1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/DialectImplementation.h"
19#include "mlir/IR/OpImplementation.h"
20#include "mlir/IR/OperationSupport.h"
21#include "mlir/Interfaces/FoldInterfaces.h"
22
23#include "llvm/ADT/BitVector.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/STLForwardCompat.h"
26#include "llvm/ADT/SmallString.h"
27#include "llvm/ADT/StringExtras.h"
28#include "llvm/ADT/StringRef.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Frontend/OpenMP/OMPConstants.h"
31#include <cstddef>
32#include <iterator>
33#include <optional>
34
35#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
36#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
37#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
38#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
39#include "mlir/Support/LogicalResult.h"
40
41using namespace mlir;
42using namespace mlir::omp;
43
44static ArrayAttr makeArrayAttr(MLIRContext *context,
45 llvm::ArrayRef<Attribute> attrs) {
46 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
47}
48
49namespace {
50struct MemRefPointerLikeModel
51 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
52 MemRefType> {
53 Type getElementType(Type pointer) const {
54 return llvm::cast<MemRefType>(pointer).getElementType();
55 }
56};
57
58struct LLVMPointerPointerLikeModel
59 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
60 LLVM::LLVMPointerType> {
61 Type getElementType(Type pointer) const { return Type(); }
62};
63
64struct OpenMPDialectFoldInterface : public DialectFoldInterface {
65 using DialectFoldInterface::DialectFoldInterface;
66
67 bool shouldMaterializeInto(Region *region) const final {
68 // Avoid folding constants across target regions
69 return isa<TargetOp>(region->getParentOp());
70 }
71};
72} // namespace
73
74void OpenMPDialect::initialize() {
75 addOperations<
76#define GET_OP_LIST
77#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
78 >();
79 addAttributes<
80#define GET_ATTRDEF_LIST
81#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
82 >();
83 addTypes<
84#define GET_TYPEDEF_LIST
85#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
86 >();
87
88 addInterface<OpenMPDialectFoldInterface>();
89 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
90 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
91 *getContext());
92
93 // Attach default offload module interface to module op to access
94 // offload functionality through
95 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
96 *getContext());
97
98 // Attach default declare target interfaces to operations which can be marked
99 // as declare target (Global Operations and Functions/Subroutines in dialects
100 // that Fortran (or other languages that lower to MLIR) translates too
101 mlir::LLVM::GlobalOp::attachInterface<
102 mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::GlobalOp>>(
103 *getContext());
104 mlir::LLVM::LLVMFuncOp::attachInterface<
105 mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::LLVMFuncOp>>(
106 *getContext());
107 mlir::func::FuncOp::attachInterface<
108 mlir::omp::DeclareTargetDefaultModel<mlir::func::FuncOp>>(*getContext());
109}
110
111//===----------------------------------------------------------------------===//
112// Parser and printer for Allocate Clause
113//===----------------------------------------------------------------------===//
114
115/// Parse an allocate clause with allocators and a list of operands with types.
116///
117/// allocate-operand-list :: = allocate-operand |
118/// allocator-operand `,` allocate-operand-list
119/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
120/// ssa-id-and-type ::= ssa-id `:` type
121static ParseResult parseAllocateAndAllocator(
122 OpAsmParser &parser,
123 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocate,
124 SmallVectorImpl<Type> &typesAllocate,
125 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operandsAllocator,
126 SmallVectorImpl<Type> &typesAllocator) {
127
128 return parser.parseCommaSeparatedList([&]() {
129 OpAsmParser::UnresolvedOperand operand;
130 Type type;
131 if (parser.parseOperand(result&: operand) || parser.parseColonType(result&: type))
132 return failure();
133 operandsAllocator.push_back(operand);
134 typesAllocator.push_back(Elt: type);
135 if (parser.parseArrow())
136 return failure();
137 if (parser.parseOperand(result&: operand) || parser.parseColonType(result&: type))
138 return failure();
139
140 operandsAllocate.push_back(operand);
141 typesAllocate.push_back(Elt: type);
142 return success();
143 });
144}
145
146/// Print allocate clause
147static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op,
148 OperandRange varsAllocate,
149 TypeRange typesAllocate,
150 OperandRange varsAllocator,
151 TypeRange typesAllocator) {
152 for (unsigned i = 0; i < varsAllocate.size(); ++i) {
153 std::string separator = i == varsAllocate.size() - 1 ? "" : ", ";
154 p << varsAllocator[i] << " : " << typesAllocator[i] << " -> ";
155 p << varsAllocate[i] << " : " << typesAllocate[i] << separator;
156 }
157}
158
159//===----------------------------------------------------------------------===//
160// Parser and printer for a clause attribute (StringEnumAttr)
161//===----------------------------------------------------------------------===//
162
163template <typename ClauseAttr>
164static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
165 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
166 StringRef enumStr;
167 SMLoc loc = parser.getCurrentLocation();
168 if (parser.parseKeyword(keyword: &enumStr))
169 return failure();
170 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
171 attr = ClauseAttr::get(parser.getContext(), *enumValue);
172 return success();
173 }
174 return parser.emitError(loc, message: "invalid clause value: '") << enumStr << "'";
175}
176
177template <typename ClauseAttr>
178void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
179 p << stringifyEnum(attr.getValue());
180}
181
182//===----------------------------------------------------------------------===//
183// Parser and printer for Linear Clause
184//===----------------------------------------------------------------------===//
185
186/// linear ::= `linear` `(` linear-list `)`
187/// linear-list := linear-val | linear-val linear-list
188/// linear-val := ssa-id-and-type `=` ssa-id-and-type
189static ParseResult
190parseLinearClause(OpAsmParser &parser,
191 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
192 SmallVectorImpl<Type> &types,
193 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &stepVars) {
194 return parser.parseCommaSeparatedList(parseElementFn: [&]() {
195 OpAsmParser::UnresolvedOperand var;
196 Type type;
197 OpAsmParser::UnresolvedOperand stepVar;
198 if (parser.parseOperand(result&: var) || parser.parseEqual() ||
199 parser.parseOperand(result&: stepVar) || parser.parseColonType(result&: type))
200 return failure();
201
202 vars.push_back(Elt: var);
203 types.push_back(Elt: type);
204 stepVars.push_back(Elt: stepVar);
205 return success();
206 });
207}
208
209/// Print Linear Clause
210static void printLinearClause(OpAsmPrinter &p, Operation *op,
211 ValueRange linearVars, TypeRange linearVarTypes,
212 ValueRange linearStepVars) {
213 size_t linearVarsSize = linearVars.size();
214 for (unsigned i = 0; i < linearVarsSize; ++i) {
215 std::string separator = i == linearVarsSize - 1 ? "" : ", ";
216 p << linearVars[i];
217 if (linearStepVars.size() > i)
218 p << " = " << linearStepVars[i];
219 p << " : " << linearVars[i].getType() << separator;
220 }
221}
222
223//===----------------------------------------------------------------------===//
224// Verifier for Nontemporal Clause
225//===----------------------------------------------------------------------===//
226
227static LogicalResult
228verifyNontemporalClause(Operation *op, OperandRange nontemporalVariables) {
229
230 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
231 DenseSet<Value> nontemporalItems;
232 for (const auto &it : nontemporalVariables)
233 if (!nontemporalItems.insert(V: it).second)
234 return op->emitOpError() << "nontemporal variable used more than once";
235
236 return success();
237}
238
239//===----------------------------------------------------------------------===//
240// Parser, verifier and printer for Aligned Clause
241//===----------------------------------------------------------------------===//
242static LogicalResult
243verifyAlignedClause(Operation *op, std::optional<ArrayAttr> alignmentValues,
244 OperandRange alignedVariables) {
245 // Check if number of alignment values equals to number of aligned variables
246 if (!alignedVariables.empty()) {
247 if (!alignmentValues || alignmentValues->size() != alignedVariables.size())
248 return op->emitOpError()
249 << "expected as many alignment values as aligned variables";
250 } else {
251 if (alignmentValues)
252 return op->emitOpError() << "unexpected alignment values attribute";
253 return success();
254 }
255
256 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
257 DenseSet<Value> alignedItems;
258 for (auto it : alignedVariables)
259 if (!alignedItems.insert(V: it).second)
260 return op->emitOpError() << "aligned variable used more than once";
261
262 if (!alignmentValues)
263 return success();
264
265 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
266 for (unsigned i = 0; i < (*alignmentValues).size(); ++i) {
267 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
268 if (intAttr.getValue().sle(0))
269 return op->emitOpError() << "alignment should be greater than 0";
270 } else {
271 return op->emitOpError() << "expected integer alignment";
272 }
273 }
274
275 return success();
276}
277
278/// aligned ::= `aligned` `(` aligned-list `)`
279/// aligned-list := aligned-val | aligned-val aligned-list
280/// aligned-val := ssa-id-and-type `->` alignment
281static ParseResult parseAlignedClause(
282 OpAsmParser &parser,
283 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &alignedItems,
284 SmallVectorImpl<Type> &types, ArrayAttr &alignmentValues) {
285 SmallVector<Attribute> alignmentVec;
286 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
287 if (parser.parseOperand(result&: alignedItems.emplace_back()) ||
288 parser.parseColonType(result&: types.emplace_back()) ||
289 parser.parseArrow() ||
290 parser.parseAttribute(result&: alignmentVec.emplace_back())) {
291 return failure();
292 }
293 return success();
294 })))
295 return failure();
296 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
297 alignmentValues = ArrayAttr::get(parser.getContext(), alignments);
298 return success();
299}
300
301/// Print Aligned Clause
302static void printAlignedClause(OpAsmPrinter &p, Operation *op,
303 ValueRange alignedVars,
304 TypeRange alignedVarTypes,
305 std::optional<ArrayAttr> alignmentValues) {
306 for (unsigned i = 0; i < alignedVars.size(); ++i) {
307 if (i != 0)
308 p << ", ";
309 p << alignedVars[i] << " : " << alignedVars[i].getType();
310 p << " -> " << (*alignmentValues)[i];
311 }
312}
313
314//===----------------------------------------------------------------------===//
315// Parser, printer and verifier for Schedule Clause
316//===----------------------------------------------------------------------===//
317
318static ParseResult
319verifyScheduleModifiers(OpAsmParser &parser,
320 SmallVectorImpl<SmallString<12>> &modifiers) {
321 if (modifiers.size() > 2)
322 return parser.emitError(loc: parser.getNameLoc()) << " unexpected modifier(s)";
323 for (const auto &mod : modifiers) {
324 // Translate the string. If it has no value, then it was not a valid
325 // modifier!
326 auto symbol = symbolizeScheduleModifier(mod);
327 if (!symbol)
328 return parser.emitError(loc: parser.getNameLoc())
329 << " unknown modifier type: " << mod;
330 }
331
332 // If we have one modifier that is "simd", then stick a "none" modiifer in
333 // index 0.
334 if (modifiers.size() == 1) {
335 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
336 modifiers.push_back(Elt: modifiers[0]);
337 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
338 }
339 } else if (modifiers.size() == 2) {
340 // If there are two modifier:
341 // First modifier should not be simd, second one should be simd
342 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
343 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
344 return parser.emitError(loc: parser.getNameLoc())
345 << " incorrect modifier order";
346 }
347 return success();
348}
349
350/// schedule ::= `schedule` `(` sched-list `)`
351/// sched-list ::= sched-val | sched-val sched-list |
352/// sched-val `,` sched-modifier
353/// sched-val ::= sched-with-chunk | sched-wo-chunk
354/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
355/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
356/// sched-wo-chunk ::= `auto` | `runtime`
357/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
358/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
359static ParseResult parseScheduleClause(
360 OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
361 ScheduleModifierAttr &scheduleModifier, UnitAttr &simdModifier,
362 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize, Type &chunkType) {
363 StringRef keyword;
364 if (parser.parseKeyword(keyword: &keyword))
365 return failure();
366 std::optional<mlir::omp::ClauseScheduleKind> schedule =
367 symbolizeClauseScheduleKind(keyword);
368 if (!schedule)
369 return parser.emitError(loc: parser.getNameLoc()) << " expected schedule kind";
370
371 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
372 switch (*schedule) {
373 case ClauseScheduleKind::Static:
374 case ClauseScheduleKind::Dynamic:
375 case ClauseScheduleKind::Guided:
376 if (succeeded(result: parser.parseOptionalEqual())) {
377 chunkSize = OpAsmParser::UnresolvedOperand{};
378 if (parser.parseOperand(result&: *chunkSize) || parser.parseColonType(result&: chunkType))
379 return failure();
380 } else {
381 chunkSize = std::nullopt;
382 }
383 break;
384 case ClauseScheduleKind::Auto:
385 case ClauseScheduleKind::Runtime:
386 chunkSize = std::nullopt;
387 }
388
389 // If there is a comma, we have one or more modifiers..
390 SmallVector<SmallString<12>> modifiers;
391 while (succeeded(result: parser.parseOptionalComma())) {
392 StringRef mod;
393 if (parser.parseKeyword(keyword: &mod))
394 return failure();
395 modifiers.push_back(Elt: mod);
396 }
397
398 if (verifyScheduleModifiers(parser, modifiers))
399 return failure();
400
401 if (!modifiers.empty()) {
402 SMLoc loc = parser.getCurrentLocation();
403 if (std::optional<ScheduleModifier> mod =
404 symbolizeScheduleModifier(modifiers[0])) {
405 scheduleModifier = ScheduleModifierAttr::get(parser.getContext(), *mod);
406 } else {
407 return parser.emitError(loc, message: "invalid schedule modifier");
408 }
409 // Only SIMD attribute is allowed here!
410 if (modifiers.size() > 1) {
411 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
412 simdModifier = UnitAttr::get(parser.getBuilder().getContext());
413 }
414 }
415
416 return success();
417}
418
419/// Print schedule clause
420static void printScheduleClause(OpAsmPrinter &p, Operation *op,
421 ClauseScheduleKindAttr schedAttr,
422 ScheduleModifierAttr modifier, UnitAttr simd,
423 Value scheduleChunkVar,
424 Type scheduleChunkType) {
425 p << stringifyClauseScheduleKind(schedAttr.getValue());
426 if (scheduleChunkVar)
427 p << " = " << scheduleChunkVar << " : " << scheduleChunkVar.getType();
428 if (modifier)
429 p << ", " << stringifyScheduleModifier(modifier.getValue());
430 if (simd)
431 p << ", simd";
432}
433
434//===----------------------------------------------------------------------===//
435// Parser, printer and verifier for ReductionVarList
436//===----------------------------------------------------------------------===//
437
438ParseResult parseClauseWithRegionArgs(
439 OpAsmParser &parser, Region &region,
440 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
441 SmallVectorImpl<Type> &types, ArrayAttr &symbols,
442 SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs) {
443 SmallVector<SymbolRefAttr> reductionVec;
444 unsigned regionArgOffset = regionPrivateArgs.size();
445
446 if (failed(
447 result: parser.parseCommaSeparatedList(delimiter: OpAsmParser::Delimiter::Paren, parseElementFn: [&]() {
448 if (parser.parseAttribute(reductionVec.emplace_back()) ||
449 parser.parseOperand(result&: operands.emplace_back()) ||
450 parser.parseArrow() ||
451 parser.parseArgument(result&: regionPrivateArgs.emplace_back()) ||
452 parser.parseColonType(result&: types.emplace_back()))
453 return failure();
454 return success();
455 })))
456 return failure();
457
458 auto *argsBegin = regionPrivateArgs.begin();
459 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
460 argsBegin + regionArgOffset + types.size());
461 for (auto [prv, type] : llvm::zip_equal(t&: argsSubrange, u&: types)) {
462 prv.type = type;
463 }
464 SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
465 symbols = ArrayAttr::get(parser.getContext(), reductions);
466 return success();
467}
468
469static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
470 ValueRange argsSubrange,
471 StringRef clauseName, ValueRange operands,
472 TypeRange types, ArrayAttr symbols) {
473 p << clauseName << "(";
474 llvm::interleaveComma(
475 llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
476 auto [sym, op, arg, type] = t;
477 p << sym << " " << op << " -> " << arg << " : " << type;
478 });
479 p << ") ";
480}
481
482static ParseResult parseParallelRegion(
483 OpAsmParser &parser, Region &region,
484 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVarOperands,
485 SmallVectorImpl<Type> &reductionVarTypes, ArrayAttr &reductionSymbols,
486 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVarOperands,
487 llvm::SmallVectorImpl<Type> &privateVarsTypes,
488 ArrayAttr &privatizerSymbols) {
489 llvm::SmallVector<OpAsmParser::Argument> regionPrivateArgs;
490
491 if (succeeded(result: parser.parseOptionalKeyword(keyword: "reduction"))) {
492 if (failed(result: parseClauseWithRegionArgs(parser, region, operands&: reductionVarOperands,
493 types&: reductionVarTypes, symbols&: reductionSymbols,
494 regionPrivateArgs)))
495 return failure();
496 }
497
498 if (succeeded(result: parser.parseOptionalKeyword(keyword: "private"))) {
499 if (failed(result: parseClauseWithRegionArgs(parser, region, operands&: privateVarOperands,
500 types&: privateVarsTypes, symbols&: privatizerSymbols,
501 regionPrivateArgs)))
502 return failure();
503 }
504
505 return parser.parseRegion(region, arguments: regionPrivateArgs);
506}
507
508static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
509 ValueRange reductionVarOperands,
510 TypeRange reductionVarTypes,
511 ArrayAttr reductionSymbols,
512 ValueRange privateVarOperands,
513 TypeRange privateVarTypes,
514 ArrayAttr privatizerSymbols) {
515 if (reductionSymbols) {
516 auto *argsBegin = region.front().getArguments().begin();
517 MutableArrayRef argsSubrange(argsBegin,
518 argsBegin + reductionVarTypes.size());
519 printClauseWithRegionArgs(p, op, argsSubrange, "reduction",
520 reductionVarOperands, reductionVarTypes,
521 reductionSymbols);
522 }
523
524 if (privatizerSymbols) {
525 auto *argsBegin = region.front().getArguments().begin();
526 MutableArrayRef argsSubrange(argsBegin + reductionVarOperands.size(),
527 argsBegin + reductionVarOperands.size() +
528 privateVarTypes.size());
529 printClauseWithRegionArgs(p, op, argsSubrange, "private",
530 privateVarOperands, privateVarTypes,
531 privatizerSymbols);
532 }
533
534 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
535}
536
537/// reduction-entry-list ::= reduction-entry
538/// | reduction-entry-list `,` reduction-entry
539/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
540static ParseResult
541parseReductionVarList(OpAsmParser &parser,
542 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
543 SmallVectorImpl<Type> &types,
544 ArrayAttr &redcuctionSymbols) {
545 SmallVector<SymbolRefAttr> reductionVec;
546 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
547 if (parser.parseAttribute(reductionVec.emplace_back()) ||
548 parser.parseArrow() ||
549 parser.parseOperand(result&: operands.emplace_back()) ||
550 parser.parseColonType(result&: types.emplace_back()))
551 return failure();
552 return success();
553 })))
554 return failure();
555 SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
556 redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
557 return success();
558}
559
560/// Print Reduction clause
561static void printReductionVarList(OpAsmPrinter &p, Operation *op,
562 OperandRange reductionVars,
563 TypeRange reductionTypes,
564 std::optional<ArrayAttr> reductions) {
565 for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
566 if (i != 0)
567 p << ", ";
568 p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
569 << reductionVars[i].getType();
570 }
571}
572
573/// Verifies Reduction Clause
574static LogicalResult verifyReductionVarList(Operation *op,
575 std::optional<ArrayAttr> reductions,
576 OperandRange reductionVars) {
577 if (!reductionVars.empty()) {
578 if (!reductions || reductions->size() != reductionVars.size())
579 return op->emitOpError()
580 << "expected as many reduction symbol references "
581 "as reduction variables";
582 } else {
583 if (reductions)
584 return op->emitOpError() << "unexpected reduction symbol references";
585 return success();
586 }
587
588 // TODO: The followings should be done in
589 // SymbolUserOpInterface::verifySymbolUses.
590 DenseSet<Value> accumulators;
591 for (auto args : llvm::zip(reductionVars, *reductions)) {
592 Value accum = std::get<0>(args);
593
594 if (!accumulators.insert(accum).second)
595 return op->emitOpError() << "accumulator variable used more than once";
596
597 Type varType = accum.getType();
598 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
599 auto decl =
600 SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef);
601 if (!decl)
602 return op->emitOpError() << "expected symbol reference " << symbolRef
603 << " to point to a reduction declaration";
604
605 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
606 return op->emitOpError()
607 << "expected accumulator (" << varType
608 << ") to be the same type as reduction declaration ("
609 << decl.getAccumulatorType() << ")";
610 }
611
612 return success();
613}
614
615//===----------------------------------------------------------------------===//
616// Parser, printer and verifier for CopyPrivateVarList
617//===----------------------------------------------------------------------===//
618
619/// copyprivate-entry-list ::= copyprivate-entry
620/// | copyprivate-entry-list `,` copyprivate-entry
621/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
622static ParseResult parseCopyPrivateVarList(
623 OpAsmParser &parser,
624 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
625 SmallVectorImpl<Type> &types, ArrayAttr &copyPrivateSymbols) {
626 SmallVector<SymbolRefAttr> copyPrivateFuncsVec;
627 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
628 if (parser.parseOperand(result&: operands.emplace_back()) ||
629 parser.parseArrow() ||
630 parser.parseAttribute(copyPrivateFuncsVec.emplace_back()) ||
631 parser.parseColonType(result&: types.emplace_back()))
632 return failure();
633 return success();
634 })))
635 return failure();
636 SmallVector<Attribute> copyPrivateFuncs(copyPrivateFuncsVec.begin(),
637 copyPrivateFuncsVec.end());
638 copyPrivateSymbols = ArrayAttr::get(parser.getContext(), copyPrivateFuncs);
639 return success();
640}
641
642/// Print CopyPrivate clause
643static void printCopyPrivateVarList(OpAsmPrinter &p, Operation *op,
644 OperandRange copyPrivateVars,
645 TypeRange copyPrivateTypes,
646 std::optional<ArrayAttr> copyPrivateFuncs) {
647 if (!copyPrivateFuncs.has_value())
648 return;
649 llvm::interleaveComma(
650 llvm::zip(copyPrivateVars, *copyPrivateFuncs, copyPrivateTypes), p,
651 [&](const auto &args) {
652 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
653 << std::get<2>(args);
654 });
655}
656
657/// Verifies CopyPrivate Clause
658static LogicalResult
659verifyCopyPrivateVarList(Operation *op, OperandRange copyPrivateVars,
660 std::optional<ArrayAttr> copyPrivateFuncs) {
661 size_t copyPrivateFuncsSize =
662 copyPrivateFuncs.has_value() ? copyPrivateFuncs->size() : 0;
663 if (copyPrivateFuncsSize != copyPrivateVars.size())
664 return op->emitOpError() << "inconsistent number of copyPrivate vars (= "
665 << copyPrivateVars.size()
666 << ") and functions (= " << copyPrivateFuncsSize
667 << "), both must be equal";
668 if (!copyPrivateFuncs.has_value())
669 return success();
670
671 for (auto copyPrivateVarAndFunc :
672 llvm::zip(copyPrivateVars, *copyPrivateFuncs)) {
673 auto symbolRef =
674 llvm::cast<SymbolRefAttr>(std::get<1>(copyPrivateVarAndFunc));
675 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
676 funcOp;
677 if (mlir::func::FuncOp mlirFuncOp =
678 SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op,
679 symbolRef))
680 funcOp = mlirFuncOp;
681 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
682 SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>(
683 op, symbolRef))
684 funcOp = llvmFuncOp;
685
686 auto getNumArguments = [&] {
687 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
688 };
689
690 auto getArgumentType = [&](unsigned i) {
691 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
692 *funcOp);
693 };
694
695 if (!funcOp)
696 return op->emitOpError() << "expected symbol reference " << symbolRef
697 << " to point to a copy function";
698
699 if (getNumArguments() != 2)
700 return op->emitOpError()
701 << "expected copy function " << symbolRef << " to have 2 operands";
702
703 Type argTy = getArgumentType(0);
704 if (argTy != getArgumentType(1))
705 return op->emitOpError() << "expected copy function " << symbolRef
706 << " arguments to have the same type";
707
708 Type varType = std::get<0>(copyPrivateVarAndFunc).getType();
709 if (argTy != varType)
710 return op->emitOpError()
711 << "expected copy function arguments' type (" << argTy
712 << ") to be the same as copyprivate variable's type (" << varType
713 << ")";
714 }
715
716 return success();
717}
718
719//===----------------------------------------------------------------------===//
720// Parser, printer and verifier for DependVarList
721//===----------------------------------------------------------------------===//
722
723/// depend-entry-list ::= depend-entry
724/// | depend-entry-list `,` depend-entry
725/// depend-entry ::= depend-kind `->` ssa-id `:` type
726static ParseResult
727parseDependVarList(OpAsmParser &parser,
728 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
729 SmallVectorImpl<Type> &types, ArrayAttr &dependsArray) {
730 SmallVector<ClauseTaskDependAttr> dependVec;
731 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
732 StringRef keyword;
733 if (parser.parseKeyword(keyword: &keyword) || parser.parseArrow() ||
734 parser.parseOperand(result&: operands.emplace_back()) ||
735 parser.parseColonType(result&: types.emplace_back()))
736 return failure();
737 if (std::optional<ClauseTaskDepend> keywordDepend =
738 (symbolizeClauseTaskDepend(keyword)))
739 dependVec.emplace_back(
740 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
741 else
742 return failure();
743 return success();
744 })))
745 return failure();
746 SmallVector<Attribute> depends(dependVec.begin(), dependVec.end());
747 dependsArray = ArrayAttr::get(parser.getContext(), depends);
748 return success();
749}
750
751/// Print Depend clause
752static void printDependVarList(OpAsmPrinter &p, Operation *op,
753 OperandRange dependVars, TypeRange dependTypes,
754 std::optional<ArrayAttr> depends) {
755
756 for (unsigned i = 0, e = depends->size(); i < e; ++i) {
757 if (i != 0)
758 p << ", ";
759 p << stringifyClauseTaskDepend(
760 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*depends)[i])
761 .getValue())
762 << " -> " << dependVars[i] << " : " << dependTypes[i];
763 }
764}
765
766/// Verifies Depend clause
767static LogicalResult verifyDependVarList(Operation *op,
768 std::optional<ArrayAttr> depends,
769 OperandRange dependVars) {
770 if (!dependVars.empty()) {
771 if (!depends || depends->size() != dependVars.size())
772 return op->emitOpError() << "expected as many depend values"
773 " as depend variables";
774 } else {
775 if (depends && !depends->empty())
776 return op->emitOpError() << "unexpected depend values";
777 return success();
778 }
779
780 return success();
781}
782
783//===----------------------------------------------------------------------===//
784// Parser, printer and verifier for Synchronization Hint (2.17.12)
785//===----------------------------------------------------------------------===//
786
787/// Parses a Synchronization Hint clause. The value of hint is an integer
788/// which is a combination of different hints from `omp_sync_hint_t`.
789///
790/// hint-clause = `hint` `(` hint-value `)`
791static ParseResult parseSynchronizationHint(OpAsmParser &parser,
792 IntegerAttr &hintAttr) {
793 StringRef hintKeyword;
794 int64_t hint = 0;
795 if (succeeded(result: parser.parseOptionalKeyword(keyword: "none"))) {
796 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
797 return success();
798 }
799 auto parseKeyword = [&]() -> ParseResult {
800 if (failed(result: parser.parseKeyword(keyword: &hintKeyword)))
801 return failure();
802 if (hintKeyword == "uncontended")
803 hint |= 1;
804 else if (hintKeyword == "contended")
805 hint |= 2;
806 else if (hintKeyword == "nonspeculative")
807 hint |= 4;
808 else if (hintKeyword == "speculative")
809 hint |= 8;
810 else
811 return parser.emitError(loc: parser.getCurrentLocation())
812 << hintKeyword << " is not a valid hint";
813 return success();
814 };
815 if (parser.parseCommaSeparatedList(parseElementFn: parseKeyword))
816 return failure();
817 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
818 return success();
819}
820
821/// Prints a Synchronization Hint clause
822static void printSynchronizationHint(OpAsmPrinter &p, Operation *op,
823 IntegerAttr hintAttr) {
824 int64_t hint = hintAttr.getInt();
825
826 if (hint == 0) {
827 p << "none";
828 return;
829 }
830
831 // Helper function to get n-th bit from the right end of `value`
832 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
833
834 bool uncontended = bitn(hint, 0);
835 bool contended = bitn(hint, 1);
836 bool nonspeculative = bitn(hint, 2);
837 bool speculative = bitn(hint, 3);
838
839 SmallVector<StringRef> hints;
840 if (uncontended)
841 hints.push_back(Elt: "uncontended");
842 if (contended)
843 hints.push_back(Elt: "contended");
844 if (nonspeculative)
845 hints.push_back(Elt: "nonspeculative");
846 if (speculative)
847 hints.push_back(Elt: "speculative");
848
849 llvm::interleaveComma(c: hints, os&: p);
850}
851
852/// Verifies a synchronization hint clause
853static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
854
855 // Helper function to get n-th bit from the right end of `value`
856 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
857
858 bool uncontended = bitn(hint, 0);
859 bool contended = bitn(hint, 1);
860 bool nonspeculative = bitn(hint, 2);
861 bool speculative = bitn(hint, 3);
862
863 if (uncontended && contended)
864 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
865 "omp_sync_hint_contended cannot be combined";
866 if (nonspeculative && speculative)
867 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
868 "omp_sync_hint_speculative cannot be combined.";
869 return success();
870}
871
872//===----------------------------------------------------------------------===//
873// Parser, printer and verifier for Target
874//===----------------------------------------------------------------------===//
875
876// Helper function to get bitwise AND of `value` and 'flag'
877uint64_t mapTypeToBitFlag(uint64_t value,
878 llvm::omp::OpenMPOffloadMappingFlags flag) {
879 return value & llvm::to_underlying(E: flag);
880}
881
882/// Parses a map_entries map type from a string format back into its numeric
883/// value.
884///
885/// map-clause = `map_clauses ( ( `(` `always, `? `close, `? `present, `? (
886/// `to` | `from` | `delete` `)` )+ `)` )
887static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) {
888 llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
889 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
890
891 // This simply verifies the correct keyword is read in, the
892 // keyword itself is stored inside of the operation
893 auto parseTypeAndMod = [&]() -> ParseResult {
894 StringRef mapTypeMod;
895 if (parser.parseKeyword(keyword: &mapTypeMod))
896 return failure();
897
898 if (mapTypeMod == "always")
899 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
900
901 if (mapTypeMod == "implicit")
902 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
903
904 if (mapTypeMod == "close")
905 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE;
906
907 if (mapTypeMod == "present")
908 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT;
909
910 if (mapTypeMod == "to")
911 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
912
913 if (mapTypeMod == "from")
914 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
915
916 if (mapTypeMod == "tofrom")
917 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
918 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
919
920 if (mapTypeMod == "delete")
921 mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
922
923 return success();
924 };
925
926 if (parser.parseCommaSeparatedList(parseElementFn: parseTypeAndMod))
927 return failure();
928
929 mapType = parser.getBuilder().getIntegerAttr(
930 parser.getBuilder().getIntegerType(64, /*isSigned=*/false),
931 llvm::to_underlying(E: mapTypeBits));
932
933 return success();
934}
935
936/// Prints a map_entries map type from its numeric value out into its string
937/// format.
938static void printMapClause(OpAsmPrinter &p, Operation *op,
939 IntegerAttr mapType) {
940 uint64_t mapTypeBits = mapType.getUInt();
941
942 bool emitAllocRelease = true;
943 llvm::SmallVector<std::string, 4> mapTypeStrs;
944
945 // handling of always, close, present placed at the beginning of the string
946 // to aid readability
947 if (mapTypeToBitFlag(value: mapTypeBits,
948 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS))
949 mapTypeStrs.push_back(Elt: "always");
950 if (mapTypeToBitFlag(value: mapTypeBits,
951 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
952 mapTypeStrs.push_back(Elt: "implicit");
953 if (mapTypeToBitFlag(value: mapTypeBits,
954 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE))
955 mapTypeStrs.push_back(Elt: "close");
956 if (mapTypeToBitFlag(value: mapTypeBits,
957 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT))
958 mapTypeStrs.push_back(Elt: "present");
959
960 // special handling of to/from/tofrom/delete and release/alloc, release +
961 // alloc are the abscense of one of the other flags, whereas tofrom requires
962 // both the to and from flag to be set.
963 bool to = mapTypeToBitFlag(value: mapTypeBits,
964 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
965 bool from = mapTypeToBitFlag(
966 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
967 if (to && from) {
968 emitAllocRelease = false;
969 mapTypeStrs.push_back(Elt: "tofrom");
970 } else if (from) {
971 emitAllocRelease = false;
972 mapTypeStrs.push_back(Elt: "from");
973 } else if (to) {
974 emitAllocRelease = false;
975 mapTypeStrs.push_back(Elt: "to");
976 }
977 if (mapTypeToBitFlag(value: mapTypeBits,
978 flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) {
979 emitAllocRelease = false;
980 mapTypeStrs.push_back(Elt: "delete");
981 }
982 if (emitAllocRelease)
983 mapTypeStrs.push_back(Elt: "exit_release_or_enter_alloc");
984
985 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
986 p << mapTypeStrs[i];
987 if (i + 1 < mapTypeStrs.size()) {
988 p << ", ";
989 }
990 }
991}
992
993static ParseResult
994parseMapEntries(OpAsmParser &parser,
995 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapOperands,
996 SmallVectorImpl<Type> &mapOperandTypes) {
997 OpAsmParser::UnresolvedOperand arg;
998 OpAsmParser::UnresolvedOperand blockArg;
999 Type argType;
1000 auto parseEntries = [&]() -> ParseResult {
1001 if (parser.parseOperand(result&: arg) || parser.parseArrow() ||
1002 parser.parseOperand(result&: blockArg))
1003 return failure();
1004 mapOperands.push_back(Elt: arg);
1005 return success();
1006 };
1007
1008 auto parseTypes = [&]() -> ParseResult {
1009 if (parser.parseType(result&: argType))
1010 return failure();
1011 mapOperandTypes.push_back(Elt: argType);
1012 return success();
1013 };
1014
1015 if (parser.parseCommaSeparatedList(parseElementFn: parseEntries))
1016 return failure();
1017
1018 if (parser.parseColon())
1019 return failure();
1020
1021 if (parser.parseCommaSeparatedList(parseElementFn: parseTypes))
1022 return failure();
1023
1024 return success();
1025}
1026
1027static void printMapEntries(OpAsmPrinter &p, Operation *op,
1028 OperandRange mapOperands,
1029 TypeRange mapOperandTypes) {
1030 auto &region = op->getRegion(index: 0);
1031 unsigned argIndex = 0;
1032
1033 for (const auto &mapOp : mapOperands) {
1034 const auto &blockArg = region.front().getArgument(i: argIndex);
1035 p << mapOp << " -> " << blockArg;
1036 argIndex++;
1037 if (argIndex < mapOperands.size())
1038 p << ", ";
1039 }
1040 p << " : ";
1041
1042 argIndex = 0;
1043 for (const auto &mapType : mapOperandTypes) {
1044 p << mapType;
1045 argIndex++;
1046 if (argIndex < mapOperands.size())
1047 p << ", ";
1048 }
1049}
1050
1051static void printCaptureType(OpAsmPrinter &p, Operation *op,
1052 VariableCaptureKindAttr mapCaptureType) {
1053 std::string typeCapStr;
1054 llvm::raw_string_ostream typeCap(typeCapStr);
1055 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1056 typeCap << "ByRef";
1057 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1058 typeCap << "ByCopy";
1059 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1060 typeCap << "VLAType";
1061 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1062 typeCap << "This";
1063 p << typeCap.str();
1064}
1065
1066static ParseResult parseCaptureType(OpAsmParser &parser,
1067 VariableCaptureKindAttr &mapCapture) {
1068 StringRef mapCaptureKey;
1069 if (parser.parseKeyword(keyword: &mapCaptureKey))
1070 return failure();
1071
1072 if (mapCaptureKey == "This")
1073 mapCapture = mlir::omp::VariableCaptureKindAttr::get(
1074 parser.getContext(), mlir::omp::VariableCaptureKind::This);
1075 if (mapCaptureKey == "ByRef")
1076 mapCapture = mlir::omp::VariableCaptureKindAttr::get(
1077 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1078 if (mapCaptureKey == "ByCopy")
1079 mapCapture = mlir::omp::VariableCaptureKindAttr::get(
1080 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1081 if (mapCaptureKey == "VLAType")
1082 mapCapture = mlir::omp::VariableCaptureKindAttr::get(
1083 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1084
1085 return success();
1086}
1087
1088static LogicalResult verifyMapClause(Operation *op, OperandRange mapOperands) {
1089 llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateToVars;
1090 llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateFromVars;
1091
1092 for (auto mapOp : mapOperands) {
1093 if (!mapOp.getDefiningOp())
1094 emitError(loc: op->getLoc(), message: "missing map operation");
1095
1096 if (auto mapInfoOp =
1097 mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) {
1098 if (!mapInfoOp.getMapType().has_value())
1099 emitError(loc: op->getLoc(), message: "missing map type for map operand");
1100
1101 if (!mapInfoOp.getMapCaptureType().has_value())
1102 emitError(loc: op->getLoc(), message: "missing map capture type for map operand");
1103
1104 uint64_t mapTypeBits = mapInfoOp.getMapType().value();
1105
1106 bool to = mapTypeToBitFlag(
1107 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO);
1108 bool from = mapTypeToBitFlag(
1109 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM);
1110 bool del = mapTypeToBitFlag(
1111 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE);
1112
1113 bool always = mapTypeToBitFlag(
1114 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS);
1115 bool close = mapTypeToBitFlag(
1116 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE);
1117 bool implicit = mapTypeToBitFlag(
1118 value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT);
1119
1120 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
1121 return emitError(loc: op->getLoc(),
1122 message: "to, from, tofrom and alloc map types are permitted");
1123
1124 if (isa<TargetEnterDataOp>(op) && (from || del))
1125 return emitError(loc: op->getLoc(), message: "to and alloc map types are permitted");
1126
1127 if (isa<TargetExitDataOp>(op) && to)
1128 return emitError(loc: op->getLoc(),
1129 message: "from, release and delete map types are permitted");
1130
1131 if (isa<TargetUpdateOp>(op)) {
1132 if (del) {
1133 return emitError(loc: op->getLoc(),
1134 message: "at least one of to or from map types must be "
1135 "specified, other map types are not permitted");
1136 }
1137
1138 if (!to && !from) {
1139 return emitError(loc: op->getLoc(),
1140 message: "at least one of to or from map types must be "
1141 "specified, other map types are not permitted");
1142 }
1143
1144 auto updateVar = mapInfoOp.getVarPtr();
1145
1146 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
1147 (from && updateToVars.contains(updateVar))) {
1148 return emitError(
1149 loc: op->getLoc(),
1150 message: "either to or from map types can be specified, not both");
1151 }
1152
1153 if (always || close || implicit) {
1154 return emitError(
1155 loc: op->getLoc(),
1156 message: "present, mapper and iterator map type modifiers are permitted");
1157 }
1158
1159 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
1160 }
1161 } else {
1162 emitError(loc: op->getLoc(), message: "map argument is not a map entry operation");
1163 }
1164 }
1165
1166 return success();
1167}
1168
1169//===----------------------------------------------------------------------===//
1170// TargetDataOp
1171//===----------------------------------------------------------------------===//
1172
1173void TargetDataOp::build(OpBuilder &builder, OperationState &state,
1174 const TargetDataClauseOps &clauses) {
1175 TargetDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1176 clauses.useDevicePtrVars, clauses.useDeviceAddrVars,
1177 clauses.mapVars);
1178}
1179
1180LogicalResult TargetDataOp::verify() {
1181 if (getMapOperands().empty() && getUseDevicePtr().empty() &&
1182 getUseDeviceAddr().empty()) {
1183 return ::emitError(this->getLoc(), "At least one of map, useDevicePtr, or "
1184 "useDeviceAddr operand must be present");
1185 }
1186 return verifyMapClause(*this, getMapOperands());
1187}
1188
1189//===----------------------------------------------------------------------===//
1190// TargetEnterDataOp
1191//===----------------------------------------------------------------------===//
1192
1193void TargetEnterDataOp::build(
1194 OpBuilder &builder, OperationState &state,
1195 const TargetEnterExitUpdateDataClauseOps &clauses) {
1196 MLIRContext *ctx = builder.getContext();
1197 TargetEnterDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1198 makeArrayAttr(ctx, clauses.dependTypeAttrs),
1199 clauses.dependVars, clauses.nowaitAttr,
1200 clauses.mapVars);
1201}
1202
1203LogicalResult TargetEnterDataOp::verify() {
1204 LogicalResult verifyDependVars =
1205 verifyDependVarList(*this, getDepends(), getDependVars());
1206 return failed(verifyDependVars) ? verifyDependVars
1207 : verifyMapClause(*this, getMapOperands());
1208}
1209
1210//===----------------------------------------------------------------------===//
1211// TargetExitDataOp
1212//===----------------------------------------------------------------------===//
1213
1214void TargetExitDataOp::build(
1215 OpBuilder &builder, OperationState &state,
1216 const TargetEnterExitUpdateDataClauseOps &clauses) {
1217 MLIRContext *ctx = builder.getContext();
1218 TargetExitDataOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1219 makeArrayAttr(ctx, clauses.dependTypeAttrs),
1220 clauses.dependVars, clauses.nowaitAttr,
1221 clauses.mapVars);
1222}
1223
1224LogicalResult TargetExitDataOp::verify() {
1225 LogicalResult verifyDependVars =
1226 verifyDependVarList(*this, getDepends(), getDependVars());
1227 return failed(verifyDependVars) ? verifyDependVars
1228 : verifyMapClause(*this, getMapOperands());
1229}
1230
1231//===----------------------------------------------------------------------===//
1232// TargetUpdateOp
1233//===----------------------------------------------------------------------===//
1234
1235void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
1236 const TargetEnterExitUpdateDataClauseOps &clauses) {
1237 MLIRContext *ctx = builder.getContext();
1238 TargetUpdateOp::build(builder, state, clauses.ifVar, clauses.deviceVar,
1239 makeArrayAttr(ctx, clauses.dependTypeAttrs),
1240 clauses.dependVars, clauses.nowaitAttr,
1241 clauses.mapVars);
1242}
1243
1244LogicalResult TargetUpdateOp::verify() {
1245 LogicalResult verifyDependVars =
1246 verifyDependVarList(*this, getDepends(), getDependVars());
1247 return failed(verifyDependVars) ? verifyDependVars
1248 : verifyMapClause(*this, getMapOperands());
1249}
1250
1251//===----------------------------------------------------------------------===//
1252// TargetOp
1253//===----------------------------------------------------------------------===//
1254
1255void TargetOp::build(OpBuilder &builder, OperationState &state,
1256 const TargetClauseOps &clauses) {
1257 MLIRContext *ctx = builder.getContext();
1258 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
1259 // inReductionDeclSymbols, privateVars, privatizers, reductionVars,
1260 // reductionByRefAttr, reductionDeclSymbols.
1261 TargetOp::build(
1262 builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
1263 makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
1264 clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
1265 clauses.mapVars);
1266}
1267
1268LogicalResult TargetOp::verify() {
1269 LogicalResult verifyDependVars =
1270 verifyDependVarList(*this, getDepends(), getDependVars());
1271 return failed(verifyDependVars) ? verifyDependVars
1272 : verifyMapClause(*this, getMapOperands());
1273}
1274
1275//===----------------------------------------------------------------------===//
1276// ParallelOp
1277//===----------------------------------------------------------------------===//
1278
1279void ParallelOp::build(OpBuilder &builder, OperationState &state,
1280 ArrayRef<NamedAttribute> attributes) {
1281 ParallelOp::build(
1282 builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr,
1283 /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(),
1284 /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr,
1285 /*proc_bind_val=*/nullptr, /*private_vars=*/ValueRange(),
1286 /*privatizers=*/nullptr, /*byref=*/false);
1287 state.addAttributes(attributes);
1288}
1289
1290void ParallelOp::build(OpBuilder &builder, OperationState &state,
1291 const ParallelClauseOps &clauses) {
1292 MLIRContext *ctx = builder.getContext();
1293 ParallelOp::build(
1294 builder, state, clauses.ifVar, clauses.numThreadsVar,
1295 clauses.allocateVars, clauses.allocatorVars, clauses.reductionVars,
1296 makeArrayAttr(ctx, clauses.reductionDeclSymbols),
1297 clauses.procBindKindAttr, clauses.privateVars,
1298 makeArrayAttr(ctx, clauses.privatizers), clauses.reductionByRefAttr);
1299}
1300
1301template <typename OpType>
1302static LogicalResult verifyPrivateVarList(OpType &op) {
1303 auto privateVars = op.getPrivateVars();
1304 auto privatizers = op.getPrivatizersAttr();
1305
1306 if (privateVars.empty() && (privatizers == nullptr || privatizers.empty()))
1307 return success();
1308
1309 auto numPrivateVars = privateVars.size();
1310 auto numPrivatizers = (privatizers == nullptr) ? 0 : privatizers.size();
1311
1312 if (numPrivateVars != numPrivatizers)
1313 return op.emitError() << "inconsistent number of private variables and "
1314 "privatizer op symbols, private vars: "
1315 << numPrivateVars
1316 << " vs. privatizer op symbols: " << numPrivatizers;
1317
1318 for (auto privateVarInfo : llvm::zip_equal(privateVars, privatizers)) {
1319 Type varType = std::get<0>(privateVarInfo).getType();
1320 SymbolRefAttr privatizerSym =
1321 cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
1322 PrivateClauseOp privatizerOp =
1323 SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op,
1324 privatizerSym);
1325
1326 if (privatizerOp == nullptr)
1327 return op.emitError() << "failed to lookup privatizer op with symbol: '"
1328 << privatizerSym << "'";
1329
1330 Type privatizerType = privatizerOp.getType();
1331
1332 if (varType != privatizerType)
1333 return op.emitError()
1334 << "type mismatch between a "
1335 << (privatizerOp.getDataSharingType() ==
1336 DataSharingClauseType::Private
1337 ? "private"
1338 : "firstprivate")
1339 << " variable and its privatizer op, var type: " << varType
1340 << " vs. privatizer op type: " << privatizerType;
1341 }
1342
1343 return success();
1344}
1345
1346LogicalResult ParallelOp::verify() {
1347 // Check that it is a valid loop wrapper if it's taking that role.
1348 if (isa<DistributeOp>((*this)->getParentOp())) {
1349 if (!isWrapper())
1350 return emitOpError() << "must take a loop wrapper role if nested inside "
1351 "of 'omp.distribute'";
1352
1353 if (LoopWrapperInterface nested = getNestedWrapper()) {
1354 // Check for the allowed leaf constructs that may appear in a composite
1355 // construct directly after PARALLEL.
1356 if (!isa<WsloopOp>(nested))
1357 return emitError() << "only supported nested wrapper is 'omp.wsloop'";
1358 } else {
1359 return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
1360 }
1361 }
1362
1363 if (getAllocateVars().size() != getAllocatorsVars().size())
1364 return emitError(
1365 "expected equal sizes for allocate and allocator variables");
1366
1367 if (failed(verifyPrivateVarList(*this)))
1368 return failure();
1369
1370 return verifyReductionVarList(*this, getReductions(), getReductionVars());
1371}
1372
1373//===----------------------------------------------------------------------===//
1374// TeamsOp
1375//===----------------------------------------------------------------------===//
1376
1377static bool opInGlobalImplicitParallelRegion(Operation *op) {
1378 while ((op = op->getParentOp()))
1379 if (isa<OpenMPDialect>(op->getDialect()))
1380 return false;
1381 return true;
1382}
1383
1384void TeamsOp::build(OpBuilder &builder, OperationState &state,
1385 const TeamsClauseOps &clauses) {
1386 MLIRContext *ctx = builder.getContext();
1387 // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
1388 TeamsOp::build(builder, state, clauses.numTeamsLowerVar,
1389 clauses.numTeamsUpperVar, clauses.ifVar,
1390 clauses.threadLimitVar, clauses.allocateVars,
1391 clauses.allocatorVars, clauses.reductionVars,
1392 makeArrayAttr(ctx, clauses.reductionDeclSymbols));
1393}
1394
1395LogicalResult TeamsOp::verify() {
1396 // Check parent region
1397 // TODO If nested inside of a target region, also check that it does not
1398 // contain any statements, declarations or directives other than this
1399 // omp.teams construct. The issue is how to support the initialization of
1400 // this operation's own arguments (allow SSA values across omp.target?).
1401 Operation *op = getOperation();
1402 if (!isa<TargetOp>(op->getParentOp()) &&
1403 !opInGlobalImplicitParallelRegion(op))
1404 return emitError("expected to be nested inside of omp.target or not nested "
1405 "in any OpenMP dialect operations");
1406
1407 // Check for num_teams clause restrictions
1408 if (auto numTeamsLowerBound = getNumTeamsLower()) {
1409 auto numTeamsUpperBound = getNumTeamsUpper();
1410 if (!numTeamsUpperBound)
1411 return emitError("expected num_teams upper bound to be defined if the "
1412 "lower bound is defined");
1413 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
1414 return emitError(
1415 "expected num_teams upper bound and lower bound to be the same type");
1416 }
1417
1418 // Check for allocate clause restrictions
1419 if (getAllocateVars().size() != getAllocatorsVars().size())
1420 return emitError(
1421 "expected equal sizes for allocate and allocator variables");
1422
1423 return verifyReductionVarList(*this, getReductions(), getReductionVars());
1424}
1425
1426//===----------------------------------------------------------------------===//
1427// SectionsOp
1428//===----------------------------------------------------------------------===//
1429
1430void SectionsOp::build(OpBuilder &builder, OperationState &state,
1431 const SectionsClauseOps &clauses) {
1432 MLIRContext *ctx = builder.getContext();
1433 // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
1434 SectionsOp::build(builder, state, clauses.reductionVars,
1435 makeArrayAttr(ctx, clauses.reductionDeclSymbols),
1436 clauses.allocateVars, clauses.allocatorVars,
1437 clauses.nowaitAttr);
1438}
1439
1440LogicalResult SectionsOp::verify() {
1441 if (getAllocateVars().size() != getAllocatorsVars().size())
1442 return emitError(
1443 "expected equal sizes for allocate and allocator variables");
1444
1445 return verifyReductionVarList(*this, getReductions(), getReductionVars());
1446}
1447
1448LogicalResult SectionsOp::verifyRegions() {
1449 for (auto &inst : *getRegion().begin()) {
1450 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
1451 return emitOpError()
1452 << "expected omp.section op or terminator op inside region";
1453 }
1454 }
1455
1456 return success();
1457}
1458
1459//===----------------------------------------------------------------------===//
1460// SingleOp
1461//===----------------------------------------------------------------------===//
1462
1463void SingleOp::build(OpBuilder &builder, OperationState &state,
1464 const SingleClauseOps &clauses) {
1465 MLIRContext *ctx = builder.getContext();
1466 // TODO Store clauses in op: privateVars, privatizers.
1467 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
1468 clauses.copyprivateVars,
1469 makeArrayAttr(ctx, clauses.copyprivateFuncs),
1470 clauses.nowaitAttr);
1471}
1472
1473LogicalResult SingleOp::verify() {
1474 // Check for allocate clause restrictions
1475 if (getAllocateVars().size() != getAllocatorsVars().size())
1476 return emitError(
1477 "expected equal sizes for allocate and allocator variables");
1478
1479 return verifyCopyPrivateVarList(*this, getCopyprivateVars(),
1480 getCopyprivateFuncs());
1481}
1482
1483//===----------------------------------------------------------------------===//
1484// WsloopOp
1485//===----------------------------------------------------------------------===//
1486
1487ParseResult
1488parseWsloop(OpAsmParser &parser, Region &region,
1489 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionOperands,
1490 SmallVectorImpl<Type> &reductionTypes,
1491 ArrayAttr &reductionSymbols) {
1492 // Parse an optional reduction clause
1493 llvm::SmallVector<OpAsmParser::Argument> privates;
1494 if (succeeded(result: parser.parseOptionalKeyword(keyword: "reduction"))) {
1495 if (failed(result: parseClauseWithRegionArgs(parser, region, operands&: reductionOperands,
1496 types&: reductionTypes, symbols&: reductionSymbols,
1497 regionPrivateArgs&: privates)))
1498 return failure();
1499 }
1500 return parser.parseRegion(region, arguments: privates);
1501}
1502
1503void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
1504 ValueRange reductionOperands, TypeRange reductionTypes,
1505 ArrayAttr reductionSymbols) {
1506 if (reductionSymbols) {
1507 auto reductionArgs = region.front().getArguments();
1508 printClauseWithRegionArgs(p, op, reductionArgs, "reduction",
1509 reductionOperands, reductionTypes,
1510 reductionSymbols);
1511 }
1512 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
1513}
1514
1515void WsloopOp::build(OpBuilder &builder, OperationState &state,
1516 ArrayRef<NamedAttribute> attributes) {
1517 build(builder, state, /*linear_vars=*/ValueRange(),
1518 /*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
1519 /*reductions=*/nullptr, /*schedule_val=*/nullptr,
1520 /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
1521 /*simd_modifier=*/false, /*nowait=*/false, /*byref=*/false,
1522 /*ordered_val=*/nullptr, /*order_val=*/nullptr);
1523 state.addAttributes(attributes);
1524}
1525
1526void WsloopOp::build(OpBuilder &builder, OperationState &state,
1527 const WsloopClauseOps &clauses) {
1528 MLIRContext *ctx = builder.getContext();
1529 // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
1530 // privatizers.
1531 WsloopOp::build(
1532 builder, state, clauses.linearVars, clauses.linearStepVars,
1533 clauses.reductionVars, makeArrayAttr(ctx, clauses.reductionDeclSymbols),
1534 clauses.scheduleValAttr, clauses.scheduleChunkVar,
1535 clauses.scheduleModAttr, clauses.scheduleSimdAttr, clauses.nowaitAttr,
1536 clauses.reductionByRefAttr, clauses.orderedAttr, clauses.orderAttr);
1537}
1538
1539LogicalResult WsloopOp::verify() {
1540 if (!isWrapper())
1541 return emitOpError() << "must be a loop wrapper";
1542
1543 if (LoopWrapperInterface nested = getNestedWrapper()) {
1544 // Check for the allowed leaf constructs that may appear in a composite
1545 // construct directly after DO/FOR.
1546 if (!isa<SimdOp>(nested))
1547 return emitError() << "only supported nested wrapper is 'omp.simd'";
1548 }
1549
1550 return verifyReductionVarList(*this, getReductions(), getReductionVars());
1551}
1552
1553//===----------------------------------------------------------------------===//
1554// Simd construct [2.9.3.1]
1555//===----------------------------------------------------------------------===//
1556
1557void SimdOp::build(OpBuilder &builder, OperationState &state,
1558 const SimdClauseOps &clauses) {
1559 MLIRContext *ctx = builder.getContext();
1560 // TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars,
1561 // privatizers, reductionDeclSymbols.
1562 SimdOp::build(builder, state, clauses.alignedVars,
1563 makeArrayAttr(ctx, clauses.alignmentAttrs), clauses.ifVar,
1564 clauses.nontemporalVars, clauses.orderAttr, clauses.simdlenAttr,
1565 clauses.safelenAttr);
1566}
1567
1568LogicalResult SimdOp::verify() {
1569 if (getSimdlen().has_value() && getSafelen().has_value() &&
1570 getSimdlen().value() > getSafelen().value())
1571 return emitOpError()
1572 << "simdlen clause and safelen clause are both present, but the "
1573 "simdlen value is not less than or equal to safelen value";
1574
1575 if (verifyAlignedClause(*this, getAlignmentValues(), getAlignedVars())
1576 .failed())
1577 return failure();
1578
1579 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
1580 return failure();
1581
1582 if (!isWrapper())
1583 return emitOpError() << "must be a loop wrapper";
1584
1585 if (getNestedWrapper())
1586 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
1587
1588 return success();
1589}
1590
1591//===----------------------------------------------------------------------===//
1592// Distribute construct [2.9.4.1]
1593//===----------------------------------------------------------------------===//
1594
1595void DistributeOp::build(OpBuilder &builder, OperationState &state,
1596 const DistributeClauseOps &clauses) {
1597 // TODO Store clauses in op: privateVars, privatizers.
1598 DistributeOp::build(builder, state, clauses.distScheduleStaticAttr,
1599 clauses.distScheduleChunkSizeVar, clauses.allocateVars,
1600 clauses.allocatorVars, clauses.orderAttr);
1601}
1602
1603LogicalResult DistributeOp::verify() {
1604 if (this->getChunkSize() && !this->getDistScheduleStatic())
1605 return emitOpError() << "chunk size set without "
1606 "dist_schedule_static being present";
1607
1608 if (getAllocateVars().size() != getAllocatorsVars().size())
1609 return emitError(
1610 "expected equal sizes for allocate and allocator variables");
1611
1612 if (!isWrapper())
1613 return emitOpError() << "must be a loop wrapper";
1614
1615 if (LoopWrapperInterface nested = getNestedWrapper()) {
1616 // Check for the allowed leaf constructs that may appear in a composite
1617 // construct directly after DISTRIBUTE.
1618 if (!isa<ParallelOp, SimdOp>(nested))
1619 return emitError() << "only supported nested wrappers are 'omp.parallel' "
1620 "and 'omp.simd'";
1621 }
1622
1623 return success();
1624}
1625
1626//===----------------------------------------------------------------------===//
1627// ReductionOp
1628//===----------------------------------------------------------------------===//
1629
1630static ParseResult parseAtomicReductionRegion(OpAsmParser &parser,
1631 Region &region) {
1632 if (parser.parseOptionalKeyword(keyword: "atomic"))
1633 return success();
1634 return parser.parseRegion(region);
1635}
1636
1637static void printAtomicReductionRegion(OpAsmPrinter &printer,
1638 DeclareReductionOp op, Region &region) {
1639 if (region.empty())
1640 return;
1641 printer << "atomic ";
1642 printer.printRegion(blocks&: region);
1643}
1644
1645static ParseResult parseCleanupReductionRegion(OpAsmParser &parser,
1646 Region &region) {
1647 if (parser.parseOptionalKeyword(keyword: "cleanup"))
1648 return success();
1649 return parser.parseRegion(region);
1650}
1651
1652static void printCleanupReductionRegion(OpAsmPrinter &printer,
1653 DeclareReductionOp op, Region &region) {
1654 if (region.empty())
1655 return;
1656 printer << "cleanup ";
1657 printer.printRegion(blocks&: region);
1658}
1659
/// Verifies the regions of `omp.declare_reduction` against the declared
/// reduction type (`getType()`).
///
/// Checks, in order:
///  - initializer region: non-empty, one entry argument of the reduction type,
///    and every `omp.yield` yields exactly one value of the reduction type;
///  - reduction region: non-empty, two entry arguments both of the reduction
///    type, and every `omp.yield` yields exactly one value of that type;
///  - atomic region (optional): two entry arguments of the same
///    PointerLikeType, whose element type is either unknown (null) or the
///    reduction type;
///  - cleanup region (optional): one entry argument of the reduction type.
/// The first failing check emits an op error and stops verification.
LogicalResult DeclareReductionOp::verifyRegions() {
  if (getInitializerRegion().empty())
    return emitOpError() << "expects non-empty initializer region";
  Block &initializerEntryBlock = getInitializerRegion().front();
  if (initializerEntryBlock.getNumArguments() != 1 ||
      initializerEntryBlock.getArgument(0).getType() != getType()) {
    return emitOpError() << "expects initializer region with one argument "
                            "of the reduction type";
  }

  // Inspect every omp.yield directly inside the initializer region's blocks.
  for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
    if (yieldOp.getResults().size() != 1 ||
        yieldOp.getResults().getTypes()[0] != getType())
      return emitOpError() << "expects initializer region to yield a value "
                              "of the reduction type";
  }

  if (getReductionRegion().empty())
    return emitOpError() << "expects non-empty reduction region";
  Block &reductionEntryBlock = getReductionRegion().front();
  // Both combiner operands must have the reduction type.
  if (reductionEntryBlock.getNumArguments() != 2 ||
      reductionEntryBlock.getArgumentTypes()[0] !=
          reductionEntryBlock.getArgumentTypes()[1] ||
      reductionEntryBlock.getArgumentTypes()[0] != getType())
    return emitOpError() << "expects reduction region with two arguments of "
                            "the reduction type";
  for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
    if (yieldOp.getResults().size() != 1 ||
        yieldOp.getResults().getTypes()[0] != getType())
      return emitOpError() << "expects reduction region to yield a value "
                              "of the reduction type";
  }

  // The atomic region is optional; when present its arguments are
  // pointer-like accumulators rather than values of the reduction type.
  if (!getAtomicReductionRegion().empty()) {
    Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
    if (atomicReductionEntryBlock.getNumArguments() != 2 ||
        atomicReductionEntryBlock.getArgumentTypes()[0] !=
            atomicReductionEntryBlock.getArgumentTypes()[1])
      return emitOpError() << "expects atomic reduction region with two "
                              "arguments of the same type";
    auto ptrType = llvm::dyn_cast<PointerLikeType>(
        atomicReductionEntryBlock.getArgumentTypes()[0]);
    // A null element type is accepted; a known element type must match.
    if (!ptrType ||
        (ptrType.getElementType() && ptrType.getElementType() != getType()))
      return emitOpError() << "expects atomic reduction region arguments to "
                              "be accumulators containing the reduction type";
  }

  // The cleanup region is optional.
  if (getCleanupRegion().empty())
    return success();
  Block &cleanupEntryBlock = getCleanupRegion().front();
  if (cleanupEntryBlock.getNumArguments() != 1 ||
      cleanupEntryBlock.getArgument(0).getType() != getType())
    return emitOpError() << "expects cleanup region with one argument "
                            "of the reduction type";

  return success();
}
1718
1719LogicalResult ReductionOp::verify() {
1720 auto *op = (*this)->getParentWithTrait<ReductionClauseInterface::Trait>();
1721 if (!op)
1722 return emitOpError() << "must be used within an operation supporting "
1723 "reduction clause interface";
1724 while (op) {
1725 for (const auto &var :
1726 cast<ReductionClauseInterface>(op).getAllReductionVars())
1727 if (var == getAccumulator())
1728 return success();
1729 op = op->getParentWithTrait<ReductionClauseInterface::Trait>();
1730 }
1731 return emitOpError() << "the accumulator is not used by the parent";
1732}
1733
1734//===----------------------------------------------------------------------===//
1735// TaskOp
1736//===----------------------------------------------------------------------===//
1737
1738void TaskOp::build(OpBuilder &builder, OperationState &state,
1739 const TaskClauseOps &clauses) {
1740 MLIRContext *ctx = builder.getContext();
1741 // TODO Store clauses in op: privateVars, privatizers.
1742 TaskOp::build(
1743 builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
1744 clauses.mergeableAttr, clauses.inReductionVars,
1745 makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
1746 makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
1747 clauses.allocateVars, clauses.allocatorVars);
1748}
1749
1750LogicalResult TaskOp::verify() {
1751 LogicalResult verifyDependVars =
1752 verifyDependVarList(*this, getDepends(), getDependVars());
1753 return failed(verifyDependVars)
1754 ? verifyDependVars
1755 : verifyReductionVarList(*this, getInReductions(),
1756 getInReductionVars());
1757}
1758
1759//===----------------------------------------------------------------------===//
1760// TaskgroupOp
1761//===----------------------------------------------------------------------===//
1762
1763void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
1764 const TaskgroupClauseOps &clauses) {
1765 MLIRContext *ctx = builder.getContext();
1766 TaskgroupOp::build(builder, state, clauses.taskReductionVars,
1767 makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
1768 clauses.allocateVars, clauses.allocatorVars);
1769}
1770
1771LogicalResult TaskgroupOp::verify() {
1772 return verifyReductionVarList(*this, getTaskReductions(),
1773 getTaskReductionVars());
1774}
1775
1776//===----------------------------------------------------------------------===//
1777// TaskloopOp
1778//===----------------------------------------------------------------------===//
1779
1780void TaskloopOp::build(OpBuilder &builder, OperationState &state,
1781 const TaskloopClauseOps &clauses) {
1782 MLIRContext *ctx = builder.getContext();
1783 // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
1784 TaskloopOp::build(
1785 builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
1786 clauses.mergeableAttr, clauses.inReductionVars,
1787 makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
1788 makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
1789 clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
1790 clauses.numTasksVar, clauses.nogroupAttr);
1791}
1792
1793SmallVector<Value> TaskloopOp::getAllReductionVars() {
1794 SmallVector<Value> allReductionNvars(getInReductionVars().begin(),
1795 getInReductionVars().end());
1796 allReductionNvars.insert(allReductionNvars.end(), getReductionVars().begin(),
1797 getReductionVars().end());
1798 return allReductionNvars;
1799}
1800
1801LogicalResult TaskloopOp::verify() {
1802 if (getAllocateVars().size() != getAllocatorsVars().size())
1803 return emitError(
1804 "expected equal sizes for allocate and allocator variables");
1805 if (failed(
1806 verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
1807 failed(verifyReductionVarList(*this, getInReductions(),
1808 getInReductionVars())))
1809 return failure();
1810
1811 if (!getReductionVars().empty() && getNogroup())
1812 return emitError("if a reduction clause is present on the taskloop "
1813 "directive, the nogroup clause must not be specified");
1814 for (auto var : getReductionVars()) {
1815 if (llvm::is_contained(getInReductionVars(), var))
1816 return emitError("the same list item cannot appear in both a reduction "
1817 "and an in_reduction clause");
1818 }
1819
1820 if (getGrainSize() && getNumTasks()) {
1821 return emitError(
1822 "the grainsize clause and num_tasks clause are mutually exclusive and "
1823 "may not appear on the same taskloop directive");
1824 }
1825
1826 if (!isWrapper())
1827 return emitOpError() << "must be a loop wrapper";
1828
1829 if (LoopWrapperInterface nested = getNestedWrapper()) {
1830 // Check for the allowed leaf constructs that may appear in a composite
1831 // construct directly after TASKLOOP.
1832 if (!isa<SimdOp>(nested))
1833 return emitError() << "only supported nested wrapper is 'omp.simd'";
1834 }
1835 return success();
1836}
1837
1838//===----------------------------------------------------------------------===//
1839// LoopNestOp
1840//===----------------------------------------------------------------------===//
1841
/// Parses `omp.loop_nest` in the form:
///
///   (%iv0, ...) : type = (%lb0, ...) to (%ub0, ...) [inclusive]
///       step (%st0, ...) region [attr-dict]
///
/// All induction variables, bounds and steps share the single `type`. The
/// bound and step lists must each have exactly as many entries as there are
/// induction variables. Operands are appended in the order lower bounds,
/// upper bounds, steps.
ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the parenthesized induction variables, their shared type, and the
  // `= (lbs) to (ubs)` bounds.
  SmallVector<OpAsmParser::Argument> ivs;
  SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs;
  Type loopVarType;
  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
      parser.parseColonType(loopVarType) ||
      // Parse loop bounds.
      parser.parseEqual() ||
      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // The region's entry arguments take the type given after the colon.
  for (auto &iv : ivs)
    iv.type = loopVarType;

  // Parse "inclusive" flag.
  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("inclusive",
                        UnitAttr::get(parser.getBuilder().getContext()));

  // Parse step values.
  SmallVector<OpAsmParser::UnresolvedOperand> steps;
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse the body; the induction variables become its entry block arguments.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, ivs))
    return failure();

  // Resolve operands; the order here defines the operand layout.
  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
      parser.resolveOperands(ubs, loopVarType, result.operands) ||
      parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  // Parse the optional attribute list, which follows the region.
  return parser.parseOptionalAttrDict(result.attributes);
}
1884
1885void LoopNestOp::print(OpAsmPrinter &p) {
1886 Region &region = getRegion();
1887 auto args = region.getArguments();
1888 p << " (" << args << ") : " << args[0].getType() << " = (" << getLowerBound()
1889 << ") to (" << getUpperBound() << ") ";
1890 if (getInclusive())
1891 p << "inclusive ";
1892 p << "step (" << getStep() << ") ";
1893 p.printRegion(region, /*printEntryBlockArgs=*/false);
1894}
1895
1896void LoopNestOp::build(OpBuilder &builder, OperationState &state,
1897 const LoopNestClauseOps &clauses) {
1898 LoopNestOp::build(builder, state, clauses.loopLBVar, clauses.loopUBVar,
1899 clauses.loopStepVar, clauses.loopInclusiveAttr);
1900}
1901
1902LogicalResult LoopNestOp::verify() {
1903 if (getLowerBound().empty())
1904 return emitOpError() << "must represent at least one loop";
1905
1906 if (getLowerBound().size() != getIVs().size())
1907 return emitOpError() << "number of range arguments and IVs do not match";
1908
1909 for (auto [lb, iv] : llvm::zip_equal(getLowerBound(), getIVs())) {
1910 if (lb.getType() != iv.getType())
1911 return emitOpError()
1912 << "range argument type does not match corresponding IV type";
1913 }
1914
1915 auto wrapper =
1916 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
1917
1918 if (!wrapper || !wrapper.isWrapper())
1919 return emitOpError() << "expects parent op to be a valid loop wrapper";
1920
1921 return success();
1922}
1923
1924void LoopNestOp::gatherWrappers(
1925 SmallVectorImpl<LoopWrapperInterface> &wrappers) {
1926 Operation *parent = (*this)->getParentOp();
1927 while (auto wrapper =
1928 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
1929 if (!wrapper.isWrapper())
1930 break;
1931 wrappers.push_back(wrapper);
1932 parent = parent->getParentOp();
1933 }
1934}
1935
1936//===----------------------------------------------------------------------===//
1937// Critical construct (2.17.1)
1938//===----------------------------------------------------------------------===//
1939
1940void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
1941 const CriticalClauseOps &clauses) {
1942 CriticalDeclareOp::build(builder, state, clauses.nameAttr, clauses.hintAttr);
1943}
1944
1945LogicalResult CriticalDeclareOp::verify() {
1946 return verifySynchronizationHint(*this, getHintVal());
1947}
1948
1949LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1950 if (getNameAttr()) {
1951 SymbolRefAttr symbolRef = getNameAttr();
1952 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
1953 *this, symbolRef);
1954 if (!decl) {
1955 return emitOpError() << "expected symbol reference " << symbolRef
1956 << " to point to a critical declaration";
1957 }
1958 }
1959
1960 return success();
1961}
1962
1963//===----------------------------------------------------------------------===//
1964// Ordered construct
1965//===----------------------------------------------------------------------===//
1966
1967static LogicalResult verifyOrderedParent(Operation &op) {
1968 bool hasRegion = op.getNumRegions() > 0;
1969 auto loopOp = op.getParentOfType<LoopNestOp>();
1970 if (!loopOp) {
1971 if (hasRegion)
1972 return success();
1973
1974 // TODO: Consider if this needs to be the case only for the standalone
1975 // variant of the ordered construct.
1976 return op.emitOpError() << "must be nested inside of a loop";
1977 }
1978
1979 Operation *wrapper = loopOp->getParentOp();
1980 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
1981 IntegerAttr orderedAttr = wsloopOp.getOrderedValAttr();
1982 if (!orderedAttr)
1983 return op.emitOpError() << "the enclosing worksharing-loop region must "
1984 "have an ordered clause";
1985
1986 if (hasRegion && orderedAttr.getInt() != 0)
1987 return op.emitOpError() << "the enclosing loop's ordered clause must not "
1988 "have a parameter present";
1989
1990 if (!hasRegion && orderedAttr.getInt() == 0)
1991 return op.emitOpError() << "the enclosing loop's ordered clause must "
1992 "have a parameter present";
1993 } else if (!isa<SimdOp>(wrapper)) {
1994 return op.emitOpError() << "must be nested inside of a worksharing, simd "
1995 "or worksharing simd loop";
1996 }
1997 return success();
1998}
1999
2000void OrderedOp::build(OpBuilder &builder, OperationState &state,
2001 const OrderedOpClauseOps &clauses) {
2002 OrderedOp::build(builder, state, clauses.doacrossDependTypeAttr,
2003 clauses.doacrossNumLoopsAttr, clauses.doacrossVectorVars);
2004}
2005
2006LogicalResult OrderedOp::verify() {
2007 if (failed(verifyOrderedParent(**this)))
2008 return failure();
2009
2010 auto wrapper = (*this)->getParentOfType<WsloopOp>();
2011 if (!wrapper || *wrapper.getOrderedVal() != *getNumLoopsVal())
2012 return emitOpError() << "number of variables in depend clause does not "
2013 << "match number of iteration variables in the "
2014 << "doacross loop";
2015
2016 return success();
2017}
2018
2019void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
2020 const OrderedRegionClauseOps &clauses) {
2021 OrderedRegionOp::build(builder, state, clauses.parLevelSimdAttr);
2022}
2023
2024LogicalResult OrderedRegionOp::verify() {
2025 // TODO: The code generation for ordered simd directive is not supported yet.
2026 if (getSimd())
2027 return failure();
2028
2029 return verifyOrderedParent(**this);
2030}
2031
2032//===----------------------------------------------------------------------===//
2033// TaskwaitOp
2034//===----------------------------------------------------------------------===//
2035
2036void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
2037 const TaskwaitClauseOps &clauses) {
2038 // TODO Store clauses in op: dependTypeAttrs, dependVars, nowaitAttr.
2039 TaskwaitOp::build(builder, state);
2040}
2041
2042//===----------------------------------------------------------------------===//
2043// Verifier for AtomicReadOp
2044//===----------------------------------------------------------------------===//
2045
2046LogicalResult AtomicReadOp::verify() {
2047 if (verifyCommon().failed())
2048 return mlir::failure();
2049
2050 if (auto mo = getMemoryOrderVal()) {
2051 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2052 *mo == ClauseMemoryOrderKind::Release) {
2053 return emitError(
2054 "memory-order must not be acq_rel or release for atomic reads");
2055 }
2056 }
2057 return verifySynchronizationHint(*this, getHintVal());
2058}
2059
2060//===----------------------------------------------------------------------===//
2061// Verifier for AtomicWriteOp
2062//===----------------------------------------------------------------------===//
2063
2064LogicalResult AtomicWriteOp::verify() {
2065 if (verifyCommon().failed())
2066 return mlir::failure();
2067
2068 if (auto mo = getMemoryOrderVal()) {
2069 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2070 *mo == ClauseMemoryOrderKind::Acquire) {
2071 return emitError(
2072 "memory-order must not be acq_rel or acquire for atomic writes");
2073 }
2074 }
2075 return verifySynchronizationHint(*this, getHintVal());
2076}
2077
2078//===----------------------------------------------------------------------===//
2079// Verifier for AtomicUpdateOp
2080//===----------------------------------------------------------------------===//
2081
2082LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2083 PatternRewriter &rewriter) {
2084 if (op.isNoOp()) {
2085 rewriter.eraseOp(op);
2086 return success();
2087 }
2088 if (Value writeVal = op.getWriteOpVal()) {
2089 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
2090 op.getHintValAttr(),
2091 op.getMemoryOrderValAttr());
2092 return success();
2093 }
2094 return failure();
2095}
2096
2097LogicalResult AtomicUpdateOp::verify() {
2098 if (verifyCommon().failed())
2099 return mlir::failure();
2100
2101 if (auto mo = getMemoryOrderVal()) {
2102 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
2103 *mo == ClauseMemoryOrderKind::Acquire) {
2104 return emitError(
2105 "memory-order must not be acq_rel or acquire for atomic updates");
2106 }
2107 }
2108
2109 return verifySynchronizationHint(*this, getHintVal());
2110}
2111
2112LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2113
2114//===----------------------------------------------------------------------===//
2115// Verifier for AtomicCaptureOp
2116//===----------------------------------------------------------------------===//
2117
2118AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2119 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2120 return op;
2121 return dyn_cast<AtomicReadOp>(getSecondOp());
2122}
2123
2124AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2125 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2126 return op;
2127 return dyn_cast<AtomicWriteOp>(getSecondOp());
2128}
2129
2130AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2131 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2132 return op;
2133 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2134}
2135
2136LogicalResult AtomicCaptureOp::verify() {
2137 return verifySynchronizationHint(*this, getHintVal());
2138}
2139
2140LogicalResult AtomicCaptureOp::verifyRegions() {
2141 if (verifyRegionsCommon().failed())
2142 return mlir::failure();
2143
2144 if (getFirstOp()->getAttr("hint_val") || getSecondOp()->getAttr("hint_val"))
2145 return emitOpError(
2146 "operations inside capture region must not have hint clause");
2147
2148 if (getFirstOp()->getAttr("memory_order_val") ||
2149 getSecondOp()->getAttr("memory_order_val"))
2150 return emitOpError(
2151 "operations inside capture region must not have memory_order clause");
2152 return success();
2153}
2154
2155//===----------------------------------------------------------------------===//
2156// Verifier for CancelOp
2157//===----------------------------------------------------------------------===//
2158
2159LogicalResult CancelOp::verify() {
2160 ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
2161 Operation *parentOp = (*this)->getParentOp();
2162
2163 if (!parentOp) {
2164 return emitOpError() << "must be used within a region supporting "
2165 "cancel directive";
2166 }
2167
2168 if ((cct == ClauseCancellationConstructType::Parallel) &&
2169 !isa<ParallelOp>(parentOp)) {
2170 return emitOpError() << "cancel parallel must appear "
2171 << "inside a parallel region";
2172 }
2173 if (cct == ClauseCancellationConstructType::Loop) {
2174 auto loopOp = dyn_cast<LoopNestOp>(parentOp);
2175 auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
2176 loopOp ? loopOp->getParentOp() : nullptr);
2177
2178 if (!wsloopOp) {
2179 return emitOpError()
2180 << "cancel loop must appear inside a worksharing-loop region";
2181 }
2182 if (wsloopOp.getNowaitAttr()) {
2183 return emitError() << "A worksharing construct that is canceled "
2184 << "must not have a nowait clause";
2185 }
2186 if (wsloopOp.getOrderedValAttr()) {
2187 return emitError() << "A worksharing construct that is canceled "
2188 << "must not have an ordered clause";
2189 }
2190
2191 } else if (cct == ClauseCancellationConstructType::Sections) {
2192 if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2193 return emitOpError() << "cancel sections must appear "
2194 << "inside a sections region";
2195 }
2196 if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
2197 cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
2198 return emitError() << "A sections construct that is canceled "
2199 << "must not have a nowait clause";
2200 }
2201 }
2202 // TODO : Add more when we support taskgroup.
2203 return success();
2204}
2205//===----------------------------------------------------------------------===//
// Verifier for CancellationPointOp
2207//===----------------------------------------------------------------------===//
2208
2209LogicalResult CancellationPointOp::verify() {
2210 ClauseCancellationConstructType cct = getCancellationConstructTypeVal();
2211 Operation *parentOp = (*this)->getParentOp();
2212
2213 if (!parentOp) {
2214 return emitOpError() << "must be used within a region supporting "
2215 "cancellation point directive";
2216 }
2217
2218 if ((cct == ClauseCancellationConstructType::Parallel) &&
2219 !(isa<ParallelOp>(parentOp))) {
2220 return emitOpError() << "cancellation point parallel must appear "
2221 << "inside a parallel region";
2222 }
2223 if ((cct == ClauseCancellationConstructType::Loop) &&
2224 (!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
2225 return emitOpError() << "cancellation point loop must appear "
2226 << "inside a worksharing-loop region";
2227 }
2228 if ((cct == ClauseCancellationConstructType::Sections) &&
2229 !(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
2230 return emitOpError() << "cancellation point sections must appear "
2231 << "inside a sections region";
2232 }
2233 // TODO : Add more when we support taskgroup.
2234 return success();
2235}
2236
2237//===----------------------------------------------------------------------===//
2238// MapBoundsOp
2239//===----------------------------------------------------------------------===//
2240
2241LogicalResult MapBoundsOp::verify() {
2242 auto extent = getExtent();
2243 auto upperbound = getUpperBound();
2244 if (!extent && !upperbound)
2245 return emitError("expected extent or upperbound.");
2246 return success();
2247}
2248
2249void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2250 TypeRange /*result_types*/, StringAttr symName,
2251 TypeAttr type) {
2252 PrivateClauseOp::build(
2253 odsBuilder, odsState, symName, type,
2254 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
2255 DataSharingClauseType::Private));
2256}
2257
/// Verifies an `omp.private` declaration.
///
/// The `alloc` region must take one argument. A `private` declaration must
/// have an empty `copy` region. A `firstprivate` declaration must have a
/// `copy` region taking two arguments. In every verified region, each
/// terminator of a block without successors must be an `omp.yield` that
/// yields exactly one value of the declared type.
LogicalResult PrivateClauseOp::verify() {
  Type symType = getType();

  // Checks one block terminator. Terminators that branch to other blocks are
  // skipped; only exit blocks must end in a correctly typed omp.yield.
  auto verifyTerminator = [&](Operation *terminator) -> LogicalResult {
    if (!terminator->getBlock()->getSuccessors().empty())
      return success();

    if (!llvm::isa<YieldOp>(terminator))
      return mlir::emitError(terminator->getLoc())
             << "expected exit block terminator to be an `omp.yield` op.";

    YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
    TypeRange yieldedTypes = yieldOp.getResults().getTypes();

    if (yieldedTypes.size() == 1 && yieldedTypes.front() == symType)
      return success();

    auto error = mlir::emitError(yieldOp.getLoc())
                 << "Invalid yielded value. Expected type: " << symType
                 << ", got: ";

    if (yieldedTypes.empty())
      error << "None";
    else
      error << yieldedTypes;

    return error;
  };

  // Checks a region's entry argument count, then the terminator of every
  // block that may have one. `regionName` is only used in the diagnostic.
  auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
                          StringRef regionName) -> LogicalResult {
    assert(!region.empty());

    if (region.getNumArguments() != expectedNumArgs)
      return mlir::emitError(region.getLoc())
             << "`" << regionName << "`: "
             << "expected " << expectedNumArgs
             << " region arguments, got: " << region.getNumArguments();

    for (Block &block : region) {
      // MLIR will verify the absence of the terminator for us.
      if (!block.mightHaveTerminator())
        continue;

      if (failed(verifyTerminator(block.getTerminator())))
        return failure();
    }

    return success();
  };

  if (failed(verifyRegion(getAllocRegion(), /*expectedNumArgs=*/1, "alloc")))
    return failure();

  DataSharingClauseType dsType = getDataSharingType();

  if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
    return emitError("`private` clauses require only an `alloc` region.");

  if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
    return emitError(
        "`firstprivate` clauses require both `alloc` and `copy` regions.");

  // The copy region is only verified for firstprivate; it is known to be
  // non-empty here because of the check above.
  if (dsType == DataSharingClauseType::FirstPrivate &&
      failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy")))
    return failure();

  return success();
}
2327
2328#define GET_ATTRDEF_CLASSES
2329#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
2330
2331#define GET_OP_CLASSES
2332#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
2333
2334#define GET_TYPEDEF_CLASSES
2335#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
2336

source code of mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp