1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
18#include "mlir/IR/Attributes.h"
19#include "mlir/IR/BuiltinAttributes.h"
20#include "mlir/IR/DialectImplementation.h"
21#include "mlir/IR/OpImplementation.h"
22#include "mlir/IR/OperationSupport.h"
23#include "mlir/Interfaces/FoldInterfaces.h"
24
25#include "llvm/ADT/ArrayRef.h"
26#include "llvm/ADT/BitVector.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/ADT/STLForwardCompat.h"
29#include "llvm/ADT/SmallString.h"
30#include "llvm/ADT/StringExtras.h"
31#include "llvm/ADT/StringRef.h"
32#include "llvm/ADT/TypeSwitch.h"
33#include "llvm/Frontend/OpenMP/OMPConstants.h"
34#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
35#include <cstddef>
36#include <iterator>
37#include <optional>
38#include <variant>
39
40#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
41#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
42#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
44
45using namespace mlir;
46using namespace mlir::omp;
47
48static ArrayAttr makeArrayAttr(MLIRContext *context,
49 llvm::ArrayRef<Attribute> attrs) {
50 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
51}
52
53static DenseBoolArrayAttr
54makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
55 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
56}
57
58namespace {
59struct MemRefPointerLikeModel
60 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
61 MemRefType> {
62 Type getElementType(Type pointer) const {
63 return llvm::cast<MemRefType>(pointer).getElementType();
64 }
65};
66
67struct LLVMPointerPointerLikeModel
68 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
69 LLVM::LLVMPointerType> {
70 Type getElementType(Type pointer) const { return Type(); }
71};
72} // namespace
73
74void OpenMPDialect::initialize() {
75 addOperations<
76#define GET_OP_LIST
77#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78 >();
79 addAttributes<
80#define GET_ATTRDEF_LIST
81#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82 >();
83 addTypes<
84#define GET_TYPEDEF_LIST
85#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86 >();
87
88 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
89
90 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
91 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
92 *getContext());
93
94 // Attach default offload module interface to module op to access
95 // offload functionality through
96 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
97 *getContext());
98
99 // Attach default declare target interfaces to operations which can be marked
100 // as declare target (Global Operations and Functions/Subroutines in dialects
101 // that Fortran (or other languages that lower to MLIR) translates too
102 mlir::LLVM::GlobalOp::attachInterface<
103 mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::GlobalOp>>(
104 *getContext());
105 mlir::LLVM::LLVMFuncOp::attachInterface<
106 mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::LLVMFuncOp>>(
107 *getContext());
108 mlir::func::FuncOp::attachInterface<
109 mlir::omp::DeclareTargetDefaultModel<mlir::func::FuncOp>>(*getContext());
110}
111
112//===----------------------------------------------------------------------===//
113// Parser and printer for Allocate Clause
114//===----------------------------------------------------------------------===//
115
116/// Parse an allocate clause with allocators and a list of operands with types.
117///
118/// allocate-operand-list :: = allocate-operand |
119/// allocator-operand `,` allocate-operand-list
120/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
121/// ssa-id-and-type ::= ssa-id `:` type
122static ParseResult parseAllocateAndAllocator(
123 OpAsmParser &parser,
124 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocateVars,
125 SmallVectorImpl<Type> &allocateTypes,
126 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocatorVars,
127 SmallVectorImpl<Type> &allocatorTypes) {
128
129 return parser.parseCommaSeparatedList([&]() {
130 OpAsmParser::UnresolvedOperand operand;
131 Type type;
132 if (parser.parseOperand(result&: operand) || parser.parseColonType(result&: type))
133 return failure();
134 allocatorVars.push_back(operand);
135 allocatorTypes.push_back(Elt: type);
136 if (parser.parseArrow())
137 return failure();
138 if (parser.parseOperand(result&: operand) || parser.parseColonType(result&: type))
139 return failure();
140
141 allocateVars.push_back(operand);
142 allocateTypes.push_back(Elt: type);
143 return success();
144 });
145}
146
147/// Print allocate clause
148static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
149 OperandRange allocateVars,
150 TypeRange allocateTypes,
151 OperandRange allocatorVars,
152 TypeRange allocatorTypes) {
153 for (unsigned i = 0; i < allocateVars.size(); ++i) {
154 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
155 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
156 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
157 }
158}
159
160//===----------------------------------------------------------------------===//
161// Parser and printer for a clause attribute (StringEnumAttr)
162//===----------------------------------------------------------------------===//
163
164template <typename ClauseAttr>
165static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
166 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
167 StringRef enumStr;
168 SMLoc loc = parser.getCurrentLocation();
169 if (parser.parseKeyword(keyword: &enumStr))
170 return failure();
171 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
172 attr = ClauseAttr::get(parser.getContext(), *enumValue);
173 return success();
174 }
175 return parser.emitError(loc, message: "invalid clause value: '") << enumStr << "'";
176}
177
178template <typename ClauseAttr>
179void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
180 p << stringifyEnum(attr.getValue());
181}
182
183//===----------------------------------------------------------------------===//
184// Parser and printer for Linear Clause
185//===----------------------------------------------------------------------===//
186
187/// linear ::= `linear` `(` linear-list `)`
188/// linear-list := linear-val | linear-val linear-list
189/// linear-val := ssa-id-and-type `=` ssa-id-and-type
190static ParseResult parseLinearClause(
191 OpAsmParser &parser,
192 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearVars,
193 SmallVectorImpl<Type> &linearTypes,
194 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearStepVars) {
195 return parser.parseCommaSeparatedList(parseElementFn: [&]() {
196 OpAsmParser::UnresolvedOperand var;
197 Type type;
198 OpAsmParser::UnresolvedOperand stepVar;
199 if (parser.parseOperand(result&: var) || parser.parseEqual() ||
200 parser.parseOperand(result&: stepVar) || parser.parseColonType(result&: type))
201 return failure();
202
203 linearVars.push_back(Elt: var);
204 linearTypes.push_back(Elt: type);
205 linearStepVars.push_back(Elt: stepVar);
206 return success();
207 });
208}
209
210/// Print Linear Clause
211static void printLinearClause(OpAsmPrinter &p, Operation *op,
212 ValueRange linearVars, TypeRange linearTypes,
213 ValueRange linearStepVars) {
214 size_t linearVarsSize = linearVars.size();
215 for (unsigned i = 0; i < linearVarsSize; ++i) {
216 std::string separator = i == linearVarsSize - 1 ? "" : ", ";
217 p << linearVars[i];
218 if (linearStepVars.size() > i)
219 p << " = " << linearStepVars[i];
220 p << " : " << linearVars[i].getType() << separator;
221 }
222}
223
224//===----------------------------------------------------------------------===//
225// Verifier for Nontemporal Clause
226//===----------------------------------------------------------------------===//
227
228static LogicalResult verifyNontemporalClause(Operation *op,
229 OperandRange nontemporalVars) {
230
231 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
232 DenseSet<Value> nontemporalItems;
233 for (const auto &it : nontemporalVars)
234 if (!nontemporalItems.insert(V: it).second)
235 return op->emitOpError() << "nontemporal variable used more than once";
236
237 return success();
238}
239
240//===----------------------------------------------------------------------===//
241// Parser, verifier and printer for Aligned Clause
242//===----------------------------------------------------------------------===//
243static LogicalResult verifyAlignedClause(Operation *op,
244 std::optional<ArrayAttr> alignments,
245 OperandRange alignedVars) {
246 // Check if number of alignment values equals to number of aligned variables
247 if (!alignedVars.empty()) {
248 if (!alignments || alignments->size() != alignedVars.size())
249 return op->emitOpError()
250 << "expected as many alignment values as aligned variables";
251 } else {
252 if (alignments)
253 return op->emitOpError() << "unexpected alignment values attribute";
254 return success();
255 }
256
257 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
258 DenseSet<Value> alignedItems;
259 for (auto it : alignedVars)
260 if (!alignedItems.insert(V: it).second)
261 return op->emitOpError() << "aligned variable used more than once";
262
263 if (!alignments)
264 return success();
265
266 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
267 for (unsigned i = 0; i < (*alignments).size(); ++i) {
268 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
269 if (intAttr.getValue().sle(0))
270 return op->emitOpError() << "alignment should be greater than 0";
271 } else {
272 return op->emitOpError() << "expected integer alignment";
273 }
274 }
275
276 return success();
277}
278
279/// aligned ::= `aligned` `(` aligned-list `)`
280/// aligned-list := aligned-val | aligned-val aligned-list
281/// aligned-val := ssa-id-and-type `->` alignment
282static ParseResult
283parseAlignedClause(OpAsmParser &parser,
284 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &alignedVars,
285 SmallVectorImpl<Type> &alignedTypes,
286 ArrayAttr &alignmentsAttr) {
287 SmallVector<Attribute> alignmentVec;
288 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
289 if (parser.parseOperand(result&: alignedVars.emplace_back()) ||
290 parser.parseColonType(result&: alignedTypes.emplace_back()) ||
291 parser.parseArrow() ||
292 parser.parseAttribute(result&: alignmentVec.emplace_back())) {
293 return failure();
294 }
295 return success();
296 })))
297 return failure();
298 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
299 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
300 return success();
301}
302
303/// Print Aligned Clause
304static void printAlignedClause(OpAsmPrinter &p, Operation *op,
305 ValueRange alignedVars, TypeRange alignedTypes,
306 std::optional<ArrayAttr> alignments) {
307 for (unsigned i = 0; i < alignedVars.size(); ++i) {
308 if (i != 0)
309 p << ", ";
310 p << alignedVars[i] << " : " << alignedVars[i].getType();
311 p << " -> " << (*alignments)[i];
312 }
313}
314
315//===----------------------------------------------------------------------===//
316// Parser, printer and verifier for Schedule Clause
317//===----------------------------------------------------------------------===//
318
319static ParseResult
320verifyScheduleModifiers(OpAsmParser &parser,
321 SmallVectorImpl<SmallString<12>> &modifiers) {
322 if (modifiers.size() > 2)
323 return parser.emitError(loc: parser.getNameLoc()) << " unexpected modifier(s)";
324 for (const auto &mod : modifiers) {
325 // Translate the string. If it has no value, then it was not a valid
326 // modifier!
327 auto symbol = symbolizeScheduleModifier(mod);
328 if (!symbol)
329 return parser.emitError(loc: parser.getNameLoc())
330 << " unknown modifier type: " << mod;
331 }
332
333 // If we have one modifier that is "simd", then stick a "none" modiifer in
334 // index 0.
335 if (modifiers.size() == 1) {
336 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
337 modifiers.push_back(Elt: modifiers[0]);
338 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
339 }
340 } else if (modifiers.size() == 2) {
341 // If there are two modifier:
342 // First modifier should not be simd, second one should be simd
343 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
344 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
345 return parser.emitError(loc: parser.getNameLoc())
346 << " incorrect modifier order";
347 }
348 return success();
349}
350
351/// schedule ::= `schedule` `(` sched-list `)`
352/// sched-list ::= sched-val | sched-val sched-list |
353/// sched-val `,` sched-modifier
354/// sched-val ::= sched-with-chunk | sched-wo-chunk
355/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
356/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
357/// sched-wo-chunk ::= `auto` | `runtime`
358/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
359/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
360static ParseResult
361parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
362 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
363 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
364 Type &chunkType) {
365 StringRef keyword;
366 if (parser.parseKeyword(keyword: &keyword))
367 return failure();
368 std::optional<mlir::omp::ClauseScheduleKind> schedule =
369 symbolizeClauseScheduleKind(keyword);
370 if (!schedule)
371 return parser.emitError(loc: parser.getNameLoc()) << " expected schedule kind";
372
373 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
374 switch (*schedule) {
375 case ClauseScheduleKind::Static:
376 case ClauseScheduleKind::Dynamic:
377 case ClauseScheduleKind::Guided:
378 if (succeeded(Result: parser.parseOptionalEqual())) {
379 chunkSize = OpAsmParser::UnresolvedOperand{};
380 if (parser.parseOperand(result&: *chunkSize) || parser.parseColonType(result&: chunkType))
381 return failure();
382 } else {
383 chunkSize = std::nullopt;
384 }
385 break;
386 case ClauseScheduleKind::Auto:
387 case ClauseScheduleKind::Runtime:
388 chunkSize = std::nullopt;
389 }
390
391 // If there is a comma, we have one or more modifiers..
392 SmallVector<SmallString<12>> modifiers;
393 while (succeeded(Result: parser.parseOptionalComma())) {
394 StringRef mod;
395 if (parser.parseKeyword(keyword: &mod))
396 return failure();
397 modifiers.push_back(Elt: mod);
398 }
399
400 if (verifyScheduleModifiers(parser, modifiers))
401 return failure();
402
403 if (!modifiers.empty()) {
404 SMLoc loc = parser.getCurrentLocation();
405 if (std::optional<ScheduleModifier> mod =
406 symbolizeScheduleModifier(modifiers[0])) {
407 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
408 } else {
409 return parser.emitError(loc, message: "invalid schedule modifier");
410 }
411 // Only SIMD attribute is allowed here!
412 if (modifiers.size() > 1) {
413 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
414 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
415 }
416 }
417
418 return success();
419}
420
421/// Print schedule clause
422static void printScheduleClause(OpAsmPrinter &p, Operation *op,
423 ClauseScheduleKindAttr scheduleKind,
424 ScheduleModifierAttr scheduleMod,
425 UnitAttr scheduleSimd, Value scheduleChunk,
426 Type scheduleChunkType) {
427 p << stringifyClauseScheduleKind(scheduleKind.getValue());
428 if (scheduleChunk)
429 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
430 if (scheduleMod)
431 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
432 if (scheduleSimd)
433 p << ", simd";
434}
435
436//===----------------------------------------------------------------------===//
437// Parser and printer for Order Clause
438//===----------------------------------------------------------------------===//
439
440// order ::= `order` `(` [order-modifier ':'] concurrent `)`
441// order-modifier ::= reproducible | unconstrained
442static ParseResult parseOrderClause(OpAsmParser &parser,
443 ClauseOrderKindAttr &order,
444 OrderModifierAttr &orderMod) {
445 StringRef enumStr;
446 SMLoc loc = parser.getCurrentLocation();
447 if (parser.parseKeyword(keyword: &enumStr))
448 return failure();
449 if (std::optional<OrderModifier> enumValue =
450 symbolizeOrderModifier(enumStr)) {
451 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
452 if (parser.parseOptionalColon())
453 return failure();
454 loc = parser.getCurrentLocation();
455 if (parser.parseKeyword(keyword: &enumStr))
456 return failure();
457 }
458 if (std::optional<ClauseOrderKind> enumValue =
459 symbolizeClauseOrderKind(enumStr)) {
460 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
461 return success();
462 }
463 return parser.emitError(loc, message: "invalid clause value: '") << enumStr << "'";
464}
465
466static void printOrderClause(OpAsmPrinter &p, Operation *op,
467 ClauseOrderKindAttr order,
468 OrderModifierAttr orderMod) {
469 if (orderMod)
470 p << stringifyOrderModifier(orderMod.getValue()) << ":";
471 if (order)
472 p << stringifyClauseOrderKind(order.getValue());
473}
474
475template <typename ClauseTypeAttr, typename ClauseType>
476static ParseResult
477parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
478 std::optional<OpAsmParser::UnresolvedOperand> &operand,
479 Type &operandType,
480 std::optional<ClauseType> (*symbolizeClause)(StringRef),
481 StringRef clauseName) {
482 StringRef enumStr;
483 if (succeeded(Result: parser.parseOptionalKeyword(keyword: &enumStr))) {
484 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
485 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
486 if (parser.parseComma())
487 return failure();
488 } else {
489 return parser.emitError(loc: parser.getCurrentLocation())
490 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
491 ;
492 }
493 }
494
495 OpAsmParser::UnresolvedOperand var;
496 if (succeeded(Result: parser.parseOperand(result&: var))) {
497 operand = var;
498 } else {
499 return parser.emitError(loc: parser.getCurrentLocation())
500 << "expected " << clauseName << " operand";
501 }
502
503 if (operand.has_value()) {
504 if (parser.parseColonType(result&: operandType))
505 return failure();
506 }
507
508 return success();
509}
510
511template <typename ClauseTypeAttr, typename ClauseType>
512static void
513printGranularityClause(OpAsmPrinter &p, Operation *op,
514 ClauseTypeAttr prescriptiveness, Value operand,
515 mlir::Type operandType,
516 StringRef (*stringifyClauseType)(ClauseType)) {
517
518 if (prescriptiveness)
519 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
520
521 if (operand)
522 p << operand << ": " << operandType;
523}
524
525//===----------------------------------------------------------------------===//
526// Parser and printer for grainsize Clause
527//===----------------------------------------------------------------------===//
528
529// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
530static ParseResult
531parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
532 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
533 Type &grainsizeType) {
534 return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
535 parser, grainsizeMod, grainsize, grainsizeType,
536 &symbolizeClauseGrainsizeType, "grainsize");
537}
538
539static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
540 ClauseGrainsizeTypeAttr grainsizeMod,
541 Value grainsize, mlir::Type grainsizeType) {
542 printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
543 p, op, grainsizeMod, grainsize, grainsizeType,
544 &stringifyClauseGrainsizeType);
545}
546
547//===----------------------------------------------------------------------===//
548// Parser and printer for num_tasks Clause
549//===----------------------------------------------------------------------===//
550
551// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
552static ParseResult
553parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
554 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
555 Type &numTasksType) {
556 return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
557 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
558 "num_tasks");
559}
560
561static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
562 ClauseNumTasksTypeAttr numTasksMod,
563 Value numTasks, mlir::Type numTasksType) {
564 printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
565 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
566}
567
568//===----------------------------------------------------------------------===//
569// Parsers for operations including clauses that define entry block arguments.
570//===----------------------------------------------------------------------===//
571
572namespace {
573struct MapParseArgs {
574 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
575 SmallVectorImpl<Type> &types;
576 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
577 SmallVectorImpl<Type> &types)
578 : vars(vars), types(types) {}
579};
580struct PrivateParseArgs {
581 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
582 llvm::SmallVectorImpl<Type> &types;
583 ArrayAttr &syms;
584 UnitAttr &needsBarrier;
585 DenseI64ArrayAttr *mapIndices;
586 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
587 SmallVectorImpl<Type> &types, ArrayAttr &syms,
588 UnitAttr &needsBarrier,
589 DenseI64ArrayAttr *mapIndices = nullptr)
590 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
591 mapIndices(mapIndices) {}
592};
593
594struct ReductionParseArgs {
595 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
596 SmallVectorImpl<Type> &types;
597 DenseBoolArrayAttr &byref;
598 ArrayAttr &syms;
599 ReductionModifierAttr *modifier;
600 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
601 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
602 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
603 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
604};
605
606struct AllRegionParseArgs {
607 std::optional<MapParseArgs> hasDeviceAddrArgs;
608 std::optional<MapParseArgs> hostEvalArgs;
609 std::optional<ReductionParseArgs> inReductionArgs;
610 std::optional<MapParseArgs> mapArgs;
611 std::optional<PrivateParseArgs> privateArgs;
612 std::optional<ReductionParseArgs> reductionArgs;
613 std::optional<ReductionParseArgs> taskReductionArgs;
614 std::optional<MapParseArgs> useDeviceAddrArgs;
615 std::optional<MapParseArgs> useDevicePtrArgs;
616};
617} // namespace
618
619static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
620 return "private_barrier";
621}
622
623static ParseResult parseClauseWithRegionArgs(
624 OpAsmParser &parser,
625 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
626 SmallVectorImpl<Type> &types,
627 SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
628 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
629 DenseBoolArrayAttr *byref = nullptr,
630 ReductionModifierAttr *modifier = nullptr,
631 UnitAttr *needsBarrier = nullptr) {
632 SmallVector<SymbolRefAttr> symbolVec;
633 SmallVector<int64_t> mapIndicesVec;
634 SmallVector<bool> isByRefVec;
635 unsigned regionArgOffset = regionPrivateArgs.size();
636
637 if (parser.parseLParen())
638 return failure();
639
640 if (modifier && succeeded(Result: parser.parseOptionalKeyword(keyword: "mod"))) {
641 StringRef enumStr;
642 if (parser.parseColon() || parser.parseKeyword(keyword: &enumStr) ||
643 parser.parseComma())
644 return failure();
645 std::optional<ReductionModifier> enumValue =
646 symbolizeReductionModifier(enumStr);
647 if (!enumValue.has_value())
648 return failure();
649 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
650 if (!*modifier)
651 return failure();
652 }
653
654 if (parser.parseCommaSeparatedList(parseElementFn: [&]() {
655 if (byref)
656 isByRefVec.push_back(
657 Elt: parser.parseOptionalKeyword(keyword: "byref").succeeded());
658
659 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
660 return failure();
661
662 if (parser.parseOperand(result&: operands.emplace_back()) ||
663 parser.parseArrow() ||
664 parser.parseArgument(result&: regionPrivateArgs.emplace_back()))
665 return failure();
666
667 if (mapIndices) {
668 if (parser.parseOptionalLSquare().succeeded()) {
669 if (parser.parseKeyword(keyword: "map_idx") || parser.parseEqual() ||
670 parser.parseInteger(result&: mapIndicesVec.emplace_back()) ||
671 parser.parseRSquare())
672 return failure();
673 } else {
674 mapIndicesVec.push_back(Elt: -1);
675 }
676 }
677
678 return success();
679 }))
680 return failure();
681
682 if (parser.parseColon())
683 return failure();
684
685 if (parser.parseCommaSeparatedList(parseElementFn: [&]() {
686 if (parser.parseType(result&: types.emplace_back()))
687 return failure();
688
689 return success();
690 }))
691 return failure();
692
693 if (operands.size() != types.size())
694 return failure();
695
696 if (parser.parseRParen())
697 return failure();
698
699 if (needsBarrier) {
700 if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling())
701 .succeeded())
702 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
703 }
704
705 auto *argsBegin = regionPrivateArgs.begin();
706 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
707 argsBegin + regionArgOffset + types.size());
708 for (auto [prv, type] : llvm::zip_equal(t&: argsSubrange, u&: types)) {
709 prv.type = type;
710 }
711
712 if (symbols) {
713 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
714 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
715 }
716
717 if (!mapIndicesVec.empty())
718 *mapIndices =
719 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
720
721 if (byref)
722 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
723
724 return success();
725}
726
727static ParseResult parseBlockArgClause(
728 OpAsmParser &parser,
729 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
730 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
731 if (succeeded(Result: parser.parseOptionalKeyword(keyword))) {
732 if (!mapArgs)
733 return failure();
734
735 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
736 entryBlockArgs)))
737 return failure();
738 }
739 return success();
740}
741
742static ParseResult parseBlockArgClause(
743 OpAsmParser &parser,
744 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
745 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
746 if (succeeded(Result: parser.parseOptionalKeyword(keyword))) {
747 if (!privateArgs)
748 return failure();
749
750 if (failed(parseClauseWithRegionArgs(
751 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
752 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
753 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
754 return failure();
755 }
756 return success();
757}
758
759static ParseResult parseBlockArgClause(
760 OpAsmParser &parser,
761 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
762 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
763 if (succeeded(Result: parser.parseOptionalKeyword(keyword))) {
764 if (!reductionArgs)
765 return failure();
766 if (failed(parseClauseWithRegionArgs(
767 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
768 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
769 reductionArgs->modifier)))
770 return failure();
771 }
772 return success();
773}
774
775static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
776 AllRegionParseArgs args) {
777 llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
778
779 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "has_device_addr",
780 mapArgs: args.hasDeviceAddrArgs)))
781 return parser.emitError(loc: parser.getCurrentLocation())
782 << "invalid `has_device_addr` format";
783
784 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "host_eval",
785 mapArgs: args.hostEvalArgs)))
786 return parser.emitError(loc: parser.getCurrentLocation())
787 << "invalid `host_eval` format";
788
789 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
790 args.inReductionArgs)))
791 return parser.emitError(loc: parser.getCurrentLocation())
792 << "invalid `in_reduction` format";
793
794 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "map_entries",
795 mapArgs: args.mapArgs)))
796 return parser.emitError(loc: parser.getCurrentLocation())
797 << "invalid `map_entries` format";
798
799 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "private",
800 privateArgs: args.privateArgs)))
801 return parser.emitError(loc: parser.getCurrentLocation())
802 << "invalid `private` format";
803
804 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
805 args.reductionArgs)))
806 return parser.emitError(loc: parser.getCurrentLocation())
807 << "invalid `reduction` format";
808
809 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
810 args.taskReductionArgs)))
811 return parser.emitError(loc: parser.getCurrentLocation())
812 << "invalid `task_reduction` format";
813
814 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "use_device_addr",
815 mapArgs: args.useDeviceAddrArgs)))
816 return parser.emitError(loc: parser.getCurrentLocation())
817 << "invalid `use_device_addr` format";
818
819 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "use_device_ptr",
820 mapArgs: args.useDevicePtrArgs)))
821 return parser.emitError(loc: parser.getCurrentLocation())
822 << "invalid `use_device_addr` format";
823
824 return parser.parseRegion(region, arguments: entryBlockArgs);
825}
826
827// These parseXyz functions correspond to the custom<Xyz> definitions
828// in the .td file(s).
829static ParseResult parseTargetOpRegion(
830 OpAsmParser &parser, Region &region,
831 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hasDeviceAddrVars,
832 SmallVectorImpl<Type> &hasDeviceAddrTypes,
833 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
834 SmallVectorImpl<Type> &hostEvalTypes,
835 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
836 SmallVectorImpl<Type> &inReductionTypes,
837 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
838 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
839 SmallVectorImpl<Type> &mapTypes,
840 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
841 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
842 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
843 AllRegionParseArgs args;
844 args.hasDeviceAddrArgs.emplace(args&: hasDeviceAddrVars, args&: hasDeviceAddrTypes);
845 args.hostEvalArgs.emplace(args&: hostEvalVars, args&: hostEvalTypes);
846 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
847 inReductionByref, inReductionSyms);
848 args.mapArgs.emplace(args&: mapVars, args&: mapTypes);
849 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
850 args&: privateNeedsBarrier, args: &privateMaps);
851 return parseBlockArgRegion(parser, region, args);
852}
853
854static ParseResult parseInReductionPrivateRegion(
855 OpAsmParser &parser, Region &region,
856 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
857 SmallVectorImpl<Type> &inReductionTypes,
858 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
859 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
860 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
861 UnitAttr &privateNeedsBarrier) {
862 AllRegionParseArgs args;
863 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
864 inReductionByref, inReductionSyms);
865 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
866 args&: privateNeedsBarrier);
867 return parseBlockArgRegion(parser, region, args);
868}
869
870static ParseResult parseInReductionPrivateReductionRegion(
871 OpAsmParser &parser, Region &region,
872 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
873 SmallVectorImpl<Type> &inReductionTypes,
874 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
875 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
876 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
877 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
878 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
879 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
880 ArrayAttr &reductionSyms) {
881 AllRegionParseArgs args;
882 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
883 inReductionByref, inReductionSyms);
884 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
885 args&: privateNeedsBarrier);
886 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
887 reductionSyms, &reductionMod);
888 return parseBlockArgRegion(parser, region, args);
889}
890
891static ParseResult parsePrivateRegion(
892 OpAsmParser &parser, Region &region,
893 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
894 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
895 UnitAttr &privateNeedsBarrier) {
896 AllRegionParseArgs args;
897 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
898 args&: privateNeedsBarrier);
899 return parseBlockArgRegion(parser, region, args);
900}
901
902static ParseResult parsePrivateReductionRegion(
903 OpAsmParser &parser, Region &region,
904 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
905 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
906 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
907 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
908 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
909 ArrayAttr &reductionSyms) {
910 AllRegionParseArgs args;
911 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
912 args&: privateNeedsBarrier);
913 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
914 reductionSyms, &reductionMod);
915 return parseBlockArgRegion(parser, region, args);
916}
917
918static ParseResult parseTaskReductionRegion(
919 OpAsmParser &parser, Region &region,
920 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
921 SmallVectorImpl<Type> &taskReductionTypes,
922 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
923 AllRegionParseArgs args;
924 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
925 taskReductionByref, taskReductionSyms);
926 return parseBlockArgRegion(parser, region, args);
927}
928
929static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
930 OpAsmParser &parser, Region &region,
931 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
932 SmallVectorImpl<Type> &useDeviceAddrTypes,
933 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars,
934 SmallVectorImpl<Type> &useDevicePtrTypes) {
935 AllRegionParseArgs args;
936 args.useDeviceAddrArgs.emplace(args&: useDeviceAddrVars, args&: useDeviceAddrTypes);
937 args.useDevicePtrArgs.emplace(args&: useDevicePtrVars, args&: useDevicePtrTypes);
938 return parseBlockArgRegion(parser, region, args);
939}
940
941//===----------------------------------------------------------------------===//
942// Printers for operations including clauses that define entry block arguments.
943//===----------------------------------------------------------------------===//
944
945namespace {
946struct MapPrintArgs {
947 ValueRange vars;
948 TypeRange types;
949 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
950};
951struct PrivatePrintArgs {
952 ValueRange vars;
953 TypeRange types;
954 ArrayAttr syms;
955 UnitAttr needsBarrier;
956 DenseI64ArrayAttr mapIndices;
957 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
958 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
959 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
960 mapIndices(mapIndices) {}
961};
962struct ReductionPrintArgs {
963 ValueRange vars;
964 TypeRange types;
965 DenseBoolArrayAttr byref;
966 ArrayAttr syms;
967 ReductionModifierAttr modifier;
968 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
969 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
970 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
971};
972struct AllRegionPrintArgs {
973 std::optional<MapPrintArgs> hasDeviceAddrArgs;
974 std::optional<MapPrintArgs> hostEvalArgs;
975 std::optional<ReductionPrintArgs> inReductionArgs;
976 std::optional<MapPrintArgs> mapArgs;
977 std::optional<PrivatePrintArgs> privateArgs;
978 std::optional<ReductionPrintArgs> reductionArgs;
979 std::optional<ReductionPrintArgs> taskReductionArgs;
980 std::optional<MapPrintArgs> useDeviceAddrArgs;
981 std::optional<MapPrintArgs> useDevicePtrArgs;
982};
983} // namespace
984
985static void printClauseWithRegionArgs(
986 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
987 ValueRange argsSubrange, ValueRange operands, TypeRange types,
988 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
989 DenseBoolArrayAttr byref = nullptr,
990 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
991 if (argsSubrange.empty())
992 return;
993
994 p << clauseName << "(";
995
996 if (modifier)
997 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
998
999 if (!symbols) {
1000 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1001 symbols = ArrayAttr::get(ctx, values);
1002 }
1003
1004 if (!mapIndices) {
1005 llvm::SmallVector<int64_t> values(operands.size(), -1);
1006 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1007 }
1008
1009 if (!byref) {
1010 mlir::SmallVector<bool> values(operands.size(), false);
1011 byref = DenseBoolArrayAttr::get(ctx, values);
1012 }
1013
1014 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1015 mapIndices.asArrayRef(),
1016 byref.asArrayRef()),
1017 p, [&p](auto t) {
1018 auto [op, arg, sym, map, isByRef] = t;
1019 if (isByRef)
1020 p << "byref ";
1021 if (sym)
1022 p << sym << " ";
1023
1024 p << op << " -> " << arg;
1025
1026 if (map != -1)
1027 p << " [map_idx=" << map << "]";
1028 });
1029 p << " : ";
1030 llvm::interleaveComma(c: types, os&: p);
1031 p << ") ";
1032
1033 if (needsBarrier)
1034 p << getPrivateNeedsBarrierSpelling() << " ";
1035}
1036
1037static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
1038 StringRef clauseName, ValueRange argsSubrange,
1039 std::optional<MapPrintArgs> mapArgs) {
1040 if (mapArgs)
1041 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, operands: mapArgs->vars,
1042 types: mapArgs->types);
1043}
1044
1045static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
1046 StringRef clauseName, ValueRange argsSubrange,
1047 std::optional<PrivatePrintArgs> privateArgs) {
1048 if (privateArgs)
1049 printClauseWithRegionArgs(
1050 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1051 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1052 /*modifier=*/nullptr, privateArgs->needsBarrier);
1053}
1054
1055static void
1056printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1057 ValueRange argsSubrange,
1058 std::optional<ReductionPrintArgs> reductionArgs) {
1059 if (reductionArgs)
1060 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1061 reductionArgs->vars, reductionArgs->types,
1062 reductionArgs->syms, /*mapIndices=*/nullptr,
1063 reductionArgs->byref, reductionArgs->modifier);
1064}
1065
1066static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1067 const AllRegionPrintArgs &args) {
1068 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1069 MLIRContext *ctx = op->getContext();
1070
1071 printBlockArgClause(p, ctx, "has_device_addr",
1072 iface.getHasDeviceAddrBlockArgs(),
1073 args.hasDeviceAddrArgs);
1074 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1075 args.hostEvalArgs);
1076 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1077 args.inReductionArgs);
1078 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1079 args.mapArgs);
1080 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1081 args.privateArgs);
1082 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1083 args.reductionArgs);
1084 printBlockArgClause(p, ctx, "task_reduction",
1085 iface.getTaskReductionBlockArgs(),
1086 args.taskReductionArgs);
1087 printBlockArgClause(p, ctx, "use_device_addr",
1088 iface.getUseDeviceAddrBlockArgs(),
1089 args.useDeviceAddrArgs);
1090 printBlockArgClause(p, ctx, "use_device_ptr",
1091 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1092
1093 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
1094}
1095
1096// These parseXyz functions correspond to the custom<Xyz> definitions
1097// in the .td file(s).
1098static void printTargetOpRegion(
1099 OpAsmPrinter &p, Operation *op, Region &region,
1100 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1101 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1102 ValueRange inReductionVars, TypeRange inReductionTypes,
1103 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1104 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1105 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1106 DenseI64ArrayAttr privateMaps) {
1107 AllRegionPrintArgs args;
1108 args.hasDeviceAddrArgs.emplace(args&: hasDeviceAddrVars, args&: hasDeviceAddrTypes);
1109 args.hostEvalArgs.emplace(args&: hostEvalVars, args&: hostEvalTypes);
1110 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1111 inReductionByref, inReductionSyms);
1112 args.mapArgs.emplace(args&: mapVars, args&: mapTypes);
1113 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1114 privateNeedsBarrier, privateMaps);
1115 printBlockArgRegion(p, op, region, args);
1116}
1117
1118static void printInReductionPrivateRegion(
1119 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1120 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1121 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1122 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1123 AllRegionPrintArgs args;
1124 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1125 inReductionByref, inReductionSyms);
1126 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1127 privateNeedsBarrier,
1128 /*mapIndices=*/nullptr);
1129 printBlockArgRegion(p, op, region, args);
1130}
1131
1132static void printInReductionPrivateReductionRegion(
1133 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1134 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1135 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1136 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1137 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1138 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1139 ArrayAttr reductionSyms) {
1140 AllRegionPrintArgs args;
1141 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1142 inReductionByref, inReductionSyms);
1143 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1144 privateNeedsBarrier,
1145 /*mapIndices=*/nullptr);
1146 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1147 reductionSyms, reductionMod);
1148 printBlockArgRegion(p, op, region, args);
1149}
1150
1151static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1152 ValueRange privateVars, TypeRange privateTypes,
1153 ArrayAttr privateSyms,
1154 UnitAttr privateNeedsBarrier) {
1155 AllRegionPrintArgs args;
1156 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1157 privateNeedsBarrier,
1158 /*mapIndices=*/nullptr);
1159 printBlockArgRegion(p, op, region, args);
1160}
1161
1162static void printPrivateReductionRegion(
1163 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1164 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1165 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1166 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1167 ArrayAttr reductionSyms) {
1168 AllRegionPrintArgs args;
1169 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1170 privateNeedsBarrier,
1171 /*mapIndices=*/nullptr);
1172 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1173 reductionSyms, reductionMod);
1174 printBlockArgRegion(p, op, region, args);
1175}
1176
1177static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
1178 Region &region,
1179 ValueRange taskReductionVars,
1180 TypeRange taskReductionTypes,
1181 DenseBoolArrayAttr taskReductionByref,
1182 ArrayAttr taskReductionSyms) {
1183 AllRegionPrintArgs args;
1184 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1185 taskReductionByref, taskReductionSyms);
1186 printBlockArgRegion(p, op, region, args);
1187}
1188
1189static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
1190 Region &region,
1191 ValueRange useDeviceAddrVars,
1192 TypeRange useDeviceAddrTypes,
1193 ValueRange useDevicePtrVars,
1194 TypeRange useDevicePtrTypes) {
1195 AllRegionPrintArgs args;
1196 args.useDeviceAddrArgs.emplace(args&: useDeviceAddrVars, args&: useDeviceAddrTypes);
1197 args.useDevicePtrArgs.emplace(args&: useDevicePtrVars, args&: useDevicePtrTypes);
1198 printBlockArgRegion(p, op, region, args);
1199}
1200
1201/// Verifies Reduction Clause
1202static LogicalResult
1203verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1204 OperandRange reductionVars,
1205 std::optional<ArrayRef<bool>> reductionByref) {
1206 if (!reductionVars.empty()) {
1207 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1208 return op->emitOpError()
1209 << "expected as many reduction symbol references "
1210 "as reduction variables";
1211 if (reductionByref && reductionByref->size() != reductionVars.size())
1212 return op->emitError() << "expected as many reduction variable by "
1213 "reference attributes as reduction variables";
1214 } else {
1215 if (reductionSyms)
1216 return op->emitOpError() << "unexpected reduction symbol references";
1217 return success();
1218 }
1219
1220 // TODO: The followings should be done in
1221 // SymbolUserOpInterface::verifySymbolUses.
1222 DenseSet<Value> accumulators;
1223 for (auto args : llvm::zip(t&: reductionVars, u: *reductionSyms)) {
1224 Value accum = std::get<0>(args);
1225
1226 if (!accumulators.insert(V: accum).second)
1227 return op->emitOpError() << "accumulator variable used more than once";
1228
1229 Type varType = accum.getType();
1230 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1231 auto decl =
1232 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
1233 if (!decl)
1234 return op->emitOpError() << "expected symbol reference " << symbolRef
1235 << " to point to a reduction declaration";
1236
1237 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1238 return op->emitOpError()
1239 << "expected accumulator (" << varType
1240 << ") to be the same type as reduction declaration ("
1241 << decl.getAccumulatorType() << ")";
1242 }
1243
1244 return success();
1245}
1246
1247//===----------------------------------------------------------------------===//
1248// Parser, printer and verifier for Copyprivate
1249//===----------------------------------------------------------------------===//
1250
1251/// copyprivate-entry-list ::= copyprivate-entry
1252/// | copyprivate-entry-list `,` copyprivate-entry
1253/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1254static ParseResult parseCopyprivate(
1255 OpAsmParser &parser,
1256 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &copyprivateVars,
1257 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1258 SmallVector<SymbolRefAttr> symsVec;
1259 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1260 if (parser.parseOperand(result&: copyprivateVars.emplace_back()) ||
1261 parser.parseArrow() ||
1262 parser.parseAttribute(symsVec.emplace_back()) ||
1263 parser.parseColonType(result&: copyprivateTypes.emplace_back()))
1264 return failure();
1265 return success();
1266 })))
1267 return failure();
1268 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1269 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1270 return success();
1271}
1272
1273/// Print Copyprivate clause
1274static void printCopyprivate(OpAsmPrinter &p, Operation *op,
1275 OperandRange copyprivateVars,
1276 TypeRange copyprivateTypes,
1277 std::optional<ArrayAttr> copyprivateSyms) {
1278 if (!copyprivateSyms.has_value())
1279 return;
1280 llvm::interleaveComma(
1281 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1282 [&](const auto &args) {
1283 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1284 << std::get<2>(args);
1285 });
1286}
1287
1288/// Verifies CopyPrivate Clause
1289static LogicalResult
1290verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars,
1291 std::optional<ArrayAttr> copyprivateSyms) {
1292 size_t copyprivateSymsSize =
1293 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1294 if (copyprivateSymsSize != copyprivateVars.size())
1295 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1296 << copyprivateVars.size()
1297 << ") and functions (= " << copyprivateSymsSize
1298 << "), both must be equal";
1299 if (!copyprivateSyms.has_value())
1300 return success();
1301
1302 for (auto copyprivateVarAndSym :
1303 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1304 auto symbolRef =
1305 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1306 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1307 funcOp;
1308 if (mlir::func::FuncOp mlirFuncOp =
1309 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
1310 symbolRef))
1311 funcOp = mlirFuncOp;
1312 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1313 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1314 op, symbolRef))
1315 funcOp = llvmFuncOp;
1316
1317 auto getNumArguments = [&] {
1318 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1319 };
1320
1321 auto getArgumentType = [&](unsigned i) {
1322 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1323 *funcOp);
1324 };
1325
1326 if (!funcOp)
1327 return op->emitOpError() << "expected symbol reference " << symbolRef
1328 << " to point to a copy function";
1329
1330 if (getNumArguments() != 2)
1331 return op->emitOpError()
1332 << "expected copy function " << symbolRef << " to have 2 operands";
1333
1334 Type argTy = getArgumentType(0);
1335 if (argTy != getArgumentType(1))
1336 return op->emitOpError() << "expected copy function " << symbolRef
1337 << " arguments to have the same type";
1338
1339 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1340 if (argTy != varType)
1341 return op->emitOpError()
1342 << "expected copy function arguments' type (" << argTy
1343 << ") to be the same as copyprivate variable's type (" << varType
1344 << ")";
1345 }
1346
1347 return success();
1348}
1349
1350//===----------------------------------------------------------------------===//
1351// Parser, printer and verifier for DependVarList
1352//===----------------------------------------------------------------------===//
1353
1354/// depend-entry-list ::= depend-entry
1355/// | depend-entry-list `,` depend-entry
1356/// depend-entry ::= depend-kind `->` ssa-id `:` type
1357static ParseResult
1358parseDependVarList(OpAsmParser &parser,
1359 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dependVars,
1360 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1361 SmallVector<ClauseTaskDependAttr> kindsVec;
1362 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1363 StringRef keyword;
1364 if (parser.parseKeyword(keyword: &keyword) || parser.parseArrow() ||
1365 parser.parseOperand(result&: dependVars.emplace_back()) ||
1366 parser.parseColonType(result&: dependTypes.emplace_back()))
1367 return failure();
1368 if (std::optional<ClauseTaskDepend> keywordDepend =
1369 (symbolizeClauseTaskDepend(keyword)))
1370 kindsVec.emplace_back(
1371 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1372 else
1373 return failure();
1374 return success();
1375 })))
1376 return failure();
1377 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1378 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1379 return success();
1380}
1381
1382/// Print Depend clause
1383static void printDependVarList(OpAsmPrinter &p, Operation *op,
1384 OperandRange dependVars, TypeRange dependTypes,
1385 std::optional<ArrayAttr> dependKinds) {
1386
1387 for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1388 if (i != 0)
1389 p << ", ";
1390 p << stringifyClauseTaskDepend(
1391 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1392 .getValue())
1393 << " -> " << dependVars[i] << " : " << dependTypes[i];
1394 }
1395}
1396
1397/// Verifies Depend clause
1398static LogicalResult verifyDependVarList(Operation *op,
1399 std::optional<ArrayAttr> dependKinds,
1400 OperandRange dependVars) {
1401 if (!dependVars.empty()) {
1402 if (!dependKinds || dependKinds->size() != dependVars.size())
1403 return op->emitOpError() << "expected as many depend values"
1404 " as depend variables";
1405 } else {
1406 if (dependKinds && !dependKinds->empty())
1407 return op->emitOpError() << "unexpected depend values";
1408 return success();
1409 }
1410
1411 return success();
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// Parser, printer and verifier for Synchronization Hint (2.17.12)
1416//===----------------------------------------------------------------------===//
1417
1418/// Parses a Synchronization Hint clause. The value of hint is an integer
1419/// which is a combination of different hints from `omp_sync_hint_t`.
1420///
1421/// hint-clause = `hint` `(` hint-value `)`
1422static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1423 IntegerAttr &hintAttr) {
1424 StringRef hintKeyword;
1425 int64_t hint = 0;
1426 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "none"))) {
1427 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1428 return success();
1429 }
1430 auto parseKeyword = [&]() -> ParseResult {
1431 if (failed(Result: parser.parseKeyword(keyword: &hintKeyword)))
1432 return failure();
1433 if (hintKeyword == "uncontended")
1434 hint |= 1;
1435 else if (hintKeyword == "contended")
1436 hint |= 2;
1437 else if (hintKeyword == "nonspeculative")
1438 hint |= 4;
1439 else if (hintKeyword == "speculative")
1440 hint |= 8;
1441 else
1442 return parser.emitError(loc: parser.getCurrentLocation())
1443 << hintKeyword << " is not a valid hint";
1444 return success();
1445 };
1446 if (parser.parseCommaSeparatedList(parseElementFn: parseKeyword))
1447 return failure();
1448 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1449 return success();
1450}
1451
1452/// Prints a Synchronization Hint clause
1453static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
1454 IntegerAttr hintAttr) {
1455 int64_t hint = hintAttr.getInt();
1456
1457 if (hint == 0) {
1458 p << "none";
1459 return;
1460 }
1461
1462 // Helper function to get n-th bit from the right end of `value`
1463 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1464
1465 bool uncontended = bitn(hint, 0);
1466 bool contended = bitn(hint, 1);
1467 bool nonspeculative = bitn(hint, 2);
1468 bool speculative = bitn(hint, 3);
1469
1470 SmallVector<StringRef> hints;
1471 if (uncontended)
1472 hints.push_back(Elt: "uncontended");
1473 if (contended)
1474 hints.push_back(Elt: "contended");
1475 if (nonspeculative)
1476 hints.push_back(Elt: "nonspeculative");
1477 if (speculative)
1478 hints.push_back(Elt: "speculative");
1479
1480 llvm::interleaveComma(c: hints, os&: p);
1481}
1482
1483/// Verifies a synchronization hint clause
1484static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1485
1486 // Helper function to get n-th bit from the right end of `value`
1487 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1488
1489 bool uncontended = bitn(hint, 0);
1490 bool contended = bitn(hint, 1);
1491 bool nonspeculative = bitn(hint, 2);
1492 bool speculative = bitn(hint, 3);
1493
1494 if (uncontended && contended)
1495 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1496 "omp_sync_hint_contended cannot be combined";
1497 if (nonspeculative && speculative)
1498 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1499 "omp_sync_hint_speculative cannot be combined.";
1500 return success();
1501}
1502
1503//===----------------------------------------------------------------------===//
1504// Parser, printer and verifier for Target
1505//===----------------------------------------------------------------------===//
1506
1507// Helper function to get bitwise AND of `value` and 'flag'
1508uint64_t mapTypeToBitFlag(uint64_t value,
1509 llvm::omp::OpenMPOffloadMappingFlags flag) {
1510 return value & llvm::to_underlying(E: flag);
1511}
1512
1513/// Parses a map_entries map type from a string format back into its numeric
1514/// value.
1515///
1516/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1517/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1518static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1519 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1520 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1521
1522 // This simply verifies the correct keyword is read in, the
1523 // keyword itself is stored inside of the operation
1524 auto parseTypeAndMod = [&]() -> ParseResult {
1525 StringRef mapTypeMod;
1526 if (parser.parseKeyword(keyword: &mapTypeMod))
1527 return failure();
1528
1529 if (mapTypeMod == "always")
1530 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1531
1532 if (mapTypeMod == "implicit")
1533 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1534
1535 if (mapTypeMod == "ompx_hold")
1536 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1537
1538 if (mapTypeMod == "close")
1539 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1540
1541 if (mapTypeMod == "present")
1542 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1543
1544 if (mapTypeMod == "to")
1545 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1546
1547 if (mapTypeMod == "from")
1548 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1549
1550 if (mapTypeMod == "tofrom")
1551 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1552 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1553
1554 if (mapTypeMod == "delete")
1555 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1556
1557 if (mapTypeMod == "return_param")
1558 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1559
1560 return success();
1561 };
1562
1563 if (parser.parseCommaSeparatedList(parseElementFn: parseTypeAndMod))
1564 return failure();
1565
1566 mapType = parser.getBuilder().getIntegerAttr(
1567 parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
1568 llvm::to_underlying(E: mapTypeBits));
1569
1570 return success();
1571}
1572
1573/// Prints a map_entries map type from its numeric value out into its string
1574/// format.
1575static void printMapClause(OpAsmPrinter &p, Operation *op,
1576 IntegerAttr mapType) {
1577 uint64_t mapTypeBits = mapType.getUInt();
1578
1579 bool emitAllocRelease = true;
1580 llvm::SmallVector<std::string, 4> mapTypeStrs;
1581
1582 // handling of always, close, present placed at the beginning of the string
1583 // to aid readability
1584 if (mapTypeToBitFlag(value: mapTypeBits,
1585 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1586 mapTypeStrs.push_back(Elt: "always");
1587 if (mapTypeToBitFlag(value: mapTypeBits,
1588 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1589 mapTypeStrs.push_back(Elt: "implicit");
1590 if (mapTypeToBitFlag(value: mapTypeBits,
1591 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1592 mapTypeStrs.push_back(Elt: "ompx_hold");
1593 if (mapTypeToBitFlag(value: mapTypeBits,
1594 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1595 mapTypeStrs.push_back(Elt: "close");
1596 if (mapTypeToBitFlag(value: mapTypeBits,
1597 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1598 mapTypeStrs.push_back(Elt: "present");
1599
1600 // special handling of to/from/tofrom/delete and release/alloc, release +
1601 // alloc are the abscense of one of the other flags, whereas tofrom requires
1602 // both the to and from flag to be set.
1603 bool to = mapTypeToBitFlag(value: mapTypeBits,
1604 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1605 bool from = mapTypeToBitFlag(
1606 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1607 if (to && from) {
1608 emitAllocRelease = false;
1609 mapTypeStrs.push_back(Elt: "tofrom");
1610 } else if (from) {
1611 emitAllocRelease = false;
1612 mapTypeStrs.push_back(Elt: "from");
1613 } else if (to) {
1614 emitAllocRelease = false;
1615 mapTypeStrs.push_back(Elt: "to");
1616 }
1617 if (mapTypeToBitFlag(value: mapTypeBits,
1618 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1619 emitAllocRelease = false;
1620 mapTypeStrs.push_back(Elt: "delete");
1621 }
1622 if (mapTypeToBitFlag(
1623 value: mapTypeBits,
1624 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1625 emitAllocRelease = false;
1626 mapTypeStrs.push_back(Elt: "return_param");
1627 }
1628 if (emitAllocRelease)
1629 mapTypeStrs.push_back(Elt: "exit_release_or_enter_alloc");
1630
1631 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1632 p << mapTypeStrs[i];
1633 if (i + 1 < mapTypeStrs.size()) {
1634 p << ", ";
1635 }
1636 }
1637}
1638
1639static ParseResult parseMembersIndex(OpAsmParser &parser,
1640 ArrayAttr &membersIdx) {
1641 SmallVector<Attribute> values, memberIdxs;
1642
1643 auto parseIndices = [&]() -> ParseResult {
1644 int64_t value;
1645 if (parser.parseInteger(result&: value))
1646 return failure();
1647 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1648 APInt(64, value, /*isSigned=*/false)));
1649 return success();
1650 };
1651
1652 do {
1653 if (failed(Result: parser.parseLSquare()))
1654 return failure();
1655
1656 if (parser.parseCommaSeparatedList(parseElementFn: parseIndices))
1657 return failure();
1658
1659 if (failed(Result: parser.parseRSquare()))
1660 return failure();
1661
1662 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1663 values.clear();
1664 } while (succeeded(Result: parser.parseOptionalComma()));
1665
1666 if (!memberIdxs.empty())
1667 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1668
1669 return success();
1670}
1671
1672static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1673 ArrayAttr membersIdx) {
1674 if (!membersIdx)
1675 return;
1676
1677 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1678 p << "[";
1679 auto memberIdx = cast<ArrayAttr>(v);
1680 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1681 p << cast<IntegerAttr>(v2).getInt();
1682 });
1683 p << "]";
1684 });
1685}
1686
1687static void printCaptureType(OpAsmPrinter &p, Operation *op,
1688 VariableCaptureKindAttr mapCaptureType) {
1689 std::string typeCapStr;
1690 llvm::raw_string_ostream typeCap(typeCapStr);
1691 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1692 typeCap << "ByRef";
1693 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1694 typeCap << "ByCopy";
1695 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1696 typeCap << "VLAType";
1697 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1698 typeCap << "This";
1699 p << typeCapStr;
1700}
1701
1702static ParseResult parseCaptureType(OpAsmParser &parser,
1703 VariableCaptureKindAttr &mapCaptureType) {
1704 StringRef mapCaptureKey;
1705 if (parser.parseKeyword(keyword: &mapCaptureKey))
1706 return failure();
1707
1708 if (mapCaptureKey == "This")
1709 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1710 parser.getContext(), mlir::omp::VariableCaptureKind::This);
1711 if (mapCaptureKey == "ByRef")
1712 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1713 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1714 if (mapCaptureKey == "ByCopy")
1715 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1716 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1717 if (mapCaptureKey == "VLAType")
1718 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1719 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1720
1721 return success();
1722}
1723
1724static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1725 llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateToVars;
1726 llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateFromVars;
1727
1728 for (auto mapOp : mapVars) {
1729 if (!mapOp.getDefiningOp())
1730 return emitError(loc: op->getLoc(), message: "missing map operation");
1731
1732 if (auto mapInfoOp =
1733 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1734 uint64_t mapTypeBits = mapInfoOp.getMapType();
1735
1736 bool to = mapTypeToBitFlag(
1737 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1738 bool from = mapTypeToBitFlag(
1739 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1740 bool del = mapTypeToBitFlag(
1741 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1742
1743 bool always = mapTypeToBitFlag(
1744 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1745 bool close = mapTypeToBitFlag(
1746 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1747 bool implicit = mapTypeToBitFlag(
1748 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1749
1750 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1751 return emitError(loc: op->getLoc(),
1752 message: "to, from, tofrom and alloc map types are permitted");
1753
1754 if (isa<TargetEnterDataOp>(op) && (from || del))
1755 return emitError(loc: op->getLoc(), message: "to and alloc map types are permitted");
1756
1757 if (isa<TargetExitDataOp>(op) && to)
1758 return emitError(loc: op->getLoc(),
1759 message: "from, release and delete map types are permitted");
1760
1761 if (isa<TargetUpdateOp>(op)) {
1762 if (del) {
1763 return emitError(loc: op->getLoc(),
1764 message: "at least one of to or from map types must be "
1765 "specified, other map types are not permitted");
1766 }
1767
1768 if (!to && !from) {
1769 return emitError(loc: op->getLoc(),
1770 message: "at least one of to or from map types must be "
1771 "specified, other map types are not permitted");
1772 }
1773
1774 auto updateVar = mapInfoOp.getVarPtr();
1775
1776 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1777 (from && updateToVars.contains(updateVar))) {
1778 return emitError(
1779 loc: op->getLoc(),
1780 message: "either to or from map types can be specified, not both");
1781 }
1782
1783 if (always || close || implicit) {
1784 return emitError(
1785 loc: op->getLoc(),
1786 message: "present, mapper and iterator map type modifiers are permitted");
1787 }
1788
1789 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1790 }
1791 } else if (!isa<DeclareMapperInfoOp>(op)) {
1792 return emitError(loc: op->getLoc(),
1793 message: "map argument is not a map entry operation");
1794 }
1795 }
1796
1797 return success();
1798}
1799
1800static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1801 std::optional<DenseI64ArrayAttr> privateMapIndices =
1802 targetOp.getPrivateMapsAttr();
1803
1804 // None of the private operands are mapped.
1805 if (!privateMapIndices.has_value() || !privateMapIndices.value())
1806 return success();
1807
1808 OperandRange privateVars = targetOp.getPrivateVars();
1809
1810 if (privateMapIndices.value().size() !=
1811 static_cast<int64_t>(privateVars.size()))
1812 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
1813 "`private_maps` attribute mismatch");
1814
1815 return success();
1816}
1817
1818//===----------------------------------------------------------------------===//
1819// MapInfoOp
1820//===----------------------------------------------------------------------===//
1821
1822static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
1823 StringRef clauseName,
1824 OperandRange vars) {
1825 for (Value var : vars)
1826 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
1827 return op->emitOpError()
1828 << "'" << clauseName
1829 << "' arguments must be defined by 'omp.map.info' ops";
1830 return success();
1831}
1832
1833LogicalResult MapInfoOp::verify() {
1834 if (getMapperId() &&
1835 !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1836 *this, getMapperIdAttr())) {
1837 return emitError("invalid mapper id");
1838 }
1839
1840 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
1841 return failure();
1842
1843 return success();
1844}
1845
1846//===----------------------------------------------------------------------===//
1847// TargetDataOp
1848//===----------------------------------------------------------------------===//
1849
1850void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1851 const TargetDataOperands &clauses) {
1852 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
1853 clauses.mapVars, clauses.useDeviceAddrVars,
1854 clauses.useDevicePtrVars);
1855}
1856
1857LogicalResult TargetDataOp::verify() {
1858 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1859 getUseDeviceAddrVars().empty()) {
1860 return ::emitError(this->getLoc(),
1861 "At least one of map, use_device_ptr_vars, or "
1862 "use_device_addr_vars operand must be present");
1863 }
1864
1865 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
1866 getUseDevicePtrVars())))
1867 return failure();
1868
1869 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
1870 getUseDeviceAddrVars())))
1871 return failure();
1872
1873 return verifyMapClause(*this, getMapVars());
1874}
1875
1876//===----------------------------------------------------------------------===//
1877// TargetEnterDataOp
1878//===----------------------------------------------------------------------===//
1879
1880void TargetEnterDataOp::build(
1881 OpBuilder &builder, OperationState &state,
1882 const TargetEnterExitUpdateDataOperands &clauses) {
1883 MLIRContext *ctx = builder.getContext();
1884 TargetEnterDataOp::build(builder, state,
1885 makeArrayAttr(ctx, clauses.dependKinds),
1886 clauses.dependVars, clauses.device, clauses.ifExpr,
1887 clauses.mapVars, clauses.nowait);
1888}
1889
1890LogicalResult TargetEnterDataOp::verify() {
1891 LogicalResult verifyDependVars =
1892 verifyDependVarList(*this, getDependKinds(), getDependVars());
1893 return failed(verifyDependVars) ? verifyDependVars
1894 : verifyMapClause(*this, getMapVars());
1895}
1896
1897//===----------------------------------------------------------------------===//
1898// TargetExitDataOp
1899//===----------------------------------------------------------------------===//
1900
1901void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1902 const TargetEnterExitUpdateDataOperands &clauses) {
1903 MLIRContext *ctx = builder.getContext();
1904 TargetExitDataOp::build(builder, state,
1905 makeArrayAttr(ctx, clauses.dependKinds),
1906 clauses.dependVars, clauses.device, clauses.ifExpr,
1907 clauses.mapVars, clauses.nowait);
1908}
1909
1910LogicalResult TargetExitDataOp::verify() {
1911 LogicalResult verifyDependVars =
1912 verifyDependVarList(*this, getDependKinds(), getDependVars());
1913 return failed(verifyDependVars) ? verifyDependVars
1914 : verifyMapClause(*this, getMapVars());
1915}
1916
1917//===----------------------------------------------------------------------===//
1918// TargetUpdateOp
1919//===----------------------------------------------------------------------===//
1920
1921void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1922 const TargetEnterExitUpdateDataOperands &clauses) {
1923 MLIRContext *ctx = builder.getContext();
1924 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
1925 clauses.dependVars, clauses.device, clauses.ifExpr,
1926 clauses.mapVars, clauses.nowait);
1927}
1928
1929LogicalResult TargetUpdateOp::verify() {
1930 LogicalResult verifyDependVars =
1931 verifyDependVarList(*this, getDependKinds(), getDependVars());
1932 return failed(verifyDependVars) ? verifyDependVars
1933 : verifyMapClause(*this, getMapVars());
1934}
1935
1936//===----------------------------------------------------------------------===//
1937// TargetOp
1938//===----------------------------------------------------------------------===//
1939
1940void TargetOp::build(OpBuilder &builder, OperationState &state,
1941 const TargetOperands &clauses) {
1942 MLIRContext *ctx = builder.getContext();
1943 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1944 // inReductionByref, inReductionSyms.
1945 TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1946 clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
1947 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1948 clauses.hostEvalVars, clauses.ifExpr,
1949 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1950 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1951 clauses.mapVars, clauses.nowait, clauses.privateVars,
1952 makeArrayAttr(ctx, clauses.privateSyms),
1953 clauses.privateNeedsBarrier, clauses.threadLimit,
1954 /*private_maps=*/nullptr);
1955}
1956
1957LogicalResult TargetOp::verify() {
1958 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
1959 return failure();
1960
1961 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
1962 getHasDeviceAddrVars())))
1963 return failure();
1964
1965 if (failed(verifyMapClause(*this, getMapVars())))
1966 return failure();
1967
1968 return verifyPrivateVarsMapping(*this);
1969}
1970
1971LogicalResult TargetOp::verifyRegions() {
1972 auto teamsOps = getOps<TeamsOp>();
1973 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1974 return emitError("target containing multiple 'omp.teams' nested ops");
1975
1976 // Check that host_eval values are only used in legal ways.
1977 Operation *capturedOp = getInnermostCapturedOmpOp();
1978 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1979 for (Value hostEvalArg :
1980 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1981 for (Operation *user : hostEvalArg.getUsers()) {
1982 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1983 if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1984 teamsOp.getNumTeamsUpper(),
1985 teamsOp.getThreadLimit()},
1986 hostEvalArg))
1987 continue;
1988
1989 return emitOpError() << "host_eval argument only legal as 'num_teams' "
1990 "and 'thread_limit' in 'omp.teams'";
1991 }
1992 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1993 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1994 parallelOp->isAncestor(capturedOp) &&
1995 hostEvalArg == parallelOp.getNumThreads())
1996 continue;
1997
1998 return emitOpError()
1999 << "host_eval argument only legal as 'num_threads' in "
2000 "'omp.parallel' when representing target SPMD";
2001 }
2002 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2003 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2004 loopNestOp.getOperation() == capturedOp &&
2005 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2006 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2007 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2008 continue;
2009
2010 return emitOpError() << "host_eval argument only legal as loop bounds "
2011 "and steps in 'omp.loop_nest' when trip count "
2012 "must be evaluated in the host";
2013 }
2014
2015 return emitOpError() << "host_eval argument illegal use in '"
2016 << user->getName() << "' operation";
2017 }
2018 }
2019 return success();
2020}
2021
2022static Operation *
2023findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2024 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2025 assert(rootOp && "expected valid operation");
2026
2027 Dialect *ompDialect = rootOp->getDialect();
2028 Operation *capturedOp = nullptr;
2029 DominanceInfo domInfo;
2030
2031 // Process in pre-order to check operations from outermost to innermost,
2032 // ensuring we only enter the region of an operation if it meets the criteria
2033 // for being captured. We stop the exploration of nested operations as soon as
2034 // we process a region holding no operations to be captured.
2035 rootOp->walk<WalkOrder::PreOrder>(callback: [&](Operation *op) {
2036 if (op == rootOp)
2037 return WalkResult::advance();
2038
2039 // Ignore operations of other dialects or omp operations with no regions,
2040 // because these will only be checked if they are siblings of an omp
2041 // operation that can potentially be captured.
2042 bool isOmpDialect = op->getDialect() == ompDialect;
2043 bool hasRegions = op->getNumRegions() > 0;
2044 if (!isOmpDialect || !hasRegions)
2045 return WalkResult::skip();
2046
2047 // This operation cannot be captured if it can be executed more than once
2048 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2049 // be executed before all exits of the region (i.e. it doesn't dominate all
2050 // blocks with no successors reachable from the entry block).
2051 if (checkSingleMandatoryExec) {
2052 Region *parentRegion = op->getParentRegion();
2053 Block *parentBlock = op->getBlock();
2054
2055 for (Block *successor : parentBlock->getSuccessors())
2056 if (successor->isReachable(other: parentBlock))
2057 return WalkResult::interrupt();
2058
2059 for (Block &block : *parentRegion)
2060 if (domInfo.isReachableFromEntry(a: &block) && block.hasNoSuccessors() &&
2061 !domInfo.dominates(a: parentBlock, b: &block))
2062 return WalkResult::interrupt();
2063 }
2064
2065 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2066 // into nested operations.
2067 for (Operation &sibling : op->getParentRegion()->getOps())
2068 if (&sibling != op && !siblingAllowedFn(&sibling))
2069 return WalkResult::interrupt();
2070
2071 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2072 // Otherwise, process the contents of this operation.
2073 capturedOp = op;
2074 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2075 : WalkResult::advance();
2076 });
2077
2078 return capturedOp;
2079}
2080
2081Operation *TargetOp::getInnermostCapturedOmpOp() {
2082 auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2083
2084 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2085 // effects, but don't include a memory write effect.
2086 return findCapturedOmpOp(
2087 *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2088 if (!sibling)
2089 return false;
2090
2091 if (ompDialect == sibling->getDialect())
2092 return sibling->hasTrait<OpTrait::IsTerminator>();
2093
2094 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2095 SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2096 effects;
2097 memOp.getEffects(effects);
2098 return !llvm::any_of(
2099 effects, [&](MemoryEffects::EffectInstance &effect) {
2100 return isa<MemoryEffects::Write>(effect.getEffect()) &&
2101 isa<SideEffects::AutomaticAllocationScopeResource>(
2102 effect.getResource());
2103 });
2104 }
2105 return true;
2106 });
2107}
2108
2109TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2110 // A non-null captured op is only valid if it resides inside of a TargetOp
2111 // and is the result of calling getInnermostCapturedOmpOp() on it.
2112 TargetOp targetOp =
2113 capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2114 assert((!capturedOp ||
2115 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2116 "unexpected captured op");
2117
2118 // If it's not capturing a loop, it's a default target region.
2119 if (!isa_and_present<LoopNestOp>(capturedOp))
2120 return TargetRegionFlags::generic;
2121
2122 // Get the innermost non-simd loop wrapper.
2123 SmallVector<LoopWrapperInterface> loopWrappers;
2124 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2125 assert(!loopWrappers.empty());
2126
2127 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2128 if (isa<SimdOp>(innermostWrapper))
2129 innermostWrapper = std::next(innermostWrapper);
2130
2131 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2132 if (numWrappers != 1 && numWrappers != 2)
2133 return TargetRegionFlags::generic;
2134
2135 // Detect target-teams-distribute-parallel-wsloop[-simd].
2136 if (numWrappers == 2) {
2137 if (!isa<WsloopOp>(innermostWrapper))
2138 return TargetRegionFlags::generic;
2139
2140 innermostWrapper = std::next(innermostWrapper);
2141 if (!isa<DistributeOp>(innermostWrapper))
2142 return TargetRegionFlags::generic;
2143
2144 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2145 if (!isa_and_present<ParallelOp>(parallelOp))
2146 return TargetRegionFlags::generic;
2147
2148 Operation *teamsOp = parallelOp->getParentOp();
2149 if (!isa_and_present<TeamsOp>(teamsOp))
2150 return TargetRegionFlags::generic;
2151
2152 if (teamsOp->getParentOp() == targetOp.getOperation())
2153 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2154 }
2155 // Detect target-teams-distribute[-simd] and target-teams-loop.
2156 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2157 Operation *teamsOp = (*innermostWrapper)->getParentOp();
2158 if (!isa_and_present<TeamsOp>(teamsOp))
2159 return TargetRegionFlags::generic;
2160
2161 if (teamsOp->getParentOp() != targetOp.getOperation())
2162 return TargetRegionFlags::generic;
2163
2164 if (isa<LoopOp>(innermostWrapper))
2165 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2166
2167 // Find single immediately nested captured omp.parallel and add spmd flag
2168 // (generic-spmd case).
2169 //
2170 // TODO: This shouldn't have to be done here, as it is too easy to break.
2171 // The openmp-opt pass should be updated to be able to promote kernels like
2172 // this from "Generic" to "Generic-SPMD". However, the use of the
2173 // `kmpc_distribute_static_loop` family of functions produced by the
2174 // OMPIRBuilder for these kernels prevents that from working.
2175 Dialect *ompDialect = targetOp->getDialect();
2176 Operation *nestedCapture = findCapturedOmpOp(
2177 capturedOp, /*checkSingleMandatoryExec=*/false,
2178 [&](Operation *sibling) {
2179 return sibling && (ompDialect != sibling->getDialect() ||
2180 sibling->hasTrait<OpTrait::IsTerminator>());
2181 });
2182
2183 TargetRegionFlags result =
2184 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2185
2186 if (!nestedCapture)
2187 return result;
2188
2189 while (nestedCapture->getParentOp() != capturedOp)
2190 nestedCapture = nestedCapture->getParentOp();
2191
2192 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2193 : result;
2194 }
2195 // Detect target-parallel-wsloop[-simd].
2196 else if (isa<WsloopOp>(innermostWrapper)) {
2197 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2198 if (!isa_and_present<ParallelOp>(parallelOp))
2199 return TargetRegionFlags::generic;
2200
2201 if (parallelOp->getParentOp() == targetOp.getOperation())
2202 return TargetRegionFlags::spmd;
2203 }
2204
2205 return TargetRegionFlags::generic;
2206}
2207
2208//===----------------------------------------------------------------------===//
2209// ParallelOp
2210//===----------------------------------------------------------------------===//
2211
2212void ParallelOp::build(OpBuilder &builder, OperationState &state,
2213 ArrayRef<NamedAttribute> attributes) {
2214 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2215 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2216 /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2217 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2218 /*proc_bind_kind=*/nullptr,
2219 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2220 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2221 state.addAttributes(attributes);
2222}
2223
2224void ParallelOp::build(OpBuilder &builder, OperationState &state,
2225 const ParallelOperands &clauses) {
2226 MLIRContext *ctx = builder.getContext();
2227 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2228 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2229 makeArrayAttr(ctx, clauses.privateSyms),
2230 clauses.privateNeedsBarrier, clauses.procBindKind,
2231 clauses.reductionMod, clauses.reductionVars,
2232 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2233 makeArrayAttr(ctx, clauses.reductionSyms));
2234}
2235
2236template <typename OpType>
2237static LogicalResult verifyPrivateVarList(OpType &op) {
2238 auto privateVars = op.getPrivateVars();
2239 auto privateSyms = op.getPrivateSymsAttr();
2240
2241 if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2242 return success();
2243
2244 auto numPrivateVars = privateVars.size();
2245 auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2246
2247 if (numPrivateVars != numPrivateSyms)
2248 return op.emitError() << "inconsistent number of private variables and "
2249 "privatizer op symbols, private vars: "
2250 << numPrivateVars
2251 << " vs. privatizer op symbols: " << numPrivateSyms;
2252
2253 for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2254 Type varType = std::get<0>(privateVarInfo).getType();
2255 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2256 PrivateClauseOp privatizerOp =
2257 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2258
2259 if (privatizerOp == nullptr)
2260 return op.emitError() << "failed to lookup privatizer op with symbol: '"
2261 << privateSym << "'";
2262
2263 Type privatizerType = privatizerOp.getArgType();
2264
2265 if (privatizerType && (varType != privatizerType))
2266 return op.emitError()
2267 << "type mismatch between a "
2268 << (privatizerOp.getDataSharingType() ==
2269 DataSharingClauseType::Private
2270 ? "private"
2271 : "firstprivate")
2272 << " variable and its privatizer op, var type: " << varType
2273 << " vs. privatizer op type: " << privatizerType;
2274 }
2275
2276 return success();
2277}
2278
2279LogicalResult ParallelOp::verify() {
2280 if (getAllocateVars().size() != getAllocatorVars().size())
2281 return emitError(
2282 "expected equal sizes for allocate and allocator variables");
2283
2284 if (failed(verifyPrivateVarList(*this)))
2285 return failure();
2286
2287 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2288 getReductionByref());
2289}
2290
2291LogicalResult ParallelOp::verifyRegions() {
2292 auto distChildOps = getOps<DistributeOp>();
2293 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2294 if (numDistChildOps > 1)
2295 return emitError()
2296 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2297
2298 if (numDistChildOps == 1) {
2299 if (!isComposite())
2300 return emitError()
2301 << "'omp.composite' attribute missing from composite operation";
2302
2303 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2304 Operation &distributeOp = **distChildOps.begin();
2305 for (Operation &childOp : getOps()) {
2306 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2307 continue;
2308
2309 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2310 return emitError() << "unexpected OpenMP operation inside of composite "
2311 "'omp.parallel': "
2312 << childOp.getName();
2313 }
2314 } else if (isComposite()) {
2315 return emitError()
2316 << "'omp.composite' attribute present in non-composite operation";
2317 }
2318 return success();
2319}
2320
2321//===----------------------------------------------------------------------===//
2322// TeamsOp
2323//===----------------------------------------------------------------------===//
2324
2325static bool opInGlobalImplicitParallelRegion(Operation *op) {
2326 while ((op = op->getParentOp()))
2327 if (isa<OpenMPDialect>(op->getDialect()))
2328 return false;
2329 return true;
2330}
2331
2332void TeamsOp::build(OpBuilder &builder, OperationState &state,
2333 const TeamsOperands &clauses) {
2334 MLIRContext *ctx = builder.getContext();
2335 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2336 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2337 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2338 /*private_vars=*/{}, /*private_syms=*/nullptr,
2339 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2340 clauses.reductionVars,
2341 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2342 makeArrayAttr(ctx, clauses.reductionSyms),
2343 clauses.threadLimit);
2344}
2345
2346LogicalResult TeamsOp::verify() {
2347 // Check parent region
2348 // TODO If nested inside of a target region, also check that it does not
2349 // contain any statements, declarations or directives other than this
2350 // omp.teams construct. The issue is how to support the initialization of
2351 // this operation's own arguments (allow SSA values across omp.target?).
2352 Operation *op = getOperation();
2353 if (!isa<TargetOp>(op->getParentOp()) &&
2354 !opInGlobalImplicitParallelRegion(op))
2355 return emitError("expected to be nested inside of omp.target or not nested "
2356 "in any OpenMP dialect operations");
2357
2358 // Check for num_teams clause restrictions
2359 if (auto numTeamsLowerBound = getNumTeamsLower()) {
2360 auto numTeamsUpperBound = getNumTeamsUpper();
2361 if (!numTeamsUpperBound)
2362 return emitError("expected num_teams upper bound to be defined if the "
2363 "lower bound is defined");
2364 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2365 return emitError(
2366 "expected num_teams upper bound and lower bound to be the same type");
2367 }
2368
2369 // Check for allocate clause restrictions
2370 if (getAllocateVars().size() != getAllocatorVars().size())
2371 return emitError(
2372 "expected equal sizes for allocate and allocator variables");
2373
2374 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2375 getReductionByref());
2376}
2377
2378//===----------------------------------------------------------------------===//
2379// SectionOp
2380//===----------------------------------------------------------------------===//
2381
2382OperandRange SectionOp::getPrivateVars() {
2383 return getParentOp().getPrivateVars();
2384}
2385
2386OperandRange SectionOp::getReductionVars() {
2387 return getParentOp().getReductionVars();
2388}
2389
2390//===----------------------------------------------------------------------===//
2391// SectionsOp
2392//===----------------------------------------------------------------------===//
2393
2394void SectionsOp::build(OpBuilder &builder, OperationState &state,
2395 const SectionsOperands &clauses) {
2396 MLIRContext *ctx = builder.getContext();
2397 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2398 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2399 clauses.nowait, /*private_vars=*/{},
2400 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2401 clauses.reductionMod, clauses.reductionVars,
2402 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2403 makeArrayAttr(ctx, clauses.reductionSyms));
2404}
2405
2406LogicalResult SectionsOp::verify() {
2407 if (getAllocateVars().size() != getAllocatorVars().size())
2408 return emitError(
2409 "expected equal sizes for allocate and allocator variables");
2410
2411 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2412 getReductionByref());
2413}
2414
2415LogicalResult SectionsOp::verifyRegions() {
2416 for (auto &inst : *getRegion().begin()) {
2417 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2418 return emitOpError()
2419 << "expected omp.section op or terminator op inside region";
2420 }
2421 }
2422
2423 return success();
2424}
2425
2426//===----------------------------------------------------------------------===//
2427// SingleOp
2428//===----------------------------------------------------------------------===//
2429
2430void SingleOp::build(OpBuilder &builder, OperationState &state,
2431 const SingleOperands &clauses) {
2432 MLIRContext *ctx = builder.getContext();
2433 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2434 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2435 clauses.copyprivateVars,
2436 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2437 /*private_vars=*/{}, /*private_syms=*/nullptr,
2438 /*private_needs_barrier=*/nullptr);
2439}
2440
2441LogicalResult SingleOp::verify() {
2442 // Check for allocate clause restrictions
2443 if (getAllocateVars().size() != getAllocatorVars().size())
2444 return emitError(
2445 "expected equal sizes for allocate and allocator variables");
2446
2447 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2448 getCopyprivateSyms());
2449}
2450
2451//===----------------------------------------------------------------------===//
2452// WorkshareOp
2453//===----------------------------------------------------------------------===//
2454
2455void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2456 const WorkshareOperands &clauses) {
2457 WorkshareOp::build(builder, state, clauses.nowait);
2458}
2459
2460//===----------------------------------------------------------------------===//
2461// WorkshareLoopWrapperOp
2462//===----------------------------------------------------------------------===//
2463
2464LogicalResult WorkshareLoopWrapperOp::verify() {
2465 if (!(*this)->getParentOfType<WorkshareOp>())
2466 return emitOpError() << "must be nested in an omp.workshare";
2467 return success();
2468}
2469
2470LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2471 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2472 getNestedWrapper())
2473 return emitOpError() << "expected to be a standalone loop wrapper";
2474
2475 return success();
2476}
2477
2478//===----------------------------------------------------------------------===//
2479// LoopWrapperInterface
2480//===----------------------------------------------------------------------===//
2481
2482LogicalResult LoopWrapperInterface::verifyImpl() {
2483 Operation *op = this->getOperation();
2484 if (!op->hasTrait<OpTrait::NoTerminator>() ||
2485 !op->hasTrait<OpTrait::SingleBlock>())
2486 return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2487 "and `SingleBlock` traits";
2488
2489 if (op->getNumRegions() != 1)
2490 return emitOpError() << "loop wrapper does not contain exactly one region";
2491
2492 Region &region = op->getRegion(0);
2493 if (range_size(region.getOps()) != 1)
2494 return emitOpError()
2495 << "loop wrapper does not contain exactly one nested op";
2496
2497 Operation &firstOp = *region.op_begin();
2498 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2499 return emitOpError() << "nested in loop wrapper is not another loop "
2500 "wrapper or `omp.loop_nest`";
2501
2502 return success();
2503}
2504
2505//===----------------------------------------------------------------------===//
2506// LoopOp
2507//===----------------------------------------------------------------------===//
2508
2509void LoopOp::build(OpBuilder &builder, OperationState &state,
2510 const LoopOperands &clauses) {
2511 MLIRContext *ctx = builder.getContext();
2512
2513 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2514 makeArrayAttr(ctx, clauses.privateSyms),
2515 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2516 clauses.reductionMod, clauses.reductionVars,
2517 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2518 makeArrayAttr(ctx, clauses.reductionSyms));
2519}
2520
2521LogicalResult LoopOp::verify() {
2522 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2523 getReductionByref());
2524}
2525
2526LogicalResult LoopOp::verifyRegions() {
2527 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2528 getNestedWrapper())
2529 return emitOpError() << "expected to be a standalone loop wrapper";
2530
2531 return success();
2532}
2533
2534//===----------------------------------------------------------------------===//
2535// WsloopOp
2536//===----------------------------------------------------------------------===//
2537
2538void WsloopOp::build(OpBuilder &builder, OperationState &state,
2539 ArrayRef<NamedAttribute> attributes) {
2540 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2541 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2542 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2543 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2544 /*private_needs_barrier=*/false,
2545 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2546 /*reduction_byref=*/nullptr,
2547 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2548 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2549 /*schedule_simd=*/false);
2550 state.addAttributes(attributes);
2551}
2552
2553void WsloopOp::build(OpBuilder &builder, OperationState &state,
2554 const WsloopOperands &clauses) {
2555 MLIRContext *ctx = builder.getContext();
2556 // TODO: Store clauses in op: allocateVars, allocatorVars
2557 WsloopOp::build(
2558 builder, state,
2559 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2560 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2561 clauses.ordered, clauses.privateVars,
2562 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2563 clauses.reductionMod, clauses.reductionVars,
2564 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2565 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2566 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2567}
2568
2569LogicalResult WsloopOp::verify() {
2570 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2571 getReductionByref());
2572}
2573
2574LogicalResult WsloopOp::verifyRegions() {
2575 bool isCompositeChildLeaf =
2576 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2577
2578 if (LoopWrapperInterface nested = getNestedWrapper()) {
2579 if (!isComposite())
2580 return emitError()
2581 << "'omp.composite' attribute missing from composite wrapper";
2582
2583 // Check for the allowed leaf constructs that may appear in a composite
2584 // construct directly after DO/FOR.
2585 if (!isa<SimdOp>(nested))
2586 return emitError() << "only supported nested wrapper is 'omp.simd'";
2587
2588 } else if (isComposite() && !isCompositeChildLeaf) {
2589 return emitError()
2590 << "'omp.composite' attribute present in non-composite wrapper";
2591 } else if (!isComposite() && isCompositeChildLeaf) {
2592 return emitError()
2593 << "'omp.composite' attribute missing from composite wrapper";
2594 }
2595
2596 return success();
2597}
2598
2599//===----------------------------------------------------------------------===//
2600// Simd construct [2.9.3.1]
2601//===----------------------------------------------------------------------===//
2602
2603void SimdOp::build(OpBuilder &builder, OperationState &state,
2604 const SimdOperands &clauses) {
2605 MLIRContext *ctx = builder.getContext();
2606 // TODO Store clauses in op: linearVars, linearStepVars
2607 SimdOp::build(builder, state, clauses.alignedVars,
2608 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2609 /*linear_vars=*/{}, /*linear_step_vars=*/{},
2610 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2611 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2612 clauses.privateNeedsBarrier, clauses.reductionMod,
2613 clauses.reductionVars,
2614 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2615 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2616 clauses.simdlen);
2617}
2618
2619LogicalResult SimdOp::verify() {
2620 if (getSimdlen().has_value() && getSafelen().has_value() &&
2621 getSimdlen().value() > getSafelen().value())
2622 return emitOpError()
2623 << "simdlen clause and safelen clause are both present, but the "
2624 "simdlen value is not less than or equal to safelen value";
2625
2626 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2627 return failure();
2628
2629 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2630 return failure();
2631
2632 bool isCompositeChildLeaf =
2633 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2634
2635 if (!isComposite() && isCompositeChildLeaf)
2636 return emitError()
2637 << "'omp.composite' attribute missing from composite wrapper";
2638
2639 if (isComposite() && !isCompositeChildLeaf)
2640 return emitError()
2641 << "'omp.composite' attribute present in non-composite wrapper";
2642
2643 return success();
2644}
2645
2646LogicalResult SimdOp::verifyRegions() {
2647 if (getNestedWrapper())
2648 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2649
2650 return success();
2651}
2652
2653//===----------------------------------------------------------------------===//
2654// Distribute construct [2.9.4.1]
2655//===----------------------------------------------------------------------===//
2656
2657void DistributeOp::build(OpBuilder &builder, OperationState &state,
2658 const DistributeOperands &clauses) {
2659 DistributeOp::build(builder, state, clauses.allocateVars,
2660 clauses.allocatorVars, clauses.distScheduleStatic,
2661 clauses.distScheduleChunkSize, clauses.order,
2662 clauses.orderMod, clauses.privateVars,
2663 makeArrayAttr(builder.getContext(), clauses.privateSyms),
2664 clauses.privateNeedsBarrier);
2665}
2666
2667LogicalResult DistributeOp::verify() {
2668 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2669 return emitOpError() << "chunk size set without "
2670 "dist_schedule_static being present";
2671
2672 if (getAllocateVars().size() != getAllocatorVars().size())
2673 return emitError(
2674 "expected equal sizes for allocate and allocator variables");
2675
2676 return success();
2677}
2678
2679LogicalResult DistributeOp::verifyRegions() {
2680 if (LoopWrapperInterface nested = getNestedWrapper()) {
2681 if (!isComposite())
2682 return emitError()
2683 << "'omp.composite' attribute missing from composite wrapper";
2684 // Check for the allowed leaf constructs that may appear in a composite
2685 // construct directly after DISTRIBUTE.
2686 if (isa<WsloopOp>(nested)) {
2687 Operation *parentOp = (*this)->getParentOp();
2688 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2689 !cast<ComposableOpInterface>(parentOp).isComposite()) {
2690 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2691 "when a composite 'omp.parallel' is the direct "
2692 "parent";
2693 }
2694 } else if (!isa<SimdOp>(nested))
2695 return emitError() << "only supported nested wrappers are 'omp.simd' and "
2696 "'omp.wsloop'";
2697 } else if (isComposite()) {
2698 return emitError()
2699 << "'omp.composite' attribute present in non-composite wrapper";
2700 }
2701
2702 return success();
2703}
2704
2705//===----------------------------------------------------------------------===//
2706// DeclareMapperOp / DeclareMapperInfoOp
2707//===----------------------------------------------------------------------===//
2708
2709LogicalResult DeclareMapperInfoOp::verify() {
2710 return verifyMapClause(*this, getMapVars());
2711}
2712
2713LogicalResult DeclareMapperOp::verifyRegions() {
2714 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2715 getRegion().getBlocks().front().getTerminator()))
2716 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
2717
2718 return success();
2719}
2720
2721//===----------------------------------------------------------------------===//
2722// DeclareReductionOp
2723//===----------------------------------------------------------------------===//
2724
2725LogicalResult DeclareReductionOp::verifyRegions() {
2726 if (!getAllocRegion().empty()) {
2727 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
2728 if (yieldOp.getResults().size() != 1 ||
2729 yieldOp.getResults().getTypes()[0] != getType())
2730 return emitOpError() << "expects alloc region to yield a value "
2731 "of the reduction type";
2732 }
2733 }
2734
2735 if (getInitializerRegion().empty())
2736 return emitOpError() << "expects non-empty initializer region";
2737 Block &initializerEntryBlock = getInitializerRegion().front();
2738
2739 if (initializerEntryBlock.getNumArguments() == 1) {
2740 if (!getAllocRegion().empty())
2741 return emitOpError() << "expects two arguments to the initializer region "
2742 "when an allocation region is used";
2743 } else if (initializerEntryBlock.getNumArguments() == 2) {
2744 if (getAllocRegion().empty())
2745 return emitOpError() << "expects one argument to the initializer region "
2746 "when no allocation region is used";
2747 } else {
2748 return emitOpError()
2749 << "expects one or two arguments to the initializer region";
2750 }
2751
2752 for (mlir::Value arg : initializerEntryBlock.getArguments())
2753 if (arg.getType() != getType())
2754 return emitOpError() << "expects initializer region argument to match "
2755 "the reduction type";
2756
2757 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
2758 if (yieldOp.getResults().size() != 1 ||
2759 yieldOp.getResults().getTypes()[0] != getType())
2760 return emitOpError() << "expects initializer region to yield a value "
2761 "of the reduction type";
2762 }
2763
2764 if (getReductionRegion().empty())
2765 return emitOpError() << "expects non-empty reduction region";
2766 Block &reductionEntryBlock = getReductionRegion().front();
2767 if (reductionEntryBlock.getNumArguments() != 2 ||
2768 reductionEntryBlock.getArgumentTypes()[0] !=
2769 reductionEntryBlock.getArgumentTypes()[1] ||
2770 reductionEntryBlock.getArgumentTypes()[0] != getType())
2771 return emitOpError() << "expects reduction region with two arguments of "
2772 "the reduction type";
2773 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
2774 if (yieldOp.getResults().size() != 1 ||
2775 yieldOp.getResults().getTypes()[0] != getType())
2776 return emitOpError() << "expects reduction region to yield a value "
2777 "of the reduction type";
2778 }
2779
2780 if (!getAtomicReductionRegion().empty()) {
2781 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
2782 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
2783 atomicReductionEntryBlock.getArgumentTypes()[0] !=
2784 atomicReductionEntryBlock.getArgumentTypes()[1])
2785 return emitOpError() << "expects atomic reduction region with two "
2786 "arguments of the same type";
2787 auto ptrType = llvm::dyn_cast<PointerLikeType>(
2788 atomicReductionEntryBlock.getArgumentTypes()[0]);
2789 if (!ptrType ||
2790 (ptrType.getElementType() && ptrType.getElementType() != getType()))
2791 return emitOpError() << "expects atomic reduction region arguments to "
2792 "be accumulators containing the reduction type";
2793 }
2794
2795 if (getCleanupRegion().empty())
2796 return success();
2797 Block &cleanupEntryBlock = getCleanupRegion().front();
2798 if (cleanupEntryBlock.getNumArguments() != 1 ||
2799 cleanupEntryBlock.getArgument(0).getType() != getType())
2800 return emitOpError() << "expects cleanup region with one argument "
2801 "of the reduction type";
2802
2803 return success();
2804}
2805
2806//===----------------------------------------------------------------------===//
2807// TaskOp
2808//===----------------------------------------------------------------------===//
2809
2810void TaskOp::build(OpBuilder &builder, OperationState &state,
2811 const TaskOperands &clauses) {
2812 MLIRContext *ctx = builder.getContext();
2813 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2814 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2815 clauses.final, clauses.ifExpr, clauses.inReductionVars,
2816 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2817 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2818 clauses.priority, /*private_vars=*/clauses.privateVars,
2819 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2820 clauses.privateNeedsBarrier, clauses.untied,
2821 clauses.eventHandle);
2822}
2823
2824LogicalResult TaskOp::verify() {
2825 LogicalResult verifyDependVars =
2826 verifyDependVarList(*this, getDependKinds(), getDependVars());
2827 return failed(verifyDependVars)
2828 ? verifyDependVars
2829 : verifyReductionVarList(*this, getInReductionSyms(),
2830 getInReductionVars(),
2831 getInReductionByref());
2832}
2833
2834//===----------------------------------------------------------------------===//
2835// TaskgroupOp
2836//===----------------------------------------------------------------------===//
2837
2838void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2839 const TaskgroupOperands &clauses) {
2840 MLIRContext *ctx = builder.getContext();
2841 TaskgroupOp::build(builder, state, clauses.allocateVars,
2842 clauses.allocatorVars, clauses.taskReductionVars,
2843 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
2844 makeArrayAttr(ctx, clauses.taskReductionSyms));
2845}
2846
2847LogicalResult TaskgroupOp::verify() {
2848 return verifyReductionVarList(*this, getTaskReductionSyms(),
2849 getTaskReductionVars(),
2850 getTaskReductionByref());
2851}
2852
2853//===----------------------------------------------------------------------===//
2854// TaskloopOp
2855//===----------------------------------------------------------------------===//
2856
2857void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2858 const TaskloopOperands &clauses) {
2859 MLIRContext *ctx = builder.getContext();
2860 TaskloopOp::build(
2861 builder, state, clauses.allocateVars, clauses.allocatorVars,
2862 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
2863 clauses.inReductionVars,
2864 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
2865 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
2866 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
2867 /*private_vars=*/clauses.privateVars,
2868 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
2869 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2870 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2871 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
2872}
2873
2874LogicalResult TaskloopOp::verify() {
2875 if (getAllocateVars().size() != getAllocatorVars().size())
2876 return emitError(
2877 "expected equal sizes for allocate and allocator variables");
2878 if (failed(verifyReductionVarList(*this, getReductionSyms(),
2879 getReductionVars(), getReductionByref())) ||
2880 failed(verifyReductionVarList(*this, getInReductionSyms(),
2881 getInReductionVars(),
2882 getInReductionByref())))
2883 return failure();
2884
2885 if (!getReductionVars().empty() && getNogroup())
2886 return emitError("if a reduction clause is present on the taskloop "
2887 "directive, the nogroup clause must not be specified");
2888 for (auto var : getReductionVars()) {
2889 if (llvm::is_contained(getInReductionVars(), var))
2890 return emitError("the same list item cannot appear in both a reduction "
2891 "and an in_reduction clause");
2892 }
2893
2894 if (getGrainsize() && getNumTasks()) {
2895 return emitError(
2896 "the grainsize clause and num_tasks clause are mutually exclusive and "
2897 "may not appear on the same taskloop directive");
2898 }
2899
2900 return success();
2901}
2902
2903LogicalResult TaskloopOp::verifyRegions() {
2904 if (LoopWrapperInterface nested = getNestedWrapper()) {
2905 if (!isComposite())
2906 return emitError()
2907 << "'omp.composite' attribute missing from composite wrapper";
2908
2909 // Check for the allowed leaf constructs that may appear in a composite
2910 // construct directly after TASKLOOP.
2911 if (!isa<SimdOp>(nested))
2912 return emitError() << "only supported nested wrapper is 'omp.simd'";
2913 } else if (isComposite()) {
2914 return emitError()
2915 << "'omp.composite' attribute present in non-composite wrapper";
2916 }
2917
2918 return success();
2919}
2920
2921//===----------------------------------------------------------------------===//
2922// LoopNestOp
2923//===----------------------------------------------------------------------===//
2924
2925ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
2926 // Parse an opening `(` followed by induction variables followed by `)`
2927 SmallVector<OpAsmParser::Argument> ivs;
2928 SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs;
2929 Type loopVarType;
2930 if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
2931 parser.parseColonType(loopVarType) ||
2932 // Parse loop bounds.
2933 parser.parseEqual() ||
2934 parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
2935 parser.parseKeyword("to") ||
2936 parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
2937 return failure();
2938
2939 for (auto &iv : ivs)
2940 iv.type = loopVarType;
2941
2942 // Parse "inclusive" flag.
2943 if (succeeded(parser.parseOptionalKeyword("inclusive")))
2944 result.addAttribute("loop_inclusive",
2945 UnitAttr::get(parser.getBuilder().getContext()));
2946
2947 // Parse step values.
2948 SmallVector<OpAsmParser::UnresolvedOperand> steps;
2949 if (parser.parseKeyword("step") ||
2950 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
2951 return failure();
2952
2953 // Parse the body.
2954 Region *region = result.addRegion();
2955 if (parser.parseRegion(*region, ivs))
2956 return failure();
2957
2958 // Resolve operands.
2959 if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
2960 parser.resolveOperands(ubs, loopVarType, result.operands) ||
2961 parser.resolveOperands(steps, loopVarType, result.operands))
2962 return failure();
2963
2964 // Parse the optional attribute list.
2965 return parser.parseOptionalAttrDict(result.attributes);
2966}
2967
2968void LoopNestOp::print(OpAsmPrinter &p) {
2969 Region &region = getRegion();
2970 auto args = region.getArguments();
2971 p << " (" << args << ") : " << args[0].getType() << " = ("
2972 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2973 if (getLoopInclusive())
2974 p << "inclusive ";
2975 p << "step (" << getLoopSteps() << ") ";
2976 p.printRegion(region, /*printEntryBlockArgs=*/false);
2977}
2978
2979void LoopNestOp::build(OpBuilder &builder, OperationState &state,
2980 const LoopNestOperands &clauses) {
2981 LoopNestOp::build(builder, state, clauses.loopLowerBounds,
2982 clauses.loopUpperBounds, clauses.loopSteps,
2983 clauses.loopInclusive);
2984}
2985
2986LogicalResult LoopNestOp::verify() {
2987 if (getLoopLowerBounds().empty())
2988 return emitOpError() << "must represent at least one loop";
2989
2990 if (getLoopLowerBounds().size() != getIVs().size())
2991 return emitOpError() << "number of range arguments and IVs do not match";
2992
2993 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
2994 if (lb.getType() != iv.getType())
2995 return emitOpError()
2996 << "range argument type does not match corresponding IV type";
2997 }
2998
2999 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3000 return emitOpError() << "expects parent op to be a loop wrapper";
3001
3002 return success();
3003}
3004
3005void LoopNestOp::gatherWrappers(
3006 SmallVectorImpl<LoopWrapperInterface> &wrappers) {
3007 Operation *parent = (*this)->getParentOp();
3008 while (auto wrapper =
3009 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3010 wrappers.push_back(wrapper);
3011 parent = parent->getParentOp();
3012 }
3013}
3014
3015//===----------------------------------------------------------------------===//
3016// Critical construct (2.17.1)
3017//===----------------------------------------------------------------------===//
3018
3019void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3020 const CriticalDeclareOperands &clauses) {
3021 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3022}
3023
3024LogicalResult CriticalDeclareOp::verify() {
3025 return verifySynchronizationHint(*this, getHint());
3026}
3027
3028LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3029 if (getNameAttr()) {
3030 SymbolRefAttr symbolRef = getNameAttr();
3031 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3032 *this, symbolRef);
3033 if (!decl) {
3034 return emitOpError() << "expected symbol reference " << symbolRef
3035 << " to point to a critical declaration";
3036 }
3037 }
3038
3039 return success();
3040}
3041
3042//===----------------------------------------------------------------------===//
3043// Ordered construct
3044//===----------------------------------------------------------------------===//
3045
3046static LogicalResult verifyOrderedParent(Operation &op) {
3047 bool hasRegion = op.getNumRegions() > 0;
3048 auto loopOp = op.getParentOfType<LoopNestOp>();
3049 if (!loopOp) {
3050 if (hasRegion)
3051 return success();
3052
3053 // TODO: Consider if this needs to be the case only for the standalone
3054 // variant of the ordered construct.
3055 return op.emitOpError() << "must be nested inside of a loop";
3056 }
3057
3058 Operation *wrapper = loopOp->getParentOp();
3059 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3060 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3061 if (!orderedAttr)
3062 return op.emitOpError() << "the enclosing worksharing-loop region must "
3063 "have an ordered clause";
3064
3065 if (hasRegion && orderedAttr.getInt() != 0)
3066 return op.emitOpError() << "the enclosing loop's ordered clause must not "
3067 "have a parameter present";
3068
3069 if (!hasRegion && orderedAttr.getInt() == 0)
3070 return op.emitOpError() << "the enclosing loop's ordered clause must "
3071 "have a parameter present";
3072 } else if (!isa<SimdOp>(wrapper)) {
3073 return op.emitOpError() << "must be nested inside of a worksharing, simd "
3074 "or worksharing simd loop";
3075 }
3076 return success();
3077}
3078
3079void OrderedOp::build(OpBuilder &builder, OperationState &state,
3080 const OrderedOperands &clauses) {
3081 OrderedOp::build(builder, state, clauses.doacrossDependType,
3082 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3083}
3084
3085LogicalResult OrderedOp::verify() {
3086 if (failed(verifyOrderedParent(**this)))
3087 return failure();
3088
3089 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3090 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3091 return emitOpError() << "number of variables in depend clause does not "
3092 << "match number of iteration variables in the "
3093 << "doacross loop";
3094
3095 return success();
3096}
3097
3098void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3099 const OrderedRegionOperands &clauses) {
3100 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3101}
3102
3103LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3104
3105//===----------------------------------------------------------------------===//
3106// TaskwaitOp
3107//===----------------------------------------------------------------------===//
3108
3109void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3110 const TaskwaitOperands &clauses) {
3111 // TODO Store clauses in op: dependKinds, dependVars, nowait.
3112 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3113 /*depend_vars=*/{}, /*nowait=*/nullptr);
3114}
3115
3116//===----------------------------------------------------------------------===//
3117// Verifier for AtomicReadOp
3118//===----------------------------------------------------------------------===//
3119
3120LogicalResult AtomicReadOp::verify() {
3121 if (verifyCommon().failed())
3122 return mlir::failure();
3123
3124 if (auto mo = getMemoryOrder()) {
3125 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3126 *mo == ClauseMemoryOrderKind::Release) {
3127 return emitError(
3128 "memory-order must not be acq_rel or release for atomic reads");
3129 }
3130 }
3131 return verifySynchronizationHint(*this, getHint());
3132}
3133
3134//===----------------------------------------------------------------------===//
3135// Verifier for AtomicWriteOp
3136//===----------------------------------------------------------------------===//
3137
3138LogicalResult AtomicWriteOp::verify() {
3139 if (verifyCommon().failed())
3140 return mlir::failure();
3141
3142 if (auto mo = getMemoryOrder()) {
3143 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3144 *mo == ClauseMemoryOrderKind::Acquire) {
3145 return emitError(
3146 "memory-order must not be acq_rel or acquire for atomic writes");
3147 }
3148 }
3149 return verifySynchronizationHint(*this, getHint());
3150}
3151
3152//===----------------------------------------------------------------------===//
3153// Verifier for AtomicUpdateOp
3154//===----------------------------------------------------------------------===//
3155
3156LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3157 PatternRewriter &rewriter) {
3158 if (op.isNoOp()) {
3159 rewriter.eraseOp(op);
3160 return success();
3161 }
3162 if (Value writeVal = op.getWriteOpVal()) {
3163 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3164 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3165 return success();
3166 }
3167 return failure();
3168}
3169
3170LogicalResult AtomicUpdateOp::verify() {
3171 if (verifyCommon().failed())
3172 return mlir::failure();
3173
3174 if (auto mo = getMemoryOrder()) {
3175 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3176 *mo == ClauseMemoryOrderKind::Acquire) {
3177 return emitError(
3178 "memory-order must not be acq_rel or acquire for atomic updates");
3179 }
3180 }
3181
3182 return verifySynchronizationHint(*this, getHint());
3183}
3184
3185LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3186
3187//===----------------------------------------------------------------------===//
3188// Verifier for AtomicCaptureOp
3189//===----------------------------------------------------------------------===//
3190
3191AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3192 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3193 return op;
3194 return dyn_cast<AtomicReadOp>(getSecondOp());
3195}
3196
3197AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3198 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3199 return op;
3200 return dyn_cast<AtomicWriteOp>(getSecondOp());
3201}
3202
3203AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3204 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3205 return op;
3206 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3207}
3208
3209LogicalResult AtomicCaptureOp::verify() {
3210 return verifySynchronizationHint(*this, getHint());
3211}
3212
3213LogicalResult AtomicCaptureOp::verifyRegions() {
3214 if (verifyRegionsCommon().failed())
3215 return mlir::failure();
3216
3217 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
3218 return emitOpError(
3219 "operations inside capture region must not have hint clause");
3220
3221 if (getFirstOp()->getAttr("memory_order") ||
3222 getSecondOp()->getAttr("memory_order"))
3223 return emitOpError(
3224 "operations inside capture region must not have memory_order clause");
3225 return success();
3226}
3227
3228//===----------------------------------------------------------------------===//
3229// CancelOp
3230//===----------------------------------------------------------------------===//
3231
3232void CancelOp::build(OpBuilder &builder, OperationState &state,
3233 const CancelOperands &clauses) {
3234 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
3235}
3236
3237static Operation *getParentInSameDialect(Operation *thisOp) {
3238 Operation *parent = thisOp->getParentOp();
3239 while (parent) {
3240 if (parent->getDialect() == thisOp->getDialect())
3241 return parent;
3242 parent = parent->getParentOp();
3243 }
3244 return nullptr;
3245}
3246
3247LogicalResult CancelOp::verify() {
3248 ClauseCancellationConstructType cct = getCancelDirective();
3249 // The next OpenMP operation in the chain of parents
3250 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3251 if (!structuralParent)
3252 return emitOpError() << "Orphaned cancel construct";
3253
3254 if ((cct == ClauseCancellationConstructType::Parallel) &&
3255 !mlir::isa<ParallelOp>(structuralParent)) {
3256 return emitOpError() << "cancel parallel must appear "
3257 << "inside a parallel region";
3258 }
3259 if (cct == ClauseCancellationConstructType::Loop) {
3260 // structural parent will be omp.loop_nest, directly nested inside
3261 // omp.wsloop
3262 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
3263
3264 if (!wsloopOp) {
3265 return emitOpError()
3266 << "cancel loop must appear inside a worksharing-loop region";
3267 }
3268 if (wsloopOp.getNowaitAttr()) {
3269 return emitError() << "A worksharing construct that is canceled "
3270 << "must not have a nowait clause";
3271 }
3272 if (wsloopOp.getOrderedAttr()) {
3273 return emitError() << "A worksharing construct that is canceled "
3274 << "must not have an ordered clause";
3275 }
3276
3277 } else if (cct == ClauseCancellationConstructType::Sections) {
3278 // structural parent will be an omp.section, directly nested inside
3279 // omp.sections
3280 auto sectionsOp =
3281 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
3282 if (!sectionsOp) {
3283 return emitOpError() << "cancel sections must appear "
3284 << "inside a sections region";
3285 }
3286 if (sectionsOp.getNowait()) {
3287 return emitError() << "A sections construct that is canceled "
3288 << "must not have a nowait clause";
3289 }
3290 }
3291 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3292 (!mlir::isa<omp::TaskOp>(structuralParent) &&
3293 !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
3294 return emitOpError() << "cancel taskgroup must appear "
3295 << "inside a task region";
3296 }
3297 return success();
3298}
3299
3300//===----------------------------------------------------------------------===//
3301// CancellationPointOp
3302//===----------------------------------------------------------------------===//
3303
3304void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3305 const CancellationPointOperands &clauses) {
3306 CancellationPointOp::build(builder, state, clauses.cancelDirective);
3307}
3308
3309LogicalResult CancellationPointOp::verify() {
3310 ClauseCancellationConstructType cct = getCancelDirective();
3311 // The next OpenMP operation in the chain of parents
3312 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3313 if (!structuralParent)
3314 return emitOpError() << "Orphaned cancellation point";
3315
3316 if ((cct == ClauseCancellationConstructType::Parallel) &&
3317 !mlir::isa<ParallelOp>(structuralParent)) {
3318 return emitOpError() << "cancellation point parallel must appear "
3319 << "inside a parallel region";
3320 }
3321 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3322 // find the wsloop
3323 if ((cct == ClauseCancellationConstructType::Loop) &&
3324 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
3325 return emitOpError() << "cancellation point loop must appear "
3326 << "inside a worksharing-loop region";
3327 }
3328 if ((cct == ClauseCancellationConstructType::Sections) &&
3329 !mlir::isa<omp::SectionOp>(structuralParent)) {
3330 return emitOpError() << "cancellation point sections must appear "
3331 << "inside a sections region";
3332 }
3333 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3334 !mlir::isa<omp::TaskOp>(structuralParent)) {
3335 return emitOpError() << "cancellation point taskgroup must appear "
3336 << "inside a task region";
3337 }
3338 return success();
3339}
3340
3341//===----------------------------------------------------------------------===//
3342// MapBoundsOp
3343//===----------------------------------------------------------------------===//
3344
3345LogicalResult MapBoundsOp::verify() {
3346 auto extent = getExtent();
3347 auto upperbound = getUpperBound();
3348 if (!extent && !upperbound)
3349 return emitError("expected extent or upperbound.");
3350 return success();
3351}
3352
3353void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3354 TypeRange /*result_types*/, StringAttr symName,
3355 TypeAttr type) {
3356 PrivateClauseOp::build(
3357 odsBuilder, odsState, symName, type,
3358 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
3359 DataSharingClauseType::Private));
3360}
3361
3362LogicalResult PrivateClauseOp::verifyRegions() {
3363 Type argType = getArgType();
3364 auto verifyTerminator = [&](Operation *terminator,
3365 bool yieldsValue) -> LogicalResult {
3366 if (!terminator->getBlock()->getSuccessors().empty())
3367 return success();
3368
3369 if (!llvm::isa<YieldOp>(terminator))
3370 return mlir::emitError(terminator->getLoc())
3371 << "expected exit block terminator to be an `omp.yield` op.";
3372
3373 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
3374 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3375
3376 if (!yieldsValue) {
3377 if (yieldedTypes.empty())
3378 return success();
3379
3380 return mlir::emitError(terminator->getLoc())
3381 << "Did not expect any values to be yielded.";
3382 }
3383
3384 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3385 return success();
3386
3387 auto error = mlir::emitError(yieldOp.getLoc())
3388 << "Invalid yielded value. Expected type: " << argType
3389 << ", got: ";
3390
3391 if (yieldedTypes.empty())
3392 error << "None";
3393 else
3394 error << yieldedTypes;
3395
3396 return error;
3397 };
3398
3399 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
3400 StringRef regionName,
3401 bool yieldsValue) -> LogicalResult {
3402 assert(!region.empty());
3403
3404 if (region.getNumArguments() != expectedNumArgs)
3405 return mlir::emitError(region.getLoc())
3406 << "`" << regionName << "`: "
3407 << "expected " << expectedNumArgs
3408 << " region arguments, got: " << region.getNumArguments();
3409
3410 for (Block &block : region) {
3411 // MLIR will verify the absence of the terminator for us.
3412 if (!block.mightHaveTerminator())
3413 continue;
3414
3415 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
3416 return failure();
3417 }
3418
3419 return success();
3420 };
3421
3422 // Ensure all of the region arguments have the same type
3423 for (Region *region : getRegions())
3424 for (Type ty : region->getArgumentTypes())
3425 if (ty != argType)
3426 return emitError() << "Region argument type mismatch: got " << ty
3427 << " expected " << argType << ".";
3428
3429 mlir::Region &initRegion = getInitRegion();
3430 if (!initRegion.empty() &&
3431 failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
3432 /*yieldsValue=*/true)))
3433 return failure();
3434
3435 DataSharingClauseType dsType = getDataSharingType();
3436
3437 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3438 return emitError("`private` clauses do not require a `copy` region.");
3439
3440 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3441 return emitError(
3442 "`firstprivate` clauses require at least a `copy` region.");
3443
3444 if (dsType == DataSharingClauseType::FirstPrivate &&
3445 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
3446 /*yieldsValue=*/true)))
3447 return failure();
3448
3449 if (!getDeallocRegion().empty() &&
3450 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
3451 /*yieldsValue=*/false)))
3452 return failure();
3453
3454 return success();
3455}
3456
3457//===----------------------------------------------------------------------===//
3458// Spec 5.2: Masked construct (10.5)
3459//===----------------------------------------------------------------------===//
3460
3461void MaskedOp::build(OpBuilder &builder, OperationState &state,
3462 const MaskedOperands &clauses) {
3463 MaskedOp::build(builder, state, clauses.filteredThreadId);
3464}
3465
3466//===----------------------------------------------------------------------===//
3467// Spec 5.2: Scan construct (5.6)
3468//===----------------------------------------------------------------------===//
3469
3470void ScanOp::build(OpBuilder &builder, OperationState &state,
3471 const ScanOperands &clauses) {
3472 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
3473}
3474
3475LogicalResult ScanOp::verify() {
3476 if (hasExclusiveVars() == hasInclusiveVars())
3477 return emitError(
3478 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3479 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3480 if (parentWsLoopOp.getReductionModAttr() &&
3481 parentWsLoopOp.getReductionModAttr().getValue() ==
3482 ReductionModifier::inscan)
3483 return success();
3484 }
3485 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3486 if (parentSimdOp.getReductionModAttr() &&
3487 parentSimdOp.getReductionModAttr().getValue() ==
3488 ReductionModifier::inscan)
3489 return success();
3490 }
3491 return emitError("SCAN directive needs to be enclosed within a parent "
3492 "worksharing loop construct or SIMD construct with INSCAN "
3493 "reduction modifier");
3494}
3495
3496#define GET_ATTRDEF_CLASSES
3497#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3498
3499#define GET_OP_CLASSES
3500#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3501
3502#define GET_TYPEDEF_CLASSES
3503#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
3504

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp