1 | //===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements the OpenMP dialect and its operations. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
14 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
15 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
16 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
17 | #include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h" |
18 | #include "mlir/IR/Attributes.h" |
19 | #include "mlir/IR/BuiltinAttributes.h" |
20 | #include "mlir/IR/DialectImplementation.h" |
21 | #include "mlir/IR/OpImplementation.h" |
22 | #include "mlir/IR/OperationSupport.h" |
23 | #include "mlir/Interfaces/FoldInterfaces.h" |
24 | |
25 | #include "llvm/ADT/ArrayRef.h" |
26 | #include "llvm/ADT/BitVector.h" |
27 | #include "llvm/ADT/STLExtras.h" |
28 | #include "llvm/ADT/STLForwardCompat.h" |
29 | #include "llvm/ADT/SmallString.h" |
30 | #include "llvm/ADT/StringExtras.h" |
31 | #include "llvm/ADT/StringRef.h" |
32 | #include "llvm/ADT/TypeSwitch.h" |
33 | #include "llvm/Frontend/OpenMP/OMPConstants.h" |
34 | #include "llvm/Frontend/OpenMP/OMPDeviceConstants.h" |
35 | #include <cstddef> |
36 | #include <iterator> |
37 | #include <optional> |
38 | #include <variant> |
39 | |
40 | #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" |
41 | #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" |
42 | #include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc" |
43 | #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" |
44 | |
45 | using namespace mlir; |
46 | using namespace mlir::omp; |
47 | |
48 | static ArrayAttr makeArrayAttr(MLIRContext *context, |
49 | llvm::ArrayRef<Attribute> attrs) { |
50 | return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs); |
51 | } |
52 | |
53 | static DenseBoolArrayAttr |
54 | makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) { |
55 | return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray); |
56 | } |
57 | |
58 | namespace { |
59 | struct MemRefPointerLikeModel |
60 | : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, |
61 | MemRefType> { |
62 | Type getElementType(Type pointer) const { |
63 | return llvm::cast<MemRefType>(pointer).getElementType(); |
64 | } |
65 | }; |
66 | |
67 | struct LLVMPointerPointerLikeModel |
68 | : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, |
69 | LLVM::LLVMPointerType> { |
70 | Type getElementType(Type pointer) const { return Type(); } |
71 | }; |
72 | } // namespace |
73 | |
74 | void OpenMPDialect::initialize() { |
75 | addOperations< |
76 | #define GET_OP_LIST |
77 | #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" |
78 | >(); |
79 | addAttributes< |
80 | #define GET_ATTRDEF_LIST |
81 | #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" |
82 | >(); |
83 | addTypes< |
84 | #define GET_TYPEDEF_LIST |
85 | #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc" |
86 | >(); |
87 | |
88 | declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>(); |
89 | |
90 | MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext()); |
91 | LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( |
92 | *getContext()); |
93 | |
94 | // Attach default offload module interface to module op to access |
95 | // offload functionality through |
96 | mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>( |
97 | *getContext()); |
98 | |
99 | // Attach default declare target interfaces to operations which can be marked |
100 | // as declare target (Global Operations and Functions/Subroutines in dialects |
101 | // that Fortran (or other languages that lower to MLIR) translates too |
102 | mlir::LLVM::GlobalOp::attachInterface< |
103 | mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::GlobalOp>>( |
104 | *getContext()); |
105 | mlir::LLVM::LLVMFuncOp::attachInterface< |
106 | mlir::omp::DeclareTargetDefaultModel<mlir::LLVM::LLVMFuncOp>>( |
107 | *getContext()); |
108 | mlir::func::FuncOp::attachInterface< |
109 | mlir::omp::DeclareTargetDefaultModel<mlir::func::FuncOp>>(*getContext()); |
110 | } |
111 | |
112 | //===----------------------------------------------------------------------===// |
113 | // Parser and printer for Allocate Clause |
114 | //===----------------------------------------------------------------------===// |
115 | |
116 | /// Parse an allocate clause with allocators and a list of operands with types. |
117 | /// |
118 | /// allocate-operand-list :: = allocate-operand | |
119 | /// allocator-operand `,` allocate-operand-list |
120 | /// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type |
121 | /// ssa-id-and-type ::= ssa-id `:` type |
122 | static ParseResult parseAllocateAndAllocator( |
123 | OpAsmParser &parser, |
124 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocateVars, |
125 | SmallVectorImpl<Type> &allocateTypes, |
126 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &allocatorVars, |
127 | SmallVectorImpl<Type> &allocatorTypes) { |
128 | |
129 | return parser.parseCommaSeparatedList([&]() { |
130 | OpAsmParser::UnresolvedOperand operand; |
131 | Type type; |
132 | if (parser.parseOperand(result&: operand) || parser.parseColonType(result&: type)) |
133 | return failure(); |
134 | allocatorVars.push_back(operand); |
135 | allocatorTypes.push_back(Elt: type); |
136 | if (parser.parseArrow()) |
137 | return failure(); |
138 | if (parser.parseOperand(result&: operand) || parser.parseColonType(result&: type)) |
139 | return failure(); |
140 | |
141 | allocateVars.push_back(operand); |
142 | allocateTypes.push_back(Elt: type); |
143 | return success(); |
144 | }); |
145 | } |
146 | |
147 | /// Print allocate clause |
148 | static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, |
149 | OperandRange allocateVars, |
150 | TypeRange allocateTypes, |
151 | OperandRange allocatorVars, |
152 | TypeRange allocatorTypes) { |
153 | for (unsigned i = 0; i < allocateVars.size(); ++i) { |
154 | std::string separator = i == allocateVars.size() - 1 ? "": ", "; |
155 | p << allocatorVars[i] << " : "<< allocatorTypes[i] << " -> "; |
156 | p << allocateVars[i] << " : "<< allocateTypes[i] << separator; |
157 | } |
158 | } |
159 | |
160 | //===----------------------------------------------------------------------===// |
161 | // Parser and printer for a clause attribute (StringEnumAttr) |
162 | //===----------------------------------------------------------------------===// |
163 | |
164 | template <typename ClauseAttr> |
165 | static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) { |
166 | using ClauseT = decltype(std::declval<ClauseAttr>().getValue()); |
167 | StringRef enumStr; |
168 | SMLoc loc = parser.getCurrentLocation(); |
169 | if (parser.parseKeyword(keyword: &enumStr)) |
170 | return failure(); |
171 | if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) { |
172 | attr = ClauseAttr::get(parser.getContext(), *enumValue); |
173 | return success(); |
174 | } |
175 | return parser.emitError(loc, message: "invalid clause value: '") << enumStr << "'"; |
176 | } |
177 | |
178 | template <typename ClauseAttr> |
179 | void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) { |
180 | p << stringifyEnum(attr.getValue()); |
181 | } |
182 | |
183 | //===----------------------------------------------------------------------===// |
184 | // Parser and printer for Linear Clause |
185 | //===----------------------------------------------------------------------===// |
186 | |
187 | /// linear ::= `linear` `(` linear-list `)` |
188 | /// linear-list := linear-val | linear-val linear-list |
189 | /// linear-val := ssa-id-and-type `=` ssa-id-and-type |
190 | static ParseResult parseLinearClause( |
191 | OpAsmParser &parser, |
192 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearVars, |
193 | SmallVectorImpl<Type> &linearTypes, |
194 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &linearStepVars) { |
195 | return parser.parseCommaSeparatedList(parseElementFn: [&]() { |
196 | OpAsmParser::UnresolvedOperand var; |
197 | Type type; |
198 | OpAsmParser::UnresolvedOperand stepVar; |
199 | if (parser.parseOperand(result&: var) || parser.parseEqual() || |
200 | parser.parseOperand(result&: stepVar) || parser.parseColonType(result&: type)) |
201 | return failure(); |
202 | |
203 | linearVars.push_back(Elt: var); |
204 | linearTypes.push_back(Elt: type); |
205 | linearStepVars.push_back(Elt: stepVar); |
206 | return success(); |
207 | }); |
208 | } |
209 | |
210 | /// Print Linear Clause |
211 | static void printLinearClause(OpAsmPrinter &p, Operation *op, |
212 | ValueRange linearVars, TypeRange linearTypes, |
213 | ValueRange linearStepVars) { |
214 | size_t linearVarsSize = linearVars.size(); |
215 | for (unsigned i = 0; i < linearVarsSize; ++i) { |
216 | std::string separator = i == linearVarsSize - 1 ? "": ", "; |
217 | p << linearVars[i]; |
218 | if (linearStepVars.size() > i) |
219 | p << " = "<< linearStepVars[i]; |
220 | p << " : "<< linearVars[i].getType() << separator; |
221 | } |
222 | } |
223 | |
224 | //===----------------------------------------------------------------------===// |
225 | // Verifier for Nontemporal Clause |
226 | //===----------------------------------------------------------------------===// |
227 | |
228 | static LogicalResult verifyNontemporalClause(Operation *op, |
229 | OperandRange nontemporalVars) { |
230 | |
231 | // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section |
232 | DenseSet<Value> nontemporalItems; |
233 | for (const auto &it : nontemporalVars) |
234 | if (!nontemporalItems.insert(V: it).second) |
235 | return op->emitOpError() << "nontemporal variable used more than once"; |
236 | |
237 | return success(); |
238 | } |
239 | |
240 | //===----------------------------------------------------------------------===// |
241 | // Parser, verifier and printer for Aligned Clause |
242 | //===----------------------------------------------------------------------===// |
243 | static LogicalResult verifyAlignedClause(Operation *op, |
244 | std::optional<ArrayAttr> alignments, |
245 | OperandRange alignedVars) { |
246 | // Check if number of alignment values equals to number of aligned variables |
247 | if (!alignedVars.empty()) { |
248 | if (!alignments || alignments->size() != alignedVars.size()) |
249 | return op->emitOpError() |
250 | << "expected as many alignment values as aligned variables"; |
251 | } else { |
252 | if (alignments) |
253 | return op->emitOpError() << "unexpected alignment values attribute"; |
254 | return success(); |
255 | } |
256 | |
257 | // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section |
258 | DenseSet<Value> alignedItems; |
259 | for (auto it : alignedVars) |
260 | if (!alignedItems.insert(V: it).second) |
261 | return op->emitOpError() << "aligned variable used more than once"; |
262 | |
263 | if (!alignments) |
264 | return success(); |
265 | |
266 | // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section |
267 | for (unsigned i = 0; i < (*alignments).size(); ++i) { |
268 | if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) { |
269 | if (intAttr.getValue().sle(0)) |
270 | return op->emitOpError() << "alignment should be greater than 0"; |
271 | } else { |
272 | return op->emitOpError() << "expected integer alignment"; |
273 | } |
274 | } |
275 | |
276 | return success(); |
277 | } |
278 | |
279 | /// aligned ::= `aligned` `(` aligned-list `)` |
280 | /// aligned-list := aligned-val | aligned-val aligned-list |
281 | /// aligned-val := ssa-id-and-type `->` alignment |
282 | static ParseResult |
283 | parseAlignedClause(OpAsmParser &parser, |
284 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &alignedVars, |
285 | SmallVectorImpl<Type> &alignedTypes, |
286 | ArrayAttr &alignmentsAttr) { |
287 | SmallVector<Attribute> alignmentVec; |
288 | if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
289 | if (parser.parseOperand(result&: alignedVars.emplace_back()) || |
290 | parser.parseColonType(result&: alignedTypes.emplace_back()) || |
291 | parser.parseArrow() || |
292 | parser.parseAttribute(result&: alignmentVec.emplace_back())) { |
293 | return failure(); |
294 | } |
295 | return success(); |
296 | }))) |
297 | return failure(); |
298 | SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end()); |
299 | alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments); |
300 | return success(); |
301 | } |
302 | |
303 | /// Print Aligned Clause |
304 | static void printAlignedClause(OpAsmPrinter &p, Operation *op, |
305 | ValueRange alignedVars, TypeRange alignedTypes, |
306 | std::optional<ArrayAttr> alignments) { |
307 | for (unsigned i = 0; i < alignedVars.size(); ++i) { |
308 | if (i != 0) |
309 | p << ", "; |
310 | p << alignedVars[i] << " : "<< alignedVars[i].getType(); |
311 | p << " -> "<< (*alignments)[i]; |
312 | } |
313 | } |
314 | |
315 | //===----------------------------------------------------------------------===// |
316 | // Parser, printer and verifier for Schedule Clause |
317 | //===----------------------------------------------------------------------===// |
318 | |
319 | static ParseResult |
320 | verifyScheduleModifiers(OpAsmParser &parser, |
321 | SmallVectorImpl<SmallString<12>> &modifiers) { |
322 | if (modifiers.size() > 2) |
323 | return parser.emitError(loc: parser.getNameLoc()) << " unexpected modifier(s)"; |
324 | for (const auto &mod : modifiers) { |
325 | // Translate the string. If it has no value, then it was not a valid |
326 | // modifier! |
327 | auto symbol = symbolizeScheduleModifier(mod); |
328 | if (!symbol) |
329 | return parser.emitError(loc: parser.getNameLoc()) |
330 | << " unknown modifier type: "<< mod; |
331 | } |
332 | |
333 | // If we have one modifier that is "simd", then stick a "none" modiifer in |
334 | // index 0. |
335 | if (modifiers.size() == 1) { |
336 | if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) { |
337 | modifiers.push_back(Elt: modifiers[0]); |
338 | modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none); |
339 | } |
340 | } else if (modifiers.size() == 2) { |
341 | // If there are two modifier: |
342 | // First modifier should not be simd, second one should be simd |
343 | if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd || |
344 | symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd) |
345 | return parser.emitError(loc: parser.getNameLoc()) |
346 | << " incorrect modifier order"; |
347 | } |
348 | return success(); |
349 | } |
350 | |
351 | /// schedule ::= `schedule` `(` sched-list `)` |
352 | /// sched-list ::= sched-val | sched-val sched-list | |
353 | /// sched-val `,` sched-modifier |
354 | /// sched-val ::= sched-with-chunk | sched-wo-chunk |
355 | /// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)? |
356 | /// sched-with-chunk-types ::= `static` | `dynamic` | `guided` |
357 | /// sched-wo-chunk ::= `auto` | `runtime` |
358 | /// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val |
359 | /// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none` |
360 | static ParseResult |
361 | parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, |
362 | ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, |
363 | std::optional<OpAsmParser::UnresolvedOperand> &chunkSize, |
364 | Type &chunkType) { |
365 | StringRef keyword; |
366 | if (parser.parseKeyword(keyword: &keyword)) |
367 | return failure(); |
368 | std::optional<mlir::omp::ClauseScheduleKind> schedule = |
369 | symbolizeClauseScheduleKind(keyword); |
370 | if (!schedule) |
371 | return parser.emitError(loc: parser.getNameLoc()) << " expected schedule kind"; |
372 | |
373 | scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule); |
374 | switch (*schedule) { |
375 | case ClauseScheduleKind::Static: |
376 | case ClauseScheduleKind::Dynamic: |
377 | case ClauseScheduleKind::Guided: |
378 | if (succeeded(Result: parser.parseOptionalEqual())) { |
379 | chunkSize = OpAsmParser::UnresolvedOperand{}; |
380 | if (parser.parseOperand(result&: *chunkSize) || parser.parseColonType(result&: chunkType)) |
381 | return failure(); |
382 | } else { |
383 | chunkSize = std::nullopt; |
384 | } |
385 | break; |
386 | case ClauseScheduleKind::Auto: |
387 | case ClauseScheduleKind::Runtime: |
388 | chunkSize = std::nullopt; |
389 | } |
390 | |
391 | // If there is a comma, we have one or more modifiers.. |
392 | SmallVector<SmallString<12>> modifiers; |
393 | while (succeeded(Result: parser.parseOptionalComma())) { |
394 | StringRef mod; |
395 | if (parser.parseKeyword(keyword: &mod)) |
396 | return failure(); |
397 | modifiers.push_back(Elt: mod); |
398 | } |
399 | |
400 | if (verifyScheduleModifiers(parser, modifiers)) |
401 | return failure(); |
402 | |
403 | if (!modifiers.empty()) { |
404 | SMLoc loc = parser.getCurrentLocation(); |
405 | if (std::optional<ScheduleModifier> mod = |
406 | symbolizeScheduleModifier(modifiers[0])) { |
407 | scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod); |
408 | } else { |
409 | return parser.emitError(loc, message: "invalid schedule modifier"); |
410 | } |
411 | // Only SIMD attribute is allowed here! |
412 | if (modifiers.size() > 1) { |
413 | assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd); |
414 | scheduleSimd = UnitAttr::get(parser.getBuilder().getContext()); |
415 | } |
416 | } |
417 | |
418 | return success(); |
419 | } |
420 | |
421 | /// Print schedule clause |
422 | static void printScheduleClause(OpAsmPrinter &p, Operation *op, |
423 | ClauseScheduleKindAttr scheduleKind, |
424 | ScheduleModifierAttr scheduleMod, |
425 | UnitAttr scheduleSimd, Value scheduleChunk, |
426 | Type scheduleChunkType) { |
427 | p << stringifyClauseScheduleKind(scheduleKind.getValue()); |
428 | if (scheduleChunk) |
429 | p << " = "<< scheduleChunk << " : "<< scheduleChunk.getType(); |
430 | if (scheduleMod) |
431 | p << ", "<< stringifyScheduleModifier(scheduleMod.getValue()); |
432 | if (scheduleSimd) |
433 | p << ", simd"; |
434 | } |
435 | |
436 | //===----------------------------------------------------------------------===// |
437 | // Parser and printer for Order Clause |
438 | //===----------------------------------------------------------------------===// |
439 | |
440 | // order ::= `order` `(` [order-modifier ':'] concurrent `)` |
441 | // order-modifier ::= reproducible | unconstrained |
442 | static ParseResult parseOrderClause(OpAsmParser &parser, |
443 | ClauseOrderKindAttr &order, |
444 | OrderModifierAttr &orderMod) { |
445 | StringRef enumStr; |
446 | SMLoc loc = parser.getCurrentLocation(); |
447 | if (parser.parseKeyword(keyword: &enumStr)) |
448 | return failure(); |
449 | if (std::optional<OrderModifier> enumValue = |
450 | symbolizeOrderModifier(enumStr)) { |
451 | orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue); |
452 | if (parser.parseOptionalColon()) |
453 | return failure(); |
454 | loc = parser.getCurrentLocation(); |
455 | if (parser.parseKeyword(keyword: &enumStr)) |
456 | return failure(); |
457 | } |
458 | if (std::optional<ClauseOrderKind> enumValue = |
459 | symbolizeClauseOrderKind(enumStr)) { |
460 | order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue); |
461 | return success(); |
462 | } |
463 | return parser.emitError(loc, message: "invalid clause value: '") << enumStr << "'"; |
464 | } |
465 | |
466 | static void printOrderClause(OpAsmPrinter &p, Operation *op, |
467 | ClauseOrderKindAttr order, |
468 | OrderModifierAttr orderMod) { |
469 | if (orderMod) |
470 | p << stringifyOrderModifier(orderMod.getValue()) << ":"; |
471 | if (order) |
472 | p << stringifyClauseOrderKind(order.getValue()); |
473 | } |
474 | |
475 | template <typename ClauseTypeAttr, typename ClauseType> |
476 | static ParseResult |
477 | parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, |
478 | std::optional<OpAsmParser::UnresolvedOperand> &operand, |
479 | Type &operandType, |
480 | std::optional<ClauseType> (*symbolizeClause)(StringRef), |
481 | StringRef clauseName) { |
482 | StringRef enumStr; |
483 | if (succeeded(Result: parser.parseOptionalKeyword(keyword: &enumStr))) { |
484 | if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) { |
485 | prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue); |
486 | if (parser.parseComma()) |
487 | return failure(); |
488 | } else { |
489 | return parser.emitError(loc: parser.getCurrentLocation()) |
490 | << "invalid "<< clauseName << " modifier : '"<< enumStr << "'"; |
491 | ; |
492 | } |
493 | } |
494 | |
495 | OpAsmParser::UnresolvedOperand var; |
496 | if (succeeded(Result: parser.parseOperand(result&: var))) { |
497 | operand = var; |
498 | } else { |
499 | return parser.emitError(loc: parser.getCurrentLocation()) |
500 | << "expected "<< clauseName << " operand"; |
501 | } |
502 | |
503 | if (operand.has_value()) { |
504 | if (parser.parseColonType(result&: operandType)) |
505 | return failure(); |
506 | } |
507 | |
508 | return success(); |
509 | } |
510 | |
511 | template <typename ClauseTypeAttr, typename ClauseType> |
512 | static void |
513 | printGranularityClause(OpAsmPrinter &p, Operation *op, |
514 | ClauseTypeAttr prescriptiveness, Value operand, |
515 | mlir::Type operandType, |
516 | StringRef (*stringifyClauseType)(ClauseType)) { |
517 | |
518 | if (prescriptiveness) |
519 | p << stringifyClauseType(prescriptiveness.getValue()) << ", "; |
520 | |
521 | if (operand) |
522 | p << operand << ": "<< operandType; |
523 | } |
524 | |
525 | //===----------------------------------------------------------------------===// |
526 | // Parser and printer for grainsize Clause |
527 | //===----------------------------------------------------------------------===// |
528 | |
529 | // grainsize ::= `grainsize` `(` [strict ':'] grain-size `)` |
530 | static ParseResult |
531 | parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, |
532 | std::optional<OpAsmParser::UnresolvedOperand> &grainsize, |
533 | Type &grainsizeType) { |
534 | return parseGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>( |
535 | parser, grainsizeMod, grainsize, grainsizeType, |
536 | &symbolizeClauseGrainsizeType, "grainsize"); |
537 | } |
538 | |
539 | static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, |
540 | ClauseGrainsizeTypeAttr grainsizeMod, |
541 | Value grainsize, mlir::Type grainsizeType) { |
542 | printGranularityClause<ClauseGrainsizeTypeAttr, ClauseGrainsizeType>( |
543 | p, op, grainsizeMod, grainsize, grainsizeType, |
544 | &stringifyClauseGrainsizeType); |
545 | } |
546 | |
547 | //===----------------------------------------------------------------------===// |
548 | // Parser and printer for num_tasks Clause |
549 | //===----------------------------------------------------------------------===// |
550 | |
551 | // numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)` |
552 | static ParseResult |
553 | parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, |
554 | std::optional<OpAsmParser::UnresolvedOperand> &numTasks, |
555 | Type &numTasksType) { |
556 | return parseGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>( |
557 | parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType, |
558 | "num_tasks"); |
559 | } |
560 | |
561 | static void printNumTasksClause(OpAsmPrinter &p, Operation *op, |
562 | ClauseNumTasksTypeAttr numTasksMod, |
563 | Value numTasks, mlir::Type numTasksType) { |
564 | printGranularityClause<ClauseNumTasksTypeAttr, ClauseNumTasksType>( |
565 | p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType); |
566 | } |
567 | |
568 | //===----------------------------------------------------------------------===// |
569 | // Parsers for operations including clauses that define entry block arguments. |
570 | //===----------------------------------------------------------------------===// |
571 | |
572 | namespace { |
573 | struct MapParseArgs { |
574 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars; |
575 | SmallVectorImpl<Type> &types; |
576 | MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars, |
577 | SmallVectorImpl<Type> &types) |
578 | : vars(vars), types(types) {} |
579 | }; |
580 | struct PrivateParseArgs { |
581 | llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars; |
582 | llvm::SmallVectorImpl<Type> &types; |
583 | ArrayAttr &syms; |
584 | UnitAttr &needsBarrier; |
585 | DenseI64ArrayAttr *mapIndices; |
586 | PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars, |
587 | SmallVectorImpl<Type> &types, ArrayAttr &syms, |
588 | UnitAttr &needsBarrier, |
589 | DenseI64ArrayAttr *mapIndices = nullptr) |
590 | : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier), |
591 | mapIndices(mapIndices) {} |
592 | }; |
593 | |
594 | struct ReductionParseArgs { |
595 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars; |
596 | SmallVectorImpl<Type> &types; |
597 | DenseBoolArrayAttr &byref; |
598 | ArrayAttr &syms; |
599 | ReductionModifierAttr *modifier; |
600 | ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars, |
601 | SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref, |
602 | ArrayAttr &syms, ReductionModifierAttr *mod = nullptr) |
603 | : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {} |
604 | }; |
605 | |
606 | struct AllRegionParseArgs { |
607 | std::optional<MapParseArgs> hasDeviceAddrArgs; |
608 | std::optional<MapParseArgs> hostEvalArgs; |
609 | std::optional<ReductionParseArgs> inReductionArgs; |
610 | std::optional<MapParseArgs> mapArgs; |
611 | std::optional<PrivateParseArgs> privateArgs; |
612 | std::optional<ReductionParseArgs> reductionArgs; |
613 | std::optional<ReductionParseArgs> taskReductionArgs; |
614 | std::optional<MapParseArgs> useDeviceAddrArgs; |
615 | std::optional<MapParseArgs> useDevicePtrArgs; |
616 | }; |
617 | } // namespace |
618 | |
619 | static inline constexpr StringRef getPrivateNeedsBarrierSpelling() { |
620 | return "private_barrier"; |
621 | } |
622 | |
623 | static ParseResult parseClauseWithRegionArgs( |
624 | OpAsmParser &parser, |
625 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, |
626 | SmallVectorImpl<Type> &types, |
627 | SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs, |
628 | ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr, |
629 | DenseBoolArrayAttr *byref = nullptr, |
630 | ReductionModifierAttr *modifier = nullptr, |
631 | UnitAttr *needsBarrier = nullptr) { |
632 | SmallVector<SymbolRefAttr> symbolVec; |
633 | SmallVector<int64_t> mapIndicesVec; |
634 | SmallVector<bool> isByRefVec; |
635 | unsigned regionArgOffset = regionPrivateArgs.size(); |
636 | |
637 | if (parser.parseLParen()) |
638 | return failure(); |
639 | |
640 | if (modifier && succeeded(Result: parser.parseOptionalKeyword(keyword: "mod"))) { |
641 | StringRef enumStr; |
642 | if (parser.parseColon() || parser.parseKeyword(keyword: &enumStr) || |
643 | parser.parseComma()) |
644 | return failure(); |
645 | std::optional<ReductionModifier> enumValue = |
646 | symbolizeReductionModifier(enumStr); |
647 | if (!enumValue.has_value()) |
648 | return failure(); |
649 | *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue); |
650 | if (!*modifier) |
651 | return failure(); |
652 | } |
653 | |
654 | if (parser.parseCommaSeparatedList(parseElementFn: [&]() { |
655 | if (byref) |
656 | isByRefVec.push_back( |
657 | Elt: parser.parseOptionalKeyword(keyword: "byref").succeeded()); |
658 | |
659 | if (symbols && parser.parseAttribute(symbolVec.emplace_back())) |
660 | return failure(); |
661 | |
662 | if (parser.parseOperand(result&: operands.emplace_back()) || |
663 | parser.parseArrow() || |
664 | parser.parseArgument(result&: regionPrivateArgs.emplace_back())) |
665 | return failure(); |
666 | |
667 | if (mapIndices) { |
668 | if (parser.parseOptionalLSquare().succeeded()) { |
669 | if (parser.parseKeyword(keyword: "map_idx") || parser.parseEqual() || |
670 | parser.parseInteger(result&: mapIndicesVec.emplace_back()) || |
671 | parser.parseRSquare()) |
672 | return failure(); |
673 | } else { |
674 | mapIndicesVec.push_back(Elt: -1); |
675 | } |
676 | } |
677 | |
678 | return success(); |
679 | })) |
680 | return failure(); |
681 | |
682 | if (parser.parseColon()) |
683 | return failure(); |
684 | |
685 | if (parser.parseCommaSeparatedList(parseElementFn: [&]() { |
686 | if (parser.parseType(result&: types.emplace_back())) |
687 | return failure(); |
688 | |
689 | return success(); |
690 | })) |
691 | return failure(); |
692 | |
693 | if (operands.size() != types.size()) |
694 | return failure(); |
695 | |
696 | if (parser.parseRParen()) |
697 | return failure(); |
698 | |
699 | if (needsBarrier) { |
700 | if (parser.parseOptionalKeyword(getPrivateNeedsBarrierSpelling()) |
701 | .succeeded()) |
702 | *needsBarrier = mlir::UnitAttr::get(parser.getContext()); |
703 | } |
704 | |
705 | auto *argsBegin = regionPrivateArgs.begin(); |
706 | MutableArrayRef argsSubrange(argsBegin + regionArgOffset, |
707 | argsBegin + regionArgOffset + types.size()); |
708 | for (auto [prv, type] : llvm::zip_equal(t&: argsSubrange, u&: types)) { |
709 | prv.type = type; |
710 | } |
711 | |
712 | if (symbols) { |
713 | SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end()); |
714 | *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs); |
715 | } |
716 | |
717 | if (!mapIndicesVec.empty()) |
718 | *mapIndices = |
719 | mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec); |
720 | |
721 | if (byref) |
722 | *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec); |
723 | |
724 | return success(); |
725 | } |
726 | |
727 | static ParseResult parseBlockArgClause( |
728 | OpAsmParser &parser, |
729 | llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs, |
730 | StringRef keyword, std::optional<MapParseArgs> mapArgs) { |
731 | if (succeeded(Result: parser.parseOptionalKeyword(keyword))) { |
732 | if (!mapArgs) |
733 | return failure(); |
734 | |
735 | if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types, |
736 | entryBlockArgs))) |
737 | return failure(); |
738 | } |
739 | return success(); |
740 | } |
741 | |
742 | static ParseResult parseBlockArgClause( |
743 | OpAsmParser &parser, |
744 | llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs, |
745 | StringRef keyword, std::optional<PrivateParseArgs> privateArgs) { |
746 | if (succeeded(Result: parser.parseOptionalKeyword(keyword))) { |
747 | if (!privateArgs) |
748 | return failure(); |
749 | |
750 | if (failed(parseClauseWithRegionArgs( |
751 | parser, privateArgs->vars, privateArgs->types, entryBlockArgs, |
752 | &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr, |
753 | /*modifier=*/nullptr, &privateArgs->needsBarrier))) |
754 | return failure(); |
755 | } |
756 | return success(); |
757 | } |
758 | |
759 | static ParseResult parseBlockArgClause( |
760 | OpAsmParser &parser, |
761 | llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs, |
762 | StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) { |
763 | if (succeeded(Result: parser.parseOptionalKeyword(keyword))) { |
764 | if (!reductionArgs) |
765 | return failure(); |
766 | if (failed(parseClauseWithRegionArgs( |
767 | parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs, |
768 | &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref, |
769 | reductionArgs->modifier))) |
770 | return failure(); |
771 | } |
772 | return success(); |
773 | } |
774 | |
775 | static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, |
776 | AllRegionParseArgs args) { |
777 | llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs; |
778 | |
779 | if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "has_device_addr", |
780 | mapArgs: args.hasDeviceAddrArgs))) |
781 | return parser.emitError(loc: parser.getCurrentLocation()) |
782 | << "invalid `has_device_addr` format"; |
783 | |
784 | if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "host_eval", |
785 | mapArgs: args.hostEvalArgs))) |
786 | return parser.emitError(loc: parser.getCurrentLocation()) |
787 | << "invalid `host_eval` format"; |
788 | |
789 | if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction", |
790 | args.inReductionArgs))) |
791 | return parser.emitError(loc: parser.getCurrentLocation()) |
792 | << "invalid `in_reduction` format"; |
793 | |
794 | if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "map_entries", |
795 | mapArgs: args.mapArgs))) |
796 | return parser.emitError(loc: parser.getCurrentLocation()) |
797 | << "invalid `map_entries` format"; |
798 | |
799 | if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "private", |
800 | privateArgs: args.privateArgs))) |
801 | return parser.emitError(loc: parser.getCurrentLocation()) |
802 | << "invalid `private` format"; |
803 | |
804 | if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction", |
805 | args.reductionArgs))) |
806 | return parser.emitError(loc: parser.getCurrentLocation()) |
807 | << "invalid `reduction` format"; |
808 | |
809 | if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction", |
810 | args.taskReductionArgs))) |
811 | return parser.emitError(loc: parser.getCurrentLocation()) |
812 | << "invalid `task_reduction` format"; |
813 | |
814 | if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "use_device_addr", |
815 | mapArgs: args.useDeviceAddrArgs))) |
816 | return parser.emitError(loc: parser.getCurrentLocation()) |
817 | << "invalid `use_device_addr` format"; |
818 | |
819 | if (failed(Result: parseBlockArgClause(parser, entryBlockArgs, keyword: "use_device_ptr", |
820 | mapArgs: args.useDevicePtrArgs))) |
821 | return parser.emitError(loc: parser.getCurrentLocation()) |
822 | << "invalid `use_device_addr` format"; |
823 | |
824 | return parser.parseRegion(region, arguments: entryBlockArgs); |
825 | } |
826 | |
827 | // These parseXyz functions correspond to the custom<Xyz> definitions |
828 | // in the .td file(s). |
829 | static ParseResult parseTargetOpRegion( |
830 | OpAsmParser &parser, Region ®ion, |
831 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hasDeviceAddrVars, |
832 | SmallVectorImpl<Type> &hasDeviceAddrTypes, |
833 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars, |
834 | SmallVectorImpl<Type> &hostEvalTypes, |
835 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars, |
836 | SmallVectorImpl<Type> &inReductionTypes, |
837 | DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, |
838 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars, |
839 | SmallVectorImpl<Type> &mapTypes, |
840 | llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, |
841 | llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms, |
842 | UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) { |
843 | AllRegionParseArgs args; |
844 | args.hasDeviceAddrArgs.emplace(args&: hasDeviceAddrVars, args&: hasDeviceAddrTypes); |
845 | args.hostEvalArgs.emplace(args&: hostEvalVars, args&: hostEvalTypes); |
846 | args.inReductionArgs.emplace(inReductionVars, inReductionTypes, |
847 | inReductionByref, inReductionSyms); |
848 | args.mapArgs.emplace(args&: mapVars, args&: mapTypes); |
849 | args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms, |
850 | args&: privateNeedsBarrier, args: &privateMaps); |
851 | return parseBlockArgRegion(parser, region, args); |
852 | } |
853 | |
854 | static ParseResult parseInReductionPrivateRegion( |
855 | OpAsmParser &parser, Region ®ion, |
856 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars, |
857 | SmallVectorImpl<Type> &inReductionTypes, |
858 | DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, |
859 | llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, |
860 | llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms, |
861 | UnitAttr &privateNeedsBarrier) { |
862 | AllRegionParseArgs args; |
863 | args.inReductionArgs.emplace(inReductionVars, inReductionTypes, |
864 | inReductionByref, inReductionSyms); |
865 | args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms, |
866 | args&: privateNeedsBarrier); |
867 | return parseBlockArgRegion(parser, region, args); |
868 | } |
869 | |
870 | static ParseResult parseInReductionPrivateReductionRegion( |
871 | OpAsmParser &parser, Region ®ion, |
872 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars, |
873 | SmallVectorImpl<Type> &inReductionTypes, |
874 | DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, |
875 | llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, |
876 | llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms, |
877 | UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, |
878 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars, |
879 | SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref, |
880 | ArrayAttr &reductionSyms) { |
881 | AllRegionParseArgs args; |
882 | args.inReductionArgs.emplace(inReductionVars, inReductionTypes, |
883 | inReductionByref, inReductionSyms); |
884 | args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms, |
885 | args&: privateNeedsBarrier); |
886 | args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, |
887 | reductionSyms, &reductionMod); |
888 | return parseBlockArgRegion(parser, region, args); |
889 | } |
890 | |
891 | static ParseResult parsePrivateRegion( |
892 | OpAsmParser &parser, Region ®ion, |
893 | llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, |
894 | llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms, |
895 | UnitAttr &privateNeedsBarrier) { |
896 | AllRegionParseArgs args; |
897 | args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms, |
898 | args&: privateNeedsBarrier); |
899 | return parseBlockArgRegion(parser, region, args); |
900 | } |
901 | |
902 | static ParseResult parsePrivateReductionRegion( |
903 | OpAsmParser &parser, Region ®ion, |
904 | llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars, |
905 | llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms, |
906 | UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, |
907 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars, |
908 | SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref, |
909 | ArrayAttr &reductionSyms) { |
910 | AllRegionParseArgs args; |
911 | args.privateArgs.emplace(args&: privateVars, args&: privateTypes, args&: privateSyms, |
912 | args&: privateNeedsBarrier); |
913 | args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, |
914 | reductionSyms, &reductionMod); |
915 | return parseBlockArgRegion(parser, region, args); |
916 | } |
917 | |
918 | static ParseResult parseTaskReductionRegion( |
919 | OpAsmParser &parser, Region ®ion, |
920 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars, |
921 | SmallVectorImpl<Type> &taskReductionTypes, |
922 | DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) { |
923 | AllRegionParseArgs args; |
924 | args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes, |
925 | taskReductionByref, taskReductionSyms); |
926 | return parseBlockArgRegion(parser, region, args); |
927 | } |
928 | |
929 | static ParseResult parseUseDeviceAddrUseDevicePtrRegion( |
930 | OpAsmParser &parser, Region ®ion, |
931 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDeviceAddrVars, |
932 | SmallVectorImpl<Type> &useDeviceAddrTypes, |
933 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &useDevicePtrVars, |
934 | SmallVectorImpl<Type> &useDevicePtrTypes) { |
935 | AllRegionParseArgs args; |
936 | args.useDeviceAddrArgs.emplace(args&: useDeviceAddrVars, args&: useDeviceAddrTypes); |
937 | args.useDevicePtrArgs.emplace(args&: useDevicePtrVars, args&: useDevicePtrTypes); |
938 | return parseBlockArgRegion(parser, region, args); |
939 | } |
940 | |
941 | //===----------------------------------------------------------------------===// |
942 | // Printers for operations including clauses that define entry block arguments. |
943 | //===----------------------------------------------------------------------===// |
944 | |
945 | namespace { |
946 | struct MapPrintArgs { |
947 | ValueRange vars; |
948 | TypeRange types; |
949 | MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {} |
950 | }; |
951 | struct PrivatePrintArgs { |
952 | ValueRange vars; |
953 | TypeRange types; |
954 | ArrayAttr syms; |
955 | UnitAttr needsBarrier; |
956 | DenseI64ArrayAttr mapIndices; |
957 | PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms, |
958 | UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices) |
959 | : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier), |
960 | mapIndices(mapIndices) {} |
961 | }; |
962 | struct ReductionPrintArgs { |
963 | ValueRange vars; |
964 | TypeRange types; |
965 | DenseBoolArrayAttr byref; |
966 | ArrayAttr syms; |
967 | ReductionModifierAttr modifier; |
968 | ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref, |
969 | ArrayAttr syms, ReductionModifierAttr mod = nullptr) |
970 | : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {} |
971 | }; |
972 | struct AllRegionPrintArgs { |
973 | std::optional<MapPrintArgs> hasDeviceAddrArgs; |
974 | std::optional<MapPrintArgs> hostEvalArgs; |
975 | std::optional<ReductionPrintArgs> inReductionArgs; |
976 | std::optional<MapPrintArgs> mapArgs; |
977 | std::optional<PrivatePrintArgs> privateArgs; |
978 | std::optional<ReductionPrintArgs> reductionArgs; |
979 | std::optional<ReductionPrintArgs> taskReductionArgs; |
980 | std::optional<MapPrintArgs> useDeviceAddrArgs; |
981 | std::optional<MapPrintArgs> useDevicePtrArgs; |
982 | }; |
983 | } // namespace |
984 | |
985 | static void printClauseWithRegionArgs( |
986 | OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, |
987 | ValueRange argsSubrange, ValueRange operands, TypeRange types, |
988 | ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr, |
989 | DenseBoolArrayAttr byref = nullptr, |
990 | ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) { |
991 | if (argsSubrange.empty()) |
992 | return; |
993 | |
994 | p << clauseName << "("; |
995 | |
996 | if (modifier) |
997 | p << "mod: "<< stringifyReductionModifier(modifier.getValue()) << ", "; |
998 | |
999 | if (!symbols) { |
1000 | llvm::SmallVector<Attribute> values(operands.size(), nullptr); |
1001 | symbols = ArrayAttr::get(ctx, values); |
1002 | } |
1003 | |
1004 | if (!mapIndices) { |
1005 | llvm::SmallVector<int64_t> values(operands.size(), -1); |
1006 | mapIndices = DenseI64ArrayAttr::get(ctx, values); |
1007 | } |
1008 | |
1009 | if (!byref) { |
1010 | mlir::SmallVector<bool> values(operands.size(), false); |
1011 | byref = DenseBoolArrayAttr::get(ctx, values); |
1012 | } |
1013 | |
1014 | llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols, |
1015 | mapIndices.asArrayRef(), |
1016 | byref.asArrayRef()), |
1017 | p, [&p](auto t) { |
1018 | auto [op, arg, sym, map, isByRef] = t; |
1019 | if (isByRef) |
1020 | p << "byref "; |
1021 | if (sym) |
1022 | p << sym << " "; |
1023 | |
1024 | p << op << " -> "<< arg; |
1025 | |
1026 | if (map != -1) |
1027 | p << " [map_idx="<< map << "]"; |
1028 | }); |
1029 | p << " : "; |
1030 | llvm::interleaveComma(c: types, os&: p); |
1031 | p << ") "; |
1032 | |
1033 | if (needsBarrier) |
1034 | p << getPrivateNeedsBarrierSpelling() << " "; |
1035 | } |
1036 | |
1037 | static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, |
1038 | StringRef clauseName, ValueRange argsSubrange, |
1039 | std::optional<MapPrintArgs> mapArgs) { |
1040 | if (mapArgs) |
1041 | printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, operands: mapArgs->vars, |
1042 | types: mapArgs->types); |
1043 | } |
1044 | |
1045 | static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, |
1046 | StringRef clauseName, ValueRange argsSubrange, |
1047 | std::optional<PrivatePrintArgs> privateArgs) { |
1048 | if (privateArgs) |
1049 | printClauseWithRegionArgs( |
1050 | p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types, |
1051 | privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr, |
1052 | /*modifier=*/nullptr, privateArgs->needsBarrier); |
1053 | } |
1054 | |
1055 | static void |
1056 | printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, |
1057 | ValueRange argsSubrange, |
1058 | std::optional<ReductionPrintArgs> reductionArgs) { |
1059 | if (reductionArgs) |
1060 | printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, |
1061 | reductionArgs->vars, reductionArgs->types, |
1062 | reductionArgs->syms, /*mapIndices=*/nullptr, |
1063 | reductionArgs->byref, reductionArgs->modifier); |
1064 | } |
1065 | |
1066 | static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, |
1067 | const AllRegionPrintArgs &args) { |
1068 | auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op); |
1069 | MLIRContext *ctx = op->getContext(); |
1070 | |
1071 | printBlockArgClause(p, ctx, "has_device_addr", |
1072 | iface.getHasDeviceAddrBlockArgs(), |
1073 | args.hasDeviceAddrArgs); |
1074 | printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(), |
1075 | args.hostEvalArgs); |
1076 | printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(), |
1077 | args.inReductionArgs); |
1078 | printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(), |
1079 | args.mapArgs); |
1080 | printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(), |
1081 | args.privateArgs); |
1082 | printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(), |
1083 | args.reductionArgs); |
1084 | printBlockArgClause(p, ctx, "task_reduction", |
1085 | iface.getTaskReductionBlockArgs(), |
1086 | args.taskReductionArgs); |
1087 | printBlockArgClause(p, ctx, "use_device_addr", |
1088 | iface.getUseDeviceAddrBlockArgs(), |
1089 | args.useDeviceAddrArgs); |
1090 | printBlockArgClause(p, ctx, "use_device_ptr", |
1091 | iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs); |
1092 | |
1093 | p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false); |
1094 | } |
1095 | |
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
1098 | static void printTargetOpRegion( |
1099 | OpAsmPrinter &p, Operation *op, Region ®ion, |
1100 | ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, |
1101 | ValueRange hostEvalVars, TypeRange hostEvalTypes, |
1102 | ValueRange inReductionVars, TypeRange inReductionTypes, |
1103 | DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, |
1104 | ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, |
1105 | TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, |
1106 | DenseI64ArrayAttr privateMaps) { |
1107 | AllRegionPrintArgs args; |
1108 | args.hasDeviceAddrArgs.emplace(args&: hasDeviceAddrVars, args&: hasDeviceAddrTypes); |
1109 | args.hostEvalArgs.emplace(args&: hostEvalVars, args&: hostEvalTypes); |
1110 | args.inReductionArgs.emplace(inReductionVars, inReductionTypes, |
1111 | inReductionByref, inReductionSyms); |
1112 | args.mapArgs.emplace(args&: mapVars, args&: mapTypes); |
1113 | args.privateArgs.emplace(privateVars, privateTypes, privateSyms, |
1114 | privateNeedsBarrier, privateMaps); |
1115 | printBlockArgRegion(p, op, region, args); |
1116 | } |
1117 | |
1118 | static void printInReductionPrivateRegion( |
1119 | OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, |
1120 | TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, |
1121 | ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, |
1122 | ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) { |
1123 | AllRegionPrintArgs args; |
1124 | args.inReductionArgs.emplace(inReductionVars, inReductionTypes, |
1125 | inReductionByref, inReductionSyms); |
1126 | args.privateArgs.emplace(privateVars, privateTypes, privateSyms, |
1127 | privateNeedsBarrier, |
1128 | /*mapIndices=*/nullptr); |
1129 | printBlockArgRegion(p, op, region, args); |
1130 | } |
1131 | |
1132 | static void printInReductionPrivateReductionRegion( |
1133 | OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, |
1134 | TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, |
1135 | ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, |
1136 | ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, |
1137 | ReductionModifierAttr reductionMod, ValueRange reductionVars, |
1138 | TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, |
1139 | ArrayAttr reductionSyms) { |
1140 | AllRegionPrintArgs args; |
1141 | args.inReductionArgs.emplace(inReductionVars, inReductionTypes, |
1142 | inReductionByref, inReductionSyms); |
1143 | args.privateArgs.emplace(privateVars, privateTypes, privateSyms, |
1144 | privateNeedsBarrier, |
1145 | /*mapIndices=*/nullptr); |
1146 | args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, |
1147 | reductionSyms, reductionMod); |
1148 | printBlockArgRegion(p, op, region, args); |
1149 | } |
1150 | |
1151 | static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, |
1152 | ValueRange privateVars, TypeRange privateTypes, |
1153 | ArrayAttr privateSyms, |
1154 | UnitAttr privateNeedsBarrier) { |
1155 | AllRegionPrintArgs args; |
1156 | args.privateArgs.emplace(privateVars, privateTypes, privateSyms, |
1157 | privateNeedsBarrier, |
1158 | /*mapIndices=*/nullptr); |
1159 | printBlockArgRegion(p, op, region, args); |
1160 | } |
1161 | |
1162 | static void printPrivateReductionRegion( |
1163 | OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, |
1164 | TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, |
1165 | ReductionModifierAttr reductionMod, ValueRange reductionVars, |
1166 | TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, |
1167 | ArrayAttr reductionSyms) { |
1168 | AllRegionPrintArgs args; |
1169 | args.privateArgs.emplace(privateVars, privateTypes, privateSyms, |
1170 | privateNeedsBarrier, |
1171 | /*mapIndices=*/nullptr); |
1172 | args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref, |
1173 | reductionSyms, reductionMod); |
1174 | printBlockArgRegion(p, op, region, args); |
1175 | } |
1176 | |
1177 | static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, |
1178 | Region ®ion, |
1179 | ValueRange taskReductionVars, |
1180 | TypeRange taskReductionTypes, |
1181 | DenseBoolArrayAttr taskReductionByref, |
1182 | ArrayAttr taskReductionSyms) { |
1183 | AllRegionPrintArgs args; |
1184 | args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes, |
1185 | taskReductionByref, taskReductionSyms); |
1186 | printBlockArgRegion(p, op, region, args); |
1187 | } |
1188 | |
1189 | static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, |
1190 | Region ®ion, |
1191 | ValueRange useDeviceAddrVars, |
1192 | TypeRange useDeviceAddrTypes, |
1193 | ValueRange useDevicePtrVars, |
1194 | TypeRange useDevicePtrTypes) { |
1195 | AllRegionPrintArgs args; |
1196 | args.useDeviceAddrArgs.emplace(args&: useDeviceAddrVars, args&: useDeviceAddrTypes); |
1197 | args.useDevicePtrArgs.emplace(args&: useDevicePtrVars, args&: useDevicePtrTypes); |
1198 | printBlockArgRegion(p, op, region, args); |
1199 | } |
1200 | |
1201 | /// Verifies Reduction Clause |
1202 | static LogicalResult |
1203 | verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms, |
1204 | OperandRange reductionVars, |
1205 | std::optional<ArrayRef<bool>> reductionByref) { |
1206 | if (!reductionVars.empty()) { |
1207 | if (!reductionSyms || reductionSyms->size() != reductionVars.size()) |
1208 | return op->emitOpError() |
1209 | << "expected as many reduction symbol references " |
1210 | "as reduction variables"; |
1211 | if (reductionByref && reductionByref->size() != reductionVars.size()) |
1212 | return op->emitError() << "expected as many reduction variable by " |
1213 | "reference attributes as reduction variables"; |
1214 | } else { |
1215 | if (reductionSyms) |
1216 | return op->emitOpError() << "unexpected reduction symbol references"; |
1217 | return success(); |
1218 | } |
1219 | |
1220 | // TODO: The followings should be done in |
1221 | // SymbolUserOpInterface::verifySymbolUses. |
1222 | DenseSet<Value> accumulators; |
1223 | for (auto args : llvm::zip(t&: reductionVars, u: *reductionSyms)) { |
1224 | Value accum = std::get<0>(args); |
1225 | |
1226 | if (!accumulators.insert(V: accum).second) |
1227 | return op->emitOpError() << "accumulator variable used more than once"; |
1228 | |
1229 | Type varType = accum.getType(); |
1230 | auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); |
1231 | auto decl = |
1232 | SymbolTable::lookupNearestSymbolFrom<DeclareReductionOp>(op, symbolRef); |
1233 | if (!decl) |
1234 | return op->emitOpError() << "expected symbol reference "<< symbolRef |
1235 | << " to point to a reduction declaration"; |
1236 | |
1237 | if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType) |
1238 | return op->emitOpError() |
1239 | << "expected accumulator ("<< varType |
1240 | << ") to be the same type as reduction declaration (" |
1241 | << decl.getAccumulatorType() << ")"; |
1242 | } |
1243 | |
1244 | return success(); |
1245 | } |
1246 | |
1247 | //===----------------------------------------------------------------------===// |
1248 | // Parser, printer and verifier for Copyprivate |
1249 | //===----------------------------------------------------------------------===// |
1250 | |
1251 | /// copyprivate-entry-list ::= copyprivate-entry |
1252 | /// | copyprivate-entry-list `,` copyprivate-entry |
1253 | /// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type |
1254 | static ParseResult parseCopyprivate( |
1255 | OpAsmParser &parser, |
1256 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> ©privateVars, |
1257 | SmallVectorImpl<Type> ©privateTypes, ArrayAttr ©privateSyms) { |
1258 | SmallVector<SymbolRefAttr> symsVec; |
1259 | if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
1260 | if (parser.parseOperand(result&: copyprivateVars.emplace_back()) || |
1261 | parser.parseArrow() || |
1262 | parser.parseAttribute(symsVec.emplace_back()) || |
1263 | parser.parseColonType(result&: copyprivateTypes.emplace_back())) |
1264 | return failure(); |
1265 | return success(); |
1266 | }))) |
1267 | return failure(); |
1268 | SmallVector<Attribute> syms(symsVec.begin(), symsVec.end()); |
1269 | copyprivateSyms = ArrayAttr::get(parser.getContext(), syms); |
1270 | return success(); |
1271 | } |
1272 | |
1273 | /// Print Copyprivate clause |
1274 | static void printCopyprivate(OpAsmPrinter &p, Operation *op, |
1275 | OperandRange copyprivateVars, |
1276 | TypeRange copyprivateTypes, |
1277 | std::optional<ArrayAttr> copyprivateSyms) { |
1278 | if (!copyprivateSyms.has_value()) |
1279 | return; |
1280 | llvm::interleaveComma( |
1281 | llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p, |
1282 | [&](const auto &args) { |
1283 | p << std::get<0>(args) << " -> "<< std::get<1>(args) << " : " |
1284 | << std::get<2>(args); |
1285 | }); |
1286 | } |
1287 | |
1288 | /// Verifies CopyPrivate Clause |
1289 | static LogicalResult |
1290 | verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, |
1291 | std::optional<ArrayAttr> copyprivateSyms) { |
1292 | size_t copyprivateSymsSize = |
1293 | copyprivateSyms.has_value() ? copyprivateSyms->size() : 0; |
1294 | if (copyprivateSymsSize != copyprivateVars.size()) |
1295 | return op->emitOpError() << "inconsistent number of copyprivate vars (= " |
1296 | << copyprivateVars.size() |
1297 | << ") and functions (= "<< copyprivateSymsSize |
1298 | << "), both must be equal"; |
1299 | if (!copyprivateSyms.has_value()) |
1300 | return success(); |
1301 | |
1302 | for (auto copyprivateVarAndSym : |
1303 | llvm::zip(copyprivateVars, *copyprivateSyms)) { |
1304 | auto symbolRef = |
1305 | llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym)); |
1306 | std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>> |
1307 | funcOp; |
1308 | if (mlir::func::FuncOp mlirFuncOp = |
1309 | SymbolTable::lookupNearestSymbolFrom<mlir::func::FuncOp>(op, |
1310 | symbolRef)) |
1311 | funcOp = mlirFuncOp; |
1312 | else if (mlir::LLVM::LLVMFuncOp llvmFuncOp = |
1313 | SymbolTable::lookupNearestSymbolFrom<mlir::LLVM::LLVMFuncOp>( |
1314 | op, symbolRef)) |
1315 | funcOp = llvmFuncOp; |
1316 | |
1317 | auto getNumArguments = [&] { |
1318 | return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp); |
1319 | }; |
1320 | |
1321 | auto getArgumentType = [&](unsigned i) { |
1322 | return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; }, |
1323 | *funcOp); |
1324 | }; |
1325 | |
1326 | if (!funcOp) |
1327 | return op->emitOpError() << "expected symbol reference "<< symbolRef |
1328 | << " to point to a copy function"; |
1329 | |
1330 | if (getNumArguments() != 2) |
1331 | return op->emitOpError() |
1332 | << "expected copy function "<< symbolRef << " to have 2 operands"; |
1333 | |
1334 | Type argTy = getArgumentType(0); |
1335 | if (argTy != getArgumentType(1)) |
1336 | return op->emitOpError() << "expected copy function "<< symbolRef |
1337 | << " arguments to have the same type"; |
1338 | |
1339 | Type varType = std::get<0>(copyprivateVarAndSym).getType(); |
1340 | if (argTy != varType) |
1341 | return op->emitOpError() |
1342 | << "expected copy function arguments' type ("<< argTy |
1343 | << ") to be the same as copyprivate variable's type ("<< varType |
1344 | << ")"; |
1345 | } |
1346 | |
1347 | return success(); |
1348 | } |
1349 | |
1350 | //===----------------------------------------------------------------------===// |
1351 | // Parser, printer and verifier for DependVarList |
1352 | //===----------------------------------------------------------------------===// |
1353 | |
1354 | /// depend-entry-list ::= depend-entry |
1355 | /// | depend-entry-list `,` depend-entry |
1356 | /// depend-entry ::= depend-kind `->` ssa-id `:` type |
1357 | static ParseResult |
1358 | parseDependVarList(OpAsmParser &parser, |
1359 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dependVars, |
1360 | SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) { |
1361 | SmallVector<ClauseTaskDependAttr> kindsVec; |
1362 | if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
1363 | StringRef keyword; |
1364 | if (parser.parseKeyword(keyword: &keyword) || parser.parseArrow() || |
1365 | parser.parseOperand(result&: dependVars.emplace_back()) || |
1366 | parser.parseColonType(result&: dependTypes.emplace_back())) |
1367 | return failure(); |
1368 | if (std::optional<ClauseTaskDepend> keywordDepend = |
1369 | (symbolizeClauseTaskDepend(keyword))) |
1370 | kindsVec.emplace_back( |
1371 | ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend)); |
1372 | else |
1373 | return failure(); |
1374 | return success(); |
1375 | }))) |
1376 | return failure(); |
1377 | SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end()); |
1378 | dependKinds = ArrayAttr::get(parser.getContext(), kinds); |
1379 | return success(); |
1380 | } |
1381 | |
1382 | /// Print Depend clause |
1383 | static void printDependVarList(OpAsmPrinter &p, Operation *op, |
1384 | OperandRange dependVars, TypeRange dependTypes, |
1385 | std::optional<ArrayAttr> dependKinds) { |
1386 | |
1387 | for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) { |
1388 | if (i != 0) |
1389 | p << ", "; |
1390 | p << stringifyClauseTaskDepend( |
1391 | llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i]) |
1392 | .getValue()) |
1393 | << " -> "<< dependVars[i] << " : "<< dependTypes[i]; |
1394 | } |
1395 | } |
1396 | |
1397 | /// Verifies Depend clause |
1398 | static LogicalResult verifyDependVarList(Operation *op, |
1399 | std::optional<ArrayAttr> dependKinds, |
1400 | OperandRange dependVars) { |
1401 | if (!dependVars.empty()) { |
1402 | if (!dependKinds || dependKinds->size() != dependVars.size()) |
1403 | return op->emitOpError() << "expected as many depend values" |
1404 | " as depend variables"; |
1405 | } else { |
1406 | if (dependKinds && !dependKinds->empty()) |
1407 | return op->emitOpError() << "unexpected depend values"; |
1408 | return success(); |
1409 | } |
1410 | |
1411 | return success(); |
1412 | } |
1413 | |
1414 | //===----------------------------------------------------------------------===// |
1415 | // Parser, printer and verifier for Synchronization Hint (2.17.12) |
1416 | //===----------------------------------------------------------------------===// |
1417 | |
1418 | /// Parses a Synchronization Hint clause. The value of hint is an integer |
1419 | /// which is a combination of different hints from `omp_sync_hint_t`. |
1420 | /// |
1421 | /// hint-clause = `hint` `(` hint-value `)` |
1422 | static ParseResult parseSynchronizationHint(OpAsmParser &parser, |
1423 | IntegerAttr &hintAttr) { |
1424 | StringRef hintKeyword; |
1425 | int64_t hint = 0; |
1426 | if (succeeded(Result: parser.parseOptionalKeyword(keyword: "none"))) { |
1427 | hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0); |
1428 | return success(); |
1429 | } |
1430 | auto parseKeyword = [&]() -> ParseResult { |
1431 | if (failed(Result: parser.parseKeyword(keyword: &hintKeyword))) |
1432 | return failure(); |
1433 | if (hintKeyword == "uncontended") |
1434 | hint |= 1; |
1435 | else if (hintKeyword == "contended") |
1436 | hint |= 2; |
1437 | else if (hintKeyword == "nonspeculative") |
1438 | hint |= 4; |
1439 | else if (hintKeyword == "speculative") |
1440 | hint |= 8; |
1441 | else |
1442 | return parser.emitError(loc: parser.getCurrentLocation()) |
1443 | << hintKeyword << " is not a valid hint"; |
1444 | return success(); |
1445 | }; |
1446 | if (parser.parseCommaSeparatedList(parseElementFn: parseKeyword)) |
1447 | return failure(); |
1448 | hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint); |
1449 | return success(); |
1450 | } |
1451 | |
1452 | /// Prints a Synchronization Hint clause |
1453 | static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, |
1454 | IntegerAttr hintAttr) { |
1455 | int64_t hint = hintAttr.getInt(); |
1456 | |
1457 | if (hint == 0) { |
1458 | p << "none"; |
1459 | return; |
1460 | } |
1461 | |
1462 | // Helper function to get n-th bit from the right end of `value` |
1463 | auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; |
1464 | |
1465 | bool uncontended = bitn(hint, 0); |
1466 | bool contended = bitn(hint, 1); |
1467 | bool nonspeculative = bitn(hint, 2); |
1468 | bool speculative = bitn(hint, 3); |
1469 | |
1470 | SmallVector<StringRef> hints; |
1471 | if (uncontended) |
1472 | hints.push_back(Elt: "uncontended"); |
1473 | if (contended) |
1474 | hints.push_back(Elt: "contended"); |
1475 | if (nonspeculative) |
1476 | hints.push_back(Elt: "nonspeculative"); |
1477 | if (speculative) |
1478 | hints.push_back(Elt: "speculative"); |
1479 | |
1480 | llvm::interleaveComma(c: hints, os&: p); |
1481 | } |
1482 | |
1483 | /// Verifies a synchronization hint clause |
1484 | static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) { |
1485 | |
1486 | // Helper function to get n-th bit from the right end of `value` |
1487 | auto bitn = [](int value, int n) -> bool { return value & (1 << n); }; |
1488 | |
1489 | bool uncontended = bitn(hint, 0); |
1490 | bool contended = bitn(hint, 1); |
1491 | bool nonspeculative = bitn(hint, 2); |
1492 | bool speculative = bitn(hint, 3); |
1493 | |
1494 | if (uncontended && contended) |
1495 | return op->emitOpError() << "the hints omp_sync_hint_uncontended and " |
1496 | "omp_sync_hint_contended cannot be combined"; |
1497 | if (nonspeculative && speculative) |
1498 | return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and " |
1499 | "omp_sync_hint_speculative cannot be combined."; |
1500 | return success(); |
1501 | } |
1502 | |
1503 | //===----------------------------------------------------------------------===// |
1504 | // Parser, printer and verifier for Target |
1505 | //===----------------------------------------------------------------------===// |
1506 | |
1507 | // Helper function to get bitwise AND of `value` and 'flag' |
1508 | uint64_t mapTypeToBitFlag(uint64_t value, |
1509 | llvm::omp::OpenMPOffloadMappingFlags flag) { |
1510 | return value & llvm::to_underlying(E: flag); |
1511 | } |
1512 | |
1513 | /// Parses a map_entries map type from a string format back into its numeric |
1514 | /// value. |
1515 | /// |
1516 | /// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `? |
1517 | /// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` ) |
1518 | static ParseResult parseMapClause(OpAsmParser &parser, IntegerAttr &mapType) { |
1519 | llvm::omp::OpenMPOffloadMappingFlags mapTypeBits = |
1520 | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE; |
1521 | |
1522 | // This simply verifies the correct keyword is read in, the |
1523 | // keyword itself is stored inside of the operation |
1524 | auto parseTypeAndMod = [&]() -> ParseResult { |
1525 | StringRef mapTypeMod; |
1526 | if (parser.parseKeyword(keyword: &mapTypeMod)) |
1527 | return failure(); |
1528 | |
1529 | if (mapTypeMod == "always") |
1530 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS; |
1531 | |
1532 | if (mapTypeMod == "implicit") |
1533 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT; |
1534 | |
1535 | if (mapTypeMod == "ompx_hold") |
1536 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD; |
1537 | |
1538 | if (mapTypeMod == "close") |
1539 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE; |
1540 | |
1541 | if (mapTypeMod == "present") |
1542 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT; |
1543 | |
1544 | if (mapTypeMod == "to") |
1545 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO; |
1546 | |
1547 | if (mapTypeMod == "from") |
1548 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; |
1549 | |
1550 | if (mapTypeMod == "tofrom") |
1551 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO | |
1552 | llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM; |
1553 | |
1554 | if (mapTypeMod == "delete") |
1555 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE; |
1556 | |
1557 | if (mapTypeMod == "return_param") |
1558 | mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM; |
1559 | |
1560 | return success(); |
1561 | }; |
1562 | |
1563 | if (parser.parseCommaSeparatedList(parseElementFn: parseTypeAndMod)) |
1564 | return failure(); |
1565 | |
1566 | mapType = parser.getBuilder().getIntegerAttr( |
1567 | parser.getBuilder().getIntegerType(64, /*isSigned=*/false), |
1568 | llvm::to_underlying(E: mapTypeBits)); |
1569 | |
1570 | return success(); |
1571 | } |
1572 | |
1573 | /// Prints a map_entries map type from its numeric value out into its string |
1574 | /// format. |
1575 | static void printMapClause(OpAsmPrinter &p, Operation *op, |
1576 | IntegerAttr mapType) { |
1577 | uint64_t mapTypeBits = mapType.getUInt(); |
1578 | |
1579 | bool emitAllocRelease = true; |
1580 | llvm::SmallVector<std::string, 4> mapTypeStrs; |
1581 | |
1582 | // handling of always, close, present placed at the beginning of the string |
1583 | // to aid readability |
1584 | if (mapTypeToBitFlag(value: mapTypeBits, |
1585 | flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS)) |
1586 | mapTypeStrs.push_back(Elt: "always"); |
1587 | if (mapTypeToBitFlag(value: mapTypeBits, |
1588 | flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT)) |
1589 | mapTypeStrs.push_back(Elt: "implicit"); |
1590 | if (mapTypeToBitFlag(value: mapTypeBits, |
1591 | flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_OMPX_HOLD)) |
1592 | mapTypeStrs.push_back(Elt: "ompx_hold"); |
1593 | if (mapTypeToBitFlag(value: mapTypeBits, |
1594 | flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE)) |
1595 | mapTypeStrs.push_back(Elt: "close"); |
1596 | if (mapTypeToBitFlag(value: mapTypeBits, |
1597 | flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) |
1598 | mapTypeStrs.push_back(Elt: "present"); |
1599 | |
1600 | // special handling of to/from/tofrom/delete and release/alloc, release + |
1601 | // alloc are the abscense of one of the other flags, whereas tofrom requires |
1602 | // both the to and from flag to be set. |
1603 | bool to = mapTypeToBitFlag(value: mapTypeBits, |
1604 | flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); |
1605 | bool from = mapTypeToBitFlag( |
1606 | value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); |
1607 | if (to && from) { |
1608 | emitAllocRelease = false; |
1609 | mapTypeStrs.push_back(Elt: "tofrom"); |
1610 | } else if (from) { |
1611 | emitAllocRelease = false; |
1612 | mapTypeStrs.push_back(Elt: "from"); |
1613 | } else if (to) { |
1614 | emitAllocRelease = false; |
1615 | mapTypeStrs.push_back(Elt: "to"); |
1616 | } |
1617 | if (mapTypeToBitFlag(value: mapTypeBits, |
1618 | flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE)) { |
1619 | emitAllocRelease = false; |
1620 | mapTypeStrs.push_back(Elt: "delete"); |
1621 | } |
1622 | if (mapTypeToBitFlag( |
1623 | value: mapTypeBits, |
1624 | flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM)) { |
1625 | emitAllocRelease = false; |
1626 | mapTypeStrs.push_back(Elt: "return_param"); |
1627 | } |
1628 | if (emitAllocRelease) |
1629 | mapTypeStrs.push_back(Elt: "exit_release_or_enter_alloc"); |
1630 | |
1631 | for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) { |
1632 | p << mapTypeStrs[i]; |
1633 | if (i + 1 < mapTypeStrs.size()) { |
1634 | p << ", "; |
1635 | } |
1636 | } |
1637 | } |
1638 | |
1639 | static ParseResult parseMembersIndex(OpAsmParser &parser, |
1640 | ArrayAttr &membersIdx) { |
1641 | SmallVector<Attribute> values, memberIdxs; |
1642 | |
1643 | auto parseIndices = [&]() -> ParseResult { |
1644 | int64_t value; |
1645 | if (parser.parseInteger(result&: value)) |
1646 | return failure(); |
1647 | values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64), |
1648 | APInt(64, value, /*isSigned=*/false))); |
1649 | return success(); |
1650 | }; |
1651 | |
1652 | do { |
1653 | if (failed(Result: parser.parseLSquare())) |
1654 | return failure(); |
1655 | |
1656 | if (parser.parseCommaSeparatedList(parseElementFn: parseIndices)) |
1657 | return failure(); |
1658 | |
1659 | if (failed(Result: parser.parseRSquare())) |
1660 | return failure(); |
1661 | |
1662 | memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values)); |
1663 | values.clear(); |
1664 | } while (succeeded(Result: parser.parseOptionalComma())); |
1665 | |
1666 | if (!memberIdxs.empty()) |
1667 | membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs); |
1668 | |
1669 | return success(); |
1670 | } |
1671 | |
1672 | static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, |
1673 | ArrayAttr membersIdx) { |
1674 | if (!membersIdx) |
1675 | return; |
1676 | |
1677 | llvm::interleaveComma(membersIdx, p, [&p](Attribute v) { |
1678 | p << "["; |
1679 | auto memberIdx = cast<ArrayAttr>(v); |
1680 | llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) { |
1681 | p << cast<IntegerAttr>(v2).getInt(); |
1682 | }); |
1683 | p << "]"; |
1684 | }); |
1685 | } |
1686 | |
1687 | static void printCaptureType(OpAsmPrinter &p, Operation *op, |
1688 | VariableCaptureKindAttr mapCaptureType) { |
1689 | std::string typeCapStr; |
1690 | llvm::raw_string_ostream typeCap(typeCapStr); |
1691 | if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef) |
1692 | typeCap << "ByRef"; |
1693 | if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy) |
1694 | typeCap << "ByCopy"; |
1695 | if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType) |
1696 | typeCap << "VLAType"; |
1697 | if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This) |
1698 | typeCap << "This"; |
1699 | p << typeCapStr; |
1700 | } |
1701 | |
1702 | static ParseResult parseCaptureType(OpAsmParser &parser, |
1703 | VariableCaptureKindAttr &mapCaptureType) { |
1704 | StringRef mapCaptureKey; |
1705 | if (parser.parseKeyword(keyword: &mapCaptureKey)) |
1706 | return failure(); |
1707 | |
1708 | if (mapCaptureKey == "This") |
1709 | mapCaptureType = mlir::omp::VariableCaptureKindAttr::get( |
1710 | parser.getContext(), mlir::omp::VariableCaptureKind::This); |
1711 | if (mapCaptureKey == "ByRef") |
1712 | mapCaptureType = mlir::omp::VariableCaptureKindAttr::get( |
1713 | parser.getContext(), mlir::omp::VariableCaptureKind::ByRef); |
1714 | if (mapCaptureKey == "ByCopy") |
1715 | mapCaptureType = mlir::omp::VariableCaptureKindAttr::get( |
1716 | parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy); |
1717 | if (mapCaptureKey == "VLAType") |
1718 | mapCaptureType = mlir::omp::VariableCaptureKindAttr::get( |
1719 | parser.getContext(), mlir::omp::VariableCaptureKind::VLAType); |
1720 | |
1721 | return success(); |
1722 | } |
1723 | |
1724 | static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) { |
1725 | llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateToVars; |
1726 | llvm::DenseSet<mlir::TypedValue<mlir::omp::PointerLikeType>> updateFromVars; |
1727 | |
1728 | for (auto mapOp : mapVars) { |
1729 | if (!mapOp.getDefiningOp()) |
1730 | return emitError(loc: op->getLoc(), message: "missing map operation"); |
1731 | |
1732 | if (auto mapInfoOp = |
1733 | mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) { |
1734 | uint64_t mapTypeBits = mapInfoOp.getMapType(); |
1735 | |
1736 | bool to = mapTypeToBitFlag( |
1737 | value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO); |
1738 | bool from = mapTypeToBitFlag( |
1739 | value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM); |
1740 | bool del = mapTypeToBitFlag( |
1741 | value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE); |
1742 | |
1743 | bool always = mapTypeToBitFlag( |
1744 | value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS); |
1745 | bool close = mapTypeToBitFlag( |
1746 | value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_CLOSE); |
1747 | bool implicit = mapTypeToBitFlag( |
1748 | value: mapTypeBits, flag: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT); |
1749 | |
1750 | if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del) |
1751 | return emitError(loc: op->getLoc(), |
1752 | message: "to, from, tofrom and alloc map types are permitted"); |
1753 | |
1754 | if (isa<TargetEnterDataOp>(op) && (from || del)) |
1755 | return emitError(loc: op->getLoc(), message: "to and alloc map types are permitted"); |
1756 | |
1757 | if (isa<TargetExitDataOp>(op) && to) |
1758 | return emitError(loc: op->getLoc(), |
1759 | message: "from, release and delete map types are permitted"); |
1760 | |
1761 | if (isa<TargetUpdateOp>(op)) { |
1762 | if (del) { |
1763 | return emitError(loc: op->getLoc(), |
1764 | message: "at least one of to or from map types must be " |
1765 | "specified, other map types are not permitted"); |
1766 | } |
1767 | |
1768 | if (!to && !from) { |
1769 | return emitError(loc: op->getLoc(), |
1770 | message: "at least one of to or from map types must be " |
1771 | "specified, other map types are not permitted"); |
1772 | } |
1773 | |
1774 | auto updateVar = mapInfoOp.getVarPtr(); |
1775 | |
1776 | if ((to && from) || (to && updateFromVars.contains(updateVar)) || |
1777 | (from && updateToVars.contains(updateVar))) { |
1778 | return emitError( |
1779 | loc: op->getLoc(), |
1780 | message: "either to or from map types can be specified, not both"); |
1781 | } |
1782 | |
1783 | if (always || close || implicit) { |
1784 | return emitError( |
1785 | loc: op->getLoc(), |
1786 | message: "present, mapper and iterator map type modifiers are permitted"); |
1787 | } |
1788 | |
1789 | to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar); |
1790 | } |
1791 | } else if (!isa<DeclareMapperInfoOp>(op)) { |
1792 | return emitError(loc: op->getLoc(), |
1793 | message: "map argument is not a map entry operation"); |
1794 | } |
1795 | } |
1796 | |
1797 | return success(); |
1798 | } |
1799 | |
1800 | static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) { |
1801 | std::optional<DenseI64ArrayAttr> privateMapIndices = |
1802 | targetOp.getPrivateMapsAttr(); |
1803 | |
1804 | // None of the private operands are mapped. |
1805 | if (!privateMapIndices.has_value() || !privateMapIndices.value()) |
1806 | return success(); |
1807 | |
1808 | OperandRange privateVars = targetOp.getPrivateVars(); |
1809 | |
1810 | if (privateMapIndices.value().size() != |
1811 | static_cast<int64_t>(privateVars.size())) |
1812 | return emitError(targetOp.getLoc(), "sizes of `private` operand range and " |
1813 | "`private_maps` attribute mismatch"); |
1814 | |
1815 | return success(); |
1816 | } |
1817 | |
1818 | //===----------------------------------------------------------------------===// |
1819 | // MapInfoOp |
1820 | //===----------------------------------------------------------------------===// |
1821 | |
1822 | static LogicalResult verifyMapInfoDefinedArgs(Operation *op, |
1823 | StringRef clauseName, |
1824 | OperandRange vars) { |
1825 | for (Value var : vars) |
1826 | if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp())) |
1827 | return op->emitOpError() |
1828 | << "'"<< clauseName |
1829 | << "' arguments must be defined by 'omp.map.info' ops"; |
1830 | return success(); |
1831 | } |
1832 | |
1833 | LogicalResult MapInfoOp::verify() { |
1834 | if (getMapperId() && |
1835 | !SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>( |
1836 | *this, getMapperIdAttr())) { |
1837 | return emitError("invalid mapper id"); |
1838 | } |
1839 | |
1840 | if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers()))) |
1841 | return failure(); |
1842 | |
1843 | return success(); |
1844 | } |
1845 | |
1846 | //===----------------------------------------------------------------------===// |
1847 | // TargetDataOp |
1848 | //===----------------------------------------------------------------------===// |
1849 | |
1850 | void TargetDataOp::build(OpBuilder &builder, OperationState &state, |
1851 | const TargetDataOperands &clauses) { |
1852 | TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr, |
1853 | clauses.mapVars, clauses.useDeviceAddrVars, |
1854 | clauses.useDevicePtrVars); |
1855 | } |
1856 | |
1857 | LogicalResult TargetDataOp::verify() { |
1858 | if (getMapVars().empty() && getUseDevicePtrVars().empty() && |
1859 | getUseDeviceAddrVars().empty()) { |
1860 | return ::emitError(this->getLoc(), |
1861 | "At least one of map, use_device_ptr_vars, or " |
1862 | "use_device_addr_vars operand must be present"); |
1863 | } |
1864 | |
1865 | if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr", |
1866 | getUseDevicePtrVars()))) |
1867 | return failure(); |
1868 | |
1869 | if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr", |
1870 | getUseDeviceAddrVars()))) |
1871 | return failure(); |
1872 | |
1873 | return verifyMapClause(*this, getMapVars()); |
1874 | } |
1875 | |
1876 | //===----------------------------------------------------------------------===// |
1877 | // TargetEnterDataOp |
1878 | //===----------------------------------------------------------------------===// |
1879 | |
1880 | void TargetEnterDataOp::build( |
1881 | OpBuilder &builder, OperationState &state, |
1882 | const TargetEnterExitUpdateDataOperands &clauses) { |
1883 | MLIRContext *ctx = builder.getContext(); |
1884 | TargetEnterDataOp::build(builder, state, |
1885 | makeArrayAttr(ctx, clauses.dependKinds), |
1886 | clauses.dependVars, clauses.device, clauses.ifExpr, |
1887 | clauses.mapVars, clauses.nowait); |
1888 | } |
1889 | |
1890 | LogicalResult TargetEnterDataOp::verify() { |
1891 | LogicalResult verifyDependVars = |
1892 | verifyDependVarList(*this, getDependKinds(), getDependVars()); |
1893 | return failed(verifyDependVars) ? verifyDependVars |
1894 | : verifyMapClause(*this, getMapVars()); |
1895 | } |
1896 | |
1897 | //===----------------------------------------------------------------------===// |
1898 | // TargetExitDataOp |
1899 | //===----------------------------------------------------------------------===// |
1900 | |
1901 | void TargetExitDataOp::build(OpBuilder &builder, OperationState &state, |
1902 | const TargetEnterExitUpdateDataOperands &clauses) { |
1903 | MLIRContext *ctx = builder.getContext(); |
1904 | TargetExitDataOp::build(builder, state, |
1905 | makeArrayAttr(ctx, clauses.dependKinds), |
1906 | clauses.dependVars, clauses.device, clauses.ifExpr, |
1907 | clauses.mapVars, clauses.nowait); |
1908 | } |
1909 | |
1910 | LogicalResult TargetExitDataOp::verify() { |
1911 | LogicalResult verifyDependVars = |
1912 | verifyDependVarList(*this, getDependKinds(), getDependVars()); |
1913 | return failed(verifyDependVars) ? verifyDependVars |
1914 | : verifyMapClause(*this, getMapVars()); |
1915 | } |
1916 | |
1917 | //===----------------------------------------------------------------------===// |
1918 | // TargetUpdateOp |
1919 | //===----------------------------------------------------------------------===// |
1920 | |
1921 | void TargetUpdateOp::build(OpBuilder &builder, OperationState &state, |
1922 | const TargetEnterExitUpdateDataOperands &clauses) { |
1923 | MLIRContext *ctx = builder.getContext(); |
1924 | TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds), |
1925 | clauses.dependVars, clauses.device, clauses.ifExpr, |
1926 | clauses.mapVars, clauses.nowait); |
1927 | } |
1928 | |
1929 | LogicalResult TargetUpdateOp::verify() { |
1930 | LogicalResult verifyDependVars = |
1931 | verifyDependVarList(*this, getDependKinds(), getDependVars()); |
1932 | return failed(verifyDependVars) ? verifyDependVars |
1933 | : verifyMapClause(*this, getMapVars()); |
1934 | } |
1935 | |
1936 | //===----------------------------------------------------------------------===// |
1937 | // TargetOp |
1938 | //===----------------------------------------------------------------------===// |
1939 | |
1940 | void TargetOp::build(OpBuilder &builder, OperationState &state, |
1941 | const TargetOperands &clauses) { |
1942 | MLIRContext *ctx = builder.getContext(); |
1943 | // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars, |
1944 | // inReductionByref, inReductionSyms. |
1945 | TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, |
1946 | clauses.bare, makeArrayAttr(ctx, clauses.dependKinds), |
1947 | clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars, |
1948 | clauses.hostEvalVars, clauses.ifExpr, |
1949 | /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr, |
1950 | /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, |
1951 | clauses.mapVars, clauses.nowait, clauses.privateVars, |
1952 | makeArrayAttr(ctx, clauses.privateSyms), |
1953 | clauses.privateNeedsBarrier, clauses.threadLimit, |
1954 | /*private_maps=*/nullptr); |
1955 | } |
1956 | |
1957 | LogicalResult TargetOp::verify() { |
1958 | if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars()))) |
1959 | return failure(); |
1960 | |
1961 | if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr", |
1962 | getHasDeviceAddrVars()))) |
1963 | return failure(); |
1964 | |
1965 | if (failed(verifyMapClause(*this, getMapVars()))) |
1966 | return failure(); |
1967 | |
1968 | return verifyPrivateVarsMapping(*this); |
1969 | } |
1970 | |
1971 | LogicalResult TargetOp::verifyRegions() { |
1972 | auto teamsOps = getOps<TeamsOp>(); |
1973 | if (std::distance(teamsOps.begin(), teamsOps.end()) > 1) |
1974 | return emitError("target containing multiple 'omp.teams' nested ops"); |
1975 | |
1976 | // Check that host_eval values are only used in legal ways. |
1977 | Operation *capturedOp = getInnermostCapturedOmpOp(); |
1978 | TargetRegionFlags execFlags = getKernelExecFlags(capturedOp); |
1979 | for (Value hostEvalArg : |
1980 | cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) { |
1981 | for (Operation *user : hostEvalArg.getUsers()) { |
1982 | if (auto teamsOp = dyn_cast<TeamsOp>(user)) { |
1983 | if (llvm::is_contained({teamsOp.getNumTeamsLower(), |
1984 | teamsOp.getNumTeamsUpper(), |
1985 | teamsOp.getThreadLimit()}, |
1986 | hostEvalArg)) |
1987 | continue; |
1988 | |
1989 | return emitOpError() << "host_eval argument only legal as 'num_teams' " |
1990 | "and 'thread_limit' in 'omp.teams'"; |
1991 | } |
1992 | if (auto parallelOp = dyn_cast<ParallelOp>(user)) { |
1993 | if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) && |
1994 | parallelOp->isAncestor(capturedOp) && |
1995 | hostEvalArg == parallelOp.getNumThreads()) |
1996 | continue; |
1997 | |
1998 | return emitOpError() |
1999 | << "host_eval argument only legal as 'num_threads' in " |
2000 | "'omp.parallel' when representing target SPMD"; |
2001 | } |
2002 | if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) { |
2003 | if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) && |
2004 | loopNestOp.getOperation() == capturedOp && |
2005 | (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) || |
2006 | llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) || |
2007 | llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg))) |
2008 | continue; |
2009 | |
2010 | return emitOpError() << "host_eval argument only legal as loop bounds " |
2011 | "and steps in 'omp.loop_nest' when trip count " |
2012 | "must be evaluated in the host"; |
2013 | } |
2014 | |
2015 | return emitOpError() << "host_eval argument illegal use in '" |
2016 | << user->getName() << "' operation"; |
2017 | } |
2018 | } |
2019 | return success(); |
2020 | } |
2021 | |
/// Walks the operations nested under `rootOp` in pre-order. It records the
/// last operation that qualifies as "captured" and returns it, or nullptr if
/// none qualifies.
///
/// An operation qualifies when all of the following hold:
///  - it is in the same dialect as `rootOp` and has at least one region;
///    ops failing this are skipped together with their nested ops;
///  - if `checkSingleMandatoryExec` is set, its block cannot be reached again
///    from its successors, and its block dominates every reachable block
///    with no successors in the same region;
///  - `siblingAllowedFn` returns true for every other operation in the same
///    region.
/// A failed execution or sibling check interrupts the whole walk. A captured
/// omp.loop_nest also ends the walk, so nothing inside it is considered.
///
/// \param rootOp                   Non-null operation whose nested ops are
///                                 inspected (`rootOp` itself is never
///                                 captured).
/// \param checkSingleMandatoryExec Whether to require single, mandatory
///                                 execution of each captured op.
/// \param siblingAllowedFn         Predicate deciding which neighbouring ops
///                                 may coexist with a captured op.
static Operation *
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
  assert(rootOp && "expected valid operation");

  // The dialect of `rootOp` defines what counts as an "omp" operation here.
  Dialect *ompDialect = rootOp->getDialect();
  Operation *capturedOp = nullptr;
  DominanceInfo domInfo;

  // Process in pre-order to check operations from outermost to innermost,
  // ensuring we only enter the region of an operation if it meets the criteria
  // for being captured. We stop the exploration of nested operations as soon as
  // we process a region holding no operations to be captured.
  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (op == rootOp)
      return WalkResult::advance();

    // Ignore operations of other dialects or omp operations with no regions,
    // because these will only be checked if they are siblings of an omp
    // operation that can potentially be captured.
    bool isOmpDialect = op->getDialect() == ompDialect;
    bool hasRegions = op->getNumRegions() > 0;
    if (!isOmpDialect || !hasRegions)
      return WalkResult::skip();

    // This operation cannot be captured if it can be executed more than once
    // (i.e. its block's successors can reach it) or if it's not guaranteed to
    // be executed before all exits of the region (i.e. it doesn't dominate all
    // blocks with no successors reachable from the entry block).
    if (checkSingleMandatoryExec) {
      Region *parentRegion = op->getParentRegion();
      Block *parentBlock = op->getBlock();

      for (Block *successor : parentBlock->getSuccessors())
        if (successor->isReachable(parentBlock))
          return WalkResult::interrupt();

      for (Block &block : *parentRegion)
        if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
            !domInfo.dominates(parentBlock, &block))
          return WalkResult::interrupt();
    }

    // Don't capture this op if it has a not-allowed sibling, and stop recursing
    // into nested operations. Note that this iterates over every operation in
    // the parent region, across all of its blocks.
    for (Operation &sibling : op->getParentRegion()->getOps())
      if (&sibling != op && !siblingAllowedFn(&sibling))
        return WalkResult::interrupt();

    // Don't continue capturing nested operations if we reach an omp.loop_nest.
    // Otherwise, process the contents of this operation.
    capturedOp = op;
    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
                                     : WalkResult::advance();
  });

  return capturedOp;
}
2080 | |
2081 | Operation *TargetOp::getInnermostCapturedOmpOp() { |
2082 | auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>(); |
2083 | |
2084 | // Only allow OpenMP terminators and non-OpenMP ops that have known memory |
2085 | // effects, but don't include a memory write effect. |
2086 | return findCapturedOmpOp( |
2087 | *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) { |
2088 | if (!sibling) |
2089 | return false; |
2090 | |
2091 | if (ompDialect == sibling->getDialect()) |
2092 | return sibling->hasTrait<OpTrait::IsTerminator>(); |
2093 | |
2094 | if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) { |
2095 | SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> |
2096 | effects; |
2097 | memOp.getEffects(effects); |
2098 | return !llvm::any_of( |
2099 | effects, [&](MemoryEffects::EffectInstance &effect) { |
2100 | return isa<MemoryEffects::Write>(effect.getEffect()) && |
2101 | isa<SideEffects::AutomaticAllocationScopeResource>( |
2102 | effect.getResource()); |
2103 | }); |
2104 | } |
2105 | return true; |
2106 | }); |
2107 | } |
2108 | |
/// Classifies how the kernel of an omp.target is expected to execute, based
/// on the op it captures (as returned by getInnermostCapturedOmpOp()).
///
/// The result is `generic` unless `capturedOp` is an omp.loop_nest whose
/// wrapper structure matches one of these patterns, with any leading omp.simd
/// wrapper ignored:
///  - target-teams-distribute-parallel-wsloop: `spmd | trip_count`;
///  - target-teams-loop: `spmd | trip_count`;
///  - target-teams-distribute: `generic | trip_count`, plus `spmd` when the
///    nested captured op chain inside the loop starts with an omp.parallel;
///  - target-parallel-wsloop: `spmd`.
///
/// \param capturedOp Null, or the result of getInnermostCapturedOmpOp() on
///                   the enclosing TargetOp (asserted in debug builds).
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
  // A non-null captured op is only valid if it resides inside of a TargetOp
  // and is the result of calling getInnermostCapturedOmpOp() on it.
  TargetOp targetOp =
      capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
  assert((!capturedOp ||
          (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
         "unexpected captured op");

  // If it's not capturing a loop, it's a default target region.
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return TargetRegionFlags::generic;

  // Get the innermost non-simd loop wrapper.
  // NOTE(review): this treats loopWrappers.begin() as the innermost wrapper,
  // i.e. assumes gatherWrappers() collects wrappers innermost-first.
  SmallVector<LoopWrapperInterface> loopWrappers;
  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
  assert(!loopWrappers.empty());

  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Only one or two non-simd wrappers are recognized; a loop wrapped solely
  // by omp.simd ends up with zero here and is treated as generic.
  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
  if (numWrappers != 1 && numWrappers != 2)
    return TargetRegionFlags::generic;

  // Detect target-teams-distribute-parallel-wsloop[-simd].
  // Expected nesting: target > teams > parallel > distribute > wsloop.
  if (numWrappers == 2) {
    if (!isa<WsloopOp>(innermostWrapper))
      return TargetRegionFlags::generic;

    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return TargetRegionFlags::generic;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetRegionFlags::generic;

    Operation *teamsOp = parallelOp->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetRegionFlags::generic;

    if (teamsOp->getParentOp() == targetOp.getOperation())
      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
    // Otherwise fall through to the final `generic` return.
  }
  // Detect target-teams-distribute[-simd] and target-teams-loop.
  // Expected nesting: target > teams > (distribute | loop).
  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetRegionFlags::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetRegionFlags::generic;

    if (isa<LoopOp>(innermostWrapper))
      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

    // Find single immediately nested captured omp.parallel and add spmd flag
    // (generic-spmd case).
    //
    // TODO: This shouldn't have to be done here, as it is too easy to break.
    // The openmp-opt pass should be updated to be able to promote kernels like
    // this from "Generic" to "Generic-SPMD". However, the use of the
    // `kmpc_distribute_static_loop` family of functions produced by the
    // OMPIRBuilder for these kernels prevents that from working.
    //
    // Siblings of a nested capture may be non-OpenMP ops or OpenMP
    // terminators; no single/mandatory execution check is applied here.
    Dialect *ompDialect = targetOp->getDialect();
    Operation *nestedCapture = findCapturedOmpOp(
        capturedOp, /*checkSingleMandatoryExec=*/false,
        [&](Operation *sibling) {
          return sibling && (ompDialect != sibling->getDialect() ||
                             sibling->hasTrait<OpTrait::IsTerminator>());
        });

    TargetRegionFlags result =
        TargetRegionFlags::generic | TargetRegionFlags::trip_count;

    if (!nestedCapture)
      return result;

    // Climb to the outermost ancestor of the nested capture that is still a
    // direct child of the loop nest; only that op decides the spmd flag.
    while (nestedCapture->getParentOp() != capturedOp)
      nestedCapture = nestedCapture->getParentOp();

    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
                                          : result;
  }
  // Detect target-parallel-wsloop[-simd].
  // Expected nesting: target > parallel > wsloop.
  else if (isa<WsloopOp>(innermostWrapper)) {
    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetRegionFlags::generic;

    if (parallelOp->getParentOp() == targetOp.getOperation())
      return TargetRegionFlags::spmd;
  }

  return TargetRegionFlags::generic;
}
2207 | |
2208 | //===----------------------------------------------------------------------===// |
2209 | // ParallelOp |
2210 | //===----------------------------------------------------------------------===// |
2211 | |
2212 | void ParallelOp::build(OpBuilder &builder, OperationState &state, |
2213 | ArrayRef<NamedAttribute> attributes) { |
2214 | ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(), |
2215 | /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr, |
2216 | /*num_threads=*/nullptr, /*private_vars=*/ValueRange(), |
2217 | /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, |
2218 | /*proc_bind_kind=*/nullptr, |
2219 | /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(), |
2220 | /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr); |
2221 | state.addAttributes(attributes); |
2222 | } |
2223 | |
2224 | void ParallelOp::build(OpBuilder &builder, OperationState &state, |
2225 | const ParallelOperands &clauses) { |
2226 | MLIRContext *ctx = builder.getContext(); |
2227 | ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, |
2228 | clauses.ifExpr, clauses.numThreads, clauses.privateVars, |
2229 | makeArrayAttr(ctx, clauses.privateSyms), |
2230 | clauses.privateNeedsBarrier, clauses.procBindKind, |
2231 | clauses.reductionMod, clauses.reductionVars, |
2232 | makeDenseBoolArrayAttr(ctx, clauses.reductionByref), |
2233 | makeArrayAttr(ctx, clauses.reductionSyms)); |
2234 | } |
2235 | |
2236 | template <typename OpType> |
2237 | static LogicalResult verifyPrivateVarList(OpType &op) { |
2238 | auto privateVars = op.getPrivateVars(); |
2239 | auto privateSyms = op.getPrivateSymsAttr(); |
2240 | |
2241 | if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty())) |
2242 | return success(); |
2243 | |
2244 | auto numPrivateVars = privateVars.size(); |
2245 | auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size(); |
2246 | |
2247 | if (numPrivateVars != numPrivateSyms) |
2248 | return op.emitError() << "inconsistent number of private variables and " |
2249 | "privatizer op symbols, private vars: " |
2250 | << numPrivateVars |
2251 | << " vs. privatizer op symbols: "<< numPrivateSyms; |
2252 | |
2253 | for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) { |
2254 | Type varType = std::get<0>(privateVarInfo).getType(); |
2255 | SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo)); |
2256 | PrivateClauseOp privatizerOp = |
2257 | SymbolTable::lookupNearestSymbolFrom<PrivateClauseOp>(op, privateSym); |
2258 | |
2259 | if (privatizerOp == nullptr) |
2260 | return op.emitError() << "failed to lookup privatizer op with symbol: '" |
2261 | << privateSym << "'"; |
2262 | |
2263 | Type privatizerType = privatizerOp.getArgType(); |
2264 | |
2265 | if (privatizerType && (varType != privatizerType)) |
2266 | return op.emitError() |
2267 | << "type mismatch between a " |
2268 | << (privatizerOp.getDataSharingType() == |
2269 | DataSharingClauseType::Private |
2270 | ? "private" |
2271 | : "firstprivate") |
2272 | << " variable and its privatizer op, var type: "<< varType |
2273 | << " vs. privatizer op type: "<< privatizerType; |
2274 | } |
2275 | |
2276 | return success(); |
2277 | } |
2278 | |
2279 | LogicalResult ParallelOp::verify() { |
2280 | if (getAllocateVars().size() != getAllocatorVars().size()) |
2281 | return emitError( |
2282 | "expected equal sizes for allocate and allocator variables"); |
2283 | |
2284 | if (failed(verifyPrivateVarList(*this))) |
2285 | return failure(); |
2286 | |
2287 | return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), |
2288 | getReductionByref()); |
2289 | } |
2290 | |
2291 | LogicalResult ParallelOp::verifyRegions() { |
2292 | auto distChildOps = getOps<DistributeOp>(); |
2293 | int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end()); |
2294 | if (numDistChildOps > 1) |
2295 | return emitError() |
2296 | << "multiple 'omp.distribute' nested inside of 'omp.parallel'"; |
2297 | |
2298 | if (numDistChildOps == 1) { |
2299 | if (!isComposite()) |
2300 | return emitError() |
2301 | << "'omp.composite' attribute missing from composite operation"; |
2302 | |
2303 | auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>(); |
2304 | Operation &distributeOp = **distChildOps.begin(); |
2305 | for (Operation &childOp : getOps()) { |
2306 | if (&childOp == &distributeOp || ompDialect != childOp.getDialect()) |
2307 | continue; |
2308 | |
2309 | if (!childOp.hasTrait<OpTrait::IsTerminator>()) |
2310 | return emitError() << "unexpected OpenMP operation inside of composite " |
2311 | "'omp.parallel': " |
2312 | << childOp.getName(); |
2313 | } |
2314 | } else if (isComposite()) { |
2315 | return emitError() |
2316 | << "'omp.composite' attribute present in non-composite operation"; |
2317 | } |
2318 | return success(); |
2319 | } |
2320 | |
2321 | //===----------------------------------------------------------------------===// |
2322 | // TeamsOp |
2323 | //===----------------------------------------------------------------------===// |
2324 | |
2325 | static bool opInGlobalImplicitParallelRegion(Operation *op) { |
2326 | while ((op = op->getParentOp())) |
2327 | if (isa<OpenMPDialect>(op->getDialect())) |
2328 | return false; |
2329 | return true; |
2330 | } |
2331 | |
2332 | void TeamsOp::build(OpBuilder &builder, OperationState &state, |
2333 | const TeamsOperands &clauses) { |
2334 | MLIRContext *ctx = builder.getContext(); |
2335 | // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier |
2336 | TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, |
2337 | clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper, |
2338 | /*private_vars=*/{}, /*private_syms=*/nullptr, |
2339 | /*private_needs_barrier=*/nullptr, clauses.reductionMod, |
2340 | clauses.reductionVars, |
2341 | makeDenseBoolArrayAttr(ctx, clauses.reductionByref), |
2342 | makeArrayAttr(ctx, clauses.reductionSyms), |
2343 | clauses.threadLimit); |
2344 | } |
2345 | |
2346 | LogicalResult TeamsOp::verify() { |
2347 | // Check parent region |
2348 | // TODO If nested inside of a target region, also check that it does not |
2349 | // contain any statements, declarations or directives other than this |
2350 | // omp.teams construct. The issue is how to support the initialization of |
2351 | // this operation's own arguments (allow SSA values across omp.target?). |
2352 | Operation *op = getOperation(); |
2353 | if (!isa<TargetOp>(op->getParentOp()) && |
2354 | !opInGlobalImplicitParallelRegion(op)) |
2355 | return emitError("expected to be nested inside of omp.target or not nested " |
2356 | "in any OpenMP dialect operations"); |
2357 | |
2358 | // Check for num_teams clause restrictions |
2359 | if (auto numTeamsLowerBound = getNumTeamsLower()) { |
2360 | auto numTeamsUpperBound = getNumTeamsUpper(); |
2361 | if (!numTeamsUpperBound) |
2362 | return emitError("expected num_teams upper bound to be defined if the " |
2363 | "lower bound is defined"); |
2364 | if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType()) |
2365 | return emitError( |
2366 | "expected num_teams upper bound and lower bound to be the same type"); |
2367 | } |
2368 | |
2369 | // Check for allocate clause restrictions |
2370 | if (getAllocateVars().size() != getAllocatorVars().size()) |
2371 | return emitError( |
2372 | "expected equal sizes for allocate and allocator variables"); |
2373 | |
2374 | return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), |
2375 | getReductionByref()); |
2376 | } |
2377 | |
2378 | //===----------------------------------------------------------------------===// |
2379 | // SectionOp |
2380 | //===----------------------------------------------------------------------===// |
2381 | |
2382 | OperandRange SectionOp::getPrivateVars() { |
2383 | return getParentOp().getPrivateVars(); |
2384 | } |
2385 | |
2386 | OperandRange SectionOp::getReductionVars() { |
2387 | return getParentOp().getReductionVars(); |
2388 | } |
2389 | |
2390 | //===----------------------------------------------------------------------===// |
2391 | // SectionsOp |
2392 | //===----------------------------------------------------------------------===// |
2393 | |
2394 | void SectionsOp::build(OpBuilder &builder, OperationState &state, |
2395 | const SectionsOperands &clauses) { |
2396 | MLIRContext *ctx = builder.getContext(); |
2397 | // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier |
2398 | SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, |
2399 | clauses.nowait, /*private_vars=*/{}, |
2400 | /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr, |
2401 | clauses.reductionMod, clauses.reductionVars, |
2402 | makeDenseBoolArrayAttr(ctx, clauses.reductionByref), |
2403 | makeArrayAttr(ctx, clauses.reductionSyms)); |
2404 | } |
2405 | |
2406 | LogicalResult SectionsOp::verify() { |
2407 | if (getAllocateVars().size() != getAllocatorVars().size()) |
2408 | return emitError( |
2409 | "expected equal sizes for allocate and allocator variables"); |
2410 | |
2411 | return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), |
2412 | getReductionByref()); |
2413 | } |
2414 | |
2415 | LogicalResult SectionsOp::verifyRegions() { |
2416 | for (auto &inst : *getRegion().begin()) { |
2417 | if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) { |
2418 | return emitOpError() |
2419 | << "expected omp.section op or terminator op inside region"; |
2420 | } |
2421 | } |
2422 | |
2423 | return success(); |
2424 | } |
2425 | |
2426 | //===----------------------------------------------------------------------===// |
2427 | // SingleOp |
2428 | //===----------------------------------------------------------------------===// |
2429 | |
2430 | void SingleOp::build(OpBuilder &builder, OperationState &state, |
2431 | const SingleOperands &clauses) { |
2432 | MLIRContext *ctx = builder.getContext(); |
2433 | // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier |
2434 | SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, |
2435 | clauses.copyprivateVars, |
2436 | makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait, |
2437 | /*private_vars=*/{}, /*private_syms=*/nullptr, |
2438 | /*private_needs_barrier=*/nullptr); |
2439 | } |
2440 | |
2441 | LogicalResult SingleOp::verify() { |
2442 | // Check for allocate clause restrictions |
2443 | if (getAllocateVars().size() != getAllocatorVars().size()) |
2444 | return emitError( |
2445 | "expected equal sizes for allocate and allocator variables"); |
2446 | |
2447 | return verifyCopyprivateVarList(*this, getCopyprivateVars(), |
2448 | getCopyprivateSyms()); |
2449 | } |
2450 | |
2451 | //===----------------------------------------------------------------------===// |
2452 | // WorkshareOp |
2453 | //===----------------------------------------------------------------------===// |
2454 | |
2455 | void WorkshareOp::build(OpBuilder &builder, OperationState &state, |
2456 | const WorkshareOperands &clauses) { |
2457 | WorkshareOp::build(builder, state, clauses.nowait); |
2458 | } |
2459 | |
2460 | //===----------------------------------------------------------------------===// |
2461 | // WorkshareLoopWrapperOp |
2462 | //===----------------------------------------------------------------------===// |
2463 | |
2464 | LogicalResult WorkshareLoopWrapperOp::verify() { |
2465 | if (!(*this)->getParentOfType<WorkshareOp>()) |
2466 | return emitOpError() << "must be nested in an omp.workshare"; |
2467 | return success(); |
2468 | } |
2469 | |
2470 | LogicalResult WorkshareLoopWrapperOp::verifyRegions() { |
2471 | if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) || |
2472 | getNestedWrapper()) |
2473 | return emitOpError() << "expected to be a standalone loop wrapper"; |
2474 | |
2475 | return success(); |
2476 | } |
2477 | |
2478 | //===----------------------------------------------------------------------===// |
2479 | // LoopWrapperInterface |
2480 | //===----------------------------------------------------------------------===// |
2481 | |
2482 | LogicalResult LoopWrapperInterface::verifyImpl() { |
2483 | Operation *op = this->getOperation(); |
2484 | if (!op->hasTrait<OpTrait::NoTerminator>() || |
2485 | !op->hasTrait<OpTrait::SingleBlock>()) |
2486 | return emitOpError() << "loop wrapper must also have the `NoTerminator` " |
2487 | "and `SingleBlock` traits"; |
2488 | |
2489 | if (op->getNumRegions() != 1) |
2490 | return emitOpError() << "loop wrapper does not contain exactly one region"; |
2491 | |
2492 | Region ®ion = op->getRegion(0); |
2493 | if (range_size(region.getOps()) != 1) |
2494 | return emitOpError() |
2495 | << "loop wrapper does not contain exactly one nested op"; |
2496 | |
2497 | Operation &firstOp = *region.op_begin(); |
2498 | if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp)) |
2499 | return emitOpError() << "nested in loop wrapper is not another loop " |
2500 | "wrapper or `omp.loop_nest`"; |
2501 | |
2502 | return success(); |
2503 | } |
2504 | |
2505 | //===----------------------------------------------------------------------===// |
2506 | // LoopOp |
2507 | //===----------------------------------------------------------------------===// |
2508 | |
2509 | void LoopOp::build(OpBuilder &builder, OperationState &state, |
2510 | const LoopOperands &clauses) { |
2511 | MLIRContext *ctx = builder.getContext(); |
2512 | |
2513 | LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars, |
2514 | makeArrayAttr(ctx, clauses.privateSyms), |
2515 | clauses.privateNeedsBarrier, clauses.order, clauses.orderMod, |
2516 | clauses.reductionMod, clauses.reductionVars, |
2517 | makeDenseBoolArrayAttr(ctx, clauses.reductionByref), |
2518 | makeArrayAttr(ctx, clauses.reductionSyms)); |
2519 | } |
2520 | |
2521 | LogicalResult LoopOp::verify() { |
2522 | return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), |
2523 | getReductionByref()); |
2524 | } |
2525 | |
2526 | LogicalResult LoopOp::verifyRegions() { |
2527 | if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) || |
2528 | getNestedWrapper()) |
2529 | return emitOpError() << "expected to be a standalone loop wrapper"; |
2530 | |
2531 | return success(); |
2532 | } |
2533 | |
2534 | //===----------------------------------------------------------------------===// |
2535 | // WsloopOp |
2536 | //===----------------------------------------------------------------------===// |
2537 | |
2538 | void WsloopOp::build(OpBuilder &builder, OperationState &state, |
2539 | ArrayRef<NamedAttribute> attributes) { |
2540 | build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, |
2541 | /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(), |
2542 | /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr, |
2543 | /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr, |
2544 | /*private_needs_barrier=*/false, |
2545 | /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(), |
2546 | /*reduction_byref=*/nullptr, |
2547 | /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr, |
2548 | /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr, |
2549 | /*schedule_simd=*/false); |
2550 | state.addAttributes(attributes); |
2551 | } |
2552 | |
2553 | void WsloopOp::build(OpBuilder &builder, OperationState &state, |
2554 | const WsloopOperands &clauses) { |
2555 | MLIRContext *ctx = builder.getContext(); |
2556 | // TODO: Store clauses in op: allocateVars, allocatorVars |
2557 | WsloopOp::build( |
2558 | builder, state, |
2559 | /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars, |
2560 | clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod, |
2561 | clauses.ordered, clauses.privateVars, |
2562 | makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, |
2563 | clauses.reductionMod, clauses.reductionVars, |
2564 | makeDenseBoolArrayAttr(ctx, clauses.reductionByref), |
2565 | makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind, |
2566 | clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd); |
2567 | } |
2568 | |
2569 | LogicalResult WsloopOp::verify() { |
2570 | return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(), |
2571 | getReductionByref()); |
2572 | } |
2573 | |
2574 | LogicalResult WsloopOp::verifyRegions() { |
2575 | bool isCompositeChildLeaf = |
2576 | llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()); |
2577 | |
2578 | if (LoopWrapperInterface nested = getNestedWrapper()) { |
2579 | if (!isComposite()) |
2580 | return emitError() |
2581 | << "'omp.composite' attribute missing from composite wrapper"; |
2582 | |
2583 | // Check for the allowed leaf constructs that may appear in a composite |
2584 | // construct directly after DO/FOR. |
2585 | if (!isa<SimdOp>(nested)) |
2586 | return emitError() << "only supported nested wrapper is 'omp.simd'"; |
2587 | |
2588 | } else if (isComposite() && !isCompositeChildLeaf) { |
2589 | return emitError() |
2590 | << "'omp.composite' attribute present in non-composite wrapper"; |
2591 | } else if (!isComposite() && isCompositeChildLeaf) { |
2592 | return emitError() |
2593 | << "'omp.composite' attribute missing from composite wrapper"; |
2594 | } |
2595 | |
2596 | return success(); |
2597 | } |
2598 | |
2599 | //===----------------------------------------------------------------------===// |
2600 | // Simd construct [2.9.3.1] |
2601 | //===----------------------------------------------------------------------===// |
2602 | |
2603 | void SimdOp::build(OpBuilder &builder, OperationState &state, |
2604 | const SimdOperands &clauses) { |
2605 | MLIRContext *ctx = builder.getContext(); |
2606 | // TODO Store clauses in op: linearVars, linearStepVars |
2607 | SimdOp::build(builder, state, clauses.alignedVars, |
2608 | makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr, |
2609 | /*linear_vars=*/{}, /*linear_step_vars=*/{}, |
2610 | clauses.nontemporalVars, clauses.order, clauses.orderMod, |
2611 | clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms), |
2612 | clauses.privateNeedsBarrier, clauses.reductionMod, |
2613 | clauses.reductionVars, |
2614 | makeDenseBoolArrayAttr(ctx, clauses.reductionByref), |
2615 | makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen, |
2616 | clauses.simdlen); |
2617 | } |
2618 | |
2619 | LogicalResult SimdOp::verify() { |
2620 | if (getSimdlen().has_value() && getSafelen().has_value() && |
2621 | getSimdlen().value() > getSafelen().value()) |
2622 | return emitOpError() |
2623 | << "simdlen clause and safelen clause are both present, but the " |
2624 | "simdlen value is not less than or equal to safelen value"; |
2625 | |
2626 | if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed()) |
2627 | return failure(); |
2628 | |
2629 | if (verifyNontemporalClause(*this, getNontemporalVars()).failed()) |
2630 | return failure(); |
2631 | |
2632 | bool isCompositeChildLeaf = |
2633 | llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()); |
2634 | |
2635 | if (!isComposite() && isCompositeChildLeaf) |
2636 | return emitError() |
2637 | << "'omp.composite' attribute missing from composite wrapper"; |
2638 | |
2639 | if (isComposite() && !isCompositeChildLeaf) |
2640 | return emitError() |
2641 | << "'omp.composite' attribute present in non-composite wrapper"; |
2642 | |
2643 | return success(); |
2644 | } |
2645 | |
2646 | LogicalResult SimdOp::verifyRegions() { |
2647 | if (getNestedWrapper()) |
2648 | return emitOpError() << "must wrap an 'omp.loop_nest' directly"; |
2649 | |
2650 | return success(); |
2651 | } |
2652 | |
2653 | //===----------------------------------------------------------------------===// |
2654 | // Distribute construct [2.9.4.1] |
2655 | //===----------------------------------------------------------------------===// |
2656 | |
2657 | void DistributeOp::build(OpBuilder &builder, OperationState &state, |
2658 | const DistributeOperands &clauses) { |
2659 | DistributeOp::build(builder, state, clauses.allocateVars, |
2660 | clauses.allocatorVars, clauses.distScheduleStatic, |
2661 | clauses.distScheduleChunkSize, clauses.order, |
2662 | clauses.orderMod, clauses.privateVars, |
2663 | makeArrayAttr(builder.getContext(), clauses.privateSyms), |
2664 | clauses.privateNeedsBarrier); |
2665 | } |
2666 | |
2667 | LogicalResult DistributeOp::verify() { |
2668 | if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic()) |
2669 | return emitOpError() << "chunk size set without " |
2670 | "dist_schedule_static being present"; |
2671 | |
2672 | if (getAllocateVars().size() != getAllocatorVars().size()) |
2673 | return emitError( |
2674 | "expected equal sizes for allocate and allocator variables"); |
2675 | |
2676 | return success(); |
2677 | } |
2678 | |
2679 | LogicalResult DistributeOp::verifyRegions() { |
2680 | if (LoopWrapperInterface nested = getNestedWrapper()) { |
2681 | if (!isComposite()) |
2682 | return emitError() |
2683 | << "'omp.composite' attribute missing from composite wrapper"; |
2684 | // Check for the allowed leaf constructs that may appear in a composite |
2685 | // construct directly after DISTRIBUTE. |
2686 | if (isa<WsloopOp>(nested)) { |
2687 | Operation *parentOp = (*this)->getParentOp(); |
2688 | if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) || |
2689 | !cast<ComposableOpInterface>(parentOp).isComposite()) { |
2690 | return emitError() << "an 'omp.wsloop' nested wrapper is only allowed " |
2691 | "when a composite 'omp.parallel' is the direct " |
2692 | "parent"; |
2693 | } |
2694 | } else if (!isa<SimdOp>(nested)) |
2695 | return emitError() << "only supported nested wrappers are 'omp.simd' and " |
2696 | "'omp.wsloop'"; |
2697 | } else if (isComposite()) { |
2698 | return emitError() |
2699 | << "'omp.composite' attribute present in non-composite wrapper"; |
2700 | } |
2701 | |
2702 | return success(); |
2703 | } |
2704 | |
2705 | //===----------------------------------------------------------------------===// |
2706 | // DeclareMapperOp / DeclareMapperInfoOp |
2707 | //===----------------------------------------------------------------------===// |
2708 | |
2709 | LogicalResult DeclareMapperInfoOp::verify() { |
2710 | return verifyMapClause(*this, getMapVars()); |
2711 | } |
2712 | |
2713 | LogicalResult DeclareMapperOp::verifyRegions() { |
2714 | if (!llvm::isa_and_present<DeclareMapperInfoOp>( |
2715 | getRegion().getBlocks().front().getTerminator())) |
2716 | return emitOpError() << "expected terminator to be a DeclareMapperInfoOp"; |
2717 | |
2718 | return success(); |
2719 | } |
2720 | |
2721 | //===----------------------------------------------------------------------===// |
2722 | // DeclareReductionOp |
2723 | //===----------------------------------------------------------------------===// |
2724 | |
2725 | LogicalResult DeclareReductionOp::verifyRegions() { |
2726 | if (!getAllocRegion().empty()) { |
2727 | for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) { |
2728 | if (yieldOp.getResults().size() != 1 || |
2729 | yieldOp.getResults().getTypes()[0] != getType()) |
2730 | return emitOpError() << "expects alloc region to yield a value " |
2731 | "of the reduction type"; |
2732 | } |
2733 | } |
2734 | |
2735 | if (getInitializerRegion().empty()) |
2736 | return emitOpError() << "expects non-empty initializer region"; |
2737 | Block &initializerEntryBlock = getInitializerRegion().front(); |
2738 | |
2739 | if (initializerEntryBlock.getNumArguments() == 1) { |
2740 | if (!getAllocRegion().empty()) |
2741 | return emitOpError() << "expects two arguments to the initializer region " |
2742 | "when an allocation region is used"; |
2743 | } else if (initializerEntryBlock.getNumArguments() == 2) { |
2744 | if (getAllocRegion().empty()) |
2745 | return emitOpError() << "expects one argument to the initializer region " |
2746 | "when no allocation region is used"; |
2747 | } else { |
2748 | return emitOpError() |
2749 | << "expects one or two arguments to the initializer region"; |
2750 | } |
2751 | |
2752 | for (mlir::Value arg : initializerEntryBlock.getArguments()) |
2753 | if (arg.getType() != getType()) |
2754 | return emitOpError() << "expects initializer region argument to match " |
2755 | "the reduction type"; |
2756 | |
2757 | for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) { |
2758 | if (yieldOp.getResults().size() != 1 || |
2759 | yieldOp.getResults().getTypes()[0] != getType()) |
2760 | return emitOpError() << "expects initializer region to yield a value " |
2761 | "of the reduction type"; |
2762 | } |
2763 | |
2764 | if (getReductionRegion().empty()) |
2765 | return emitOpError() << "expects non-empty reduction region"; |
2766 | Block &reductionEntryBlock = getReductionRegion().front(); |
2767 | if (reductionEntryBlock.getNumArguments() != 2 || |
2768 | reductionEntryBlock.getArgumentTypes()[0] != |
2769 | reductionEntryBlock.getArgumentTypes()[1] || |
2770 | reductionEntryBlock.getArgumentTypes()[0] != getType()) |
2771 | return emitOpError() << "expects reduction region with two arguments of " |
2772 | "the reduction type"; |
2773 | for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) { |
2774 | if (yieldOp.getResults().size() != 1 || |
2775 | yieldOp.getResults().getTypes()[0] != getType()) |
2776 | return emitOpError() << "expects reduction region to yield a value " |
2777 | "of the reduction type"; |
2778 | } |
2779 | |
2780 | if (!getAtomicReductionRegion().empty()) { |
2781 | Block &atomicReductionEntryBlock = getAtomicReductionRegion().front(); |
2782 | if (atomicReductionEntryBlock.getNumArguments() != 2 || |
2783 | atomicReductionEntryBlock.getArgumentTypes()[0] != |
2784 | atomicReductionEntryBlock.getArgumentTypes()[1]) |
2785 | return emitOpError() << "expects atomic reduction region with two " |
2786 | "arguments of the same type"; |
2787 | auto ptrType = llvm::dyn_cast<PointerLikeType>( |
2788 | atomicReductionEntryBlock.getArgumentTypes()[0]); |
2789 | if (!ptrType || |
2790 | (ptrType.getElementType() && ptrType.getElementType() != getType())) |
2791 | return emitOpError() << "expects atomic reduction region arguments to " |
2792 | "be accumulators containing the reduction type"; |
2793 | } |
2794 | |
2795 | if (getCleanupRegion().empty()) |
2796 | return success(); |
2797 | Block &cleanupEntryBlock = getCleanupRegion().front(); |
2798 | if (cleanupEntryBlock.getNumArguments() != 1 || |
2799 | cleanupEntryBlock.getArgument(0).getType() != getType()) |
2800 | return emitOpError() << "expects cleanup region with one argument " |
2801 | "of the reduction type"; |
2802 | |
2803 | return success(); |
2804 | } |
2805 | |
2806 | //===----------------------------------------------------------------------===// |
2807 | // TaskOp |
2808 | //===----------------------------------------------------------------------===// |
2809 | |
2810 | void TaskOp::build(OpBuilder &builder, OperationState &state, |
2811 | const TaskOperands &clauses) { |
2812 | MLIRContext *ctx = builder.getContext(); |
2813 | TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars, |
2814 | makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars, |
2815 | clauses.final, clauses.ifExpr, clauses.inReductionVars, |
2816 | makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), |
2817 | makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable, |
2818 | clauses.priority, /*private_vars=*/clauses.privateVars, |
2819 | /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms), |
2820 | clauses.privateNeedsBarrier, clauses.untied, |
2821 | clauses.eventHandle); |
2822 | } |
2823 | |
2824 | LogicalResult TaskOp::verify() { |
2825 | LogicalResult verifyDependVars = |
2826 | verifyDependVarList(*this, getDependKinds(), getDependVars()); |
2827 | return failed(verifyDependVars) |
2828 | ? verifyDependVars |
2829 | : verifyReductionVarList(*this, getInReductionSyms(), |
2830 | getInReductionVars(), |
2831 | getInReductionByref()); |
2832 | } |
2833 | |
2834 | //===----------------------------------------------------------------------===// |
2835 | // TaskgroupOp |
2836 | //===----------------------------------------------------------------------===// |
2837 | |
2838 | void TaskgroupOp::build(OpBuilder &builder, OperationState &state, |
2839 | const TaskgroupOperands &clauses) { |
2840 | MLIRContext *ctx = builder.getContext(); |
2841 | TaskgroupOp::build(builder, state, clauses.allocateVars, |
2842 | clauses.allocatorVars, clauses.taskReductionVars, |
2843 | makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref), |
2844 | makeArrayAttr(ctx, clauses.taskReductionSyms)); |
2845 | } |
2846 | |
2847 | LogicalResult TaskgroupOp::verify() { |
2848 | return verifyReductionVarList(*this, getTaskReductionSyms(), |
2849 | getTaskReductionVars(), |
2850 | getTaskReductionByref()); |
2851 | } |
2852 | |
2853 | //===----------------------------------------------------------------------===// |
2854 | // TaskloopOp |
2855 | //===----------------------------------------------------------------------===// |
2856 | |
2857 | void TaskloopOp::build(OpBuilder &builder, OperationState &state, |
2858 | const TaskloopOperands &clauses) { |
2859 | MLIRContext *ctx = builder.getContext(); |
2860 | TaskloopOp::build( |
2861 | builder, state, clauses.allocateVars, clauses.allocatorVars, |
2862 | clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr, |
2863 | clauses.inReductionVars, |
2864 | makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), |
2865 | makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable, |
2866 | clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority, |
2867 | /*private_vars=*/clauses.privateVars, |
2868 | /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms), |
2869 | clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars, |
2870 | makeDenseBoolArrayAttr(ctx, clauses.reductionByref), |
2871 | makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied); |
2872 | } |
2873 | |
2874 | LogicalResult TaskloopOp::verify() { |
2875 | if (getAllocateVars().size() != getAllocatorVars().size()) |
2876 | return emitError( |
2877 | "expected equal sizes for allocate and allocator variables"); |
2878 | if (failed(verifyReductionVarList(*this, getReductionSyms(), |
2879 | getReductionVars(), getReductionByref())) || |
2880 | failed(verifyReductionVarList(*this, getInReductionSyms(), |
2881 | getInReductionVars(), |
2882 | getInReductionByref()))) |
2883 | return failure(); |
2884 | |
2885 | if (!getReductionVars().empty() && getNogroup()) |
2886 | return emitError("if a reduction clause is present on the taskloop " |
2887 | "directive, the nogroup clause must not be specified"); |
2888 | for (auto var : getReductionVars()) { |
2889 | if (llvm::is_contained(getInReductionVars(), var)) |
2890 | return emitError("the same list item cannot appear in both a reduction " |
2891 | "and an in_reduction clause"); |
2892 | } |
2893 | |
2894 | if (getGrainsize() && getNumTasks()) { |
2895 | return emitError( |
2896 | "the grainsize clause and num_tasks clause are mutually exclusive and " |
2897 | "may not appear on the same taskloop directive"); |
2898 | } |
2899 | |
2900 | return success(); |
2901 | } |
2902 | |
2903 | LogicalResult TaskloopOp::verifyRegions() { |
2904 | if (LoopWrapperInterface nested = getNestedWrapper()) { |
2905 | if (!isComposite()) |
2906 | return emitError() |
2907 | << "'omp.composite' attribute missing from composite wrapper"; |
2908 | |
2909 | // Check for the allowed leaf constructs that may appear in a composite |
2910 | // construct directly after TASKLOOP. |
2911 | if (!isa<SimdOp>(nested)) |
2912 | return emitError() << "only supported nested wrapper is 'omp.simd'"; |
2913 | } else if (isComposite()) { |
2914 | return emitError() |
2915 | << "'omp.composite' attribute present in non-composite wrapper"; |
2916 | } |
2917 | |
2918 | return success(); |
2919 | } |
2920 | |
2921 | //===----------------------------------------------------------------------===// |
2922 | // LoopNestOp |
2923 | //===----------------------------------------------------------------------===// |
2924 | |
2925 | ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) { |
2926 | // Parse an opening `(` followed by induction variables followed by `)` |
2927 | SmallVector<OpAsmParser::Argument> ivs; |
2928 | SmallVector<OpAsmParser::UnresolvedOperand> lbs, ubs; |
2929 | Type loopVarType; |
2930 | if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) || |
2931 | parser.parseColonType(loopVarType) || |
2932 | // Parse loop bounds. |
2933 | parser.parseEqual() || |
2934 | parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) || |
2935 | parser.parseKeyword("to") || |
2936 | parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren)) |
2937 | return failure(); |
2938 | |
2939 | for (auto &iv : ivs) |
2940 | iv.type = loopVarType; |
2941 | |
2942 | // Parse "inclusive" flag. |
2943 | if (succeeded(parser.parseOptionalKeyword("inclusive"))) |
2944 | result.addAttribute("loop_inclusive", |
2945 | UnitAttr::get(parser.getBuilder().getContext())); |
2946 | |
2947 | // Parse step values. |
2948 | SmallVector<OpAsmParser::UnresolvedOperand> steps; |
2949 | if (parser.parseKeyword("step") || |
2950 | parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren)) |
2951 | return failure(); |
2952 | |
2953 | // Parse the body. |
2954 | Region *region = result.addRegion(); |
2955 | if (parser.parseRegion(*region, ivs)) |
2956 | return failure(); |
2957 | |
2958 | // Resolve operands. |
2959 | if (parser.resolveOperands(lbs, loopVarType, result.operands) || |
2960 | parser.resolveOperands(ubs, loopVarType, result.operands) || |
2961 | parser.resolveOperands(steps, loopVarType, result.operands)) |
2962 | return failure(); |
2963 | |
2964 | // Parse the optional attribute list. |
2965 | return parser.parseOptionalAttrDict(result.attributes); |
2966 | } |
2967 | |
2968 | void LoopNestOp::print(OpAsmPrinter &p) { |
2969 | Region ®ion = getRegion(); |
2970 | auto args = region.getArguments(); |
2971 | p << " ("<< args << ") : "<< args[0].getType() << " = (" |
2972 | << getLoopLowerBounds() << ") to ("<< getLoopUpperBounds() << ") "; |
2973 | if (getLoopInclusive()) |
2974 | p << "inclusive "; |
2975 | p << "step ("<< getLoopSteps() << ") "; |
2976 | p.printRegion(region, /*printEntryBlockArgs=*/false); |
2977 | } |
2978 | |
2979 | void LoopNestOp::build(OpBuilder &builder, OperationState &state, |
2980 | const LoopNestOperands &clauses) { |
2981 | LoopNestOp::build(builder, state, clauses.loopLowerBounds, |
2982 | clauses.loopUpperBounds, clauses.loopSteps, |
2983 | clauses.loopInclusive); |
2984 | } |
2985 | |
2986 | LogicalResult LoopNestOp::verify() { |
2987 | if (getLoopLowerBounds().empty()) |
2988 | return emitOpError() << "must represent at least one loop"; |
2989 | |
2990 | if (getLoopLowerBounds().size() != getIVs().size()) |
2991 | return emitOpError() << "number of range arguments and IVs do not match"; |
2992 | |
2993 | for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) { |
2994 | if (lb.getType() != iv.getType()) |
2995 | return emitOpError() |
2996 | << "range argument type does not match corresponding IV type"; |
2997 | } |
2998 | |
2999 | if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp())) |
3000 | return emitOpError() << "expects parent op to be a loop wrapper"; |
3001 | |
3002 | return success(); |
3003 | } |
3004 | |
3005 | void LoopNestOp::gatherWrappers( |
3006 | SmallVectorImpl<LoopWrapperInterface> &wrappers) { |
3007 | Operation *parent = (*this)->getParentOp(); |
3008 | while (auto wrapper = |
3009 | llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) { |
3010 | wrappers.push_back(wrapper); |
3011 | parent = parent->getParentOp(); |
3012 | } |
3013 | } |
3014 | |
3015 | //===----------------------------------------------------------------------===// |
3016 | // Critical construct (2.17.1) |
3017 | //===----------------------------------------------------------------------===// |
3018 | |
3019 | void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state, |
3020 | const CriticalDeclareOperands &clauses) { |
3021 | CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint); |
3022 | } |
3023 | |
3024 | LogicalResult CriticalDeclareOp::verify() { |
3025 | return verifySynchronizationHint(*this, getHint()); |
3026 | } |
3027 | |
3028 | LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { |
3029 | if (getNameAttr()) { |
3030 | SymbolRefAttr symbolRef = getNameAttr(); |
3031 | auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>( |
3032 | *this, symbolRef); |
3033 | if (!decl) { |
3034 | return emitOpError() << "expected symbol reference "<< symbolRef |
3035 | << " to point to a critical declaration"; |
3036 | } |
3037 | } |
3038 | |
3039 | return success(); |
3040 | } |
3041 | |
3042 | //===----------------------------------------------------------------------===// |
3043 | // Ordered construct |
3044 | //===----------------------------------------------------------------------===// |
3045 | |
3046 | static LogicalResult verifyOrderedParent(Operation &op) { |
3047 | bool hasRegion = op.getNumRegions() > 0; |
3048 | auto loopOp = op.getParentOfType<LoopNestOp>(); |
3049 | if (!loopOp) { |
3050 | if (hasRegion) |
3051 | return success(); |
3052 | |
3053 | // TODO: Consider if this needs to be the case only for the standalone |
3054 | // variant of the ordered construct. |
3055 | return op.emitOpError() << "must be nested inside of a loop"; |
3056 | } |
3057 | |
3058 | Operation *wrapper = loopOp->getParentOp(); |
3059 | if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) { |
3060 | IntegerAttr orderedAttr = wsloopOp.getOrderedAttr(); |
3061 | if (!orderedAttr) |
3062 | return op.emitOpError() << "the enclosing worksharing-loop region must " |
3063 | "have an ordered clause"; |
3064 | |
3065 | if (hasRegion && orderedAttr.getInt() != 0) |
3066 | return op.emitOpError() << "the enclosing loop's ordered clause must not " |
3067 | "have a parameter present"; |
3068 | |
3069 | if (!hasRegion && orderedAttr.getInt() == 0) |
3070 | return op.emitOpError() << "the enclosing loop's ordered clause must " |
3071 | "have a parameter present"; |
3072 | } else if (!isa<SimdOp>(wrapper)) { |
3073 | return op.emitOpError() << "must be nested inside of a worksharing, simd " |
3074 | "or worksharing simd loop"; |
3075 | } |
3076 | return success(); |
3077 | } |
3078 | |
3079 | void OrderedOp::build(OpBuilder &builder, OperationState &state, |
3080 | const OrderedOperands &clauses) { |
3081 | OrderedOp::build(builder, state, clauses.doacrossDependType, |
3082 | clauses.doacrossNumLoops, clauses.doacrossDependVars); |
3083 | } |
3084 | |
3085 | LogicalResult OrderedOp::verify() { |
3086 | if (failed(verifyOrderedParent(**this))) |
3087 | return failure(); |
3088 | |
3089 | auto wrapper = (*this)->getParentOfType<WsloopOp>(); |
3090 | if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops()) |
3091 | return emitOpError() << "number of variables in depend clause does not " |
3092 | << "match number of iteration variables in the " |
3093 | << "doacross loop"; |
3094 | |
3095 | return success(); |
3096 | } |
3097 | |
3098 | void OrderedRegionOp::build(OpBuilder &builder, OperationState &state, |
3099 | const OrderedRegionOperands &clauses) { |
3100 | OrderedRegionOp::build(builder, state, clauses.parLevelSimd); |
3101 | } |
3102 | |
3103 | LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); } |
3104 | |
3105 | //===----------------------------------------------------------------------===// |
3106 | // TaskwaitOp |
3107 | //===----------------------------------------------------------------------===// |
3108 | |
3109 | void TaskwaitOp::build(OpBuilder &builder, OperationState &state, |
3110 | const TaskwaitOperands &clauses) { |
3111 | // TODO Store clauses in op: dependKinds, dependVars, nowait. |
3112 | TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr, |
3113 | /*depend_vars=*/{}, /*nowait=*/nullptr); |
3114 | } |
3115 | |
3116 | //===----------------------------------------------------------------------===// |
3117 | // Verifier for AtomicReadOp |
3118 | //===----------------------------------------------------------------------===// |
3119 | |
3120 | LogicalResult AtomicReadOp::verify() { |
3121 | if (verifyCommon().failed()) |
3122 | return mlir::failure(); |
3123 | |
3124 | if (auto mo = getMemoryOrder()) { |
3125 | if (*mo == ClauseMemoryOrderKind::Acq_rel || |
3126 | *mo == ClauseMemoryOrderKind::Release) { |
3127 | return emitError( |
3128 | "memory-order must not be acq_rel or release for atomic reads"); |
3129 | } |
3130 | } |
3131 | return verifySynchronizationHint(*this, getHint()); |
3132 | } |
3133 | |
3134 | //===----------------------------------------------------------------------===// |
3135 | // Verifier for AtomicWriteOp |
3136 | //===----------------------------------------------------------------------===// |
3137 | |
3138 | LogicalResult AtomicWriteOp::verify() { |
3139 | if (verifyCommon().failed()) |
3140 | return mlir::failure(); |
3141 | |
3142 | if (auto mo = getMemoryOrder()) { |
3143 | if (*mo == ClauseMemoryOrderKind::Acq_rel || |
3144 | *mo == ClauseMemoryOrderKind::Acquire) { |
3145 | return emitError( |
3146 | "memory-order must not be acq_rel or acquire for atomic writes"); |
3147 | } |
3148 | } |
3149 | return verifySynchronizationHint(*this, getHint()); |
3150 | } |
3151 | |
3152 | //===----------------------------------------------------------------------===// |
3153 | // Verifier for AtomicUpdateOp |
3154 | //===----------------------------------------------------------------------===// |
3155 | |
3156 | LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op, |
3157 | PatternRewriter &rewriter) { |
3158 | if (op.isNoOp()) { |
3159 | rewriter.eraseOp(op); |
3160 | return success(); |
3161 | } |
3162 | if (Value writeVal = op.getWriteOpVal()) { |
3163 | rewriter.replaceOpWithNewOp<AtomicWriteOp>( |
3164 | op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr()); |
3165 | return success(); |
3166 | } |
3167 | return failure(); |
3168 | } |
3169 | |
3170 | LogicalResult AtomicUpdateOp::verify() { |
3171 | if (verifyCommon().failed()) |
3172 | return mlir::failure(); |
3173 | |
3174 | if (auto mo = getMemoryOrder()) { |
3175 | if (*mo == ClauseMemoryOrderKind::Acq_rel || |
3176 | *mo == ClauseMemoryOrderKind::Acquire) { |
3177 | return emitError( |
3178 | "memory-order must not be acq_rel or acquire for atomic updates"); |
3179 | } |
3180 | } |
3181 | |
3182 | return verifySynchronizationHint(*this, getHint()); |
3183 | } |
3184 | |
3185 | LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); } |
3186 | |
3187 | //===----------------------------------------------------------------------===// |
3188 | // Verifier for AtomicCaptureOp |
3189 | //===----------------------------------------------------------------------===// |
3190 | |
3191 | AtomicReadOp AtomicCaptureOp::getAtomicReadOp() { |
3192 | if (auto op = dyn_cast<AtomicReadOp>(getFirstOp())) |
3193 | return op; |
3194 | return dyn_cast<AtomicReadOp>(getSecondOp()); |
3195 | } |
3196 | |
3197 | AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() { |
3198 | if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp())) |
3199 | return op; |
3200 | return dyn_cast<AtomicWriteOp>(getSecondOp()); |
3201 | } |
3202 | |
3203 | AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() { |
3204 | if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp())) |
3205 | return op; |
3206 | return dyn_cast<AtomicUpdateOp>(getSecondOp()); |
3207 | } |
3208 | |
3209 | LogicalResult AtomicCaptureOp::verify() { |
3210 | return verifySynchronizationHint(*this, getHint()); |
3211 | } |
3212 | |
3213 | LogicalResult AtomicCaptureOp::verifyRegions() { |
3214 | if (verifyRegionsCommon().failed()) |
3215 | return mlir::failure(); |
3216 | |
3217 | if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr( "hint")) |
3218 | return emitOpError( |
3219 | "operations inside capture region must not have hint clause"); |
3220 | |
3221 | if (getFirstOp()->getAttr("memory_order") || |
3222 | getSecondOp()->getAttr("memory_order")) |
3223 | return emitOpError( |
3224 | "operations inside capture region must not have memory_order clause"); |
3225 | return success(); |
3226 | } |
3227 | |
3228 | //===----------------------------------------------------------------------===// |
3229 | // CancelOp |
3230 | //===----------------------------------------------------------------------===// |
3231 | |
3232 | void CancelOp::build(OpBuilder &builder, OperationState &state, |
3233 | const CancelOperands &clauses) { |
3234 | CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr); |
3235 | } |
3236 | |
3237 | static Operation *getParentInSameDialect(Operation *thisOp) { |
3238 | Operation *parent = thisOp->getParentOp(); |
3239 | while (parent) { |
3240 | if (parent->getDialect() == thisOp->getDialect()) |
3241 | return parent; |
3242 | parent = parent->getParentOp(); |
3243 | } |
3244 | return nullptr; |
3245 | } |
3246 | |
3247 | LogicalResult CancelOp::verify() { |
3248 | ClauseCancellationConstructType cct = getCancelDirective(); |
3249 | // The next OpenMP operation in the chain of parents |
3250 | Operation *structuralParent = getParentInSameDialect((*this).getOperation()); |
3251 | if (!structuralParent) |
3252 | return emitOpError() << "Orphaned cancel construct"; |
3253 | |
3254 | if ((cct == ClauseCancellationConstructType::Parallel) && |
3255 | !mlir::isa<ParallelOp>(structuralParent)) { |
3256 | return emitOpError() << "cancel parallel must appear " |
3257 | << "inside a parallel region"; |
3258 | } |
3259 | if (cct == ClauseCancellationConstructType::Loop) { |
3260 | // structural parent will be omp.loop_nest, directly nested inside |
3261 | // omp.wsloop |
3262 | auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp()); |
3263 | |
3264 | if (!wsloopOp) { |
3265 | return emitOpError() |
3266 | << "cancel loop must appear inside a worksharing-loop region"; |
3267 | } |
3268 | if (wsloopOp.getNowaitAttr()) { |
3269 | return emitError() << "A worksharing construct that is canceled " |
3270 | << "must not have a nowait clause"; |
3271 | } |
3272 | if (wsloopOp.getOrderedAttr()) { |
3273 | return emitError() << "A worksharing construct that is canceled " |
3274 | << "must not have an ordered clause"; |
3275 | } |
3276 | |
3277 | } else if (cct == ClauseCancellationConstructType::Sections) { |
3278 | // structural parent will be an omp.section, directly nested inside |
3279 | // omp.sections |
3280 | auto sectionsOp = |
3281 | mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp()); |
3282 | if (!sectionsOp) { |
3283 | return emitOpError() << "cancel sections must appear " |
3284 | << "inside a sections region"; |
3285 | } |
3286 | if (sectionsOp.getNowait()) { |
3287 | return emitError() << "A sections construct that is canceled " |
3288 | << "must not have a nowait clause"; |
3289 | } |
3290 | } |
3291 | if ((cct == ClauseCancellationConstructType::Taskgroup) && |
3292 | (!mlir::isa<omp::TaskOp>(structuralParent) && |
3293 | !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) { |
3294 | return emitOpError() << "cancel taskgroup must appear " |
3295 | << "inside a task region"; |
3296 | } |
3297 | return success(); |
3298 | } |
3299 | |
3300 | //===----------------------------------------------------------------------===// |
3301 | // CancellationPointOp |
3302 | //===----------------------------------------------------------------------===// |
3303 | |
3304 | void CancellationPointOp::build(OpBuilder &builder, OperationState &state, |
3305 | const CancellationPointOperands &clauses) { |
3306 | CancellationPointOp::build(builder, state, clauses.cancelDirective); |
3307 | } |
3308 | |
3309 | LogicalResult CancellationPointOp::verify() { |
3310 | ClauseCancellationConstructType cct = getCancelDirective(); |
3311 | // The next OpenMP operation in the chain of parents |
3312 | Operation *structuralParent = getParentInSameDialect((*this).getOperation()); |
3313 | if (!structuralParent) |
3314 | return emitOpError() << "Orphaned cancellation point"; |
3315 | |
3316 | if ((cct == ClauseCancellationConstructType::Parallel) && |
3317 | !mlir::isa<ParallelOp>(structuralParent)) { |
3318 | return emitOpError() << "cancellation point parallel must appear " |
3319 | << "inside a parallel region"; |
3320 | } |
3321 | // Strucutal parent here will be an omp.loop_nest. Get the parent of that to |
3322 | // find the wsloop |
3323 | if ((cct == ClauseCancellationConstructType::Loop) && |
3324 | !mlir::isa<WsloopOp>(structuralParent->getParentOp())) { |
3325 | return emitOpError() << "cancellation point loop must appear " |
3326 | << "inside a worksharing-loop region"; |
3327 | } |
3328 | if ((cct == ClauseCancellationConstructType::Sections) && |
3329 | !mlir::isa<omp::SectionOp>(structuralParent)) { |
3330 | return emitOpError() << "cancellation point sections must appear " |
3331 | << "inside a sections region"; |
3332 | } |
3333 | if ((cct == ClauseCancellationConstructType::Taskgroup) && |
3334 | !mlir::isa<omp::TaskOp>(structuralParent)) { |
3335 | return emitOpError() << "cancellation point taskgroup must appear " |
3336 | << "inside a task region"; |
3337 | } |
3338 | return success(); |
3339 | } |
3340 | |
3341 | //===----------------------------------------------------------------------===// |
3342 | // MapBoundsOp |
3343 | //===----------------------------------------------------------------------===// |
3344 | |
3345 | LogicalResult MapBoundsOp::verify() { |
3346 | auto extent = getExtent(); |
3347 | auto upperbound = getUpperBound(); |
3348 | if (!extent && !upperbound) |
3349 | return emitError("expected extent or upperbound."); |
3350 | return success(); |
3351 | } |
3352 | |
3353 | void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
3354 | TypeRange /*result_types*/, StringAttr symName, |
3355 | TypeAttr type) { |
3356 | PrivateClauseOp::build( |
3357 | odsBuilder, odsState, symName, type, |
3358 | DataSharingClauseTypeAttr::get(odsBuilder.getContext(), |
3359 | DataSharingClauseType::Private)); |
3360 | } |
3361 | |
3362 | LogicalResult PrivateClauseOp::verifyRegions() { |
3363 | Type argType = getArgType(); |
3364 | auto verifyTerminator = [&](Operation *terminator, |
3365 | bool yieldsValue) -> LogicalResult { |
3366 | if (!terminator->getBlock()->getSuccessors().empty()) |
3367 | return success(); |
3368 | |
3369 | if (!llvm::isa<YieldOp>(terminator)) |
3370 | return mlir::emitError(terminator->getLoc()) |
3371 | << "expected exit block terminator to be an `omp.yield` op."; |
3372 | |
3373 | YieldOp yieldOp = llvm::cast<YieldOp>(terminator); |
3374 | TypeRange yieldedTypes = yieldOp.getResults().getTypes(); |
3375 | |
3376 | if (!yieldsValue) { |
3377 | if (yieldedTypes.empty()) |
3378 | return success(); |
3379 | |
3380 | return mlir::emitError(terminator->getLoc()) |
3381 | << "Did not expect any values to be yielded."; |
3382 | } |
3383 | |
3384 | if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType) |
3385 | return success(); |
3386 | |
3387 | auto error = mlir::emitError(yieldOp.getLoc()) |
3388 | << "Invalid yielded value. Expected type: "<< argType |
3389 | << ", got: "; |
3390 | |
3391 | if (yieldedTypes.empty()) |
3392 | error << "None"; |
3393 | else |
3394 | error << yieldedTypes; |
3395 | |
3396 | return error; |
3397 | }; |
3398 | |
3399 | auto verifyRegion = [&](Region ®ion, unsigned expectedNumArgs, |
3400 | StringRef regionName, |
3401 | bool yieldsValue) -> LogicalResult { |
3402 | assert(!region.empty()); |
3403 | |
3404 | if (region.getNumArguments() != expectedNumArgs) |
3405 | return mlir::emitError(region.getLoc()) |
3406 | << "`"<< regionName << "`: " |
3407 | << "expected "<< expectedNumArgs |
3408 | << " region arguments, got: "<< region.getNumArguments(); |
3409 | |
3410 | for (Block &block : region) { |
3411 | // MLIR will verify the absence of the terminator for us. |
3412 | if (!block.mightHaveTerminator()) |
3413 | continue; |
3414 | |
3415 | if (failed(verifyTerminator(block.getTerminator(), yieldsValue))) |
3416 | return failure(); |
3417 | } |
3418 | |
3419 | return success(); |
3420 | }; |
3421 | |
3422 | // Ensure all of the region arguments have the same type |
3423 | for (Region *region : getRegions()) |
3424 | for (Type ty : region->getArgumentTypes()) |
3425 | if (ty != argType) |
3426 | return emitError() << "Region argument type mismatch: got "<< ty |
3427 | << " expected "<< argType << "."; |
3428 | |
3429 | mlir::Region &initRegion = getInitRegion(); |
3430 | if (!initRegion.empty() && |
3431 | failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init", |
3432 | /*yieldsValue=*/true))) |
3433 | return failure(); |
3434 | |
3435 | DataSharingClauseType dsType = getDataSharingType(); |
3436 | |
3437 | if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty()) |
3438 | return emitError("`private` clauses do not require a `copy` region."); |
3439 | |
3440 | if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty()) |
3441 | return emitError( |
3442 | "`firstprivate` clauses require at least a `copy` region."); |
3443 | |
3444 | if (dsType == DataSharingClauseType::FirstPrivate && |
3445 | failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy", |
3446 | /*yieldsValue=*/true))) |
3447 | return failure(); |
3448 | |
3449 | if (!getDeallocRegion().empty() && |
3450 | failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc", |
3451 | /*yieldsValue=*/false))) |
3452 | return failure(); |
3453 | |
3454 | return success(); |
3455 | } |
3456 | |
3457 | //===----------------------------------------------------------------------===// |
3458 | // Spec 5.2: Masked construct (10.5) |
3459 | //===----------------------------------------------------------------------===// |
3460 | |
3461 | void MaskedOp::build(OpBuilder &builder, OperationState &state, |
3462 | const MaskedOperands &clauses) { |
3463 | MaskedOp::build(builder, state, clauses.filteredThreadId); |
3464 | } |
3465 | |
3466 | //===----------------------------------------------------------------------===// |
3467 | // Spec 5.2: Scan construct (5.6) |
3468 | //===----------------------------------------------------------------------===// |
3469 | |
3470 | void ScanOp::build(OpBuilder &builder, OperationState &state, |
3471 | const ScanOperands &clauses) { |
3472 | ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars); |
3473 | } |
3474 | |
3475 | LogicalResult ScanOp::verify() { |
3476 | if (hasExclusiveVars() == hasInclusiveVars()) |
3477 | return emitError( |
3478 | "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected"); |
3479 | if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) { |
3480 | if (parentWsLoopOp.getReductionModAttr() && |
3481 | parentWsLoopOp.getReductionModAttr().getValue() == |
3482 | ReductionModifier::inscan) |
3483 | return success(); |
3484 | } |
3485 | if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) { |
3486 | if (parentSimdOp.getReductionModAttr() && |
3487 | parentSimdOp.getReductionModAttr().getValue() == |
3488 | ReductionModifier::inscan) |
3489 | return success(); |
3490 | } |
3491 | return emitError("SCAN directive needs to be enclosed within a parent " |
3492 | "worksharing loop construct or SIMD construct with INSCAN " |
3493 | "reduction modifier"); |
3494 | } |
3495 | |
3496 | #define GET_ATTRDEF_CLASSES |
3497 | #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" |
3498 | |
3499 | #define GET_OP_CLASSES |
3500 | #include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc" |
3501 | |
3502 | #define GET_TYPEDEF_CLASSES |
3503 | #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc" |
3504 |
Definitions
- makeArrayAttr
- makeDenseBoolArrayAttr
- MemRefPointerLikeModel
- getElementType
- LLVMPointerPointerLikeModel
- getElementType
- parseAllocateAndAllocator
- printAllocateAndAllocator
- parseClauseAttr
- printClauseAttr
- parseLinearClause
- printLinearClause
- verifyNontemporalClause
- verifyAlignedClause
- parseAlignedClause
- printAlignedClause
- verifyScheduleModifiers
- parseScheduleClause
- printScheduleClause
- parseOrderClause
- printOrderClause
- parseGranularityClause
- printGranularityClause
- parseGrainsizeClause
- printGrainsizeClause
- parseNumTasksClause
- printNumTasksClause
- MapParseArgs
- MapParseArgs
- PrivateParseArgs
- PrivateParseArgs
- ReductionParseArgs
- ReductionParseArgs
- AllRegionParseArgs
- getPrivateNeedsBarrierSpelling
- parseClauseWithRegionArgs
- parseBlockArgClause
- parseBlockArgClause
- parseBlockArgClause
- parseBlockArgRegion
- parseTargetOpRegion
- parseInReductionPrivateRegion
- parseInReductionPrivateReductionRegion
- parsePrivateRegion
- parsePrivateReductionRegion
- parseTaskReductionRegion
- parseUseDeviceAddrUseDevicePtrRegion
- MapPrintArgs
- MapPrintArgs
- PrivatePrintArgs
- PrivatePrintArgs
- ReductionPrintArgs
- ReductionPrintArgs
- AllRegionPrintArgs
- printClauseWithRegionArgs
- printBlockArgClause
- printBlockArgClause
- printBlockArgClause
- printBlockArgRegion
- printTargetOpRegion
- printInReductionPrivateRegion
- printInReductionPrivateReductionRegion
- printPrivateRegion
- printPrivateReductionRegion
- printTaskReductionRegion
- printUseDeviceAddrUseDevicePtrRegion
- verifyReductionVarList
- parseCopyprivate
- printCopyprivate
- verifyCopyprivateVarList
- parseDependVarList
- printDependVarList
- verifyDependVarList
- parseSynchronizationHint
- printSynchronizationHint
- verifySynchronizationHint
- mapTypeToBitFlag
- parseMapClause
- printMapClause
- parseMembersIndex
- printMembersIndex
- printCaptureType
- parseCaptureType
- verifyMapClause
- verifyPrivateVarsMapping
- verifyMapInfoDefinedArgs
- findCapturedOmpOp
- verifyPrivateVarList
- opInGlobalImplicitParallelRegion
- verifyOrderedParent
Improve your Profiling and Debugging skills
Find out more