1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
18#include "mlir/Dialect/OpenMP/OpenMPClauseOperands.h"
19#include "mlir/IR/Attributes.h"
20#include "mlir/IR/BuiltinAttributes.h"
21#include "mlir/IR/DialectImplementation.h"
22#include "mlir/IR/OpImplementation.h"
23#include "mlir/IR/OperationSupport.h"
24#include "mlir/IR/SymbolTable.h"
25#include "mlir/Interfaces/FoldInterfaces.h"
26
27#include "llvm/ADT/ArrayRef.h"
28#include "llvm/ADT/BitVector.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/Frontend/OpenMP/OMPConstants.h"
37#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
38#include <cstddef>
39#include <iterator>
40#include <optional>
41#include <variant>
42
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
47
48using namespace mlir;
49using namespace mlir::omp;
50
51static ArrayAttr makeArrayAttr(MLIRContext *context,
52 llvm::ArrayRef<Attribute> attrs) {
53 return attrs.empty() ? nullptr : ArrayAttr::get(context, value: attrs);
54}
55
56static DenseBoolArrayAttr
57makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
58 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(context: ctx, content: boolArray);
59}
60
61namespace {
62struct MemRefPointerLikeModel
63 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
64 MemRefType> {
65 Type getElementType(Type pointer) const {
66 return llvm::cast<MemRefType>(Val&: pointer).getElementType();
67 }
68};
69
70struct LLVMPointerPointerLikeModel
71 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
72 LLVM::LLVMPointerType> {
73 Type getElementType(Type pointer) const { return Type(); }
74};
75} // namespace
76
77void OpenMPDialect::initialize() {
78 addOperations<
79#define GET_OP_LIST
80#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
81 >();
82 addAttributes<
83#define GET_ATTRDEF_LIST
84#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
85 >();
86 addTypes<
87#define GET_TYPEDEF_LIST
88#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
89 >();
90
91 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
92
93 MemRefType::attachInterface<MemRefPointerLikeModel>(context&: *getContext());
94 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
95 context&: *getContext());
96
97 // Attach default offload module interface to module op to access
98 // offload functionality through
99 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
100 context&: *getContext());
101
102 // Attach default declare target interfaces to operations which can be marked
103 // as declare target (Global Operations and Functions/Subroutines in dialects
104 // that Fortran (or other languages that lower to MLIR) translates too
105 mlir::LLVM::GlobalOp::attachInterface<
106 mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::GlobalOp>>(
107 context&: *getContext());
108 mlir::LLVM::LLVMFuncOp::attachInterface<
109 mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::LLVMFuncOp>>(
110 context&: *getContext());
111 mlir::func::FuncOp::attachInterface<
112 mlir::omp::DeclareTargetDefaultModel<mlir::func::FuncOp>>(context&: *getContext());
113}
114
115//===----------------------------------------------------------------------===//
116// Parser and printer for Allocate Clause
117//===----------------------------------------------------------------------===//
118
119/// Parse an allocate clause with allocators and a list of operands with types.
120///
121/// allocate-operand-list :: = allocate-operand |
122/// allocator-operand `,` allocate-operand-list
123/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
124/// ssa-id-and-type ::= ssa-id `:` type
125static ParseResult parseAllocateAndAllocator(
126 OpAsmParser &parser,
127 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocateVars,
128 SmallVectorImpl<Type> &allocateTypes,
129 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocatorVars,
130 SmallVectorImpl<Type> &allocatorTypes) {
131
132 return parser.parseCommaSeparatedList(parseElementFn: [&]() {
133 OpAsmParser::UnresolvedOperand operand;
134 Type type;
135 if (parser.parseOperand(result&: operand) || parser.parseColonType(result&: type))
136 return failure();
137 allocatorVars.push_back(Elt: operand);
138 allocatorTypes.push_back(Elt: type);
139 if (parser.parseArrow())
140 return failure();
141 if (parser.parseOperand(result&: operand) || parser.parseColonType(result&: type))
142 return failure();
143
144 allocateVars.push_back(Elt: operand);
145 allocateTypes.push_back(Elt: type);
146 return success();
147 });
148}
149
150/// Print allocate clause
151static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
152 OperandRange allocateVars,
153 TypeRange allocateTypes,
154 OperandRange allocatorVars,
155 TypeRange allocatorTypes) {
156 for (unsigned i = 0; i < allocateVars.size(); ++i) {
157 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
158 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
159 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
160 }
161}
162
163//===----------------------------------------------------------------------===//
164// Parser and printer for a clause attribute (StringEnumAttr)
165//===----------------------------------------------------------------------===//
166
167template <typename ClauseAttr>
168static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
169 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
170 StringRef enumStr;
171 SMLoc loc = parser.getCurrentLocation();
172 if (parser.parseKeyword(keyword: &enumStr))
173 return failure();
174 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
175 attr = ClauseAttr::get(parser.getContext(), *enumValue);
176 return success();
177 }
178 return parser.emitError(loc, message: "invalid clause value: '") << enumStr << "'";
179}
180
181template <typename ClauseAttr>
182void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
183 p << stringifyEnum(attr.getValue());
184}
185
186//===----------------------------------------------------------------------===//
187// Parser and printer for Linear Clause
188//===----------------------------------------------------------------------===//
189
190/// linear ::= `linear` `(` linear-list `)`
191/// linear-list := linear-val | linear-val linear-list
192/// linear-val := ssa-id-and-type `=` ssa-id-and-type
193static ParseResult parseLinearClause(
194 OpAsmParser &parser,
195 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearVars,
196 SmallVectorImpl<Type> &linearTypes,
197 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearStepVars) {
198 return parser.parseCommaSeparatedList(parseElementFn: [&]() {
199 OpAsmParser::UnresolvedOperand var;
200 Type type;
201 OpAsmParser::UnresolvedOperand stepVar;
202 if (parser.parseOperand(result&: var) || parser.parseEqual() ||
203 parser.parseOperand(result&: stepVar) || parser.parseColonType(result&: type))
204 return failure();
205
206 linearVars.push_back(Elt: var);
207 linearTypes.push_back(Elt: type);
208 linearStepVars.push_back(Elt: stepVar);
209 return success();
210 });
211}
212
213/// Print Linear Clause
214static void printLinearClause(OpAsmPrinter &p, Operation *op,
215 ValueRange linearVars, TypeRange linearTypes,
216 ValueRange linearStepVars) {
217 size_t linearVarsSize = linearVars.size();
218 for (unsigned i = 0; i < linearVarsSize; ++i) {
219 std::string separator = i == linearVarsSize - 1 ? "" : ", ";
220 p << linearVars[i];
221 if (linearStepVars.size() > i)
222 p << " = " << linearStepVars[i];
223 p << " : " << linearVars[i].getType() << separator;
224 }
225}
226
227//===----------------------------------------------------------------------===//
228// Verifier for Nontemporal Clause
229//===----------------------------------------------------------------------===//
230
231static LogicalResult verifyNontemporalClause(Operation *op,
232 OperandRange nontemporalVars) {
233
234 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
235 DenseSet<Value> nontemporalItems;
236 for (const auto &it : nontemporalVars)
237 if (!nontemporalItems.insert(V: it).second)
238 return op->emitOpError() << "nontemporal variable used more than once";
239
240 return success();
241}
242
243//===----------------------------------------------------------------------===//
244// Parser, verifier and printer for Aligned Clause
245//===----------------------------------------------------------------------===//
246static LogicalResult verifyAlignedClause(Operation *op,
247 std::optional<ArrayAttr> alignments,
248 OperandRange alignedVars) {
249 // Check if number of alignment values equals to number of aligned variables
250 if (!alignedVars.empty()) {
251 if (!alignments || alignments->size() != alignedVars.size())
252 return op->emitOpError()
253 << "expected as many alignment values as aligned variables";
254 } else {
255 if (alignments)
256 return op->emitOpError() << "unexpected alignment values attribute";
257 return success();
258 }
259
260 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
261 DenseSet<Value> alignedItems;
262 for (auto it : alignedVars)
263 if (!alignedItems.insert(V: it).second)
264 return op->emitOpError() << "aligned variable used more than once";
265
266 if (!alignments)
267 return success();
268
269 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
270 for (unsigned i = 0; i < (*alignments).size(); ++i) {
271 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(Val: (*alignments)[i])) {
272 if (intAttr.getValue().sle(RHS: 0))
273 return op->emitOpError() << "alignment should be greater than 0";
274 } else {
275 return op->emitOpError() << "expected integer alignment";
276 }
277 }
278
279 return success();
280}
281
282/// aligned ::= `aligned` `(` aligned-list `)`
283/// aligned-list := aligned-val | aligned-val aligned-list
284/// aligned-val := ssa-id-and-type `->` alignment
285static ParseResult
286parseAlignedClause(OpAsmParser &parser,
287 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &alignedVars,
288 SmallVectorImpl<Type> &alignedTypes,
289 ArrayAttr &alignmentsAttr) {
290 SmallVector<Attribute> alignmentVec;
291 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
292 if (parser.parseOperand(result&: alignedVars.emplace_back()) ||
293 parser.parseColonType(result&: alignedTypes.emplace_back()) ||
294 parser.parseArrow() ||
295 parser.parseAttribute(result&: alignmentVec.emplace_back())) {
296 return failure();
297 }
298 return success();
299 })))
300 return failure();
301 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
302 alignmentsAttr = ArrayAttr::get(context: parser.getContext(), value: alignments);
303 return success();
304}
305
306/// Print Aligned Clause
307static void printAlignedClause(OpAsmPrinter &p, Operation *op,
308 ValueRange alignedVars, TypeRange alignedTypes,
309 std::optional<ArrayAttr> alignments) {
310 for (unsigned i = 0; i < alignedVars.size(); ++i) {
311 if (i != 0)
312 p << ", ";
313 p << alignedVars[i] << " : " << alignedVars[i].getType();
314 p << " -> " << (*alignments)[i];
315 }
316}
317
318//===----------------------------------------------------------------------===//
319// Parser, printer and verifier for Schedule Clause
320//===----------------------------------------------------------------------===//
321
322static ParseResult
323verifyScheduleModifiers(OpAsmParser &parser,
324 SmallVectorImpl<SmallString<12>> &modifiers) {
325 if (modifiers.size() > 2)
326 return parser.emitError(loc: parser.getNameLoc()) << " unexpected modifier(s)";
327 for (const auto &mod : modifiers) {
328 // Translate the string. If it has no value, then it was not a valid
329 // modifier!
330 auto symbol = symbolizeScheduleModifier(str: mod);
331 if (!symbol)
332 return parser.emitError(loc: parser.getNameLoc())
333 << " unknown modifier type: " << mod;
334 }
335
336 // If we have one modifier that is "simd", then stick a "none" modiifer in
337 // index 0.
338 if (modifiers.size() == 1) {
339 if (symbolizeScheduleModifier(str: modifiers[0]) == ScheduleModifier::simd) {
340 modifiers.push_back(Elt: modifiers[0]);
341 modifiers[0] = stringifyScheduleModifier(val: ScheduleModifier::none);
342 }
343 } else if (modifiers.size() == 2) {
344 // If there are two modifier:
345 // First modifier should not be simd, second one should be simd
346 if (symbolizeScheduleModifier(str: modifiers[0]) == ScheduleModifier::simd ||
347 symbolizeScheduleModifier(str: modifiers[1]) != ScheduleModifier::simd)
348 return parser.emitError(loc: parser.getNameLoc())
349 << " incorrect modifier order";
350 }
351 return success();
352}
353
354/// schedule ::= `schedule` `(` sched-list `)`
355/// sched-list ::= sched-val | sched-val sched-list |
356/// sched-val `,` sched-modifier
357/// sched-val ::= sched-with-chunk | sched-wo-chunk
358/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
359/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
360/// sched-wo-chunk ::= `auto` | `runtime`
361/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
362/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
363static ParseResult
364parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
365 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
366 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
367 Type &chunkType) {
368 StringRef keyword;
369 if (parser.parseKeyword(keyword: &keyword))
370 return failure();
371 std::optional<mlir::omp::ClauseScheduleKind> schedule =
372 symbolizeClauseScheduleKind(str: keyword);
373 if (!schedule)
374 return parser.emitError(loc: parser.getNameLoc()) << " expected schedule kind";
375
376 scheduleAttr = ClauseScheduleKindAttr::get(context: parser.getContext(), value: *schedule);
377 switch (*schedule) {
378 case ClauseScheduleKind::Static:
379 case ClauseScheduleKind::Dynamic:
380 case ClauseScheduleKind::Guided:
381 if (succeeded(Result: parser.parseOptionalEqual())) {
382 chunkSize = OpAsmParser::UnresolvedOperand{};
383 if (parser.parseOperand(result&: *chunkSize) || parser.parseColonType(result&: chunkType))
384 return failure();
385 } else {
386 chunkSize = std::nullopt;
387 }
388 break;
389 case ClauseScheduleKind::Auto:
390 case ClauseScheduleKind::Runtime:
391 chunkSize = std::nullopt;
392 }
393
394 // If there is a comma, we have one or more modifiers..
395 SmallVector<SmallString<12>> modifiers;
396 while (succeeded(Result: parser.parseOptionalComma())) {
397 StringRef mod;
398 if (parser.parseKeyword(keyword: &mod))
399 return failure();
400 modifiers.push_back(Elt: mod);
401 }
402
403 if (verifyScheduleModifiers(parser, modifiers))
404 return failure();
405
406 if (!modifiers.empty()) {
407 SMLoc loc = parser.getCurrentLocation();
408 if (std::optional<ScheduleModifier> mod =
409 symbolizeScheduleModifier(str: modifiers[0])) {
410 scheduleMod = ScheduleModifierAttr::get(context: parser.getContext(), value: *mod);
411 } else {
412 return parser.emitError(loc, message: "invalid schedule modifier");
413 }
414 // Only SIMD attribute is allowed here!
415 if (modifiers.size() > 1) {
416 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
417 scheduleSimd = UnitAttr::get(context: parser.getBuilder().getContext());
418 }
419 }
420
421 return success();
422}
423
424/// Print schedule clause
425static void printScheduleClause(OpAsmPrinter &p, Operation *op,
426 ClauseScheduleKindAttr scheduleKind,
427 ScheduleModifierAttr scheduleMod,
428 UnitAttr scheduleSimd, Value scheduleChunk,
429 Type scheduleChunkType) {
430 p << stringifyClauseScheduleKind(val: scheduleKind.getValue());
431 if (scheduleChunk)
432 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
433 if (scheduleMod)
434 p << ", " << stringifyScheduleModifier(val: scheduleMod.getValue());
435 if (scheduleSimd)
436 p << ", simd";
437}
438
439//===----------------------------------------------------------------------===//
440// Parser and printer for Order Clause
441//===----------------------------------------------------------------------===//
442
443// order ::= `order` `(` [order-modifier ':'] concurrent `)`
444// order-modifier ::= reproducible | unconstrained
445static ParseResult parseOrderClause(OpAsmParser &parser,
446 ClauseOrderKindAttr &order,
447 OrderModifierAttr &orderMod) {
448 StringRef enumStr;
449 SMLoc loc = parser.getCurrentLocation();
450 if (parser.parseKeyword(keyword: &enumStr))
451 return failure();
452 if (std::optional<OrderModifier> enumValue =
453 symbolizeOrderModifier(str: enumStr)) {
454 orderMod = OrderModifierAttr::get(context: parser.getContext(), value: *enumValue);
455 if (parser.parseOptionalColon())
456 return failure();
457 loc = parser.getCurrentLocation();
458 if (parser.parseKeyword(keyword: &enumStr))
459 return failure();
460 }
461 if (std::optional<ClauseOrderKind> enumValue =
462 symbolizeClauseOrderKind(str: enumStr)) {
463 order = ClauseOrderKindAttr::get(context: parser.getContext(), value: *enumValue);
464 return success();
465 }
466 return parser.emitError(loc, message: "invalid clause value: '") << enumStr << "'";
467}
468
469static void printOrderClause(OpAsmPrinter &p, Operation *op,
470 ClauseOrderKindAttr order,
471 OrderModifierAttr orderMod) {
472 if (orderMod)
473 p << stringifyOrderModifier(val: orderMod.getValue()) << ":";
474 if (order)
475 p << stringifyClauseOrderKind(val: order.getValue());
476}
477
478template <typename ClauseTypeAttr, typename ClauseType>
479static ParseResult
480parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
481 std::optional<OpAsmParser::UnresolvedOperand> &operand,
482 Type &operandType,
483 std::optional<ClauseType> (*symbolizeClause)(StringRef),
484 StringRef clauseName) {
485 StringRef enumStr;
486 if (succeeded(Result: parser.parseOptionalKeyword(keyword: &enumStr))) {
487 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
488 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
489 if (parser.parseComma())
490 return failure();
491 } else {
492 return parser.emitError(loc: parser.getCurrentLocation())
493 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
494 ;
495 }
496 }
497
498 OpAsmParser::UnresolvedOperand var;
499 if (succeeded(Result: parser.parseOperand(result&: var))) {
500 operand = var;
501 } else {
502 return parser.emitError(loc: parser.getCurrentLocation())
503 << "expected " << clauseName << " operand";
504 }
505
506 if (operand.has_value()) {
507 if (parser.parseColonType(result&: operandType))
508 return failure();
509 }
510
511 return success();
512}
513
514template <typename ClauseTypeAttr, typename ClauseType>
515static void
516printGranularityClause(OpAsmPrinter &p, Operation *op,
517 ClauseTypeAttr prescriptiveness, Value operand,
518 mlir::Type operandType,
519 StringRef (*stringifyClauseType)(ClauseType)) {
520
521 if (prescriptiveness)
522 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
523
524 if (operand)
525 p << operand << ": " << operandType;
526}
527
528//===----------------------------------------------------------------------===//
529// Parser and printer for grainsize Clause
530//===----------------------------------------------------------------------===//
531
532// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
533static ParseResult
534parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
535 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
536 Type &grainsizeType) {
537 return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
538 parser, prescriptiveness&: grainsizeMod, operand&: grainsize, operandType&: grainsizeType,
539 symbolizeClause: &symbolizeClauseGrainsizeType, clauseName: "grainsize");
540}
541
542static void printGrainsizeClause(OpAsmPrinter &p, Operation *op,
543 ClauseGrainsizeTypeAttr grainsizeMod,
544 Value grainsize, mlir::Type grainsizeType) {
545 printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>(
546 p, op, prescriptiveness: grainsizeMod, operand: grainsize, operandType: grainsizeType,
547 stringifyClauseType: &stringifyClauseGrainsizeType);
548}
549
550//===----------------------------------------------------------------------===//
551// Parser and printer for num_tasks Clause
552//===----------------------------------------------------------------------===//
553
554// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
555static ParseResult
556parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
557 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
558 Type &numTasksType) {
559 return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
560 parser, prescriptiveness&: numTasksMod, operand&: numTasks, operandType&: numTasksType, symbolizeClause: &symbolizeClauseNumTasksType,
561 clauseName: "num_tasks");
562}
563
564static void printNumTasksClause(OpAsmPrinter &p, Operation *op,
565 ClauseNumTasksTypeAttr numTasksMod,
566 Value numTasks, mlir::Type numTasksType) {
567 printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>(
568 p, op, prescriptiveness: numTasksMod, operand: numTasks, operandType: numTasksType, stringifyClauseType: &stringifyClauseNumTasksType);
569}
570
571//===----------------------------------------------------------------------===//
572// Parsers for operations including clauses that define entry block arguments.
573//===----------------------------------------------------------------------===//
574
575namespace {
576struct MapParseArgs {
577 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
578 SmallVectorImpl<Type> &types;
579 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
580 SmallVectorImpl<Type> &types)
581 : vars(vars), types(types) {}
582};
583struct PrivateParseArgs {
584 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
585 llvm::SmallVectorImpl<Type> &types;
586 ArrayAttr &syms;
587 UnitAttr &needsBarrier;
588 DenseI64ArrayAttr *mapIndices;
589 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
590 SmallVectorImpl<Type> &types, ArrayAttr &syms,
591 UnitAttr &needsBarrier,
592 DenseI64ArrayAttr *mapIndices = nullptr)
593 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
594 mapIndices(mapIndices) {}
595};
596
597struct ReductionParseArgs {
598 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
599 SmallVectorImpl<Type> &types;
600 DenseBoolArrayAttr &byref;
601 ArrayAttr &syms;
602 ReductionModifierAttr *modifier;
603 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
604 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
605 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
606 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
607};
608
609struct AllRegionParseArgs {
610 std::optional<MapParseArgs> hasDeviceAddrArgs;
611 std::optional<MapParseArgs> hostEvalArgs;
612 std::optional<ReductionParseArgs> inReductionArgs;
613 std::optional<MapParseArgs> mapArgs;
614 std::optional<PrivateParseArgs> privateArgs;
615 std::optional<ReductionParseArgs> reductionArgs;
616 std::optional<ReductionParseArgs> taskReductionArgs;
617 std::optional<MapParseArgs> useDeviceAddrArgs;
618 std::optional<MapParseArgs> useDevicePtrArgs;
619};
620} // namespace
621
622static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
623 return "private_barrier";
624}
625
626static ParseResult parseClauseWithRegionArgs(
627 OpAsmParser &parser,
628 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
629 SmallVectorImpl<Type> &types,
630 SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
631 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
632 DenseBoolArrayAttr *byref = nullptr,
633 ReductionModifierAttr *modifier = nullptr,
634 UnitAttr *needsBarrier = nullptr) {
635 SmallVector<SymbolRefAttr> symbolVec;
636 SmallVector<int64_t> mapIndicesVec;
637 SmallVector<bool> isByRefVec;
638 unsigned regionArgOffset = regionPrivateArgs.size();
639
640 if (parser.parseLParen())
641 return failure();
642
643 if (modifier && succeeded(Result: parser.parseOptionalKeyword(keyword: "mod"))) {
644 StringRef enumStr;
645 if (parser.parseColon() || parser.parseKeyword(keyword: &enumStr) ||
646 parser.parseComma())
647 return failure();
648 std::optional<ReductionModifier> enumValue =
649 symbolizeReductionModifier(str: enumStr);
650 if (!enumValue.has_value())
651 return failure();
652 *modifier = ReductionModifierAttr::get(context: parser.getContext(), value: *enumValue);
653 if (!*modifier)
654 return failure();
655 }
656
657 if (parser.parseCommaSeparatedList(parseElementFn: [&]() {
658 if (byref)
659 isByRefVec.push_back(
660 Elt: parser.parseOptionalKeyword(keyword: "byref").succeeded());
661
662 if (symbols && parser.parseAttribute(result&: symbolVec.emplace_back()))
663 return failure();
664
665 if (parser.parseOperand(result&: operands.emplace_back()) ||
666 parser.parseArrow() ||
667 parser.parseArgument(result&: regionPrivateArgs.emplace_back()))
668 return failure();
669
670 if (mapIndices) {
671 if (parser.parseOptionalLSquare().succeeded()) {
672 if (parser.parseKeyword(keyword: "map_idx") || parser.parseEqual() ||
673 parser.parseInteger(result&: mapIndicesVec.emplace_back()) ||
674 parser.parseRSquare())
675 return failure();
676 } else {
677 mapIndicesVec.push_back(Elt: -1);
678 }
679 }
680
681 return success();
682 }))
683 return failure();
684
685 if (parser.parseColon())
686 return failure();
687
688 if (parser.parseCommaSeparatedList(parseElementFn: [&]() {
689 if (parser.parseType(result&: types.emplace_back()))
690 return failure();
691
692 return success();
693 }))
694 return failure();
695
696 if (operands.size() != types.size())
697 return failure();
698
699 if (parser.parseRParen())
700 return failure();
701
702 if (needsBarrier) {
703 if (parser.parseOptionalKeyword(keyword: getPrivateNeedsBarrierSpelling())
704 .succeeded())
705 *needsBarrier = mlir::UnitAttr::get(context: parser.getContext());
706 }
707
708 auto *argsBegin = regionPrivateArgs.begin();
709 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
710 argsBegin + regionArgOffset + types.size());
711 for (auto [prv, type] : llvm::zip_equal(t&: argsSubrange, u&: types)) {
712 prv.type = type;
713 }
714
715 if (symbols) {
716 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
717 *symbols = ArrayAttr::get(context: parser.getContext(), value: symbolAttrs);
718 }
719
720 if (!mapIndicesVec.empty())
721 *mapIndices =
722 mlir::DenseI64ArrayAttr::get(context: parser.getContext(), content: mapIndicesVec);
723
724 if (byref)
725 *byref = makeDenseBoolArrayAttr(ctx: parser.getContext(), boolArray: isByRefVec);
726
727 return success();
728}
729
730static ParseResult parseBlockArgClause(
731 OpAsmParser &parser,
732 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
733 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
734 if (succeeded(Result: parser.parseOptionalKeyword(keyword))) {
735 if (!mapArgs)
736 return failure();
737
738 if (failed(Result: parseClauseWithRegionArgs(parser, operands&: mapArgs->vars, types&: mapArgs->types,
739 regionPrivateArgs&: entryBlockArgs)))
740 return failure();
741 }
742 return success();
743}
744
745static ParseResult parseBlockArgClause(
746 OpAsmParser &parser,
747 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
748 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
749 if (succeeded(Result: parser.parseOptionalKeyword(keyword))) {
750 if (!privateArgs)
751 return failure();
752
753 if (failed(Result: parseClauseWithRegionArgs(
754 parser, operands&: privateArgs->vars, types&: privateArgs->types, regionPrivateArgs&: entryBlockArgs,
755 symbols: &privateArgs->syms, mapIndices: privateArgs->mapIndices, /*byref=*/nullptr,
756 /*modifier=*/nullptr, needsBarrier: &privateArgs->needsBarrier)))
757 return failure();
758 }
759 return success();
760}
761
762static ParseResult parseBlockArgClause(
763 OpAsmParser &parser,
764 llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
765 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
766 if (succeeded(Result: parser.parseOptionalKeyword(keyword))) {
767 if (!reductionArgs)
768 return failure();
769 if (failed(Result: parseClauseWithRegionArgs(
770 parser, operands&: reductionArgs->vars, types&: reductionArgs->types, regionPrivateArgs&: entryBlockArgs,
771 symbols: &reductionArgs->syms, /*mapIndices=*/nullptr, byref: &reductionArgs->byref,
772 modifier: reductionArgs->modifier)))
773 return failure();
774 }
775 return success();
776}
777
778static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
779 AllRegionParseArgs args) {
780 llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
781
782 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "has_device_addr",
783 mapArgs: args.hasDeviceAddrArgs)))
784 return parser.emitError(loc: parser.getCurrentLocation())
785 << "invalid `has_device_addr` format";
786
787 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "host_eval",
788 mapArgs: args.hostEvalArgs)))
789 return parser.emitError(loc: parser.getCurrentLocation())
790 << "invalid `host_eval` format";
791
792 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "in_reduction",
793 reductionArgs: args.inReductionArgs)))
794 return parser.emitError(loc: parser.getCurrentLocation())
795 << "invalid `in_reduction` format";
796
797 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "map_entries",
798 mapArgs: args.mapArgs)))
799 return parser.emitError(loc: parser.getCurrentLocation())
800 << "invalid `map_entries` format";
801
802 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "private",
803 privateArgs: args.privateArgs)))
804 return parser.emitError(loc: parser.getCurrentLocation())
805 << "invalid `private` format";
806
807 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "reduction",
808 reductionArgs: args.reductionArgs)))
809 return parser.emitError(loc: parser.getCurrentLocation())
810 << "invalid `reduction` format";
811
812 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "task_reduction",
813 reductionArgs: args.taskReductionArgs)))
814 return parser.emitError(loc: parser.getCurrentLocation())
815 << "invalid `task_reduction` format";
816
817 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "use_device_addr",
818 mapArgs: args.useDeviceAddrArgs)))
819 return parser.emitError(loc: parser.getCurrentLocation())
820 << "invalid `use_device_addr` format";
821
822 if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "use_device_ptr",
823 mapArgs: args.useDevicePtrArgs)))
824 return parser.emitError(loc: parser.getCurrentLocation())
825 << "invalid `use_device_addr` format";
826
827 return parser.parseRegion(region, arguments: entryBlockArgs);
828}
829
830// These parseXyz functions correspond to the custom<Xyz> definitions
831// in the .td file(s).
832static ParseResult parseTargetOpRegion(
833 OpAsmParser &parser, Region &region,
834 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hasDeviceAddrVars,
835 SmallVectorImpl<Type> &hasDeviceAddrTypes,
836 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
837 SmallVectorImpl<Type> &hostEvalTypes,
838 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
839 SmallVectorImpl<Type> &inReductionTypes,
840 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
841 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
842 SmallVectorImpl<Type> &mapTypes,
843 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
844 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
845 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
846 AllRegionParseArgs args;
847 args.hasDeviceAddrArgs.emplace(args&: hasDeviceAddrVars, args&: hasDeviceAddrTypes);
848 args.hostEvalArgs.emplace(args&: hostEvalVars, args&: hostEvalTypes);
849 args.inReductionArgs.emplace(args&: inReductionVars, args&: inReductionTypes,
850 args&: inReductionByref, args&: inReductionSyms);
851 args.mapArgs.emplace(args&: mapVars, args&: mapTypes);
852 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
853 args&: privateNeedsBarrier, args: &privateMaps);
854 return parseBlockArgRegion(parser, region, args);
855}
856
857static ParseResult parseInReductionPrivateRegion(
858 OpAsmParser &parser, Region &region,
859 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
860 SmallVectorImpl<Type> &inReductionTypes,
861 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
862 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
863 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
864 UnitAttr &privateNeedsBarrier) {
865 AllRegionParseArgs args;
866 args.inReductionArgs.emplace(args&: inReductionVars, args&: inReductionTypes,
867 args&: inReductionByref, args&: inReductionSyms);
868 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
869 args&: privateNeedsBarrier);
870 return parseBlockArgRegion(parser, region, args);
871}
872
873static ParseResult parseInReductionPrivateReductionRegion(
874 OpAsmParser &parser, Region &region,
875 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
876 SmallVectorImpl<Type> &inReductionTypes,
877 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
878 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
879 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
880 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
881 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
882 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
883 ArrayAttr &reductionSyms) {
884 AllRegionParseArgs args;
885 args.inReductionArgs.emplace(args&: inReductionVars, args&: inReductionTypes,
886 args&: inReductionByref, args&: inReductionSyms);
887 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
888 args&: privateNeedsBarrier);
889 args.reductionArgs.emplace(args&: reductionVars, args&: reductionTypes, args&: reductionByref,
890 args&: reductionSyms, args: &reductionMod);
891 return parseBlockArgRegion(parser, region, args);
892}
893
894static ParseResult parsePrivateRegion(
895 OpAsmParser &parser, Region &region,
896 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
897 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
898 UnitAttr &privateNeedsBarrier) {
899 AllRegionParseArgs args;
900 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
901 args&: privateNeedsBarrier);
902 return parseBlockArgRegion(parser, region, args);
903}
904
905static ParseResult parsePrivateReductionRegion(
906 OpAsmParser &parser, Region &region,
907 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
908 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
909 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
910 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
911 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
912 ArrayAttr &reductionSyms) {
913 AllRegionParseArgs args;
914 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
915 args&: privateNeedsBarrier);
916 args.reductionArgs.emplace(args&: reductionVars, args&: reductionTypes, args&: reductionByref,
917 args&: reductionSyms, args: &reductionMod);
918 return parseBlockArgRegion(parser, region, args);
919}
920
921static ParseResult parseTaskReductionRegion(
922 OpAsmParser &parser, Region &region,
923 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
924 SmallVectorImpl<Type> &taskReductionTypes,
925 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
926 AllRegionParseArgs args;
927 args.taskReductionArgs.emplace(args&: taskReductionVars, args&: taskReductionTypes,
928 args&: taskReductionByref, args&: taskReductionSyms);
929 return parseBlockArgRegion(parser, region, args);
930}
931
932static ParseResult parseUseDeviceAddrUseDevicePtrRegion(
933 OpAsmParser &parser, Region &region,
934 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars,
935 SmallVectorImpl<Type> &useDeviceAddrTypes,
936 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars,
937 SmallVectorImpl<Type> &useDevicePtrTypes) {
938 AllRegionParseArgs args;
939 args.useDeviceAddrArgs.emplace(args&: useDeviceAddrVars, args&: useDeviceAddrTypes);
940 args.useDevicePtrArgs.emplace(args&: useDevicePtrVars, args&: useDevicePtrTypes);
941 return parseBlockArgRegion(parser, region, args);
942}
943
944//===----------------------------------------------------------------------===//
945// Printers for operations including clauses that define entry block arguments.
946//===----------------------------------------------------------------------===//
947
948namespace {
949struct MapPrintArgs {
950 ValueRange vars;
951 TypeRange types;
952 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
953};
954struct PrivatePrintArgs {
955 ValueRange vars;
956 TypeRange types;
957 ArrayAttr syms;
958 UnitAttr needsBarrier;
959 DenseI64ArrayAttr mapIndices;
960 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
961 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
962 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
963 mapIndices(mapIndices) {}
964};
965struct ReductionPrintArgs {
966 ValueRange vars;
967 TypeRange types;
968 DenseBoolArrayAttr byref;
969 ArrayAttr syms;
970 ReductionModifierAttr modifier;
971 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
972 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
973 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
974};
975struct AllRegionPrintArgs {
976 std::optional<MapPrintArgs> hasDeviceAddrArgs;
977 std::optional<MapPrintArgs> hostEvalArgs;
978 std::optional<ReductionPrintArgs> inReductionArgs;
979 std::optional<MapPrintArgs> mapArgs;
980 std::optional<PrivatePrintArgs> privateArgs;
981 std::optional<ReductionPrintArgs> reductionArgs;
982 std::optional<ReductionPrintArgs> taskReductionArgs;
983 std::optional<MapPrintArgs> useDeviceAddrArgs;
984 std::optional<MapPrintArgs> useDevicePtrArgs;
985};
986} // namespace
987
988static void printClauseWithRegionArgs(
989 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
990 ValueRange argsSubrange, ValueRange operands, TypeRange types,
991 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
992 DenseBoolArrayAttr byref = nullptr,
993 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
994 if (argsSubrange.empty())
995 return;
996
997 p << clauseName << "(";
998
999 if (modifier)
1000 p << "mod: " << stringifyReductionModifier(val: modifier.getValue()) << ", ";
1001
1002 if (!symbols) {
1003 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1004 symbols = ArrayAttr::get(context: ctx, value: values);
1005 }
1006
1007 if (!mapIndices) {
1008 llvm::SmallVector<int64_t> values(operands.size(), -1);
1009 mapIndices = DenseI64ArrayAttr::get(context: ctx, content: values);
1010 }
1011
1012 if (!byref) {
1013 mlir::SmallVector<bool> values(operands.size(), false);
1014 byref = DenseBoolArrayAttr::get(context: ctx, content: values);
1015 }
1016
1017 llvm::interleaveComma(c: llvm::zip_equal(t&: operands, u&: argsSubrange, args&: symbols,
1018 args: mapIndices.asArrayRef(),
1019 args: byref.asArrayRef()),
1020 os&: p, each_fn: [&p](auto t) {
1021 auto [op, arg, sym, map, isByRef] = t;
1022 if (isByRef)
1023 p << "byref ";
1024 if (sym)
1025 p << sym << " ";
1026
1027 p << op << " -> " << arg;
1028
1029 if (map != -1)
1030 p << " [map_idx=" << map << "]";
1031 });
1032 p << " : ";
1033 llvm::interleaveComma(c: types, os&: p);
1034 p << ") ";
1035
1036 if (needsBarrier)
1037 p << getPrivateNeedsBarrierSpelling() << " ";
1038}
1039
1040static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
1041 StringRef clauseName, ValueRange argsSubrange,
1042 std::optional<MapPrintArgs> mapArgs) {
1043 if (mapArgs)
1044 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, operands: mapArgs->vars,
1045 types: mapArgs->types);
1046}
1047
1048static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
1049 StringRef clauseName, ValueRange argsSubrange,
1050 std::optional<PrivatePrintArgs> privateArgs) {
1051 if (privateArgs)
1052 printClauseWithRegionArgs(
1053 p, ctx, clauseName, argsSubrange, operands: privateArgs->vars, types: privateArgs->types,
1054 symbols: privateArgs->syms, mapIndices: privateArgs->mapIndices, /*byref=*/nullptr,
1055 /*modifier=*/nullptr, needsBarrier: privateArgs->needsBarrier);
1056}
1057
1058static void
1059printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1060 ValueRange argsSubrange,
1061 std::optional<ReductionPrintArgs> reductionArgs) {
1062 if (reductionArgs)
1063 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1064 operands: reductionArgs->vars, types: reductionArgs->types,
1065 symbols: reductionArgs->syms, /*mapIndices=*/nullptr,
1066 byref: reductionArgs->byref, modifier: reductionArgs->modifier);
1067}
1068
1069static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
1070 const AllRegionPrintArgs &args) {
1071 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(Val: op);
1072 MLIRContext *ctx = op->getContext();
1073
1074 printBlockArgClause(p, ctx, clauseName: "has_device_addr",
1075 argsSubrange: iface.getHasDeviceAddrBlockArgs(),
1076 mapArgs: args.hasDeviceAddrArgs);
1077 printBlockArgClause(p, ctx, clauseName: "host_eval", argsSubrange: iface.getHostEvalBlockArgs(),
1078 mapArgs: args.hostEvalArgs);
1079 printBlockArgClause(p, ctx, clauseName: "in_reduction", argsSubrange: iface.getInReductionBlockArgs(),
1080 reductionArgs: args.inReductionArgs);
1081 printBlockArgClause(p, ctx, clauseName: "map_entries", argsSubrange: iface.getMapBlockArgs(),
1082 mapArgs: args.mapArgs);
1083 printBlockArgClause(p, ctx, clauseName: "private", argsSubrange: iface.getPrivateBlockArgs(),
1084 privateArgs: args.privateArgs);
1085 printBlockArgClause(p, ctx, clauseName: "reduction", argsSubrange: iface.getReductionBlockArgs(),
1086 reductionArgs: args.reductionArgs);
1087 printBlockArgClause(p, ctx, clauseName: "task_reduction",
1088 argsSubrange: iface.getTaskReductionBlockArgs(),
1089 reductionArgs: args.taskReductionArgs);
1090 printBlockArgClause(p, ctx, clauseName: "use_device_addr",
1091 argsSubrange: iface.getUseDeviceAddrBlockArgs(),
1092 mapArgs: args.useDeviceAddrArgs);
1093 printBlockArgClause(p, ctx, clauseName: "use_device_ptr",
1094 argsSubrange: iface.getUseDevicePtrBlockArgs(), mapArgs: args.useDevicePtrArgs);
1095
1096 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
1097}
1098
1099// These parseXyz functions correspond to the custom<Xyz> definitions
1100// in the .td file(s).
1101static void printTargetOpRegion(
1102 OpAsmPrinter &p, Operation *op, Region &region,
1103 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1104 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1105 ValueRange inReductionVars, TypeRange inReductionTypes,
1106 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1107 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1108 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1109 DenseI64ArrayAttr privateMaps) {
1110 AllRegionPrintArgs args;
1111 args.hasDeviceAddrArgs.emplace(args&: hasDeviceAddrVars, args&: hasDeviceAddrTypes);
1112 args.hostEvalArgs.emplace(args&: hostEvalVars, args&: hostEvalTypes);
1113 args.inReductionArgs.emplace(args&: inReductionVars, args&: inReductionTypes,
1114 args&: inReductionByref, args&: inReductionSyms);
1115 args.mapArgs.emplace(args&: mapVars, args&: mapTypes);
1116 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
1117 args&: privateNeedsBarrier, args&: privateMaps);
1118 printBlockArgRegion(p, op, region, args);
1119}
1120
1121static void printInReductionPrivateRegion(
1122 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1123 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1124 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1125 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1126 AllRegionPrintArgs args;
1127 args.inReductionArgs.emplace(args&: inReductionVars, args&: inReductionTypes,
1128 args&: inReductionByref, args&: inReductionSyms);
1129 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
1130 args&: privateNeedsBarrier,
1131 /*mapIndices=*/args: nullptr);
1132 printBlockArgRegion(p, op, region, args);
1133}
1134
1135static void printInReductionPrivateReductionRegion(
1136 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1137 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1138 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1139 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1140 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1141 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1142 ArrayAttr reductionSyms) {
1143 AllRegionPrintArgs args;
1144 args.inReductionArgs.emplace(args&: inReductionVars, args&: inReductionTypes,
1145 args&: inReductionByref, args&: inReductionSyms);
1146 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
1147 args&: privateNeedsBarrier,
1148 /*mapIndices=*/args: nullptr);
1149 args.reductionArgs.emplace(args&: reductionVars, args&: reductionTypes, args&: reductionByref,
1150 args&: reductionSyms, args&: reductionMod);
1151 printBlockArgRegion(p, op, region, args);
1152}
1153
1154static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
1155 ValueRange privateVars, TypeRange privateTypes,
1156 ArrayAttr privateSyms,
1157 UnitAttr privateNeedsBarrier) {
1158 AllRegionPrintArgs args;
1159 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
1160 args&: privateNeedsBarrier,
1161 /*mapIndices=*/args: nullptr);
1162 printBlockArgRegion(p, op, region, args);
1163}
1164
1165static void printPrivateReductionRegion(
1166 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1167 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1168 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1169 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1170 ArrayAttr reductionSyms) {
1171 AllRegionPrintArgs args;
1172 args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms,
1173 args&: privateNeedsBarrier,
1174 /*mapIndices=*/args: nullptr);
1175 args.reductionArgs.emplace(args&: reductionVars, args&: reductionTypes, args&: reductionByref,
1176 args&: reductionSyms, args&: reductionMod);
1177 printBlockArgRegion(p, op, region, args);
1178}
1179
1180static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
1181 Region &region,
1182 ValueRange taskReductionVars,
1183 TypeRange taskReductionTypes,
1184 DenseBoolArrayAttr taskReductionByref,
1185 ArrayAttr taskReductionSyms) {
1186 AllRegionPrintArgs args;
1187 args.taskReductionArgs.emplace(args&: taskReductionVars, args&: taskReductionTypes,
1188 args&: taskReductionByref, args&: taskReductionSyms);
1189 printBlockArgRegion(p, op, region, args);
1190}
1191
1192static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op,
1193 Region &region,
1194 ValueRange useDeviceAddrVars,
1195 TypeRange useDeviceAddrTypes,
1196 ValueRange useDevicePtrVars,
1197 TypeRange useDevicePtrTypes) {
1198 AllRegionPrintArgs args;
1199 args.useDeviceAddrArgs.emplace(args&: useDeviceAddrVars, args&: useDeviceAddrTypes);
1200 args.useDevicePtrArgs.emplace(args&: useDevicePtrVars, args&: useDevicePtrTypes);
1201 printBlockArgRegion(p, op, region, args);
1202}
1203
1204/// Verifies Reduction Clause
1205static LogicalResult
1206verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1207 OperandRange reductionVars,
1208 std::optional<ArrayRef<bool>> reductionByref) {
1209 if (!reductionVars.empty()) {
1210 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1211 return op->emitOpError()
1212 << "expected as many reduction symbol references "
1213 "as reduction variables";
1214 if (reductionByref && reductionByref->size() != reductionVars.size())
1215 return op->emitError() << "expected as many reduction variable by "
1216 "reference attributes as reduction variables";
1217 } else {
1218 if (reductionSyms)
1219 return op->emitOpError() << "unexpected reduction symbol references";
1220 return success();
1221 }
1222
1223 // TODO: The followings should be done in
1224 // SymbolUserOpInterface::verifySymbolUses.
1225 DenseSet<Value> accumulators;
1226 for (auto args : llvm::zip(t&: reductionVars, u&: *reductionSyms)) {
1227 Value accum = std::get<0>(t&: args);
1228
1229 if (!accumulators.insert(V: accum).second)
1230 return op->emitOpError() << "accumulator variable used more than once";
1231
1232 Type varType = accum.getType();
1233 auto symbolRef = llvm::cast<SymbolRefAttr>(Val: std::get<1>(t&: args));
1234 auto decl =
1235 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(from: op, symbol: symbolRef);
1236 if (!decl)
1237 return op->emitOpError() << "expected symbol reference " << symbolRef
1238 << " to point to a reduction declaration";
1239
1240 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1241 return op->emitOpError()
1242 << "expected accumulator (" << varType
1243 << ") to be the same type as reduction declaration ("
1244 << decl.getAccumulatorType() << ")";
1245 }
1246
1247 return success();
1248}
1249
1250//===----------------------------------------------------------------------===//
1251// Parser, printer and verifier for Copyprivate
1252//===----------------------------------------------------------------------===//
1253
1254/// copyprivate-entry-list ::= copyprivate-entry
1255/// | copyprivate-entry-list `,` copyprivate-entry
1256/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1257static ParseResult parseCopyprivate(
1258 OpAsmParser &parser,
1259 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &copyprivateVars,
1260 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1261 SmallVector<SymbolRefAttr> symsVec;
1262 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1263 if (parser.parseOperand(result&: copyprivateVars.emplace_back()) ||
1264 parser.parseArrow() ||
1265 parser.parseAttribute(result&: symsVec.emplace_back()) ||
1266 parser.parseColonType(result&: copyprivateTypes.emplace_back()))
1267 return failure();
1268 return success();
1269 })))
1270 return failure();
1271 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1272 copyprivateSyms = ArrayAttr::get(context: parser.getContext(), value: syms);
1273 return success();
1274}
1275
1276/// Print Copyprivate clause
1277static void printCopyprivate(OpAsmPrinter &p, Operation *op,
1278 OperandRange copyprivateVars,
1279 TypeRange copyprivateTypes,
1280 std::optional<ArrayAttr> copyprivateSyms) {
1281 if (!copyprivateSyms.has_value())
1282 return;
1283 llvm::interleaveComma(
1284 c: llvm::zip(t&: copyprivateVars, u&: *copyprivateSyms, args&: copyprivateTypes), os&: p,
1285 each_fn: [&](const auto &args) {
1286 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1287 << std::get<2>(args);
1288 });
1289}
1290
1291/// Verifies CopyPrivate Clause
1292static LogicalResult
1293verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars,
1294 std::optional<ArrayAttr> copyprivateSyms) {
1295 size_t copyprivateSymsSize =
1296 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1297 if (copyprivateSymsSize != copyprivateVars.size())
1298 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1299 << copyprivateVars.size()
1300 << ") and functions (= " << copyprivateSymsSize
1301 << "), both must be equal";
1302 if (!copyprivateSyms.has_value())
1303 return success();
1304
1305 for (auto copyprivateVarAndSym :
1306 llvm::zip(t&: copyprivateVars, u&: *copyprivateSyms)) {
1307 auto symbolRef =
1308 llvm::cast<SymbolRefAttr>(Val: std::get<1>(t&: copyprivateVarAndSym));
1309 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1310 funcOp;
1311 if (mlir::func::FuncOp mlirFuncOp =
1312 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(from: op,
1313 symbol: symbolRef))
1314 funcOp = mlirFuncOp;
1315 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1316 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
1317 from: op, symbol: symbolRef))
1318 funcOp = llvmFuncOp;
1319
1320 auto getNumArguments = [&] {
1321 return std::visit(visitor: [](auto &f) { return f.getNumArguments(); }, variants&: *funcOp);
1322 };
1323
1324 auto getArgumentType = [&](unsigned i) {
1325 return std::visit(visitor: [i](auto &f) { return f.getArgumentTypes()[i]; },
1326 variants&: *funcOp);
1327 };
1328
1329 if (!funcOp)
1330 return op->emitOpError() << "expected symbol reference " << symbolRef
1331 << " to point to a copy function";
1332
1333 if (getNumArguments() != 2)
1334 return op->emitOpError()
1335 << "expected copy function " << symbolRef << " to have 2 operands";
1336
1337 Type argTy = getArgumentType(0);
1338 if (argTy != getArgumentType(1))
1339 return op->emitOpError() << "expected copy function " << symbolRef
1340 << " arguments to have the same type";
1341
1342 Type varType = std::get<0>(t&: copyprivateVarAndSym).getType();
1343 if (argTy != varType)
1344 return op->emitOpError()
1345 << "expected copy function arguments' type (" << argTy
1346 << ") to be the same as copyprivate variable's type (" << varType
1347 << ")";
1348 }
1349
1350 return success();
1351}
1352
1353//===----------------------------------------------------------------------===//
1354// Parser, printer and verifier for DependVarList
1355//===----------------------------------------------------------------------===//
1356
1357/// depend-entry-list ::= depend-entry
1358/// | depend-entry-list `,` depend-entry
1359/// depend-entry ::= depend-kind `->` ssa-id `:` type
1360static ParseResult
1361parseDependVarList(OpAsmParser &parser,
1362 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dependVars,
1363 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1364 SmallVector<ClauseTaskDependAttr> kindsVec;
1365 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1366 StringRef keyword;
1367 if (parser.parseKeyword(keyword: &keyword) || parser.parseArrow() ||
1368 parser.parseOperand(result&: dependVars.emplace_back()) ||
1369 parser.parseColonType(result&: dependTypes.emplace_back()))
1370 return failure();
1371 if (std::optional<ClauseTaskDepend> keywordDepend =
1372 (symbolizeClauseTaskDepend(str: keyword)))
1373 kindsVec.emplace_back(
1374 Args: ClauseTaskDependAttr::get(context: parser.getContext(), value: *keywordDepend));
1375 else
1376 return failure();
1377 return success();
1378 })))
1379 return failure();
1380 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1381 dependKinds = ArrayAttr::get(context: parser.getContext(), value: kinds);
1382 return success();
1383}
1384
1385/// Print Depend clause
1386static void printDependVarList(OpAsmPrinter &p, Operation *op,
1387 OperandRange dependVars, TypeRange dependTypes,
1388 std::optional<ArrayAttr> dependKinds) {
1389
1390 for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1391 if (i != 0)
1392 p << ", ";
1393 p << stringifyClauseTaskDepend(
1394 val: llvm::cast<mlir::omp::ClauseTaskDependAttr>(Val: (*dependKinds)[i])
1395 .getValue())
1396 << " -> " << dependVars[i] << " : " << dependTypes[i];
1397 }
1398}
1399
1400/// Verifies Depend clause
1401static LogicalResult verifyDependVarList(Operation *op,
1402 std::optional<ArrayAttr> dependKinds,
1403 OperandRange dependVars) {
1404 if (!dependVars.empty()) {
1405 if (!dependKinds || dependKinds->size() != dependVars.size())
1406 return op->emitOpError() << "expected as many depend values"
1407 " as depend variables";
1408 } else {
1409 if (dependKinds && !dependKinds->empty())
1410 return op->emitOpError() << "unexpected depend values";
1411 return success();
1412 }
1413
1414 return success();
1415}
1416
1417//===----------------------------------------------------------------------===//
1418// Parser, printer and verifier for Synchronization Hint (2.17.12)
1419//===----------------------------------------------------------------------===//
1420
1421/// Parses a Synchronization Hint clause. The value of hint is an integer
1422/// which is a combination of different hints from `omp_sync_hint_t`.
1423///
1424/// hint-clause = `hint` `(` hint-value `)`
1425static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1426 IntegerAttr &hintAttr) {
1427 StringRef hintKeyword;
1428 int64_t hint = 0;
1429 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "none"))) {
1430 hintAttr = IntegerAttr::get(type: parser.getBuilder().getI64Type(), value: 0);
1431 return success();
1432 }
1433 auto parseKeyword = [&]() -> ParseResult {
1434 if (failed(Result: parser.parseKeyword(keyword: &hintKeyword)))
1435 return failure();
1436 if (hintKeyword == "uncontended")
1437 hint |= 1;
1438 else if (hintKeyword == "contended")
1439 hint |= 2;
1440 else if (hintKeyword == "nonspeculative")
1441 hint |= 4;
1442 else if (hintKeyword == "speculative")
1443 hint |= 8;
1444 else
1445 return parser.emitError(loc: parser.getCurrentLocation())
1446 << hintKeyword << " is not a valid hint";
1447 return success();
1448 };
1449 if (parser.parseCommaSeparatedList(parseElementFn: parseKeyword))
1450 return failure();
1451 hintAttr = IntegerAttr::get(type: parser.getBuilder().getI64Type(), value: hint);
1452 return success();
1453}
1454
1455/// Prints a Synchronization Hint clause
1456static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
1457 IntegerAttr hintAttr) {
1458 int64_t hint = hintAttr.getInt();
1459
1460 if (hint == 0) {
1461 p << "none";
1462 return;
1463 }
1464
1465 // Helper function to get n-th bit from the right end of `value`
1466 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1467
1468 bool uncontended = bitn(hint, 0);
1469 bool contended = bitn(hint, 1);
1470 bool nonspeculative = bitn(hint, 2);
1471 bool speculative = bitn(hint, 3);
1472
1473 SmallVector<StringRef> hints;
1474 if (uncontended)
1475 hints.push_back(Elt: "uncontended");
1476 if (contended)
1477 hints.push_back(Elt: "contended");
1478 if (nonspeculative)
1479 hints.push_back(Elt: "nonspeculative");
1480 if (speculative)
1481 hints.push_back(Elt: "speculative");
1482
1483 llvm::interleaveComma(c: hints, os&: p);
1484}
1485
1486/// Verifies a synchronization hint clause
1487static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1488
1489 // Helper function to get n-th bit from the right end of `value`
1490 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1491
1492 bool uncontended = bitn(hint, 0);
1493 bool contended = bitn(hint, 1);
1494 bool nonspeculative = bitn(hint, 2);
1495 bool speculative = bitn(hint, 3);
1496
1497 if (uncontended && contended)
1498 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1499 "omp_sync_hint_contended cannot be combined";
1500 if (nonspeculative && speculative)
1501 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1502 "omp_sync_hint_speculative cannot be combined.";
1503 return success();
1504}
1505
1506//===----------------------------------------------------------------------===//
1507// Parser, printer and verifier for Target
1508//===----------------------------------------------------------------------===//
1509
1510// Helper function to get bitwise AND of `value` and 'flag'
1511uint64_t mapTypeToBitFlag(uint64_t value,
1512 llvm::omp::OpenMPOffloadMappingFlags flag) {
1513 return value & llvm::to_underlying(E: flag);
1514}
1515
1516/// Parses a map_entries map type from a string format back into its numeric
1517/// value.
1518///
1519/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1520/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1521static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
1522 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
1523 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
1524
1525 // This simply verifies the correct keyword is read in, the
1526 // keyword itself is stored inside of the operation
1527 auto parseTypeAndMod = [&]() -> ParseResult {
1528 StringRef mapTypeMod;
1529 if (parser.parseKeyword(keyword: &mapTypeMod))
1530 return failure();
1531
1532 if (mapTypeMod == "always")
1533 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
1534
1535 if (mapTypeMod == "implicit")
1536 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
1537
1538 if (mapTypeMod == "ompx_hold")
1539 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD;
1540
1541 if (mapTypeMod == "close")
1542 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
1543
1544 if (mapTypeMod == "present")
1545 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
1546
1547 if (mapTypeMod == "to")
1548 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
1549
1550 if (mapTypeMod == "from")
1551 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1552
1553 if (mapTypeMod == "tofrom")
1554 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
1555 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
1556
1557 if (mapTypeMod == "delete")
1558 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
1559
1560 if (mapTypeMod == "return_param")
1561 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
1562
1563 return success();
1564 };
1565
1566 if (parser.parseCommaSeparatedList(parseElementFn: parseTypeAndMod))
1567 return failure();
1568
1569 mapType = parser.getBuilder().getIntegerAttr(
1570 type: parser.getBuilder().getIntegerType(width: 64, /*isSigned=*/false),
1571 value: llvm::to_underlying(E: mapTypeBits));
1572
1573 return success();
1574}
1575
1576/// Prints a map_entries map type from its numeric value out into its string
1577/// format.
1578static void printMapClause(OpAsmPrinter &p, Operation *op,
1579 IntegerAttr mapType) {
1580 uint64_t mapTypeBits = mapType.getUInt();
1581
1582 bool emitAllocRelease = true;
1583 llvm::SmallVector<std::string, 4> mapTypeStrs;
1584
1585 // handling of always, close, present placed at the beginning of the string
1586 // to aid readability
1587 if (mapTypeToBitFlag(value: mapTypeBits,
1588 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
1589 mapTypeStrs.push_back(Elt: "always");
1590 if (mapTypeToBitFlag(value: mapTypeBits,
1591 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
1592 mapTypeStrs.push_back(Elt: "implicit");
1593 if (mapTypeToBitFlag(value: mapTypeBits,
1594 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD))
1595 mapTypeStrs.push_back(Elt: "ompx_hold");
1596 if (mapTypeToBitFlag(value: mapTypeBits,
1597 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
1598 mapTypeStrs.push_back(Elt: "close");
1599 if (mapTypeToBitFlag(value: mapTypeBits,
1600 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
1601 mapTypeStrs.push_back(Elt: "present");
1602
1603 // special handling of to/from/tofrom/delete and release/alloc, release +
1604 // alloc are the abscense of one of the other flags, whereas tofrom requires
1605 // both the to and from flag to be set.
1606 bool to = mapTypeToBitFlag(value: mapTypeBits,
1607 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1608 bool from = mapTypeToBitFlag(
1609 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1610 if (to && from) {
1611 emitAllocRelease = false;
1612 mapTypeStrs.push_back(Elt: "tofrom");
1613 } else if (from) {
1614 emitAllocRelease = false;
1615 mapTypeStrs.push_back(Elt: "from");
1616 } else if (to) {
1617 emitAllocRelease = false;
1618 mapTypeStrs.push_back(Elt: "to");
1619 }
1620 if (mapTypeToBitFlag(value: mapTypeBits,
1621 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
1622 emitAllocRelease = false;
1623 mapTypeStrs.push_back(Elt: "delete");
1624 }
1625 if (mapTypeToBitFlag(
1626 value: mapTypeBits,
1627 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) {
1628 emitAllocRelease = false;
1629 mapTypeStrs.push_back(Elt: "return_param");
1630 }
1631 if (emitAllocRelease)
1632 mapTypeStrs.push_back(Elt: "exit_release_or_enter_alloc");
1633
1634 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1635 p << mapTypeStrs[i];
1636 if (i + 1 < mapTypeStrs.size()) {
1637 p << ", ";
1638 }
1639 }
1640}
1641
1642static ParseResult parseMembersIndex(OpAsmParser &parser,
1643 ArrayAttr &membersIdx) {
1644 SmallVector<Attribute> values, memberIdxs;
1645
1646 auto parseIndices = [&]() -> ParseResult {
1647 int64_t value;
1648 if (parser.parseInteger(result&: value))
1649 return failure();
1650 values.push_back(Elt: IntegerAttr::get(type: parser.getBuilder().getIntegerType(width: 64),
1651 value: APInt(64, value, /*isSigned=*/false)));
1652 return success();
1653 };
1654
1655 do {
1656 if (failed(Result: parser.parseLSquare()))
1657 return failure();
1658
1659 if (parser.parseCommaSeparatedList(parseElementFn: parseIndices))
1660 return failure();
1661
1662 if (failed(Result: parser.parseRSquare()))
1663 return failure();
1664
1665 memberIdxs.push_back(Elt: ArrayAttr::get(context: parser.getContext(), value: values));
1666 values.clear();
1667 } while (succeeded(Result: parser.parseOptionalComma()));
1668
1669 if (!memberIdxs.empty())
1670 membersIdx = ArrayAttr::get(context: parser.getContext(), value: memberIdxs);
1671
1672 return success();
1673}
1674
1675static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1676 ArrayAttr membersIdx) {
1677 if (!membersIdx)
1678 return;
1679
1680 llvm::interleaveComma(c: membersIdx, os&: p, each_fn: [&p](Attribute v) {
1681 p << "[";
1682 auto memberIdx = cast<ArrayAttr>(Val&: v);
1683 llvm::interleaveComma(c: memberIdx.getValue(), os&: p, each_fn: [&p](Attribute v2) {
1684 p << cast<IntegerAttr>(Val&: v2).getInt();
1685 });
1686 p << "]";
1687 });
1688}
1689
1690static void printCaptureType(OpAsmPrinter &p, Operation *op,
1691 VariableCaptureKindAttr mapCaptureType) {
1692 std::string typeCapStr;
1693 llvm::raw_string_ostream typeCap(typeCapStr);
1694 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1695 typeCap << "ByRef";
1696 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1697 typeCap << "ByCopy";
1698 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1699 typeCap << "VLAType";
1700 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1701 typeCap << "This";
1702 p << typeCapStr;
1703}
1704
1705static ParseResult parseCaptureType(OpAsmParser &parser,
1706 VariableCaptureKindAttr &mapCaptureType) {
1707 StringRef mapCaptureKey;
1708 if (parser.parseKeyword(keyword: &mapCaptureKey))
1709 return failure();
1710
1711 if (mapCaptureKey == "This")
1712 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1713 context: parser.getContext(), value: mlir::omp::VariableCaptureKind::This);
1714 if (mapCaptureKey == "ByRef")
1715 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1716 context: parser.getContext(), value: mlir::omp::VariableCaptureKind::ByRef);
1717 if (mapCaptureKey == "ByCopy")
1718 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1719 context: parser.getContext(), value: mlir::omp::VariableCaptureKind::ByCopy);
1720 if (mapCaptureKey == "VLAType")
1721 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1722 context: parser.getContext(), value: mlir::omp::VariableCaptureKind::VLAType);
1723
1724 return success();
1725}
1726
1727static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1728 llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateToVars;
1729 llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateFromVars;
1730
1731 for (auto mapOp : mapVars) {
1732 if (!mapOp.getDefiningOp())
1733 return emitError(loc: op->getLoc(), message: "missing map operation");
1734
1735 if (auto mapInfoOp =
1736 mlir::dyn_cast<mlir::omp::MapInfoOp>(Val: mapOp.getDefiningOp())) {
1737 uint64_t mapTypeBits = mapInfoOp.getMapType();
1738
1739 bool to = mapTypeToBitFlag(
1740 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1741 bool from = mapTypeToBitFlag(
1742 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1743 bool del = mapTypeToBitFlag(
1744 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1745
1746 bool always = mapTypeToBitFlag(
1747 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1748 bool close = mapTypeToBitFlag(
1749 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1750 bool implicit = mapTypeToBitFlag(
1751 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1752
1753 if ((isa<TargetDataOp>(Val: op) || isa<TargetOp>(Val: op)) && del)
1754 return emitError(loc: op->getLoc(),
1755 message: "to, from, tofrom and alloc map types are permitted");
1756
1757 if (isa<TargetEnterDataOp>(Val: op) && (from || del))
1758 return emitError(loc: op->getLoc(), message: "to and alloc map types are permitted");
1759
1760 if (isa<TargetExitDataOp>(Val: op) && to)
1761 return emitError(loc: op->getLoc(),
1762 message: "from, release and delete map types are permitted");
1763
1764 if (isa<TargetUpdateOp>(Val: op)) {
1765 if (del) {
1766 return emitError(loc: op->getLoc(),
1767 message: "at least one of to or from map types must be "
1768 "specified, other map types are not permitted");
1769 }
1770
1771 if (!to && !from) {
1772 return emitError(loc: op->getLoc(),
1773 message: "at least one of to or from map types must be "
1774 "specified, other map types are not permitted");
1775 }
1776
1777 auto updateVar = mapInfoOp.getVarPtr();
1778
1779 if ((to && from) || (to && updateFromVars.contains(V: updateVar)) ||
1780 (from && updateToVars.contains(V: updateVar))) {
1781 return emitError(
1782 loc: op->getLoc(),
1783 message: "either to or from map types can be specified, not both");
1784 }
1785
1786 if (always || close || implicit) {
1787 return emitError(
1788 loc: op->getLoc(),
1789 message: "present, mapper and iterator map type modifiers are permitted");
1790 }
1791
1792 to ? updateToVars.insert(V: updateVar) : updateFromVars.insert(V: updateVar);
1793 }
1794 } else if (!isa<DeclareMapperInfoOp>(Val: op)) {
1795 return emitError(loc: op->getLoc(),
1796 message: "map argument is not a map entry operation");
1797 }
1798 }
1799
1800 return success();
1801}
1802
1803static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
1804 std::optional<DenseI64ArrayAttr> privateMapIndices =
1805 targetOp.getPrivateMapsAttr();
1806
1807 // None of the private operands are mapped.
1808 if (!privateMapIndices.has_value() || !privateMapIndices.value())
1809 return success();
1810
1811 OperandRange privateVars = targetOp.getPrivateVars();
1812
1813 if (privateMapIndices.value().size() !=
1814 static_cast<int64_t>(privateVars.size()))
1815 return emitError(loc: targetOp.getLoc(), message: "sizes of `private` operand range and "
1816 "`private_maps` attribute mismatch");
1817
1818 return success();
1819}
1820
1821//===----------------------------------------------------------------------===//
1822// MapInfoOp
1823//===----------------------------------------------------------------------===//
1824
1825static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
1826 StringRef clauseName,
1827 OperandRange vars) {
1828 for (Value var : vars)
1829 if (!llvm::isa_and_present<MapInfoOp>(Val: var.getDefiningOp()))
1830 return op->emitOpError()
1831 << "'" << clauseName
1832 << "' arguments must be defined by 'omp.map.info' ops";
1833 return success();
1834}
1835
1836LogicalResult MapInfoOp::verify() {
1837 if (getMapperId() &&
1838 !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
1839 from: *this, symbol: getMapperIdAttr())) {
1840 return emitError(message: "invalid mapper id");
1841 }
1842
1843 if (failed(Result: verifyMapInfoDefinedArgs(op: *this, clauseName: "members", vars: getMembers())))
1844 return failure();
1845
1846 return success();
1847}
1848
1849//===----------------------------------------------------------------------===//
1850// TargetDataOp
1851//===----------------------------------------------------------------------===//
1852
1853void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1854 const TargetDataOperands &clauses) {
1855 TargetDataOp::build(odsBuilder&: builder, odsState&: state, device: clauses.device, if_expr: clauses.ifExpr,
1856 map_vars: clauses.mapVars, use_device_addr_vars: clauses.useDeviceAddrVars,
1857 use_device_ptr_vars: clauses.useDevicePtrVars);
1858}
1859
1860LogicalResult TargetDataOp::verify() {
1861 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
1862 getUseDeviceAddrVars().empty()) {
1863 return ::emitError(loc: this->getLoc(),
1864 message: "At least one of map, use_device_ptr_vars, or "
1865 "use_device_addr_vars operand must be present");
1866 }
1867
1868 if (failed(Result: verifyMapInfoDefinedArgs(op: *this, clauseName: "use_device_ptr",
1869 vars: getUseDevicePtrVars())))
1870 return failure();
1871
1872 if (failed(Result: verifyMapInfoDefinedArgs(op: *this, clauseName: "use_device_addr",
1873 vars: getUseDeviceAddrVars())))
1874 return failure();
1875
1876 return verifyMapClause(op: *this, mapVars: getMapVars());
1877}
1878
1879//===----------------------------------------------------------------------===//
1880// TargetEnterDataOp
1881//===----------------------------------------------------------------------===//
1882
1883void TargetEnterDataOp::build(
1884 OpBuilder &builder, OperationState &state,
1885 const TargetEnterExitUpdateDataOperands &clauses) {
1886 MLIRContext *ctx = builder.getContext();
1887 TargetEnterDataOp::build(odsBuilder&: builder, odsState&: state,
1888 depend_kinds: makeArrayAttr(context: ctx, attrs: clauses.dependKinds),
1889 depend_vars: clauses.dependVars, device: clauses.device, if_expr: clauses.ifExpr,
1890 map_vars: clauses.mapVars, nowait: clauses.nowait);
1891}
1892
1893LogicalResult TargetEnterDataOp::verify() {
1894 LogicalResult verifyDependVars =
1895 verifyDependVarList(op: *this, dependKinds: getDependKinds(), dependVars: getDependVars());
1896 return failed(Result: verifyDependVars) ? verifyDependVars
1897 : verifyMapClause(op: *this, mapVars: getMapVars());
1898}
1899
1900//===----------------------------------------------------------------------===//
1901// TargetExitDataOp
1902//===----------------------------------------------------------------------===//
1903
1904void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
1905 const TargetEnterExitUpdateDataOperands &clauses) {
1906 MLIRContext *ctx = builder.getContext();
1907 TargetExitDataOp::build(odsBuilder&: builder, odsState&: state,
1908 depend_kinds: makeArrayAttr(context: ctx, attrs: clauses.dependKinds),
1909 depend_vars: clauses.dependVars, device: clauses.device, if_expr: clauses.ifExpr,
1910 map_vars: clauses.mapVars, nowait: clauses.nowait);
1911}
1912
1913LogicalResult TargetExitDataOp::verify() {
1914 LogicalResult verifyDependVars =
1915 verifyDependVarList(op: *this, dependKinds: getDependKinds(), dependVars: getDependVars());
1916 return failed(Result: verifyDependVars) ? verifyDependVars
1917 : verifyMapClause(op: *this, mapVars: getMapVars());
1918}
1919
1920//===----------------------------------------------------------------------===//
1921// TargetUpdateOp
1922//===----------------------------------------------------------------------===//
1923
1924void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1925 const TargetEnterExitUpdateDataOperands &clauses) {
1926 MLIRContext *ctx = builder.getContext();
1927 TargetUpdateOp::build(odsBuilder&: builder, odsState&: state, depend_kinds: makeArrayAttr(context: ctx, attrs: clauses.dependKinds),
1928 depend_vars: clauses.dependVars, device: clauses.device, if_expr: clauses.ifExpr,
1929 map_vars: clauses.mapVars, nowait: clauses.nowait);
1930}
1931
1932LogicalResult TargetUpdateOp::verify() {
1933 LogicalResult verifyDependVars =
1934 verifyDependVarList(op: *this, dependKinds: getDependKinds(), dependVars: getDependVars());
1935 return failed(Result: verifyDependVars) ? verifyDependVars
1936 : verifyMapClause(op: *this, mapVars: getMapVars());
1937}
1938
1939//===----------------------------------------------------------------------===//
1940// TargetOp
1941//===----------------------------------------------------------------------===//
1942
1943void TargetOp::build(OpBuilder &builder, OperationState &state,
1944 const TargetOperands &clauses) {
1945 MLIRContext *ctx = builder.getContext();
1946 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1947 // inReductionByref, inReductionSyms.
1948 TargetOp::build(odsBuilder&: builder, odsState&: state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
1949 bare: clauses.bare, depend_kinds: makeArrayAttr(context: ctx, attrs: clauses.dependKinds),
1950 depend_vars: clauses.dependVars, device: clauses.device, has_device_addr_vars: clauses.hasDeviceAddrVars,
1951 host_eval_vars: clauses.hostEvalVars, if_expr: clauses.ifExpr,
1952 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1953 /*in_reduction_syms=*/nullptr, is_device_ptr_vars: clauses.isDevicePtrVars,
1954 map_vars: clauses.mapVars, nowait: clauses.nowait, private_vars: clauses.privateVars,
1955 private_syms: makeArrayAttr(context: ctx, attrs: clauses.privateSyms),
1956 private_needs_barrier: clauses.privateNeedsBarrier, thread_limit: clauses.threadLimit,
1957 /*private_maps=*/nullptr);
1958}
1959
1960LogicalResult TargetOp::verify() {
1961 if (failed(Result: verifyDependVarList(op: *this, dependKinds: getDependKinds(), dependVars: getDependVars())))
1962 return failure();
1963
1964 if (failed(Result: verifyMapInfoDefinedArgs(op: *this, clauseName: "has_device_addr",
1965 vars: getHasDeviceAddrVars())))
1966 return failure();
1967
1968 if (failed(Result: verifyMapClause(op: *this, mapVars: getMapVars())))
1969 return failure();
1970
1971 return verifyPrivateVarsMapping(targetOp: *this);
1972}
1973
1974LogicalResult TargetOp::verifyRegions() {
1975 auto teamsOps = getOps<TeamsOp>();
1976 if (std::distance(first: teamsOps.begin(), last: teamsOps.end()) > 1)
1977 return emitError(message: "target containing multiple 'omp.teams' nested ops");
1978
1979 // Check that host_eval values are only used in legal ways.
1980 Operation *capturedOp = getInnermostCapturedOmpOp();
1981 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
1982 for (Value hostEvalArg :
1983 cast<BlockArgOpenMPOpInterface>(Val: getOperation()).getHostEvalBlockArgs()) {
1984 for (Operation *user : hostEvalArg.getUsers()) {
1985 if (auto teamsOp = dyn_cast<TeamsOp>(Val: user)) {
1986 if (llvm::is_contained(Set: {teamsOp.getNumTeamsLower(),
1987 teamsOp.getNumTeamsUpper(),
1988 teamsOp.getThreadLimit()},
1989 Element: hostEvalArg))
1990 continue;
1991
1992 return emitOpError() << "host_eval argument only legal as 'num_teams' "
1993 "and 'thread_limit' in 'omp.teams'";
1994 }
1995 if (auto parallelOp = dyn_cast<ParallelOp>(Val: user)) {
1996 if (bitEnumContainsAny(bits: execFlags, bit: TargetRegionFlags::spmd) &&
1997 parallelOp->isAncestor(other: capturedOp) &&
1998 hostEvalArg == parallelOp.getNumThreads())
1999 continue;
2000
2001 return emitOpError()
2002 << "host_eval argument only legal as 'num_threads' in "
2003 "'omp.parallel' when representing target SPMD";
2004 }
2005 if (auto loopNestOp = dyn_cast<LoopNestOp>(Val: user)) {
2006 if (bitEnumContainsAny(bits: execFlags, bit: TargetRegionFlags::trip_count) &&
2007 loopNestOp.getOperation() == capturedOp &&
2008 (llvm::is_contained(Range: loopNestOp.getLoopLowerBounds(), Element: hostEvalArg) ||
2009 llvm::is_contained(Range: loopNestOp.getLoopUpperBounds(), Element: hostEvalArg) ||
2010 llvm::is_contained(Range: loopNestOp.getLoopSteps(), Element: hostEvalArg)))
2011 continue;
2012
2013 return emitOpError() << "host_eval argument only legal as loop bounds "
2014 "and steps in 'omp.loop_nest' when trip count "
2015 "must be evaluated in the host";
2016 }
2017
2018 return emitOpError() << "host_eval argument illegal use in '"
2019 << user->getName() << "' operation";
2020 }
2021 }
2022 return success();
2023}
2024
2025static Operation *
2026findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2027 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2028 assert(rootOp && "expected valid operation");
2029
2030 Dialect *ompDialect = rootOp->getDialect();
2031 Operation *capturedOp = nullptr;
2032 DominanceInfo domInfo;
2033
2034 // Process in pre-order to check operations from outermost to innermost,
2035 // ensuring we only enter the region of an operation if it meets the criteria
2036 // for being captured. We stop the exploration of nested operations as soon as
2037 // we process a region holding no operations to be captured.
2038 rootOp->walk<WalkOrder::PreOrder>(callback: [&](Operation *op) {
2039 if (op == rootOp)
2040 return WalkResult::advance();
2041
2042 // Ignore operations of other dialects or omp operations with no regions,
2043 // because these will only be checked if they are siblings of an omp
2044 // operation that can potentially be captured.
2045 bool isOmpDialect = op->getDialect() == ompDialect;
2046 bool hasRegions = op->getNumRegions() > 0;
2047 if (!isOmpDialect || !hasRegions)
2048 return WalkResult::skip();
2049
2050 // This operation cannot be captured if it can be executed more than once
2051 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2052 // be executed before all exits of the region (i.e. it doesn't dominate all
2053 // blocks with no successors reachable from the entry block).
2054 if (checkSingleMandatoryExec) {
2055 Region *parentRegion = op->getParentRegion();
2056 Block *parentBlock = op->getBlock();
2057
2058 for (Block *successor : parentBlock->getSuccessors())
2059 if (successor->isReachable(other: parentBlock))
2060 return WalkResult::interrupt();
2061
2062 for (Block &block : *parentRegion)
2063 if (domInfo.isReachableFromEntry(a: &block) && block.hasNoSuccessors() &&
2064 !domInfo.dominates(a: parentBlock, b: &block))
2065 return WalkResult::interrupt();
2066 }
2067
2068 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2069 // into nested operations.
2070 for (Operation &sibling : op->getParentRegion()->getOps())
2071 if (&sibling != op && !siblingAllowedFn(&sibling))
2072 return WalkResult::interrupt();
2073
2074 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2075 // Otherwise, process the contents of this operation.
2076 capturedOp = op;
2077 return llvm::isa<LoopNestOp>(Val: op) ? WalkResult::interrupt()
2078 : WalkResult::advance();
2079 });
2080
2081 return capturedOp;
2082}
2083
2084Operation *TargetOp::getInnermostCapturedOmpOp() {
2085 auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2086
2087 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2088 // effects, but don't include a memory write effect.
2089 return findCapturedOmpOp(
2090 rootOp: *this, /*checkSingleMandatoryExec=*/true, siblingAllowedFn: [&](Operation *sibling) {
2091 if (!sibling)
2092 return false;
2093
2094 if (ompDialect == sibling->getDialect())
2095 return sibling->hasTrait<OpTrait::IsTerminator>();
2096
2097 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(Val: sibling)) {
2098 SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2099 effects;
2100 memOp.getEffects(effects);
2101 return !llvm::any_of(
2102 Range&: effects, P: [&](MemoryEffects::EffectInstance &effect) {
2103 return isa<MemoryEffects::Write>(Val: effect.getEffect()) &&
2104 isa<SideEffects::AutomaticAllocationScopeResource>(
2105 Val: effect.getResource());
2106 });
2107 }
2108 return true;
2109 });
2110}
2111
2112TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2113 // A non-null captured op is only valid if it resides inside of a TargetOp
2114 // and is the result of calling getInnermostCapturedOmpOp() on it.
2115 TargetOp targetOp =
2116 capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2117 assert((!capturedOp ||
2118 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2119 "unexpected captured op");
2120
2121 // If it's not capturing a loop, it's a default target region.
2122 if (!isa_and_present<LoopNestOp>(Val: capturedOp))
2123 return TargetRegionFlags::generic;
2124
2125 // Get the innermost non-simd loop wrapper.
2126 SmallVector<LoopWrapperInterface> loopWrappers;
2127 cast<LoopNestOp>(Val: capturedOp).gatherWrappers(wrappers&: loopWrappers);
2128 assert(!loopWrappers.empty());
2129
2130 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2131 if (isa<SimdOp>(Val: innermostWrapper))
2132 innermostWrapper = std::next(x: innermostWrapper);
2133
2134 auto numWrappers = std::distance(first: innermostWrapper, last: loopWrappers.end());
2135 if (numWrappers != 1 && numWrappers != 2)
2136 return TargetRegionFlags::generic;
2137
2138 // Detect target-teams-distribute-parallel-wsloop[-simd].
2139 if (numWrappers == 2) {
2140 if (!isa<WsloopOp>(Val: innermostWrapper))
2141 return TargetRegionFlags::generic;
2142
2143 innermostWrapper = std::next(x: innermostWrapper);
2144 if (!isa<DistributeOp>(Val: innermostWrapper))
2145 return TargetRegionFlags::generic;
2146
2147 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2148 if (!isa_and_present<ParallelOp>(Val: parallelOp))
2149 return TargetRegionFlags::generic;
2150
2151 Operation *teamsOp = parallelOp->getParentOp();
2152 if (!isa_and_present<TeamsOp>(Val: teamsOp))
2153 return TargetRegionFlags::generic;
2154
2155 if (teamsOp->getParentOp() == targetOp.getOperation())
2156 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2157 }
2158 // Detect target-teams-distribute[-simd] and target-teams-loop.
2159 else if (isa<DistributeOp, LoopOp>(Val: innermostWrapper)) {
2160 Operation *teamsOp = (*innermostWrapper)->getParentOp();
2161 if (!isa_and_present<TeamsOp>(Val: teamsOp))
2162 return TargetRegionFlags::generic;
2163
2164 if (teamsOp->getParentOp() != targetOp.getOperation())
2165 return TargetRegionFlags::generic;
2166
2167 if (isa<LoopOp>(Val: innermostWrapper))
2168 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2169
2170 // Find single immediately nested captured omp.parallel and add spmd flag
2171 // (generic-spmd case).
2172 //
2173 // TODO: This shouldn't have to be done here, as it is too easy to break.
2174 // The openmp-opt pass should be updated to be able to promote kernels like
2175 // this from "Generic" to "Generic-SPMD". However, the use of the
2176 // `kmpc_distribute_static_loop` family of functions produced by the
2177 // OMPIRBuilder for these kernels prevents that from working.
2178 Dialect *ompDialect = targetOp->getDialect();
2179 Operation *nestedCapture = findCapturedOmpOp(
2180 rootOp: capturedOp, /*checkSingleMandatoryExec=*/false,
2181 siblingAllowedFn: [&](Operation *sibling) {
2182 return sibling && (ompDialect != sibling->getDialect() ||
2183 sibling->hasTrait<OpTrait::IsTerminator>());
2184 });
2185
2186 TargetRegionFlags result =
2187 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2188
2189 if (!nestedCapture)
2190 return result;
2191
2192 while (nestedCapture->getParentOp() != capturedOp)
2193 nestedCapture = nestedCapture->getParentOp();
2194
2195 return isa<ParallelOp>(Val: nestedCapture) ? result | TargetRegionFlags::spmd
2196 : result;
2197 }
2198 // Detect target-parallel-wsloop[-simd].
2199 else if (isa<WsloopOp>(Val: innermostWrapper)) {
2200 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2201 if (!isa_and_present<ParallelOp>(Val: parallelOp))
2202 return TargetRegionFlags::generic;
2203
2204 if (parallelOp->getParentOp() == targetOp.getOperation())
2205 return TargetRegionFlags::spmd;
2206 }
2207
2208 return TargetRegionFlags::generic;
2209}
2210
2211//===----------------------------------------------------------------------===//
2212// ParallelOp
2213//===----------------------------------------------------------------------===//
2214
2215void ParallelOp::build(OpBuilder &builder, OperationState &state,
2216 ArrayRef<NamedAttribute> attributes) {
2217 ParallelOp::build(odsBuilder&: builder, odsState&: state, /*allocate_vars=*/ValueRange(),
2218 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2219 /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2220 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2221 /*proc_bind_kind=*/nullptr,
2222 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2223 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2224 state.addAttributes(newAttributes: attributes);
2225}
2226
2227void ParallelOp::build(OpBuilder &builder, OperationState &state,
2228 const ParallelOperands &clauses) {
2229 MLIRContext *ctx = builder.getContext();
2230 ParallelOp::build(odsBuilder&: builder, odsState&: state, allocate_vars: clauses.allocateVars, allocator_vars: clauses.allocatorVars,
2231 if_expr: clauses.ifExpr, num_threads: clauses.numThreads, private_vars: clauses.privateVars,
2232 private_syms: makeArrayAttr(context: ctx, attrs: clauses.privateSyms),
2233 private_needs_barrier: clauses.privateNeedsBarrier, proc_bind_kind: clauses.procBindKind,
2234 reduction_mod: clauses.reductionMod, reduction_vars: clauses.reductionVars,
2235 reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.reductionByref),
2236 reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.reductionSyms));
2237}
2238
2239template <typename OpType>
2240static LogicalResult verifyPrivateVarList(OpType &op) {
2241 auto privateVars = op.getPrivateVars();
2242 auto privateSyms = op.getPrivateSymsAttr();
2243
2244 if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2245 return success();
2246
2247 auto numPrivateVars = privateVars.size();
2248 auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2249
2250 if (numPrivateVars != numPrivateSyms)
2251 return op.emitError() << "inconsistent number of private variables and "
2252 "privatizer op symbols, private vars: "
2253 << numPrivateVars
2254 << " vs. privatizer op symbols: " << numPrivateSyms;
2255
2256 for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2257 Type varType = std::get<0>(privateVarInfo).getType();
2258 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2259 PrivateClauseOp privatizerOp =
2260 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym);
2261
2262 if (privatizerOp == nullptr)
2263 return op.emitError() << "failed to lookup privatizer op with symbol: '"
2264 << privateSym << "'";
2265
2266 Type privatizerType = privatizerOp.getArgType();
2267
2268 if (privatizerType && (varType != privatizerType))
2269 return op.emitError()
2270 << "type mismatch between a "
2271 << (privatizerOp.getDataSharingType() ==
2272 DataSharingClauseType::Private
2273 ? "private"
2274 : "firstprivate")
2275 << " variable and its privatizer op, var type: " << varType
2276 << " vs. privatizer op type: " << privatizerType;
2277 }
2278
2279 return success();
2280}
2281
2282LogicalResult ParallelOp::verify() {
2283 if (getAllocateVars().size() != getAllocatorVars().size())
2284 return emitError(
2285 message: "expected equal sizes for allocate and allocator variables");
2286
2287 if (failed(Result: verifyPrivateVarList(op&: *this)))
2288 return failure();
2289
2290 return verifyReductionVarList(op: *this, reductionSyms: getReductionSyms(), reductionVars: getReductionVars(),
2291 reductionByref: getReductionByref());
2292}
2293
2294LogicalResult ParallelOp::verifyRegions() {
2295 auto distChildOps = getOps<DistributeOp>();
2296 int numDistChildOps = std::distance(first: distChildOps.begin(), last: distChildOps.end());
2297 if (numDistChildOps > 1)
2298 return emitError()
2299 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2300
2301 if (numDistChildOps == 1) {
2302 if (!isComposite())
2303 return emitError()
2304 << "'omp.composite' attribute missing from composite operation";
2305
2306 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2307 Operation &distributeOp = **distChildOps.begin();
2308 for (Operation &childOp : getOps()) {
2309 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2310 continue;
2311
2312 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2313 return emitError() << "unexpected OpenMP operation inside of composite "
2314 "'omp.parallel': "
2315 << childOp.getName();
2316 }
2317 } else if (isComposite()) {
2318 return emitError()
2319 << "'omp.composite' attribute present in non-composite operation";
2320 }
2321 return success();
2322}
2323
2324//===----------------------------------------------------------------------===//
2325// TeamsOp
2326//===----------------------------------------------------------------------===//
2327
2328static bool opInGlobalImplicitParallelRegion(Operation *op) {
2329 while ((op = op->getParentOp()))
2330 if (isa<OpenMPDialect>(Val: op->getDialect()))
2331 return false;
2332 return true;
2333}
2334
2335void TeamsOp::build(OpBuilder &builder, OperationState &state,
2336 const TeamsOperands &clauses) {
2337 MLIRContext *ctx = builder.getContext();
2338 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2339 TeamsOp::build(odsBuilder&: builder, odsState&: state, allocate_vars: clauses.allocateVars, allocator_vars: clauses.allocatorVars,
2340 if_expr: clauses.ifExpr, num_teams_lower: clauses.numTeamsLower, num_teams_upper: clauses.numTeamsUpper,
2341 /*private_vars=*/{}, /*private_syms=*/nullptr,
2342 /*private_needs_barrier=*/nullptr, reduction_mod: clauses.reductionMod,
2343 reduction_vars: clauses.reductionVars,
2344 reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.reductionByref),
2345 reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.reductionSyms),
2346 thread_limit: clauses.threadLimit);
2347}
2348
2349LogicalResult TeamsOp::verify() {
2350 // Check parent region
2351 // TODO If nested inside of a target region, also check that it does not
2352 // contain any statements, declarations or directives other than this
2353 // omp.teams construct. The issue is how to support the initialization of
2354 // this operation's own arguments (allow SSA values across omp.target?).
2355 Operation *op = getOperation();
2356 if (!isa<TargetOp>(Val: op->getParentOp()) &&
2357 !opInGlobalImplicitParallelRegion(op))
2358 return emitError(message: "expected to be nested inside of omp.target or not nested "
2359 "in any OpenMP dialect operations");
2360
2361 // Check for num_teams clause restrictions
2362 if (auto numTeamsLowerBound = getNumTeamsLower()) {
2363 auto numTeamsUpperBound = getNumTeamsUpper();
2364 if (!numTeamsUpperBound)
2365 return emitError(message: "expected num_teams upper bound to be defined if the "
2366 "lower bound is defined");
2367 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2368 return emitError(
2369 message: "expected num_teams upper bound and lower bound to be the same type");
2370 }
2371
2372 // Check for allocate clause restrictions
2373 if (getAllocateVars().size() != getAllocatorVars().size())
2374 return emitError(
2375 message: "expected equal sizes for allocate and allocator variables");
2376
2377 return verifyReductionVarList(op: *this, reductionSyms: getReductionSyms(), reductionVars: getReductionVars(),
2378 reductionByref: getReductionByref());
2379}
2380
2381//===----------------------------------------------------------------------===//
2382// SectionOp
2383//===----------------------------------------------------------------------===//
2384
2385OperandRange SectionOp::getPrivateVars() {
2386 return getParentOp().getPrivateVars();
2387}
2388
2389OperandRange SectionOp::getReductionVars() {
2390 return getParentOp().getReductionVars();
2391}
2392
2393//===----------------------------------------------------------------------===//
2394// SectionsOp
2395//===----------------------------------------------------------------------===//
2396
2397void SectionsOp::build(OpBuilder &builder, OperationState &state,
2398 const SectionsOperands &clauses) {
2399 MLIRContext *ctx = builder.getContext();
2400 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2401 SectionsOp::build(odsBuilder&: builder, odsState&: state, allocate_vars: clauses.allocateVars, allocator_vars: clauses.allocatorVars,
2402 nowait: clauses.nowait, /*private_vars=*/{},
2403 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2404 reduction_mod: clauses.reductionMod, reduction_vars: clauses.reductionVars,
2405 reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.reductionByref),
2406 reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.reductionSyms));
2407}
2408
2409LogicalResult SectionsOp::verify() {
2410 if (getAllocateVars().size() != getAllocatorVars().size())
2411 return emitError(
2412 message: "expected equal sizes for allocate and allocator variables");
2413
2414 return verifyReductionVarList(op: *this, reductionSyms: getReductionSyms(), reductionVars: getReductionVars(),
2415 reductionByref: getReductionByref());
2416}
2417
2418LogicalResult SectionsOp::verifyRegions() {
2419 for (auto &inst : *getRegion().begin()) {
2420 if (!(isa<SectionOp>(Val: inst) || isa<TerminatorOp>(Val: inst))) {
2421 return emitOpError()
2422 << "expected omp.section op or terminator op inside region";
2423 }
2424 }
2425
2426 return success();
2427}
2428
2429//===----------------------------------------------------------------------===//
2430// SingleOp
2431//===----------------------------------------------------------------------===//
2432
2433void SingleOp::build(OpBuilder &builder, OperationState &state,
2434 const SingleOperands &clauses) {
2435 MLIRContext *ctx = builder.getContext();
2436 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2437 SingleOp::build(odsBuilder&: builder, odsState&: state, allocate_vars: clauses.allocateVars, allocator_vars: clauses.allocatorVars,
2438 copyprivate_vars: clauses.copyprivateVars,
2439 copyprivate_syms: makeArrayAttr(context: ctx, attrs: clauses.copyprivateSyms), nowait: clauses.nowait,
2440 /*private_vars=*/{}, /*private_syms=*/nullptr,
2441 /*private_needs_barrier=*/nullptr);
2442}
2443
2444LogicalResult SingleOp::verify() {
2445 // Check for allocate clause restrictions
2446 if (getAllocateVars().size() != getAllocatorVars().size())
2447 return emitError(
2448 message: "expected equal sizes for allocate and allocator variables");
2449
2450 return verifyCopyprivateVarList(op: *this, copyprivateVars: getCopyprivateVars(),
2451 copyprivateSyms: getCopyprivateSyms());
2452}
2453
2454//===----------------------------------------------------------------------===//
2455// WorkshareOp
2456//===----------------------------------------------------------------------===//
2457
2458void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2459 const WorkshareOperands &clauses) {
2460 WorkshareOp::build(odsBuilder&: builder, odsState&: state, nowait: clauses.nowait);
2461}
2462
2463//===----------------------------------------------------------------------===//
2464// WorkshareLoopWrapperOp
2465//===----------------------------------------------------------------------===//
2466
2467LogicalResult WorkshareLoopWrapperOp::verify() {
2468 if (!(*this)->getParentOfType<WorkshareOp>())
2469 return emitOpError() << "must be nested in an omp.workshare";
2470 return success();
2471}
2472
2473LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2474 if (isa_and_nonnull<LoopWrapperInterface>(Val: (*this)->getParentOp()) ||
2475 getNestedWrapper())
2476 return emitOpError() << "expected to be a standalone loop wrapper";
2477
2478 return success();
2479}
2480
2481//===----------------------------------------------------------------------===//
2482// LoopWrapperInterface
2483//===----------------------------------------------------------------------===//
2484
2485LogicalResult LoopWrapperInterface::verifyImpl() {
2486 Operation *op = this->getOperation();
2487 if (!op->hasTrait<OpTrait::NoTerminator>() ||
2488 !op->hasTrait<OpTrait::SingleBlock>())
2489 return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2490 "and `SingleBlock` traits";
2491
2492 if (op->getNumRegions() != 1)
2493 return emitOpError() << "loop wrapper does not contain exactly one region";
2494
2495 Region &region = op->getRegion(index: 0);
2496 if (range_size(Range: region.getOps()) != 1)
2497 return emitOpError()
2498 << "loop wrapper does not contain exactly one nested op";
2499
2500 Operation &firstOp = *region.op_begin();
2501 if (!isa<LoopNestOp, LoopWrapperInterface>(Val: firstOp))
2502 return emitOpError() << "nested in loop wrapper is not another loop "
2503 "wrapper or `omp.loop_nest`";
2504
2505 return success();
2506}
2507
2508//===----------------------------------------------------------------------===//
2509// LoopOp
2510//===----------------------------------------------------------------------===//
2511
2512void LoopOp::build(OpBuilder &builder, OperationState &state,
2513 const LoopOperands &clauses) {
2514 MLIRContext *ctx = builder.getContext();
2515
2516 LoopOp::build(odsBuilder&: builder, odsState&: state, bind_kind: clauses.bindKind, private_vars: clauses.privateVars,
2517 private_syms: makeArrayAttr(context: ctx, attrs: clauses.privateSyms),
2518 private_needs_barrier: clauses.privateNeedsBarrier, order: clauses.order, order_mod: clauses.orderMod,
2519 reduction_mod: clauses.reductionMod, reduction_vars: clauses.reductionVars,
2520 reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.reductionByref),
2521 reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.reductionSyms));
2522}
2523
2524LogicalResult LoopOp::verify() {
2525 return verifyReductionVarList(op: *this, reductionSyms: getReductionSyms(), reductionVars: getReductionVars(),
2526 reductionByref: getReductionByref());
2527}
2528
2529LogicalResult LoopOp::verifyRegions() {
2530 if (llvm::isa_and_nonnull<LoopWrapperInterface>(Val: (*this)->getParentOp()) ||
2531 getNestedWrapper())
2532 return emitOpError() << "expected to be a standalone loop wrapper";
2533
2534 return success();
2535}
2536
2537//===----------------------------------------------------------------------===//
2538// WsloopOp
2539//===----------------------------------------------------------------------===//
2540
/// Builds an omp.wsloop with every clause operand and attribute left empty or
/// unset, then attaches the caller-provided `attributes` to the op state.
void WsloopOp::build(OpBuilder &builder, OperationState &state,
                     ArrayRef<NamedAttribute> attributes) {
  // Positional call into the ODS-generated builder; the inline comments name
  // each parameter because the argument order must match the generated
  // signature exactly.
  build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
        /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
        /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
        /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
        /*private_needs_barrier=*/false,
        /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
        /*reduction_byref=*/nullptr,
        /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
        /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
        /*schedule_simd=*/false);
  // Added after the ODS builder has populated `state`.
  state.addAttributes(attributes);
}
2555
2556void WsloopOp::build(OpBuilder &builder, OperationState &state,
2557 const WsloopOperands &clauses) {
2558 MLIRContext *ctx = builder.getContext();
2559 // TODO: Store clauses in op: allocateVars, allocatorVars
2560 WsloopOp::build(
2561 odsBuilder&: builder, odsState&: state,
2562 /*allocate_vars=*/{}, /*allocator_vars=*/{}, linear_vars: clauses.linearVars,
2563 linear_step_vars: clauses.linearStepVars, nowait: clauses.nowait, order: clauses.order, order_mod: clauses.orderMod,
2564 ordered: clauses.ordered, private_vars: clauses.privateVars,
2565 private_syms: makeArrayAttr(context: ctx, attrs: clauses.privateSyms), private_needs_barrier: clauses.privateNeedsBarrier,
2566 reduction_mod: clauses.reductionMod, reduction_vars: clauses.reductionVars,
2567 reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.reductionByref),
2568 reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.reductionSyms), schedule_kind: clauses.scheduleKind,
2569 schedule_chunk: clauses.scheduleChunk, schedule_mod: clauses.scheduleMod, schedule_simd: clauses.scheduleSimd);
2570}
2571
2572LogicalResult WsloopOp::verify() {
2573 return verifyReductionVarList(op: *this, reductionSyms: getReductionSyms(), reductionVars: getReductionVars(),
2574 reductionByref: getReductionByref());
2575}
2576
2577LogicalResult WsloopOp::verifyRegions() {
2578 bool isCompositeChildLeaf =
2579 llvm::dyn_cast_if_present<LoopWrapperInterface>(Val: (*this)->getParentOp());
2580
2581 if (LoopWrapperInterface nested = getNestedWrapper()) {
2582 if (!isComposite())
2583 return emitError()
2584 << "'omp.composite' attribute missing from composite wrapper";
2585
2586 // Check for the allowed leaf constructs that may appear in a composite
2587 // construct directly after DO/FOR.
2588 if (!isa<SimdOp>(Val: nested))
2589 return emitError() << "only supported nested wrapper is 'omp.simd'";
2590
2591 } else if (isComposite() && !isCompositeChildLeaf) {
2592 return emitError()
2593 << "'omp.composite' attribute present in non-composite wrapper";
2594 } else if (!isComposite() && isCompositeChildLeaf) {
2595 return emitError()
2596 << "'omp.composite' attribute missing from composite wrapper";
2597 }
2598
2599 return success();
2600}
2601
2602//===----------------------------------------------------------------------===//
2603// Simd construct [2.9.3.1]
2604//===----------------------------------------------------------------------===//
2605
2606void SimdOp::build(OpBuilder &builder, OperationState &state,
2607 const SimdOperands &clauses) {
2608 MLIRContext *ctx = builder.getContext();
2609 // TODO Store clauses in op: linearVars, linearStepVars
2610 SimdOp::build(odsBuilder&: builder, odsState&: state, aligned_vars: clauses.alignedVars,
2611 alignments: makeArrayAttr(context: ctx, attrs: clauses.alignments), if_expr: clauses.ifExpr,
2612 /*linear_vars=*/{}, /*linear_step_vars=*/{},
2613 nontemporal_vars: clauses.nontemporalVars, order: clauses.order, order_mod: clauses.orderMod,
2614 private_vars: clauses.privateVars, private_syms: makeArrayAttr(context: ctx, attrs: clauses.privateSyms),
2615 private_needs_barrier: clauses.privateNeedsBarrier, reduction_mod: clauses.reductionMod,
2616 reduction_vars: clauses.reductionVars,
2617 reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.reductionByref),
2618 reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.reductionSyms), safelen: clauses.safelen,
2619 simdlen: clauses.simdlen);
2620}
2621
2622LogicalResult SimdOp::verify() {
2623 if (getSimdlen().has_value() && getSafelen().has_value() &&
2624 getSimdlen().value() > getSafelen().value())
2625 return emitOpError()
2626 << "simdlen clause and safelen clause are both present, but the "
2627 "simdlen value is not less than or equal to safelen value";
2628
2629 if (verifyAlignedClause(op: *this, alignments: getAlignments(), alignedVars: getAlignedVars()).failed())
2630 return failure();
2631
2632 if (verifyNontemporalClause(op: *this, nontemporalVars: getNontemporalVars()).failed())
2633 return failure();
2634
2635 bool isCompositeChildLeaf =
2636 llvm::dyn_cast_if_present<LoopWrapperInterface>(Val: (*this)->getParentOp());
2637
2638 if (!isComposite() && isCompositeChildLeaf)
2639 return emitError()
2640 << "'omp.composite' attribute missing from composite wrapper";
2641
2642 if (isComposite() && !isCompositeChildLeaf)
2643 return emitError()
2644 << "'omp.composite' attribute present in non-composite wrapper";
2645
2646 // Firstprivate is not allowed for SIMD in the standard. Check that none of
2647 // the private decls are for firstprivate.
2648 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2649 if (privateSyms) {
2650 for (const Attribute &sym : *privateSyms) {
2651 auto symRef = cast<SymbolRefAttr>(Val: sym);
2652 omp::PrivateClauseOp privatizer =
2653 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
2654 from: getOperation(), symbol: symRef);
2655 if (!privatizer)
2656 return emitError() << "Cannot find privatizer '" << symRef << "'";
2657 if (privatizer.getDataSharingType() ==
2658 DataSharingClauseType::FirstPrivate)
2659 return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
2660 }
2661 }
2662
2663 return success();
2664}
2665
2666LogicalResult SimdOp::verifyRegions() {
2667 if (getNestedWrapper())
2668 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2669
2670 return success();
2671}
2672
2673//===----------------------------------------------------------------------===//
2674// Distribute construct [2.9.4.1]
2675//===----------------------------------------------------------------------===//
2676
2677void DistributeOp::build(OpBuilder &builder, OperationState &state,
2678 const DistributeOperands &clauses) {
2679 DistributeOp::build(odsBuilder&: builder, odsState&: state, allocate_vars: clauses.allocateVars,
2680 allocator_vars: clauses.allocatorVars, dist_schedule_static: clauses.distScheduleStatic,
2681 dist_schedule_chunk_size: clauses.distScheduleChunkSize, order: clauses.order,
2682 order_mod: clauses.orderMod, private_vars: clauses.privateVars,
2683 private_syms: makeArrayAttr(context: builder.getContext(), attrs: clauses.privateSyms),
2684 private_needs_barrier: clauses.privateNeedsBarrier);
2685}
2686
2687LogicalResult DistributeOp::verify() {
2688 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2689 return emitOpError() << "chunk size set without "
2690 "dist_schedule_static being present";
2691
2692 if (getAllocateVars().size() != getAllocatorVars().size())
2693 return emitError(
2694 message: "expected equal sizes for allocate and allocator variables");
2695
2696 return success();
2697}
2698
2699LogicalResult DistributeOp::verifyRegions() {
2700 if (LoopWrapperInterface nested = getNestedWrapper()) {
2701 if (!isComposite())
2702 return emitError()
2703 << "'omp.composite' attribute missing from composite wrapper";
2704 // Check for the allowed leaf constructs that may appear in a composite
2705 // construct directly after DISTRIBUTE.
2706 if (isa<WsloopOp>(Val: nested)) {
2707 Operation *parentOp = (*this)->getParentOp();
2708 if (!llvm::dyn_cast_if_present<ParallelOp>(Val: parentOp) ||
2709 !cast<ComposableOpInterface>(Val: parentOp).isComposite()) {
2710 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2711 "when a composite 'omp.parallel' is the direct "
2712 "parent";
2713 }
2714 } else if (!isa<SimdOp>(Val: nested))
2715 return emitError() << "only supported nested wrappers are 'omp.simd' and "
2716 "'omp.wsloop'";
2717 } else if (isComposite()) {
2718 return emitError()
2719 << "'omp.composite' attribute present in non-composite wrapper";
2720 }
2721
2722 return success();
2723}
2724
2725//===----------------------------------------------------------------------===//
2726// DeclareMapperOp / DeclareMapperInfoOp
2727//===----------------------------------------------------------------------===//
2728
2729LogicalResult DeclareMapperInfoOp::verify() {
2730 return verifyMapClause(op: *this, mapVars: getMapVars());
2731}
2732
2733LogicalResult DeclareMapperOp::verifyRegions() {
2734 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
2735 Val: getRegion().getBlocks().front().getTerminator()))
2736 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
2737
2738 return success();
2739}
2740
2741//===----------------------------------------------------------------------===//
2742// DeclareReductionOp
2743//===----------------------------------------------------------------------===//
2744
/// Verifies the regions of a reduction declaration against its reduction type
/// (`getType()`). Checks run in this order and stop at the first failure:
///   - alloc region (optional): each omp.yield yields exactly one value of the
///     reduction type;
///   - initializer region (required, non-empty): its entry block takes one
///     argument, or two when an alloc region is present, every argument is of
///     the reduction type, and each omp.yield yields one reduction-typed value;
///   - reduction region (required, non-empty): entry block takes two
///     reduction-typed arguments; each omp.yield yields one reduction-typed
///     value;
///   - atomic reduction region (optional): entry block takes two arguments of
///     the same PointerLikeType whose element type, when known, is the
///     reduction type;
///   - cleanup region (optional): entry block takes one reduction-typed
///     argument.
LogicalResult DeclareReductionOp::verifyRegions() {
  if (!getAllocRegion().empty()) {
    for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
      if (yieldOp.getResults().size() != 1 ||
          yieldOp.getResults().getTypes()[0] != getType())
        return emitOpError() << "expects alloc region to yield a value "
                                "of the reduction type";
    }
  }

  if (getInitializerRegion().empty())
    return emitOpError() << "expects non-empty initializer region";
  Block &initializerEntryBlock = getInitializerRegion().front();

  // The initializer's arity depends on whether an alloc region exists: with
  // one, the initializer takes two arguments; without, it takes one.
  if (initializerEntryBlock.getNumArguments() == 1) {
    if (!getAllocRegion().empty())
      return emitOpError() << "expects two arguments to the initializer region "
                              "when an allocation region is used";
  } else if (initializerEntryBlock.getNumArguments() == 2) {
    if (getAllocRegion().empty())
      return emitOpError() << "expects one argument to the initializer region "
                              "when no allocation region is used";
  } else {
    return emitOpError()
           << "expects one or two arguments to the initializer region";
  }

  for (mlir::Value arg : initializerEntryBlock.getArguments())
    if (arg.getType() != getType())
      return emitOpError() << "expects initializer region argument to match "
                              "the reduction type";

  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
    if (yieldOp.getResults().size() != 1 ||
        yieldOp.getResults().getTypes()[0] != getType())
      return emitOpError() << "expects initializer region to yield a value "
                              "of the reduction type";
  }

  if (getReductionRegion().empty())
    return emitOpError() << "expects non-empty reduction region";
  Block &reductionEntryBlock = getReductionRegion().front();
  // The combiner takes two values of the reduction type.
  if (reductionEntryBlock.getNumArguments() != 2 ||
      reductionEntryBlock.getArgumentTypes()[0] !=
          reductionEntryBlock.getArgumentTypes()[1] ||
      reductionEntryBlock.getArgumentTypes()[0] != getType())
    return emitOpError() << "expects reduction region with two arguments of "
                            "the reduction type";
  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
    if (yieldOp.getResults().size() != 1 ||
        yieldOp.getResults().getTypes()[0] != getType())
      return emitOpError() << "expects reduction region to yield a value "
                              "of the reduction type";
  }

  if (!getAtomicReductionRegion().empty()) {
    Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
    if (atomicReductionEntryBlock.getNumArguments() != 2 ||
        atomicReductionEntryBlock.getArgumentTypes()[0] !=
            atomicReductionEntryBlock.getArgumentTypes()[1])
      return emitOpError() << "expects atomic reduction region with two "
                              "arguments of the same type";
    // Arguments must be pointer-like. An element type that is not known (null)
    // is accepted; a known one must equal the reduction type.
    auto ptrType = llvm::dyn_cast<PointerLikeType>(
        atomicReductionEntryBlock.getArgumentTypes()[0]);
    if (!ptrType ||
        (ptrType.getElementType() && ptrType.getElementType() != getType()))
      return emitOpError() << "expects atomic reduction region arguments to "
                              "be accumulators containing the reduction type";
  }

  if (getCleanupRegion().empty())
    return success();
  Block &cleanupEntryBlock = getCleanupRegion().front();
  if (cleanupEntryBlock.getNumArguments() != 1 ||
      cleanupEntryBlock.getArgument(0).getType() != getType())
    return emitOpError() << "expects cleanup region with one argument "
                            "of the reduction type";

  return success();
}
2825
2826//===----------------------------------------------------------------------===//
2827// TaskOp
2828//===----------------------------------------------------------------------===//
2829
2830void TaskOp::build(OpBuilder &builder, OperationState &state,
2831 const TaskOperands &clauses) {
2832 MLIRContext *ctx = builder.getContext();
2833 TaskOp::build(odsBuilder&: builder, odsState&: state, allocate_vars: clauses.allocateVars, allocator_vars: clauses.allocatorVars,
2834 depend_kinds: makeArrayAttr(context: ctx, attrs: clauses.dependKinds), depend_vars: clauses.dependVars,
2835 final: clauses.final, if_expr: clauses.ifExpr, in_reduction_vars: clauses.inReductionVars,
2836 in_reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.inReductionByref),
2837 in_reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.inReductionSyms), mergeable: clauses.mergeable,
2838 priority: clauses.priority, /*private_vars=*/clauses.privateVars,
2839 /*private_syms=*/makeArrayAttr(context: ctx, attrs: clauses.privateSyms),
2840 private_needs_barrier: clauses.privateNeedsBarrier, untied: clauses.untied,
2841 event_handle: clauses.eventHandle);
2842}
2843
2844LogicalResult TaskOp::verify() {
2845 LogicalResult verifyDependVars =
2846 verifyDependVarList(op: *this, dependKinds: getDependKinds(), dependVars: getDependVars());
2847 return failed(Result: verifyDependVars)
2848 ? verifyDependVars
2849 : verifyReductionVarList(op: *this, reductionSyms: getInReductionSyms(),
2850 reductionVars: getInReductionVars(),
2851 reductionByref: getInReductionByref());
2852}
2853
2854//===----------------------------------------------------------------------===//
2855// TaskgroupOp
2856//===----------------------------------------------------------------------===//
2857
2858void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
2859 const TaskgroupOperands &clauses) {
2860 MLIRContext *ctx = builder.getContext();
2861 TaskgroupOp::build(odsBuilder&: builder, odsState&: state, allocate_vars: clauses.allocateVars,
2862 allocator_vars: clauses.allocatorVars, task_reduction_vars: clauses.taskReductionVars,
2863 task_reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.taskReductionByref),
2864 task_reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.taskReductionSyms));
2865}
2866
2867LogicalResult TaskgroupOp::verify() {
2868 return verifyReductionVarList(op: *this, reductionSyms: getTaskReductionSyms(),
2869 reductionVars: getTaskReductionVars(),
2870 reductionByref: getTaskReductionByref());
2871}
2872
2873//===----------------------------------------------------------------------===//
2874// TaskloopOp
2875//===----------------------------------------------------------------------===//
2876
2877void TaskloopOp::build(OpBuilder &builder, OperationState &state,
2878 const TaskloopOperands &clauses) {
2879 MLIRContext *ctx = builder.getContext();
2880 TaskloopOp::build(
2881 odsBuilder&: builder, odsState&: state, allocate_vars: clauses.allocateVars, allocator_vars: clauses.allocatorVars,
2882 final: clauses.final, grainsize_mod: clauses.grainsizeMod, grainsize: clauses.grainsize, if_expr: clauses.ifExpr,
2883 in_reduction_vars: clauses.inReductionVars,
2884 in_reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.inReductionByref),
2885 in_reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.inReductionSyms), mergeable: clauses.mergeable,
2886 nogroup: clauses.nogroup, num_tasks_mod: clauses.numTasksMod, num_tasks: clauses.numTasks, priority: clauses.priority,
2887 /*private_vars=*/clauses.privateVars,
2888 /*private_syms=*/makeArrayAttr(context: ctx, attrs: clauses.privateSyms),
2889 private_needs_barrier: clauses.privateNeedsBarrier, reduction_mod: clauses.reductionMod, reduction_vars: clauses.reductionVars,
2890 reduction_byref: makeDenseBoolArrayAttr(ctx, boolArray: clauses.reductionByref),
2891 reduction_syms: makeArrayAttr(context: ctx, attrs: clauses.reductionSyms), untied: clauses.untied);
2892}
2893
2894LogicalResult TaskloopOp::verify() {
2895 if (getAllocateVars().size() != getAllocatorVars().size())
2896 return emitError(
2897 message: "expected equal sizes for allocate and allocator variables");
2898 if (failed(Result: verifyReductionVarList(op: *this, reductionSyms: getReductionSyms(),
2899 reductionVars: getReductionVars(), reductionByref: getReductionByref())) ||
2900 failed(Result: verifyReductionVarList(op: *this, reductionSyms: getInReductionSyms(),
2901 reductionVars: getInReductionVars(),
2902 reductionByref: getInReductionByref())))
2903 return failure();
2904
2905 if (!getReductionVars().empty() && getNogroup())
2906 return emitError(message: "if a reduction clause is present on the taskloop "
2907 "directive, the nogroup clause must not be specified");
2908 for (auto var : getReductionVars()) {
2909 if (llvm::is_contained(Range: getInReductionVars(), Element: var))
2910 return emitError(message: "the same list item cannot appear in both a reduction "
2911 "and an in_reduction clause");
2912 }
2913
2914 if (getGrainsize() && getNumTasks()) {
2915 return emitError(
2916 message: "the grainsize clause and num_tasks clause are mutually exclusive and "
2917 "may not appear on the same taskloop directive");
2918 }
2919
2920 return success();
2921}
2922
2923LogicalResult TaskloopOp::verifyRegions() {
2924 if (LoopWrapperInterface nested = getNestedWrapper()) {
2925 if (!isComposite())
2926 return emitError()
2927 << "'omp.composite' attribute missing from composite wrapper";
2928
2929 // Check for the allowed leaf constructs that may appear in a composite
2930 // construct directly after TASKLOOP.
2931 if (!isa<SimdOp>(Val: nested))
2932 return emitError() << "only supported nested wrapper is 'omp.simd'";
2933 } else if (isComposite()) {
2934 return emitError()
2935 << "'omp.composite' attribute present in non-composite wrapper";
2936 }
2937
2938 return success();
2939}
2940
2941//===----------------------------------------------------------------------===//
2942// LoopNestOp
2943//===----------------------------------------------------------------------===//
2944
/// Parses the custom assembly format of omp.loop_nest:
///
///   (%iv, ...) : type = (%lb, ...) to (%ub, ...) [inclusive]
///       step (%step, ...) region [attr-dict]
///
/// The single parsed type is assigned to every IV and used to resolve every
/// bound and step operand. The bound and step lists must each contain as many
/// entries as there are IVs.
ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse an opening `(` followed by induction variables followed by `)`
  SmallVector<OpAsmParser::Argument> ivs;
  SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs;
  Type loopVarType;
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(loopVarType) ||
      // Parse loop bounds.
      parser.parseEqual() ||
      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Every IV gets the one type written after the IV list.
  for (auto &iv : ivs)
    iv.type = loopVarType;

  // Parse "inclusive" flag.
  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("loop_inclusive",
                        UnitAttr::get(parser.getBuilder().getContext()));

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse the body.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, ivs))
    return failure();

  // Resolve operands. They are appended in lower-bound, upper-bound, step
  // order, all with the IV type.
  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
      parser.resolveOperands(ubs, loopVarType, result.operands) ||
      parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
2987
2988void LoopNestOp::print(OpAsmPrinter &p) {
2989 Region &region = getRegion();
2990 auto args = region.getArguments();
2991 p << " (" << args << ") : " << args[0].getType() << " = ("
2992 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
2993 if (getLoopInclusive())
2994 p << "inclusive ";
2995 p << "step (" << getLoopSteps() << ") ";
2996 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
2997}
2998
2999void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3000 const LoopNestOperands &clauses) {
3001 LoopNestOp::build(odsBuilder&: builder, odsState&: state, loop_lower_bounds: clauses.loopLowerBounds,
3002 loop_upper_bounds: clauses.loopUpperBounds, loop_steps: clauses.loopSteps,
3003 loop_inclusive: clauses.loopInclusive);
3004}
3005
3006LogicalResult LoopNestOp::verify() {
3007 if (getLoopLowerBounds().empty())
3008 return emitOpError() << "must represent at least one loop";
3009
3010 if (getLoopLowerBounds().size() != getIVs().size())
3011 return emitOpError() << "number of range arguments and IVs do not match";
3012
3013 for (auto [lb, iv] : llvm::zip_equal(t: getLoopLowerBounds(), u: getIVs())) {
3014 if (lb.getType() != iv.getType())
3015 return emitOpError()
3016 << "range argument type does not match corresponding IV type";
3017 }
3018
3019 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>(Val: (*this)->getParentOp()))
3020 return emitOpError() << "expects parent op to be a loop wrapper";
3021
3022 return success();
3023}
3024
3025void LoopNestOp::gatherWrappers(
3026 SmallVectorImpl<LoopWrapperInterface> &wrappers) {
3027 Operation *parent = (*this)->getParentOp();
3028 while (auto wrapper =
3029 llvm::dyn_cast_if_present<LoopWrapperInterface>(Val: parent)) {
3030 wrappers.push_back(Elt: wrapper);
3031 parent = parent->getParentOp();
3032 }
3033}
3034
3035//===----------------------------------------------------------------------===//
3036// OpenMP canonical loop handling
3037//===----------------------------------------------------------------------===//
3038
3039std::tuple<NewCliOp, OpOperand *, OpOperand *>
3040mlir::omp ::decodeCli(Value cli) {
3041
3042 // Defining a CLI for a generated loop is optional; if there is none then
3043 // there is no followup-tranformation
3044 if (!cli)
3045 return {{}, nullptr, nullptr};
3046
3047 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3048 "Unexpected type of cli");
3049
3050 NewCliOp create = cast<NewCliOp>(Val: cli.getDefiningOp());
3051 OpOperand *gen = nullptr;
3052 OpOperand *cons = nullptr;
3053 for (OpOperand &use : cli.getUses()) {
3054 auto op = cast<LoopTransformationInterface>(Val: use.getOwner());
3055
3056 unsigned opnum = use.getOperandNumber();
3057 if (op.isGeneratee(opnum)) {
3058 assert(!gen && "Each CLI may have at most one def");
3059 gen = &use;
3060 } else if (op.isApplyee(opnum)) {
3061 assert(!cons && "Each CLI may have at most one consumer");
3062 cons = &use;
3063 } else {
3064 llvm_unreachable("Unexpected operand for a CLI");
3065 }
3066 }
3067
3068 return {create, gen, cons};
3069}
3070
3071void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3072 ::mlir::OperationState &odsState) {
3073 odsState.addTypes(newTypes: CanonicalLoopInfoType::get(ctx: odsBuilder.getContext()));
3074}
3075
3076void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3077 Value result = getResult();
3078 auto [newCli, gen, cons] = decodeCli(cli: result);
3079
3080 // Derive the CLI variable name from its generator:
3081 // * "canonloop" for omp.canonical_loop
3082 // * custom name for loop transformation generatees
3083 // * "cli" as fallback if no generator
3084 // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3085 // at that level
3086 // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3087 // the index of that region
3088 std::string cliName{"cli"};
3089 if (gen) {
3090 cliName =
3091 TypeSwitch<Operation *, std::string>(gen->getOwner())
3092 .Case(caseFn: [&](CanonicalLoopOp op) {
3093 // Find the canonical loop nesting: For each ancestor add a
3094 // "+_r<idx>" suffix (in reverse order)
3095 SmallVector<std::string> components;
3096 Operation *o = op.getOperation();
3097 while (o) {
3098 if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
3099 break;
3100
3101 Region *r = o->getParentRegion();
3102 if (!r)
3103 break;
3104
3105 auto getSequentialIndex = [](Region *r, Operation *o) {
3106 llvm::ReversePostOrderTraversal<Block *> traversal(
3107 &r->getBlocks().front());
3108 size_t idx = 0;
3109 for (Block *b : traversal) {
3110 for (Operation &op : *b) {
3111 if (&op == o)
3112 return idx;
3113 // Only consider operations that are containers as
3114 // possible children
3115 if (!op.getRegions().empty())
3116 idx += 1;
3117 }
3118 }
3119 llvm_unreachable("Operation not part of the region");
3120 };
3121 size_t sequentialIdx = getSequentialIndex(r, o);
3122 components.push_back(Elt: ("s" + Twine(sequentialIdx)).str());
3123
3124 Operation *parent = r->getParentOp();
3125 if (!parent)
3126 break;
3127
3128 // If the operation has more than one region, also count in
3129 // which of the regions
3130 if (parent->getRegions().size() > 1) {
3131 auto getRegionIndex = [](Operation *o, Region *r) {
3132 for (auto [idx, region] :
3133 llvm::enumerate(First: o->getRegions())) {
3134 if (&region == r)
3135 return idx;
3136 }
3137 llvm_unreachable("Region not child its parent operation");
3138 };
3139 size_t regionIdx = getRegionIndex(parent, r);
3140 components.push_back(Elt: ("r" + Twine(regionIdx)).str());
3141 }
3142
3143 // next parent
3144 o = parent;
3145 }
3146
3147 SmallString<64> Name("canonloop");
3148 for (std::string s : reverse(C&: components)) {
3149 Name += '_';
3150 Name += s;
3151 }
3152
3153 return Name;
3154 })
3155 .Case(caseFn: [&](UnrollHeuristicOp op) -> std::string {
3156 llvm_unreachable("heuristic unrolling does not generate a loop");
3157 })
3158 .Default(defaultFn: [&](Operation *op) {
3159 assert(false && "TODO: Custom name for this operation");
3160 return "transformed";
3161 });
3162 }
3163
3164 setNameFn(result, cliName);
3165}
3166
3167LogicalResult NewCliOp::verify() {
3168 Value cli = getResult();
3169
3170 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3171 "Unexpected type of cli");
3172
3173 // Check that the CLI is used in at most generator and one consumer
3174 OpOperand *gen = nullptr;
3175 OpOperand *cons = nullptr;
3176 for (mlir::OpOperand &use : cli.getUses()) {
3177 auto op = cast<mlir::omp::LoopTransformationInterface>(Val: use.getOwner());
3178
3179 unsigned opnum = use.getOperandNumber();
3180 if (op.isGeneratee(opnum)) {
3181 if (gen) {
3182 InFlightDiagnostic error =
3183 emitOpError(message: "CLI must have at most one generator");
3184 error.attachNote(noteLoc: gen->getOwner()->getLoc())
3185 .append(arg: "first generator here:");
3186 error.attachNote(noteLoc: use.getOwner()->getLoc())
3187 .append(arg: "second generator here:");
3188 return error;
3189 }
3190
3191 gen = &use;
3192 } else if (op.isApplyee(opnum)) {
3193 if (cons) {
3194 InFlightDiagnostic error =
3195 emitOpError(message: "CLI must have at most one consumer");
3196 error.attachNote(noteLoc: cons->getOwner()->getLoc())
3197 .append(arg: "first consumer here:")
3198 .appendOp(op&: *cons->getOwner(),
3199 flags: OpPrintingFlags().printGenericOpForm());
3200 error.attachNote(noteLoc: use.getOwner()->getLoc())
3201 .append(arg: "second consumer here:")
3202 .appendOp(op&: *use.getOwner(), flags: OpPrintingFlags().printGenericOpForm());
3203 return error;
3204 }
3205
3206 cons = &use;
3207 } else {
3208 llvm_unreachable("Unexpected operand for a CLI");
3209 }
3210 }
3211
3212 // If the CLI is source of a transformation, it must have a generator
3213 if (cons && !gen) {
3214 InFlightDiagnostic error = emitOpError(message: "CLI has no generator");
3215 error.attachNote(noteLoc: cons->getOwner()->getLoc())
3216 .append(arg: "see consumer here: ")
3217 .appendOp(op&: *cons->getOwner(), flags: OpPrintingFlags().printGenericOpForm());
3218 return error;
3219 }
3220
3221 return success();
3222}
3223
3224void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3225 Value tripCount) {
3226 odsState.addOperands(newOperands: tripCount);
3227 odsState.addOperands(newOperands: Value());
3228 (void)odsState.addRegion();
3229}
3230
3231void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3232 Value tripCount, ::mlir::Value cli) {
3233 odsState.addOperands(newOperands: tripCount);
3234 odsState.addOperands(newOperands: cli);
3235 (void)odsState.addRegion();
3236}
3237
3238void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3239 setNameFn(&getRegion().front(), "body_entry");
3240}
3241
3242void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3243 OpAsmSetValueNameFn setNameFn) {
3244 setNameFn(region.getArgument(i: 0), "iv");
3245}
3246
3247void CanonicalLoopOp::print(OpAsmPrinter &p) {
3248 if (getCli())
3249 p << '(' << getCli() << ')';
3250 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3251 << " in range(" << getTripCount() << ") ";
3252
3253 p.printRegion(blocks&: getRegion(), /*printEntryBlockArgs=*/false,
3254 /*printBlockTerminators=*/true);
3255
3256 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
3257}
3258
3259mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3260 ::mlir::OperationState &result) {
3261 CanonicalLoopInfoType cliType =
3262 CanonicalLoopInfoType::get(ctx: parser.getContext());
3263
3264 // Parse (optional) omp.cli identifier
3265 OpAsmParser::UnresolvedOperand cli;
3266 SmallVector<mlir::Value, 1> cliOperand;
3267 if (!parser.parseOptionalLParen()) {
3268 if (parser.parseOperand(result&: cli) ||
3269 parser.resolveOperand(operand: cli, type: cliType, result&: cliOperand) || parser.parseRParen())
3270 return failure();
3271 }
3272
3273 // We derive the type of tripCount from inductionVariable. MLIR requires the
3274 // type of tripCount to be known when calling resolveOperand so we have parse
3275 // the type before processing the inductionVariable.
3276 OpAsmParser::Argument inductionVariable;
3277 OpAsmParser::UnresolvedOperand tripcount;
3278 if (parser.parseArgument(result&: inductionVariable, /*allowType*/ true) ||
3279 parser.parseKeyword(keyword: "in") || parser.parseKeyword(keyword: "range") ||
3280 parser.parseLParen() || parser.parseOperand(result&: tripcount) ||
3281 parser.parseRParen() ||
3282 parser.resolveOperand(operand: tripcount, type: inductionVariable.type, result&: result.operands))
3283 return failure();
3284
3285 // Parse the loop body.
3286 Region *region = result.addRegion();
3287 if (parser.parseRegion(region&: *region, arguments: {inductionVariable}))
3288 return failure();
3289
3290 // We parsed the cli operand forst, but because it is optional, it must be
3291 // last in the operand list.
3292 result.operands.append(RHS: cliOperand);
3293
3294 // Parse the optional attribute list.
3295 if (parser.parseOptionalAttrDict(result&: result.attributes))
3296 return failure();
3297
3298 return mlir::success();
3299}
3300
3301LogicalResult CanonicalLoopOp::verify() {
3302 // The region's entry must accept the induction variable
3303 // It can also be empty if just created
3304 if (!getRegion().empty()) {
3305 Region &region = getRegion();
3306 if (region.getNumArguments() != 1)
3307 return emitOpError(
3308 message: "Canonical loop region must have exactly one argument");
3309
3310 if (getInductionVar().getType() != getTripCount().getType())
3311 return emitOpError(
3312 message: "Region argument must be the same type as the trip count");
3313 }
3314
3315 return success();
3316}
3317
3318Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(i: 0); }
3319
3320std::pair<unsigned, unsigned>
3321CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3322 // No applyees
3323 return {0, 0};
3324}
3325
3326std::pair<unsigned, unsigned>
3327CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3328 return getODSOperandIndexAndLength(index: odsIndex_cli);
3329}
3330
3331//===----------------------------------------------------------------------===//
3332// UnrollHeuristicOp
3333//===----------------------------------------------------------------------===//
3334
3335void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3336 ::mlir::OperationState &odsState,
3337 ::mlir::Value cli) {
3338 odsState.addOperands(newOperands: cli);
3339}
3340
3341void UnrollHeuristicOp::print(OpAsmPrinter &p) {
3342 p << '(' << getApplyee() << ')';
3343
3344 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
3345}
3346
3347mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3348 ::mlir::OperationState &result) {
3349 auto cliType = CanonicalLoopInfoType::get(ctx: parser.getContext());
3350
3351 if (parser.parseLParen())
3352 return failure();
3353
3354 OpAsmParser::UnresolvedOperand applyee;
3355 if (parser.parseOperand(result&: applyee) ||
3356 parser.resolveOperand(operand: applyee, type: cliType, result&: result.operands))
3357 return failure();
3358
3359 if (parser.parseRParen())
3360 return failure();
3361
3362 // Optional output loop (full unrolling has none)
3363 if (!parser.parseOptionalArrow()) {
3364 if (parser.parseLParen() || parser.parseRParen())
3365 return failure();
3366 }
3367
3368 // Parse the optional attribute list.
3369 if (parser.parseOptionalAttrDict(result&: result.attributes))
3370 return failure();
3371
3372 return mlir::success();
3373}
3374
3375std::pair<unsigned, unsigned>
3376UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3377 return getODSOperandIndexAndLength(index: odsIndex_applyee);
3378}
3379
3380std::pair<unsigned, unsigned>
3381UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3382 return {0, 0};
3383}
3384
3385//===----------------------------------------------------------------------===//
3386// Critical construct (2.17.1)
3387//===----------------------------------------------------------------------===//
3388
3389void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3390 const CriticalDeclareOperands &clauses) {
3391 CriticalDeclareOp::build(odsBuilder&: builder, odsState&: state, sym_name: clauses.symName, hint: clauses.hint);
3392}
3393
3394LogicalResult CriticalDeclareOp::verify() {
3395 return verifySynchronizationHint(op: *this, hint: getHint());
3396}
3397
3398LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3399 if (getNameAttr()) {
3400 SymbolRefAttr symbolRef = getNameAttr();
3401 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3402 from: *this, symbol: symbolRef);
3403 if (!decl) {
3404 return emitOpError() << "expected symbol reference " << symbolRef
3405 << " to point to a critical declaration";
3406 }
3407 }
3408
3409 return success();
3410}
3411
3412//===----------------------------------------------------------------------===//
3413// Ordered construct
3414//===----------------------------------------------------------------------===//
3415
3416static LogicalResult verifyOrderedParent(Operation &op) {
3417 bool hasRegion = op.getNumRegions() > 0;
3418 auto loopOp = op.getParentOfType<LoopNestOp>();
3419 if (!loopOp) {
3420 if (hasRegion)
3421 return success();
3422
3423 // TODO: Consider if this needs to be the case only for the standalone
3424 // variant of the ordered construct.
3425 return op.emitOpError() << "must be nested inside of a loop";
3426 }
3427
3428 Operation *wrapper = loopOp->getParentOp();
3429 if (auto wsloopOp = dyn_cast<WsloopOp>(Val: wrapper)) {
3430 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3431 if (!orderedAttr)
3432 return op.emitOpError() << "the enclosing worksharing-loop region must "
3433 "have an ordered clause";
3434
3435 if (hasRegion && orderedAttr.getInt() != 0)
3436 return op.emitOpError() << "the enclosing loop's ordered clause must not "
3437 "have a parameter present";
3438
3439 if (!hasRegion && orderedAttr.getInt() == 0)
3440 return op.emitOpError() << "the enclosing loop's ordered clause must "
3441 "have a parameter present";
3442 } else if (!isa<SimdOp>(Val: wrapper)) {
3443 return op.emitOpError() << "must be nested inside of a worksharing, simd "
3444 "or worksharing simd loop";
3445 }
3446 return success();
3447}
3448
3449void OrderedOp::build(OpBuilder &builder, OperationState &state,
3450 const OrderedOperands &clauses) {
3451 OrderedOp::build(odsBuilder&: builder, odsState&: state, doacross_depend_type: clauses.doacrossDependType,
3452 doacross_num_loops: clauses.doacrossNumLoops, doacross_depend_vars: clauses.doacrossDependVars);
3453}
3454
3455LogicalResult OrderedOp::verify() {
3456 if (failed(Result: verifyOrderedParent(op&: **this)))
3457 return failure();
3458
3459 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3460 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3461 return emitOpError() << "number of variables in depend clause does not "
3462 << "match number of iteration variables in the "
3463 << "doacross loop";
3464
3465 return success();
3466}
3467
3468void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3469 const OrderedRegionOperands &clauses) {
3470 OrderedRegionOp::build(odsBuilder&: builder, odsState&: state, par_level_simd: clauses.parLevelSimd);
3471}
3472
3473LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(op&: **this); }
3474
3475//===----------------------------------------------------------------------===//
3476// TaskwaitOp
3477//===----------------------------------------------------------------------===//
3478
3479void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3480 const TaskwaitOperands &clauses) {
3481 // TODO Store clauses in op: dependKinds, dependVars, nowait.
3482 TaskwaitOp::build(odsBuilder&: builder, odsState&: state, /*depend_kinds=*/nullptr,
3483 /*depend_vars=*/{}, /*nowait=*/nullptr);
3484}
3485
3486//===----------------------------------------------------------------------===//
3487// Verifier for AtomicReadOp
3488//===----------------------------------------------------------------------===//
3489
3490LogicalResult AtomicReadOp::verify() {
3491 if (verifyCommon().failed())
3492 return mlir::failure();
3493
3494 if (auto mo = getMemoryOrder()) {
3495 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3496 *mo == ClauseMemoryOrderKind::Release) {
3497 return emitError(
3498 message: "memory-order must not be acq_rel or release for atomic reads");
3499 }
3500 }
3501 return verifySynchronizationHint(op: *this, hint: getHint());
3502}
3503
3504//===----------------------------------------------------------------------===//
3505// Verifier for AtomicWriteOp
3506//===----------------------------------------------------------------------===//
3507
3508LogicalResult AtomicWriteOp::verify() {
3509 if (verifyCommon().failed())
3510 return mlir::failure();
3511
3512 if (auto mo = getMemoryOrder()) {
3513 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3514 *mo == ClauseMemoryOrderKind::Acquire) {
3515 return emitError(
3516 message: "memory-order must not be acq_rel or acquire for atomic writes");
3517 }
3518 }
3519 return verifySynchronizationHint(op: *this, hint: getHint());
3520}
3521
3522//===----------------------------------------------------------------------===//
3523// Verifier for AtomicUpdateOp
3524//===----------------------------------------------------------------------===//
3525
3526LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3527 PatternRewriter &rewriter) {
3528 if (op.isNoOp()) {
3529 rewriter.eraseOp(op);
3530 return success();
3531 }
3532 if (Value writeVal = op.getWriteOpVal()) {
3533 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3534 op, args: op.getX(), args&: writeVal, args: op.getHintAttr(), args: op.getMemoryOrderAttr());
3535 return success();
3536 }
3537 return failure();
3538}
3539
3540LogicalResult AtomicUpdateOp::verify() {
3541 if (verifyCommon().failed())
3542 return mlir::failure();
3543
3544 if (auto mo = getMemoryOrder()) {
3545 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3546 *mo == ClauseMemoryOrderKind::Acquire) {
3547 return emitError(
3548 message: "memory-order must not be acq_rel or acquire for atomic updates");
3549 }
3550 }
3551
3552 return verifySynchronizationHint(op: *this, hint: getHint());
3553}
3554
3555LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3556
3557//===----------------------------------------------------------------------===//
3558// Verifier for AtomicCaptureOp
3559//===----------------------------------------------------------------------===//
3560
3561AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3562 if (auto op = dyn_cast<AtomicReadOp>(Val: getFirstOp()))
3563 return op;
3564 return dyn_cast<AtomicReadOp>(Val: getSecondOp());
3565}
3566
3567AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3568 if (auto op = dyn_cast<AtomicWriteOp>(Val: getFirstOp()))
3569 return op;
3570 return dyn_cast<AtomicWriteOp>(Val: getSecondOp());
3571}
3572
3573AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3574 if (auto op = dyn_cast<AtomicUpdateOp>(Val: getFirstOp()))
3575 return op;
3576 return dyn_cast<AtomicUpdateOp>(Val: getSecondOp());
3577}
3578
3579LogicalResult AtomicCaptureOp::verify() {
3580 return verifySynchronizationHint(op: *this, hint: getHint());
3581}
3582
3583LogicalResult AtomicCaptureOp::verifyRegions() {
3584 if (verifyRegionsCommon().failed())
3585 return mlir::failure();
3586
3587 if (getFirstOp()->getAttr(name: "hint") || getSecondOp()->getAttr(name: "hint"))
3588 return emitOpError(
3589 message: "operations inside capture region must not have hint clause");
3590
3591 if (getFirstOp()->getAttr(name: "memory_order") ||
3592 getSecondOp()->getAttr(name: "memory_order"))
3593 return emitOpError(
3594 message: "operations inside capture region must not have memory_order clause");
3595 return success();
3596}
3597
3598//===----------------------------------------------------------------------===//
3599// CancelOp
3600//===----------------------------------------------------------------------===//
3601
3602void CancelOp::build(OpBuilder &builder, OperationState &state,
3603 const CancelOperands &clauses) {
3604 CancelOp::build(odsBuilder&: builder, odsState&: state, cancel_directive: clauses.cancelDirective, if_expr: clauses.ifExpr);
3605}
3606
3607static Operation *getParentInSameDialect(Operation *thisOp) {
3608 Operation *parent = thisOp->getParentOp();
3609 while (parent) {
3610 if (parent->getDialect() == thisOp->getDialect())
3611 return parent;
3612 parent = parent->getParentOp();
3613 }
3614 return nullptr;
3615}
3616
3617LogicalResult CancelOp::verify() {
3618 ClauseCancellationConstructType cct = getCancelDirective();
3619 // The next OpenMP operation in the chain of parents
3620 Operation *structuralParent = getParentInSameDialect(thisOp: (*this).getOperation());
3621 if (!structuralParent)
3622 return emitOpError() << "Orphaned cancel construct";
3623
3624 if ((cct == ClauseCancellationConstructType::Parallel) &&
3625 !mlir::isa<ParallelOp>(Val: structuralParent)) {
3626 return emitOpError() << "cancel parallel must appear "
3627 << "inside a parallel region";
3628 }
3629 if (cct == ClauseCancellationConstructType::Loop) {
3630 // structural parent will be omp.loop_nest, directly nested inside
3631 // omp.wsloop
3632 auto wsloopOp = mlir::dyn_cast<WsloopOp>(Val: structuralParent->getParentOp());
3633
3634 if (!wsloopOp) {
3635 return emitOpError()
3636 << "cancel loop must appear inside a worksharing-loop region";
3637 }
3638 if (wsloopOp.getNowaitAttr()) {
3639 return emitError() << "A worksharing construct that is canceled "
3640 << "must not have a nowait clause";
3641 }
3642 if (wsloopOp.getOrderedAttr()) {
3643 return emitError() << "A worksharing construct that is canceled "
3644 << "must not have an ordered clause";
3645 }
3646
3647 } else if (cct == ClauseCancellationConstructType::Sections) {
3648 // structural parent will be an omp.section, directly nested inside
3649 // omp.sections
3650 auto sectionsOp =
3651 mlir::dyn_cast<SectionsOp>(Val: structuralParent->getParentOp());
3652 if (!sectionsOp) {
3653 return emitOpError() << "cancel sections must appear "
3654 << "inside a sections region";
3655 }
3656 if (sectionsOp.getNowait()) {
3657 return emitError() << "A sections construct that is canceled "
3658 << "must not have a nowait clause";
3659 }
3660 }
3661 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3662 (!mlir::isa<omp::TaskOp>(Val: structuralParent) &&
3663 !mlir::isa<omp::TaskloopOp>(Val: structuralParent->getParentOp()))) {
3664 return emitOpError() << "cancel taskgroup must appear "
3665 << "inside a task region";
3666 }
3667 return success();
3668}
3669
3670//===----------------------------------------------------------------------===//
3671// CancellationPointOp
3672//===----------------------------------------------------------------------===//
3673
3674void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
3675 const CancellationPointOperands &clauses) {
3676 CancellationPointOp::build(odsBuilder&: builder, odsState&: state, cancel_directive: clauses.cancelDirective);
3677}
3678
3679LogicalResult CancellationPointOp::verify() {
3680 ClauseCancellationConstructType cct = getCancelDirective();
3681 // The next OpenMP operation in the chain of parents
3682 Operation *structuralParent = getParentInSameDialect(thisOp: (*this).getOperation());
3683 if (!structuralParent)
3684 return emitOpError() << "Orphaned cancellation point";
3685
3686 if ((cct == ClauseCancellationConstructType::Parallel) &&
3687 !mlir::isa<ParallelOp>(Val: structuralParent)) {
3688 return emitOpError() << "cancellation point parallel must appear "
3689 << "inside a parallel region";
3690 }
3691 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3692 // find the wsloop
3693 if ((cct == ClauseCancellationConstructType::Loop) &&
3694 !mlir::isa<WsloopOp>(Val: structuralParent->getParentOp())) {
3695 return emitOpError() << "cancellation point loop must appear "
3696 << "inside a worksharing-loop region";
3697 }
3698 if ((cct == ClauseCancellationConstructType::Sections) &&
3699 !mlir::isa<omp::SectionOp>(Val: structuralParent)) {
3700 return emitOpError() << "cancellation point sections must appear "
3701 << "inside a sections region";
3702 }
3703 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
3704 !mlir::isa<omp::TaskOp>(Val: structuralParent)) {
3705 return emitOpError() << "cancellation point taskgroup must appear "
3706 << "inside a task region";
3707 }
3708 return success();
3709}
3710
3711//===----------------------------------------------------------------------===//
3712// MapBoundsOp
3713//===----------------------------------------------------------------------===//
3714
3715LogicalResult MapBoundsOp::verify() {
3716 auto extent = getExtent();
3717 auto upperbound = getUpperBound();
3718 if (!extent && !upperbound)
3719 return emitError(message: "expected extent or upperbound.");
3720 return success();
3721}
3722
3723void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3724 TypeRange /*result_types*/, StringAttr symName,
3725 TypeAttr type) {
3726 PrivateClauseOp::build(
3727 odsBuilder, odsState, sym_name: symName, type,
3728 data_sharing_type: DataSharingClauseTypeAttr::get(context: odsBuilder.getContext(),
3729 value: DataSharingClauseType::Private));
3730}
3731
3732LogicalResult PrivateClauseOp::verifyRegions() {
3733 Type argType = getArgType();
3734 auto verifyTerminator = [&](Operation *terminator,
3735 bool yieldsValue) -> LogicalResult {
3736 if (!terminator->getBlock()->getSuccessors().empty())
3737 return success();
3738
3739 if (!llvm::isa<YieldOp>(Val: terminator))
3740 return mlir::emitError(loc: terminator->getLoc())
3741 << "expected exit block terminator to be an `omp.yield` op.";
3742
3743 YieldOp yieldOp = llvm::cast<YieldOp>(Val: terminator);
3744 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
3745
3746 if (!yieldsValue) {
3747 if (yieldedTypes.empty())
3748 return success();
3749
3750 return mlir::emitError(loc: terminator->getLoc())
3751 << "Did not expect any values to be yielded.";
3752 }
3753
3754 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
3755 return success();
3756
3757 auto error = mlir::emitError(loc: yieldOp.getLoc())
3758 << "Invalid yielded value. Expected type: " << argType
3759 << ", got: ";
3760
3761 if (yieldedTypes.empty())
3762 error << "None";
3763 else
3764 error << yieldedTypes;
3765
3766 return error;
3767 };
3768
3769 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
3770 StringRef regionName,
3771 bool yieldsValue) -> LogicalResult {
3772 assert(!region.empty());
3773
3774 if (region.getNumArguments() != expectedNumArgs)
3775 return mlir::emitError(loc: region.getLoc())
3776 << "`" << regionName << "`: "
3777 << "expected " << expectedNumArgs
3778 << " region arguments, got: " << region.getNumArguments();
3779
3780 for (Block &block : region) {
3781 // MLIR will verify the absence of the terminator for us.
3782 if (!block.mightHaveTerminator())
3783 continue;
3784
3785 if (failed(Result: verifyTerminator(block.getTerminator(), yieldsValue)))
3786 return failure();
3787 }
3788
3789 return success();
3790 };
3791
3792 // Ensure all of the region arguments have the same type
3793 for (Region *region : getRegions())
3794 for (Type ty : region->getArgumentTypes())
3795 if (ty != argType)
3796 return emitError() << "Region argument type mismatch: got " << ty
3797 << " expected " << argType << ".";
3798
3799 mlir::Region &initRegion = getInitRegion();
3800 if (!initRegion.empty() &&
3801 failed(Result: verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
3802 /*yieldsValue=*/true)))
3803 return failure();
3804
3805 DataSharingClauseType dsType = getDataSharingType();
3806
3807 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
3808 return emitError(message: "`private` clauses do not require a `copy` region.");
3809
3810 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
3811 return emitError(
3812 message: "`firstprivate` clauses require at least a `copy` region.");
3813
3814 if (dsType == DataSharingClauseType::FirstPrivate &&
3815 failed(Result: verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
3816 /*yieldsValue=*/true)))
3817 return failure();
3818
3819 if (!getDeallocRegion().empty() &&
3820 failed(Result: verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
3821 /*yieldsValue=*/false)))
3822 return failure();
3823
3824 return success();
3825}
3826
3827//===----------------------------------------------------------------------===//
3828// Spec 5.2: Masked construct (10.5)
3829//===----------------------------------------------------------------------===//
3830
3831void MaskedOp::build(OpBuilder &builder, OperationState &state,
3832 const MaskedOperands &clauses) {
3833 MaskedOp::build(odsBuilder&: builder, odsState&: state, filtered_thread_id: clauses.filteredThreadId);
3834}
3835
3836//===----------------------------------------------------------------------===//
3837// Spec 5.2: Scan construct (5.6)
3838//===----------------------------------------------------------------------===//
3839
3840void ScanOp::build(OpBuilder &builder, OperationState &state,
3841 const ScanOperands &clauses) {
3842 ScanOp::build(odsBuilder&: builder, odsState&: state, inclusive_vars: clauses.inclusiveVars, exclusive_vars: clauses.exclusiveVars);
3843}
3844
3845LogicalResult ScanOp::verify() {
3846 if (hasExclusiveVars() == hasInclusiveVars())
3847 return emitError(
3848 message: "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
3849 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
3850 if (parentWsLoopOp.getReductionModAttr() &&
3851 parentWsLoopOp.getReductionModAttr().getValue() ==
3852 ReductionModifier::inscan)
3853 return success();
3854 }
3855 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
3856 if (parentSimdOp.getReductionModAttr() &&
3857 parentSimdOp.getReductionModAttr().getValue() ==
3858 ReductionModifier::inscan)
3859 return success();
3860 }
3861 return emitError(message: "SCAN directive needs to be enclosed within a parent "
3862 "worksharing loop construct or SIMD construct with INSCAN "
3863 "reduction modifier");
3864}
3865
3866#define GET_ATTRDEF_CLASSES
3867#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
3868
3869#define GET_OP_CLASSES
3870#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
3871
3872#define GET_TYPEDEF_CLASSES
3873#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
3874

source code of mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp