1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// =============================================================================
8
9#include "mlir/Dialect/OpenACC/OpenACC.h"
10#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/DialectImplementation.h"
16#include "mlir/IR/Matchers.h"
17#include "mlir/IR/OpImplementation.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include "llvm/ADT/SmallSet.h"
20#include "llvm/ADT/TypeSwitch.h"
21
22using namespace mlir;
23using namespace acc;
24
25#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
26#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
27#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
28#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
29
30namespace {
31struct MemRefPointerLikeModel
32 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
33 MemRefType> {
34 Type getElementType(Type pointer) const {
35 return llvm::cast<MemRefType>(pointer).getElementType();
36 }
37};
38
39struct LLVMPointerPointerLikeModel
40 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
41 LLVM::LLVMPointerType> {
42 Type getElementType(Type pointer) const { return Type(); }
43};
44} // namespace
45
46//===----------------------------------------------------------------------===//
47// OpenACC operations
48//===----------------------------------------------------------------------===//
49
50void OpenACCDialect::initialize() {
51 addOperations<
52#define GET_OP_LIST
53#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
54 >();
55 addAttributes<
56#define GET_ATTRDEF_LIST
57#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
58 >();
59 addTypes<
60#define GET_TYPEDEF_LIST
61#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
62 >();
63
64 // By attaching interfaces here, we make the OpenACC dialect dependent on
65 // the other dialects. This is probably better than having dialects like LLVM
66 // and memref be dependent on OpenACC.
67 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
68 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
69 *getContext());
70}
71
72//===----------------------------------------------------------------------===//
73// device_type support helpers
74//===----------------------------------------------------------------------===//
75
76static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
77 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
78 return true;
79 return false;
80}
81
82static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
83 mlir::acc::DeviceType deviceType) {
84 if (!hasDeviceTypeValues(arrayAttr))
85 return false;
86
87 for (auto attr : *arrayAttr) {
88 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
89 if (deviceTypeAttr.getValue() == deviceType)
90 return true;
91 }
92
93 return false;
94}
95
96static void printDeviceTypes(mlir::OpAsmPrinter &p,
97 std::optional<mlir::ArrayAttr> deviceTypes) {
98 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
99 return;
100
101 p << "[";
102 llvm::interleaveComma(*deviceTypes, p,
103 [&](mlir::Attribute attr) { p << attr; });
104 p << "]";
105}
106
107static std::optional<unsigned> findSegment(ArrayAttr segments,
108 mlir::acc::DeviceType deviceType) {
109 unsigned segmentIdx = 0;
110 for (auto attr : segments) {
111 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
112 if (deviceTypeAttr.getValue() == deviceType)
113 return std::make_optional(segmentIdx);
114 ++segmentIdx;
115 }
116 return std::nullopt;
117}
118
119static mlir::Operation::operand_range
120getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
121 mlir::Operation::operand_range range,
122 std::optional<llvm::ArrayRef<int32_t>> segments,
123 mlir::acc::DeviceType deviceType) {
124 if (!arrayAttr)
125 return range.take_front(n: 0);
126 if (auto pos = findSegment(*arrayAttr, deviceType)) {
127 int32_t nbOperandsBefore = 0;
128 for (unsigned i = 0; i < *pos; ++i)
129 nbOperandsBefore += (*segments)[i];
130 return range.drop_front(n: nbOperandsBefore).take_front(n: (*segments)[*pos]);
131 }
132 return range.take_front(n: 0);
133}
134
135static mlir::Value
136getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
137 mlir::Operation::operand_range operands,
138 std::optional<llvm::ArrayRef<int32_t>> segments,
139 std::optional<mlir::ArrayAttr> hasWaitDevnum,
140 mlir::acc::DeviceType deviceType) {
141 if (!hasDeviceTypeValues(arrayAttr: deviceTypeAttr))
142 return {};
143 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
144 if (hasWaitDevnum->getValue()[*pos])
145 return getValuesFromSegments(deviceTypeAttr, operands, segments,
146 deviceType)
147 .front();
148 return {};
149}
150
151static mlir::Operation::operand_range
152getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
153 mlir::Operation::operand_range operands,
154 std::optional<llvm::ArrayRef<int32_t>> segments,
155 std::optional<mlir::ArrayAttr> hasWaitDevnum,
156 mlir::acc::DeviceType deviceType) {
157 auto range =
158 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
159 if (range.empty())
160 return range;
161 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
162 if (hasWaitDevnum && *hasWaitDevnum) {
163 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
164 if (boolAttr.getValue())
165 return range.drop_front(1); // first value is devnum
166 }
167 }
168 return range;
169}
170
171template <typename Op>
172static LogicalResult checkWaitAndAsyncConflict(Op op) {
173 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
174 ++dtypeInt) {
175 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
176
177 // The async attribute represent the async clause without value. Therefore
178 // the attribute and operand cannot appear at the same time.
179 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
180 op.hasAsyncOnly(dtype))
181 return op.emitError("async attribute cannot appear with asyncOperand");
182
183 // The wait attribute represent the wait clause without values. Therefore
184 // the attribute and operands cannot appear at the same time.
185 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
186 op.hasWaitOnly(dtype))
187 return op.emitError("wait attribute cannot appear with waitOperands");
188 }
189 return success();
190}
191
192//===----------------------------------------------------------------------===//
193// DataBoundsOp
194//===----------------------------------------------------------------------===//
195LogicalResult acc::DataBoundsOp::verify() {
196 auto extent = getExtent();
197 auto upperbound = getUpperbound();
198 if (!extent && !upperbound)
199 return emitError("expected extent or upperbound.");
200 return success();
201}
202
203//===----------------------------------------------------------------------===//
204// PrivateOp
205//===----------------------------------------------------------------------===//
206LogicalResult acc::PrivateOp::verify() {
207 if (getDataClause() != acc::DataClause::acc_private)
208 return emitError(
209 "data clause associated with private operation must match its intent");
210 return success();
211}
212
213//===----------------------------------------------------------------------===//
214// FirstprivateOp
215//===----------------------------------------------------------------------===//
216LogicalResult acc::FirstprivateOp::verify() {
217 if (getDataClause() != acc::DataClause::acc_firstprivate)
218 return emitError("data clause associated with firstprivate operation must "
219 "match its intent");
220 return success();
221}
222
223//===----------------------------------------------------------------------===//
224// ReductionOp
225//===----------------------------------------------------------------------===//
226LogicalResult acc::ReductionOp::verify() {
227 if (getDataClause() != acc::DataClause::acc_reduction)
228 return emitError("data clause associated with reduction operation must "
229 "match its intent");
230 return success();
231}
232
233//===----------------------------------------------------------------------===//
234// DevicePtrOp
235//===----------------------------------------------------------------------===//
236LogicalResult acc::DevicePtrOp::verify() {
237 if (getDataClause() != acc::DataClause::acc_deviceptr)
238 return emitError("data clause associated with deviceptr operation must "
239 "match its intent");
240 return success();
241}
242
243//===----------------------------------------------------------------------===//
244// PresentOp
245//===----------------------------------------------------------------------===//
246LogicalResult acc::PresentOp::verify() {
247 if (getDataClause() != acc::DataClause::acc_present)
248 return emitError(
249 "data clause associated with present operation must match its intent");
250 return success();
251}
252
253//===----------------------------------------------------------------------===//
254// CopyinOp
255//===----------------------------------------------------------------------===//
256LogicalResult acc::CopyinOp::verify() {
257 // Test for all clauses this operation can be decomposed from:
258 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
259 getDataClause() != acc::DataClause::acc_copyin_readonly &&
260 getDataClause() != acc::DataClause::acc_copy &&
261 getDataClause() != acc::DataClause::acc_reduction)
262 return emitError(
263 "data clause associated with copyin operation must match its intent"
264 " or specify original clause this operation was decomposed from");
265 return success();
266}
267
268bool acc::CopyinOp::isCopyinReadonly() {
269 return getDataClause() == acc::DataClause::acc_copyin_readonly;
270}
271
272//===----------------------------------------------------------------------===//
273// CreateOp
274//===----------------------------------------------------------------------===//
275LogicalResult acc::CreateOp::verify() {
276 // Test for all clauses this operation can be decomposed from:
277 if (getDataClause() != acc::DataClause::acc_create &&
278 getDataClause() != acc::DataClause::acc_create_zero &&
279 getDataClause() != acc::DataClause::acc_copyout &&
280 getDataClause() != acc::DataClause::acc_copyout_zero)
281 return emitError(
282 "data clause associated with create operation must match its intent"
283 " or specify original clause this operation was decomposed from");
284 return success();
285}
286
287bool acc::CreateOp::isCreateZero() {
288 // The zero modifier is encoded in the data clause.
289 return getDataClause() == acc::DataClause::acc_create_zero ||
290 getDataClause() == acc::DataClause::acc_copyout_zero;
291}
292
293//===----------------------------------------------------------------------===//
294// NoCreateOp
295//===----------------------------------------------------------------------===//
296LogicalResult acc::NoCreateOp::verify() {
297 if (getDataClause() != acc::DataClause::acc_no_create)
298 return emitError("data clause associated with no_create operation must "
299 "match its intent");
300 return success();
301}
302
303//===----------------------------------------------------------------------===//
304// AttachOp
305//===----------------------------------------------------------------------===//
306LogicalResult acc::AttachOp::verify() {
307 if (getDataClause() != acc::DataClause::acc_attach)
308 return emitError(
309 "data clause associated with attach operation must match its intent");
310 return success();
311}
312
313//===----------------------------------------------------------------------===//
314// DeclareDeviceResidentOp
315//===----------------------------------------------------------------------===//
316
317LogicalResult acc::DeclareDeviceResidentOp::verify() {
318 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
319 return emitError("data clause associated with device_resident operation "
320 "must match its intent");
321 return success();
322}
323
324//===----------------------------------------------------------------------===//
325// DeclareLinkOp
326//===----------------------------------------------------------------------===//
327
328LogicalResult acc::DeclareLinkOp::verify() {
329 if (getDataClause() != acc::DataClause::acc_declare_link)
330 return emitError(
331 "data clause associated with link operation must match its intent");
332 return success();
333}
334
335//===----------------------------------------------------------------------===//
336// CopyoutOp
337//===----------------------------------------------------------------------===//
338LogicalResult acc::CopyoutOp::verify() {
339 // Test for all clauses this operation can be decomposed from:
340 if (getDataClause() != acc::DataClause::acc_copyout &&
341 getDataClause() != acc::DataClause::acc_copyout_zero &&
342 getDataClause() != acc::DataClause::acc_copy &&
343 getDataClause() != acc::DataClause::acc_reduction)
344 return emitError(
345 "data clause associated with copyout operation must match its intent"
346 " or specify original clause this operation was decomposed from");
347 if (!getVarPtr() || !getAccPtr())
348 return emitError("must have both host and device pointers");
349 return success();
350}
351
352bool acc::CopyoutOp::isCopyoutZero() {
353 return getDataClause() == acc::DataClause::acc_copyout_zero;
354}
355
356//===----------------------------------------------------------------------===//
357// DeleteOp
358//===----------------------------------------------------------------------===//
359LogicalResult acc::DeleteOp::verify() {
360 // Test for all clauses this operation can be decomposed from:
361 if (getDataClause() != acc::DataClause::acc_delete &&
362 getDataClause() != acc::DataClause::acc_create &&
363 getDataClause() != acc::DataClause::acc_create_zero &&
364 getDataClause() != acc::DataClause::acc_copyin &&
365 getDataClause() != acc::DataClause::acc_copyin_readonly &&
366 getDataClause() != acc::DataClause::acc_present &&
367 getDataClause() != acc::DataClause::acc_declare_device_resident &&
368 getDataClause() != acc::DataClause::acc_declare_link)
369 return emitError(
370 "data clause associated with delete operation must match its intent"
371 " or specify original clause this operation was decomposed from");
372 if (!getAccPtr())
373 return emitError("must have device pointer");
374 return success();
375}
376
377//===----------------------------------------------------------------------===//
378// DetachOp
379//===----------------------------------------------------------------------===//
380LogicalResult acc::DetachOp::verify() {
381 // Test for all clauses this operation can be decomposed from:
382 if (getDataClause() != acc::DataClause::acc_detach &&
383 getDataClause() != acc::DataClause::acc_attach)
384 return emitError(
385 "data clause associated with detach operation must match its intent"
386 " or specify original clause this operation was decomposed from");
387 if (!getAccPtr())
388 return emitError("must have device pointer");
389 return success();
390}
391
392//===----------------------------------------------------------------------===//
// UpdateHostOp
394//===----------------------------------------------------------------------===//
395LogicalResult acc::UpdateHostOp::verify() {
396 // Test for all clauses this operation can be decomposed from:
397 if (getDataClause() != acc::DataClause::acc_update_host &&
398 getDataClause() != acc::DataClause::acc_update_self)
399 return emitError(
400 "data clause associated with host operation must match its intent"
401 " or specify original clause this operation was decomposed from");
402 if (!getVarPtr() || !getAccPtr())
403 return emitError("must have both host and device pointers");
404 return success();
405}
406
407//===----------------------------------------------------------------------===//
// UpdateDeviceOp
409//===----------------------------------------------------------------------===//
410LogicalResult acc::UpdateDeviceOp::verify() {
411 // Test for all clauses this operation can be decomposed from:
412 if (getDataClause() != acc::DataClause::acc_update_device)
413 return emitError(
414 "data clause associated with device operation must match its intent"
415 " or specify original clause this operation was decomposed from");
416 return success();
417}
418
419//===----------------------------------------------------------------------===//
420// UseDeviceOp
421//===----------------------------------------------------------------------===//
422LogicalResult acc::UseDeviceOp::verify() {
423 // Test for all clauses this operation can be decomposed from:
424 if (getDataClause() != acc::DataClause::acc_use_device)
425 return emitError(
426 "data clause associated with use_device operation must match its intent"
427 " or specify original clause this operation was decomposed from");
428 return success();
429}
430
431//===----------------------------------------------------------------------===//
432// CacheOp
433//===----------------------------------------------------------------------===//
434LogicalResult acc::CacheOp::verify() {
435 // Test for all clauses this operation can be decomposed from:
436 if (getDataClause() != acc::DataClause::acc_cache &&
437 getDataClause() != acc::DataClause::acc_cache_readonly)
438 return emitError(
439 "data clause associated with cache operation must match its intent"
440 " or specify original clause this operation was decomposed from");
441 return success();
442}
443
444template <typename StructureOp>
445static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
446 unsigned nRegions = 1) {
447
448 SmallVector<Region *, 2> regions;
449 for (unsigned i = 0; i < nRegions; ++i)
450 regions.push_back(Elt: state.addRegion());
451
452 for (Region *region : regions)
453 if (parser.parseRegion(region&: *region, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
454 return failure();
455
456 return success();
457}
458
459static bool isComputeOperation(Operation *op) {
460 return isa<acc::ParallelOp, acc::LoopOp>(op);
461}
462
463namespace {
464/// Pattern to remove operation without region that have constant false `ifCond`
465/// and remove the condition from the operation if the `ifCond` is a true
466/// constant.
467template <typename OpTy>
468struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
469 using OpRewritePattern<OpTy>::OpRewritePattern;
470
471 LogicalResult matchAndRewrite(OpTy op,
472 PatternRewriter &rewriter) const override {
473 // Early return if there is no condition.
474 Value ifCond = op.getIfCond();
475 if (!ifCond)
476 return failure();
477
478 IntegerAttr constAttr;
479 if (!matchPattern(ifCond, m_Constant(&constAttr)))
480 return failure();
481 if (constAttr.getInt())
482 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
483 else
484 rewriter.eraseOp(op);
485
486 return success();
487 }
488};
489
490/// Replaces the given op with the contents of the given single-block region,
491/// using the operands of the block terminator to replace operation results.
492static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
493 Region &region, ValueRange blockArgs = {}) {
494 assert(llvm::hasSingleElement(region) && "expected single-region block");
495 Block *block = &region.front();
496 Operation *terminator = block->getTerminator();
497 ValueRange results = terminator->getOperands();
498 rewriter.inlineBlockBefore(source: block, op, argValues: blockArgs);
499 rewriter.replaceOp(op, newValues: results);
500 rewriter.eraseOp(op: terminator);
501}
502
503/// Pattern to remove operation with region that have constant false `ifCond`
504/// and remove the condition from the operation if the `ifCond` is constant
505/// true.
506template <typename OpTy>
507struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
508 using OpRewritePattern<OpTy>::OpRewritePattern;
509
510 LogicalResult matchAndRewrite(OpTy op,
511 PatternRewriter &rewriter) const override {
512 // Early return if there is no condition.
513 Value ifCond = op.getIfCond();
514 if (!ifCond)
515 return failure();
516
517 IntegerAttr constAttr;
518 if (!matchPattern(ifCond, m_Constant(&constAttr)))
519 return failure();
520 if (constAttr.getInt())
521 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
522 else
523 replaceOpWithRegion(rewriter, op, op.getRegion());
524
525 return success();
526 }
527};
528
529} // namespace
530
531//===----------------------------------------------------------------------===//
532// PrivateRecipeOp
533//===----------------------------------------------------------------------===//
534
535static LogicalResult verifyInitLikeSingleArgRegion(
536 Operation *op, Region &region, StringRef regionType, StringRef regionName,
537 Type type, bool verifyYield, bool optional = false) {
538 if (optional && region.empty())
539 return success();
540
541 if (region.empty())
542 return op->emitOpError() << "expects non-empty " << regionName << " region";
543 Block &firstBlock = region.front();
544 if (firstBlock.getNumArguments() < 1 ||
545 firstBlock.getArgument(i: 0).getType() != type)
546 return op->emitOpError() << "expects " << regionName
547 << " region first "
548 "argument of the "
549 << regionType << " type";
550
551 if (verifyYield) {
552 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
553 if (yieldOp.getOperands().size() != 1 ||
554 yieldOp.getOperands().getTypes()[0] != type)
555 return op->emitOpError() << "expects " << regionName
556 << " region to "
557 "yield a value of the "
558 << regionType << " type";
559 }
560 }
561 return success();
562}
563
564LogicalResult acc::PrivateRecipeOp::verifyRegions() {
565 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
566 "privatization", "init", getType(),
567 /*verifyYield=*/false)))
568 return failure();
569 if (failed(verifyInitLikeSingleArgRegion(
570 *this, getDestroyRegion(), "privatization", "destroy", getType(),
571 /*verifyYield=*/false, /*optional=*/true)))
572 return failure();
573 return success();
574}
575
576//===----------------------------------------------------------------------===//
577// FirstprivateRecipeOp
578//===----------------------------------------------------------------------===//
579
580LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
581 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
582 "privatization", "init", getType(),
583 /*verifyYield=*/false)))
584 return failure();
585
586 if (getCopyRegion().empty())
587 return emitOpError() << "expects non-empty copy region";
588
589 Block &firstBlock = getCopyRegion().front();
590 if (firstBlock.getNumArguments() < 2 ||
591 firstBlock.getArgument(0).getType() != getType())
592 return emitOpError() << "expects copy region with two arguments of the "
593 "privatization type";
594
595 if (getDestroyRegion().empty())
596 return success();
597
598 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
599 "privatization", "destroy",
600 getType(), /*verifyYield=*/false)))
601 return failure();
602
603 return success();
604}
605
606//===----------------------------------------------------------------------===//
607// ReductionRecipeOp
608//===----------------------------------------------------------------------===//
609
610LogicalResult acc::ReductionRecipeOp::verifyRegions() {
611 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
612 "init", getType(),
613 /*verifyYield=*/false)))
614 return failure();
615
616 if (getCombinerRegion().empty())
617 return emitOpError() << "expects non-empty combiner region";
618
619 Block &reductionBlock = getCombinerRegion().front();
620 if (reductionBlock.getNumArguments() < 2 ||
621 reductionBlock.getArgument(0).getType() != getType() ||
622 reductionBlock.getArgument(1).getType() != getType())
623 return emitOpError() << "expects combiner region with the first two "
624 << "arguments of the reduction type";
625
626 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
627 if (yieldOp.getOperands().size() != 1 ||
628 yieldOp.getOperands().getTypes()[0] != getType())
629 return emitOpError() << "expects combiner region to yield a value "
630 "of the reduction type";
631 }
632
633 return success();
634}
635
636//===----------------------------------------------------------------------===//
// Custom parser and printer for symbol/operand lists (`@sym -> %v : type`)
638//===----------------------------------------------------------------------===//
639
640static ParseResult parseSymOperandList(
641 mlir::OpAsmParser &parser,
642 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
643 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
644 llvm::SmallVector<SymbolRefAttr> attributes;
645 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
646 if (parser.parseAttribute(attributes.emplace_back()) ||
647 parser.parseArrow() ||
648 parser.parseOperand(result&: operands.emplace_back()) ||
649 parser.parseColonType(result&: types.emplace_back()))
650 return failure();
651 return success();
652 })))
653 return failure();
654 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
655 attributes.end());
656 symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
657 return success();
658}
659
660static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
661 mlir::OperandRange operands,
662 mlir::TypeRange types,
663 std::optional<mlir::ArrayAttr> attributes) {
664 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
665 p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
666 << std::get<1>(it).getType();
667 });
668}
669
670//===----------------------------------------------------------------------===//
671// ParallelOp
672//===----------------------------------------------------------------------===//
673
674/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
675template <typename Op>
676static LogicalResult checkDataOperands(Op op,
677 const mlir::ValueRange &operands) {
678 for (mlir::Value operand : operands)
679 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
680 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
681 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
682 operand.getDefiningOp()))
683 return op.emitError(
684 "expect data entry/exit operation or acc.getdeviceptr "
685 "as defining op");
686 return success();
687}
688
689template <typename Op>
690static LogicalResult
691checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
692 mlir::OperandRange operands, llvm::StringRef operandName,
693 llvm::StringRef symbolName, bool checkOperandType = true) {
694 if (!operands.empty()) {
695 if (!attributes || attributes->size() != operands.size())
696 return op->emitOpError()
697 << "expected as many " << symbolName << " symbol reference as "
698 << operandName << " operands";
699 } else {
700 if (attributes)
701 return op->emitOpError()
702 << "unexpected " << symbolName << " symbol reference";
703 return success();
704 }
705
706 llvm::DenseSet<Value> set;
707 for (auto args : llvm::zip(operands, *attributes)) {
708 mlir::Value operand = std::get<0>(args);
709
710 if (!set.insert(operand).second)
711 return op->emitOpError()
712 << operandName << " operand appears more than once";
713
714 mlir::Type varType = operand.getType();
715 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
716 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
717 if (!decl)
718 return op->emitOpError()
719 << "expected symbol reference " << symbolRef << " to point to a "
720 << operandName << " declaration";
721
722 if (checkOperandType && decl.getType() && decl.getType() != varType)
723 return op->emitOpError() << "expected " << operandName << " (" << varType
724 << ") to be the same type as " << operandName
725 << " declaration (" << decl.getType() << ")";
726 }
727
728 return success();
729}
730
731unsigned ParallelOp::getNumDataOperands() {
732 return getReductionOperands().size() + getGangPrivateOperands().size() +
733 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
734}
735
736Value ParallelOp::getDataOperand(unsigned i) {
737 unsigned numOptional = getAsyncOperands().size();
738 numOptional += getNumGangs().size();
739 numOptional += getNumWorkers().size();
740 numOptional += getVectorLength().size();
741 numOptional += getIfCond() ? 1 : 0;
742 numOptional += getSelfCond() ? 1 : 0;
743 return getOperand(getWaitOperands().size() + numOptional + i);
744}
745
746template <typename Op>
747static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
748 ArrayAttr deviceTypes,
749 llvm::StringRef keyword) {
750 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
751 return op.emitOpError() << keyword << " operands count must match "
752 << keyword << " device_type count";
753 return success();
754}
755
756template <typename Op>
757static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
758 Op op, OperandRange operands, DenseI32ArrayAttr segments,
759 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
760 std::size_t numOperandsInSegments = 0;
761
762 if (!segments)
763 return success();
764
765 for (auto segCount : segments.asArrayRef()) {
766 if (maxInSegment != 0 && segCount > maxInSegment)
767 return op.emitOpError() << keyword << " expects a maximum of "
768 << maxInSegment << " values per segment";
769 numOperandsInSegments += segCount;
770 }
771 if (numOperandsInSegments != operands.size())
772 return op.emitOpError()
773 << keyword << " operand count does not match count in segments";
774 if (deviceTypes.getValue().size() != (size_t)segments.size())
775 return op.emitOpError()
776 << keyword << " segment count does not match device_type count";
777 return success();
778}
779
780LogicalResult acc::ParallelOp::verify() {
781 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
782 *this, getPrivatizations(), getGangPrivateOperands(), "private",
783 "privatizations", /*checkOperandType=*/false)))
784 return failure();
785 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
786 *this, getReductionRecipes(), getReductionOperands(), "reduction",
787 "reductions", false)))
788 return failure();
789
790 if (failed(verifyDeviceTypeAndSegmentCountMatch(
791 *this, getNumGangs(), getNumGangsSegmentsAttr(),
792 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
793 return failure();
794
795 if (failed(verifyDeviceTypeAndSegmentCountMatch(
796 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
797 getWaitOperandsDeviceTypeAttr(), "wait")))
798 return failure();
799
800 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
801 getNumWorkersDeviceTypeAttr(),
802 "num_workers")))
803 return failure();
804
805 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
806 getVectorLengthDeviceTypeAttr(),
807 "vector_length")))
808 return failure();
809
810 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
811 getAsyncOperandsDeviceTypeAttr(),
812 "async")))
813 return failure();
814
815 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
816 return failure();
817
818 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
819}
820
821static mlir::Value
822getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
823 mlir::Operation::operand_range range,
824 mlir::acc::DeviceType deviceType) {
825 if (!arrayAttr)
826 return {};
827 if (auto pos = findSegment(*arrayAttr, deviceType))
828 return range[*pos];
829 return {};
830}
831
832bool acc::ParallelOp::hasAsyncOnly() {
833 return hasAsyncOnly(mlir::acc::DeviceType::None);
834}
835
836bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
837 return hasDeviceType(getAsyncOnly(), deviceType);
838}
839
840mlir::Value acc::ParallelOp::getAsyncValue() {
841 return getAsyncValue(mlir::acc::DeviceType::None);
842}
843
844mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
845 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
846 getAsyncOperands(), deviceType);
847}
848
849mlir::Value acc::ParallelOp::getNumWorkersValue() {
850 return getNumWorkersValue(mlir::acc::DeviceType::None);
851}
852
853mlir::Value
854acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
855 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
856 deviceType);
857}
858
859mlir::Value acc::ParallelOp::getVectorLengthValue() {
860 return getVectorLengthValue(mlir::acc::DeviceType::None);
861}
862
863mlir::Value
864acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
865 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
866 getVectorLength(), deviceType);
867}
868
869mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
870 return getNumGangsValues(mlir::acc::DeviceType::None);
871}
872
873mlir::Operation::operand_range
874ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
875 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
876 getNumGangsSegments(), deviceType);
877}
878
879bool acc::ParallelOp::hasWaitOnly() {
880 return hasWaitOnly(mlir::acc::DeviceType::None);
881}
882
883bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
884 return hasDeviceType(getWaitOnly(), deviceType);
885}
886
887mlir::Operation::operand_range ParallelOp::getWaitValues() {
888 return getWaitValues(mlir::acc::DeviceType::None);
889}
890
891mlir::Operation::operand_range
892ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
893 return getWaitValuesWithoutDevnum(
894 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
895 getHasWaitDevnum(), deviceType);
896}
897
898mlir::Value ParallelOp::getWaitDevnum() {
899 return getWaitDevnum(mlir::acc::DeviceType::None);
900}
901
902mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
903 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
904 getWaitOperandsSegments(), getHasWaitDevnum(),
905 deviceType);
906}
907
908static ParseResult parseNumGangs(
909 mlir::OpAsmParser &parser,
910 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
911 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
912 mlir::DenseI32ArrayAttr &segments) {
913 llvm::SmallVector<DeviceTypeAttr> attributes;
914 llvm::SmallVector<int32_t> seg;
915
916 do {
917 if (failed(result: parser.parseLBrace()))
918 return failure();
919
920 int32_t crtOperandsSize = operands.size();
921 if (failed(result: parser.parseCommaSeparatedList(
922 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
923 if (parser.parseOperand(result&: operands.emplace_back()) ||
924 parser.parseColonType(result&: types.emplace_back()))
925 return failure();
926 return success();
927 })))
928 return failure();
929 seg.push_back(Elt: operands.size() - crtOperandsSize);
930
931 if (failed(result: parser.parseRBrace()))
932 return failure();
933
934 if (succeeded(result: parser.parseOptionalLSquare())) {
935 if (parser.parseAttribute(attributes.emplace_back()) ||
936 parser.parseRSquare())
937 return failure();
938 } else {
939 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
940 parser.getContext(), mlir::acc::DeviceType::None));
941 }
942 } while (succeeded(result: parser.parseOptionalComma()));
943
944 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
945 attributes.end());
946 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
947 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
948
949 return success();
950}
951
952static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
953 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
954 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
955 p << " [" << attr << "]";
956}
957
958static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op,
959 mlir::OperandRange operands, mlir::TypeRange types,
960 std::optional<mlir::ArrayAttr> deviceTypes,
961 std::optional<mlir::DenseI32ArrayAttr> segments) {
962 unsigned opIdx = 0;
963 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
964 p << "{";
965 llvm::interleaveComma(
966 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
967 p << operands[opIdx] << " : " << operands[opIdx].getType();
968 ++opIdx;
969 });
970 p << "}";
971 printSingleDeviceType(p, it.value());
972 });
973}
974
975static ParseResult parseDeviceTypeOperandsWithSegment(
976 mlir::OpAsmParser &parser,
977 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
978 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
979 mlir::DenseI32ArrayAttr &segments) {
980 llvm::SmallVector<DeviceTypeAttr> attributes;
981 llvm::SmallVector<int32_t> seg;
982
983 do {
984 if (failed(result: parser.parseLBrace()))
985 return failure();
986
987 int32_t crtOperandsSize = operands.size();
988
989 if (failed(result: parser.parseCommaSeparatedList(
990 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
991 if (parser.parseOperand(result&: operands.emplace_back()) ||
992 parser.parseColonType(result&: types.emplace_back()))
993 return failure();
994 return success();
995 })))
996 return failure();
997
998 seg.push_back(Elt: operands.size() - crtOperandsSize);
999
1000 if (failed(result: parser.parseRBrace()))
1001 return failure();
1002
1003 if (succeeded(result: parser.parseOptionalLSquare())) {
1004 if (parser.parseAttribute(attributes.emplace_back()) ||
1005 parser.parseRSquare())
1006 return failure();
1007 } else {
1008 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1009 parser.getContext(), mlir::acc::DeviceType::None));
1010 }
1011 } while (succeeded(result: parser.parseOptionalComma()));
1012
1013 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1014 attributes.end());
1015 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1016 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1017
1018 return success();
1019}
1020
1021static void printDeviceTypeOperandsWithSegment(
1022 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1023 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1024 std::optional<mlir::DenseI32ArrayAttr> segments) {
1025 unsigned opIdx = 0;
1026 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1027 p << "{";
1028 llvm::interleaveComma(
1029 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1030 p << operands[opIdx] << " : " << operands[opIdx].getType();
1031 ++opIdx;
1032 });
1033 p << "}";
1034 printSingleDeviceType(p, it.value());
1035 });
1036}
1037
1038static ParseResult parseWaitClause(
1039 mlir::OpAsmParser &parser,
1040 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1041 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1042 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1043 mlir::ArrayAttr &keywordOnly) {
1044 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1045 llvm::SmallVector<int32_t> seg;
1046
1047 bool needCommaBeforeOperands = false;
1048
1049 // Keyword only
1050 if (failed(result: parser.parseOptionalLParen())) {
1051 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1052 parser.getContext(), mlir::acc::DeviceType::None));
1053 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1054 return success();
1055 }
1056
1057 // Parse keyword only attributes
1058 if (succeeded(result: parser.parseOptionalLSquare())) {
1059 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1060 if (parser.parseAttribute(result&: keywordAttrs.emplace_back()))
1061 return failure();
1062 return success();
1063 })))
1064 return failure();
1065 if (parser.parseRSquare())
1066 return failure();
1067 needCommaBeforeOperands = true;
1068 }
1069
1070 if (needCommaBeforeOperands && failed(result: parser.parseComma()))
1071 return failure();
1072
1073 do {
1074 if (failed(result: parser.parseLBrace()))
1075 return failure();
1076
1077 int32_t crtOperandsSize = operands.size();
1078
1079 if (succeeded(result: parser.parseOptionalKeyword(keyword: "devnum"))) {
1080 if (failed(result: parser.parseColon()))
1081 return failure();
1082 devnum.push_back(Elt: BoolAttr::get(context: parser.getContext(), value: true));
1083 } else {
1084 devnum.push_back(Elt: BoolAttr::get(context: parser.getContext(), value: false));
1085 }
1086
1087 if (failed(result: parser.parseCommaSeparatedList(
1088 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
1089 if (parser.parseOperand(result&: operands.emplace_back()) ||
1090 parser.parseColonType(result&: types.emplace_back()))
1091 return failure();
1092 return success();
1093 })))
1094 return failure();
1095
1096 seg.push_back(Elt: operands.size() - crtOperandsSize);
1097
1098 if (failed(result: parser.parseRBrace()))
1099 return failure();
1100
1101 if (succeeded(result: parser.parseOptionalLSquare())) {
1102 if (parser.parseAttribute(result&: deviceTypeAttrs.emplace_back()) ||
1103 parser.parseRSquare())
1104 return failure();
1105 } else {
1106 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1107 parser.getContext(), mlir::acc::DeviceType::None));
1108 }
1109 } while (succeeded(result: parser.parseOptionalComma()));
1110
1111 if (failed(result: parser.parseRParen()))
1112 return failure();
1113
1114 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1115 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1116 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1117 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1118
1119 return success();
1120}
1121
1122static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1123 if (!hasDeviceTypeValues(arrayAttr: attrs))
1124 return false;
1125 if (attrs->size() != 1)
1126 return false;
1127 if (auto deviceTypeAttr =
1128 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1129 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1130 return false;
1131}
1132
1133static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
1134 mlir::OperandRange operands, mlir::TypeRange types,
1135 std::optional<mlir::ArrayAttr> deviceTypes,
1136 std::optional<mlir::DenseI32ArrayAttr> segments,
1137 std::optional<mlir::ArrayAttr> hasDevNum,
1138 std::optional<mlir::ArrayAttr> keywordOnly) {
1139
1140 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(attrs: keywordOnly))
1141 return;
1142
1143 p << "(";
1144
1145 printDeviceTypes(p, deviceTypes: keywordOnly);
1146 if (hasDeviceTypeValues(arrayAttr: keywordOnly) && hasDeviceTypeValues(arrayAttr: deviceTypes))
1147 p << ", ";
1148
1149 unsigned opIdx = 0;
1150 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1151 p << "{";
1152 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1153 if (boolAttr && boolAttr.getValue())
1154 p << "devnum: ";
1155 llvm::interleaveComma(
1156 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1157 p << operands[opIdx] << " : " << operands[opIdx].getType();
1158 ++opIdx;
1159 });
1160 p << "}";
1161 printSingleDeviceType(p, it.value());
1162 });
1163
1164 p << ")";
1165}
1166
1167static ParseResult parseDeviceTypeOperands(
1168 mlir::OpAsmParser &parser,
1169 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1170 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1171 llvm::SmallVector<DeviceTypeAttr> attributes;
1172 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1173 if (parser.parseOperand(result&: operands.emplace_back()) ||
1174 parser.parseColonType(result&: types.emplace_back()))
1175 return failure();
1176 if (succeeded(result: parser.parseOptionalLSquare())) {
1177 if (parser.parseAttribute(attributes.emplace_back()) ||
1178 parser.parseRSquare())
1179 return failure();
1180 } else {
1181 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1182 parser.getContext(), mlir::acc::DeviceType::None));
1183 }
1184 return success();
1185 })))
1186 return failure();
1187 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1188 attributes.end());
1189 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1190 return success();
1191}
1192
1193static void
1194printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
1195 mlir::OperandRange operands, mlir::TypeRange types,
1196 std::optional<mlir::ArrayAttr> deviceTypes) {
1197 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
1198 return;
1199 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1200 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1201 printSingleDeviceType(p, std::get<0>(it));
1202 });
1203}
1204
1205static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
1206 mlir::OpAsmParser &parser,
1207 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1208 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1209 mlir::ArrayAttr &keywordOnlyDeviceType) {
1210
1211 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1212 bool needCommaBeforeOperands = false;
1213
1214 if (failed(result: parser.parseOptionalLParen())) {
1215 // Keyword only
1216 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1217 parser.getContext(), mlir::acc::DeviceType::None));
1218 keywordOnlyDeviceType =
1219 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1220 return success();
1221 }
1222
1223 // Parse keyword only attributes
1224 if (succeeded(result: parser.parseOptionalLSquare())) {
1225 // Parse keyword only attributes
1226 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1227 if (parser.parseAttribute(
1228 result&: keywordOnlyDeviceTypeAttributes.emplace_back()))
1229 return failure();
1230 return success();
1231 })))
1232 return failure();
1233 if (parser.parseRSquare())
1234 return failure();
1235 needCommaBeforeOperands = true;
1236 }
1237
1238 if (needCommaBeforeOperands && failed(result: parser.parseComma()))
1239 return failure();
1240
1241 llvm::SmallVector<DeviceTypeAttr> attributes;
1242 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1243 if (parser.parseOperand(result&: operands.emplace_back()) ||
1244 parser.parseColonType(result&: types.emplace_back()))
1245 return failure();
1246 if (succeeded(result: parser.parseOptionalLSquare())) {
1247 if (parser.parseAttribute(attributes.emplace_back()) ||
1248 parser.parseRSquare())
1249 return failure();
1250 } else {
1251 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1252 parser.getContext(), mlir::acc::DeviceType::None));
1253 }
1254 return success();
1255 })))
1256 return failure();
1257
1258 if (failed(result: parser.parseRParen()))
1259 return failure();
1260
1261 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1262 attributes.end());
1263 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1264 return success();
1265}
1266
1267static void printDeviceTypeOperandsWithKeywordOnly(
1268 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1269 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1270 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1271
1272 if (operands.begin() == operands.end() &&
1273 hasOnlyDeviceTypeNone(attrs: keywordOnlyDeviceTypes)) {
1274 return;
1275 }
1276
1277 p << "(";
1278 printDeviceTypes(p, deviceTypes: keywordOnlyDeviceTypes);
1279 if (hasDeviceTypeValues(arrayAttr: keywordOnlyDeviceTypes) &&
1280 hasDeviceTypeValues(arrayAttr: deviceTypes))
1281 p << ", ";
1282 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1283 p << ")";
1284}
1285
1286static ParseResult
1287parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
1288 mlir::acc::CombinedConstructsTypeAttr &attr) {
1289 if (succeeded(result: parser.parseOptionalKeyword(keyword: "combined"))) {
1290 if (parser.parseLParen())
1291 return failure();
1292 if (succeeded(result: parser.parseOptionalKeyword(keyword: "kernels"))) {
1293 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1294 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1295 } else if (succeeded(result: parser.parseOptionalKeyword(keyword: "parallel"))) {
1296 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1297 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1298 } else if (succeeded(result: parser.parseOptionalKeyword(keyword: "serial"))) {
1299 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1300 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1301 } else {
1302 parser.emitError(loc: parser.getCurrentLocation(),
1303 message: "expected compute construct name");
1304 return failure();
1305 }
1306 if (parser.parseRParen())
1307 return failure();
1308 }
1309 return success();
1310}
1311
1312static void
1313printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
1314 mlir::acc::CombinedConstructsTypeAttr attr) {
1315 if (attr) {
1316 switch (attr.getValue()) {
1317 case mlir::acc::CombinedConstructsType::KernelsLoop:
1318 p << "combined(kernels)";
1319 break;
1320 case mlir::acc::CombinedConstructsType::ParallelLoop:
1321 p << "combined(parallel)";
1322 break;
1323 case mlir::acc::CombinedConstructsType::SerialLoop:
1324 p << "combined(serial)";
1325 break;
1326 };
1327 }
1328}
1329
1330//===----------------------------------------------------------------------===//
1331// SerialOp
1332//===----------------------------------------------------------------------===//
1333
1334unsigned SerialOp::getNumDataOperands() {
1335 return getReductionOperands().size() + getGangPrivateOperands().size() +
1336 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
1337}
1338
1339Value SerialOp::getDataOperand(unsigned i) {
1340 unsigned numOptional = getAsyncOperands().size();
1341 numOptional += getIfCond() ? 1 : 0;
1342 numOptional += getSelfCond() ? 1 : 0;
1343 return getOperand(getWaitOperands().size() + numOptional + i);
1344}
1345
1346bool acc::SerialOp::hasAsyncOnly() {
1347 return hasAsyncOnly(mlir::acc::DeviceType::None);
1348}
1349
1350bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1351 return hasDeviceType(getAsyncOnly(), deviceType);
1352}
1353
1354mlir::Value acc::SerialOp::getAsyncValue() {
1355 return getAsyncValue(mlir::acc::DeviceType::None);
1356}
1357
1358mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1359 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1360 getAsyncOperands(), deviceType);
1361}
1362
1363bool acc::SerialOp::hasWaitOnly() {
1364 return hasWaitOnly(mlir::acc::DeviceType::None);
1365}
1366
1367bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1368 return hasDeviceType(getWaitOnly(), deviceType);
1369}
1370
1371mlir::Operation::operand_range SerialOp::getWaitValues() {
1372 return getWaitValues(mlir::acc::DeviceType::None);
1373}
1374
1375mlir::Operation::operand_range
1376SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1377 return getWaitValuesWithoutDevnum(
1378 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1379 getHasWaitDevnum(), deviceType);
1380}
1381
1382mlir::Value SerialOp::getWaitDevnum() {
1383 return getWaitDevnum(mlir::acc::DeviceType::None);
1384}
1385
1386mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1387 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1388 getWaitOperandsSegments(), getHasWaitDevnum(),
1389 deviceType);
1390}
1391
1392LogicalResult acc::SerialOp::verify() {
1393 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1394 *this, getPrivatizations(), getGangPrivateOperands(), "private",
1395 "privatizations", /*checkOperandType=*/false)))
1396 return failure();
1397 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1398 *this, getReductionRecipes(), getReductionOperands(), "reduction",
1399 "reductions", false)))
1400 return failure();
1401
1402 if (failed(verifyDeviceTypeAndSegmentCountMatch(
1403 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1404 getWaitOperandsDeviceTypeAttr(), "wait")))
1405 return failure();
1406
1407 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1408 getAsyncOperandsDeviceTypeAttr(),
1409 "async")))
1410 return failure();
1411
1412 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1413 return failure();
1414
1415 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1416}
1417
1418//===----------------------------------------------------------------------===//
1419// KernelsOp
1420//===----------------------------------------------------------------------===//
1421
1422unsigned KernelsOp::getNumDataOperands() {
1423 return getDataClauseOperands().size();
1424}
1425
1426Value KernelsOp::getDataOperand(unsigned i) {
1427 unsigned numOptional = getAsyncOperands().size();
1428 numOptional += getWaitOperands().size();
1429 numOptional += getNumGangs().size();
1430 numOptional += getNumWorkers().size();
1431 numOptional += getVectorLength().size();
1432 numOptional += getIfCond() ? 1 : 0;
1433 numOptional += getSelfCond() ? 1 : 0;
1434 return getOperand(numOptional + i);
1435}
1436
1437bool acc::KernelsOp::hasAsyncOnly() {
1438 return hasAsyncOnly(mlir::acc::DeviceType::None);
1439}
1440
1441bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1442 return hasDeviceType(getAsyncOnly(), deviceType);
1443}
1444
1445mlir::Value acc::KernelsOp::getAsyncValue() {
1446 return getAsyncValue(mlir::acc::DeviceType::None);
1447}
1448
1449mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1450 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1451 getAsyncOperands(), deviceType);
1452}
1453
1454mlir::Value acc::KernelsOp::getNumWorkersValue() {
1455 return getNumWorkersValue(mlir::acc::DeviceType::None);
1456}
1457
1458mlir::Value
1459acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1460 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1461 deviceType);
1462}
1463
1464mlir::Value acc::KernelsOp::getVectorLengthValue() {
1465 return getVectorLengthValue(mlir::acc::DeviceType::None);
1466}
1467
1468mlir::Value
1469acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1470 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1471 getVectorLength(), deviceType);
1472}
1473
1474mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
1475 return getNumGangsValues(mlir::acc::DeviceType::None);
1476}
1477
1478mlir::Operation::operand_range
1479KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1480 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1481 getNumGangsSegments(), deviceType);
1482}
1483
1484bool acc::KernelsOp::hasWaitOnly() {
1485 return hasWaitOnly(mlir::acc::DeviceType::None);
1486}
1487
1488bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1489 return hasDeviceType(getWaitOnly(), deviceType);
1490}
1491
1492mlir::Operation::operand_range KernelsOp::getWaitValues() {
1493 return getWaitValues(mlir::acc::DeviceType::None);
1494}
1495
1496mlir::Operation::operand_range
1497KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1498 return getWaitValuesWithoutDevnum(
1499 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1500 getHasWaitDevnum(), deviceType);
1501}
1502
1503mlir::Value KernelsOp::getWaitDevnum() {
1504 return getWaitDevnum(mlir::acc::DeviceType::None);
1505}
1506
1507mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1508 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1509 getWaitOperandsSegments(), getHasWaitDevnum(),
1510 deviceType);
1511}
1512
1513LogicalResult acc::KernelsOp::verify() {
1514 if (failed(verifyDeviceTypeAndSegmentCountMatch(
1515 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1516 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1517 return failure();
1518
1519 if (failed(verifyDeviceTypeAndSegmentCountMatch(
1520 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1521 getWaitOperandsDeviceTypeAttr(), "wait")))
1522 return failure();
1523
1524 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1525 getNumWorkersDeviceTypeAttr(),
1526 "num_workers")))
1527 return failure();
1528
1529 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1530 getVectorLengthDeviceTypeAttr(),
1531 "vector_length")))
1532 return failure();
1533
1534 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1535 getAsyncOperandsDeviceTypeAttr(),
1536 "async")))
1537 return failure();
1538
1539 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
1540 return failure();
1541
1542 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
1543}
1544
1545//===----------------------------------------------------------------------===//
1546// HostDataOp
1547//===----------------------------------------------------------------------===//
1548
1549LogicalResult acc::HostDataOp::verify() {
1550 if (getDataClauseOperands().empty())
1551 return emitError("at least one operand must appear on the host_data "
1552 "operation");
1553
1554 for (mlir::Value operand : getDataClauseOperands())
1555 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1556 return emitError("expect data entry operation as defining op");
1557 return success();
1558}
1559
1560void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
1561 MLIRContext *context) {
1562 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1563}
1564
1565//===----------------------------------------------------------------------===//
1566// LoopOp
1567//===----------------------------------------------------------------------===//
1568
1569static ParseResult parseGangValue(
1570 OpAsmParser &parser, llvm::StringRef keyword,
1571 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1572 llvm::SmallVectorImpl<Type> &types,
1573 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
1574 bool &needCommaBetweenValues, bool &newValue) {
1575 if (succeeded(result: parser.parseOptionalKeyword(keyword))) {
1576 if (parser.parseEqual())
1577 return failure();
1578 if (parser.parseOperand(result&: operands.emplace_back()) ||
1579 parser.parseColonType(result&: types.emplace_back()))
1580 return failure();
1581 attributes.push_back(gangArgType);
1582 needCommaBetweenValues = true;
1583 newValue = true;
1584 }
1585 return success();
1586}
1587
1588static ParseResult parseGangClause(
1589 OpAsmParser &parser,
1590 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
1591 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
1592 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
1593 mlir::ArrayAttr &gangOnlyDeviceType) {
1594 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
1595 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
1596 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
1597 llvm::SmallVector<int32_t> seg;
1598 bool needCommaBetweenValues = false;
1599 bool needCommaBeforeOperands = false;
1600
1601 if (failed(result: parser.parseOptionalLParen())) {
1602 // Gang only keyword
1603 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1604 parser.getContext(), mlir::acc::DeviceType::None));
1605 gangOnlyDeviceType =
1606 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
1607 return success();
1608 }
1609
1610 // Parse gang only attributes
1611 if (succeeded(result: parser.parseOptionalLSquare())) {
1612 // Parse gang only attributes
1613 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1614 if (parser.parseAttribute(
1615 result&: gangOnlyDeviceTypeAttributes.emplace_back()))
1616 return failure();
1617 return success();
1618 })))
1619 return failure();
1620 if (parser.parseRSquare())
1621 return failure();
1622 needCommaBeforeOperands = true;
1623 }
1624
1625 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1626 mlir::acc::GangArgType::Num);
1627 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
1628 mlir::acc::GangArgType::Dim);
1629 auto argStatic = mlir::acc::GangArgTypeAttr::get(
1630 parser.getContext(), mlir::acc::GangArgType::Static);
1631
1632 do {
1633 if (needCommaBeforeOperands) {
1634 needCommaBeforeOperands = false;
1635 continue;
1636 }
1637
1638 if (failed(result: parser.parseLBrace()))
1639 return failure();
1640
1641 int32_t crtOperandsSize = gangOperands.size();
1642 while (true) {
1643 bool newValue = false;
1644 bool needValue = false;
1645 if (needCommaBetweenValues) {
1646 if (succeeded(result: parser.parseOptionalComma()))
1647 needValue = true; // expect a new value after comma.
1648 else
1649 break;
1650 }
1651
1652 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
1653 gangOperands, gangOperandsType,
1654 gangArgTypeAttributes, argNum,
1655 needCommaBetweenValues, newValue)))
1656 return failure();
1657 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
1658 gangOperands, gangOperandsType,
1659 gangArgTypeAttributes, argDim,
1660 needCommaBetweenValues, newValue)))
1661 return failure();
1662 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1663 gangOperands, gangOperandsType,
1664 gangArgTypeAttributes, argStatic,
1665 needCommaBetweenValues, newValue)))
1666 return failure();
1667
1668 if (!newValue && needValue) {
1669 parser.emitError(loc: parser.getCurrentLocation(),
1670 message: "new value expected after comma");
1671 return failure();
1672 }
1673
1674 if (!newValue)
1675 break;
1676 }
1677
1678 if (gangOperands.empty())
1679 return parser.emitError(
1680 loc: parser.getCurrentLocation(),
1681 message: "expect at least one of num, dim or static values");
1682
1683 if (failed(result: parser.parseRBrace()))
1684 return failure();
1685
1686 if (succeeded(result: parser.parseOptionalLSquare())) {
1687 if (parser.parseAttribute(result&: deviceTypeAttributes.emplace_back()) ||
1688 parser.parseRSquare())
1689 return failure();
1690 } else {
1691 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1692 parser.getContext(), mlir::acc::DeviceType::None));
1693 }
1694
1695 seg.push_back(Elt: gangOperands.size() - crtOperandsSize);
1696
1697 } while (succeeded(result: parser.parseOptionalComma()));
1698
1699 if (failed(result: parser.parseRParen()))
1700 return failure();
1701
1702 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
1703 gangArgTypeAttributes.end());
1704 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
1705 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
1706
1707 llvm::SmallVector<mlir::Attribute> gangOnlyAttr(
1708 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1709 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
1710
1711 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1712 return success();
1713}
1714
1715void printGangClause(OpAsmPrinter &p, Operation *op,
1716 mlir::OperandRange operands, mlir::TypeRange types,
1717 std::optional<mlir::ArrayAttr> gangArgTypes,
1718 std::optional<mlir::ArrayAttr> deviceTypes,
1719 std::optional<mlir::DenseI32ArrayAttr> segments,
1720 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1721
1722 if (operands.begin() == operands.end() &&
1723 hasOnlyDeviceTypeNone(attrs: gangOnlyDeviceTypes)) {
1724 return;
1725 }
1726
1727 p << "(";
1728
1729 printDeviceTypes(p, deviceTypes: gangOnlyDeviceTypes);
1730
1731 if (hasDeviceTypeValues(arrayAttr: gangOnlyDeviceTypes) &&
1732 hasDeviceTypeValues(arrayAttr: deviceTypes))
1733 p << ", ";
1734
1735 if (hasDeviceTypeValues(arrayAttr: deviceTypes)) {
1736 unsigned opIdx = 0;
1737 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1738 p << "{";
1739 llvm::interleaveComma(
1740 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1741 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1742 (*gangArgTypes)[opIdx]);
1743 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1744 p << LoopOp::getGangNumKeyword();
1745 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1746 p << LoopOp::getGangDimKeyword();
1747 else if (gangArgTypeAttr.getValue() ==
1748 mlir::acc::GangArgType::Static)
1749 p << LoopOp::getGangStaticKeyword();
1750 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
1751 ++opIdx;
1752 });
1753 p << "}";
1754 printSingleDeviceType(p, it.value());
1755 });
1756 }
1757 p << ")";
1758}
1759
1760bool hasDuplicateDeviceTypes(
1761 std::optional<mlir::ArrayAttr> segments,
1762 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
1763 if (!segments)
1764 return false;
1765 for (auto attr : *segments) {
1766 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1767 if (deviceTypes.contains(deviceTypeAttr.getValue()))
1768 return true;
1769 deviceTypes.insert(deviceTypeAttr.getValue());
1770 }
1771 return false;
1772}
1773
1774/// Check for duplicates in the DeviceType array attribute.
1775LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
1776 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1777 if (!deviceTypes)
1778 return success();
1779 for (auto attr : deviceTypes) {
1780 auto deviceTypeAttr =
1781 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1782 if (!deviceTypeAttr)
1783 return failure();
1784 if (crtDeviceTypes.contains(deviceTypeAttr.getValue()))
1785 return failure();
1786 crtDeviceTypes.insert(deviceTypeAttr.getValue());
1787 }
1788 return success();
1789}
1790
/// Verify the structural invariants of an acc.loop operation.
///
/// The checks run in the order below. The first failing check determines the
/// emitted diagnostic:
///  - inclusiveUpperbound, when present together with upperbounds, has one
///    entry per upperbound;
///  - collapse requires collapseDeviceType with the same number of entries,
///    and collapseDeviceType holds no duplicate device types;
///  - gang operands require gangOperandsArgType with one entry per operand;
///    the gang attribute holds no duplicate device types; gang operand
///    segments are checked against their device types;
///  - the worker and vector attributes and operand device types hold no
///    duplicates, and their operands are checked against their device types;
///  - tile operand segments are checked against their device types;
///  - a device type appears in at most one of auto, independent and seq;
///  - for every device type listed in seq, no gang/worker/vector (keyword or
///    value) is present for that same device type;
///  - private and reduction operands are checked against their recipe symbol
///    lists via checkSymOperandList;
///  - the combined attribute, if present, is ParallelLoop, KernelsLoop or
///    SerialLoop;
///  - the body region is non-empty.
LogicalResult acc::LoopOp::verify() {
  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
      (getUpperbound().size() != getInclusiveUpperbound()->size()))
    return emitError() << "inclusiveUpperbound size is expected to be the same"
                       << " as upperbound size";

  // Check collapse
  // NOTE(review): "must be define" is a typo in the diagnostic, but tests may
  // match the message verbatim, so it is left unchanged here.
  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
    return emitOpError() << "collapse device_type attr must be define when"
                         << " collapse attr is present";

  // collapse[i] applies to collapseDeviceType[i], so the arrays must line up.
  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
      getCollapseAttr().getValue().size() !=
          getCollapseDeviceTypeAttr().getValue().size())
    return emitOpError() << "collapse attribute count must match collapse"
                         << " device_type count";
  if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
    return emitOpError()
           << "duplicate device_type found in collapseDeviceType attribute";

  // Check gang
  // Each gang operand needs an arg-type entry (num/dim/static) telling which
  // keyword it was written with.
  if (!getGangOperands().empty()) {
    if (!getGangOperandsArgType())
      return emitOpError() << "gangOperandsArgType attribute must be defined"
                           << " when gang operands are present";

    if (getGangOperands().size() !=
        getGangOperandsArgTypeAttr().getValue().size())
      return emitOpError() << "gangOperandsArgType attribute count must match"
                           << " gangOperands count";
  }
  if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
    return emitOpError() << "duplicate device_type found in gang attribute";

  if (failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getGangOperands(), getGangOperandsSegmentsAttr(),
          getGangOperandsDeviceTypeAttr(), "gang")))
    return failure();

  // Check worker
  if (failed(checkDeviceTypes(getWorkerAttr())))
    return emitOpError() << "duplicate device_type found in worker attribute";
  if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
    return emitOpError() << "duplicate device_type found in "
                            "workerNumOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
                                        getWorkerNumOperandsDeviceTypeAttr(),
                                        "worker")))
    return failure();

  // Check vector
  if (failed(checkDeviceTypes(getVectorAttr())))
    return emitOpError() << "duplicate device_type found in vector attribute";
  if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
    return emitOpError() << "duplicate device_type found in "
                            "vectorOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
                                        getVectorOperandsDeviceTypeAttr(),
                                        "vector")))
    return failure();

  // Check tile
  if (failed(verifyDeviceTypeAndSegmentCountMatch(
          *this, getTileOperands(), getTileOperandsSegmentsAttr(),
          getTileOperandsDeviceTypeAttr(), "tile")))
    return failure();

  // auto, independent and seq attribute are mutually exclusive.
  // One set is shared by all three calls. A device type listed in two of the
  // attributes is therefore reported, not only duplicates within a single one.
  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
      hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
      hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
    return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName()
                       << "\", " << getIndependentAttrName() << ", "
                       << getSeqAttrName()
                       << " can be present at the same time";
  }

  // Gang, worker and vector are incompatible with seq.
  // For every device type marked seq, both the keyword-only form (has*) and
  // the value form (get*Value) of gang/worker/vector are rejected.
  if (getSeqAttr()) {
    for (auto attr : getSeqAttr()) {
      auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
      if (hasVector(deviceTypeAttr.getValue()) ||
          getVectorValue(deviceTypeAttr.getValue()) ||
          hasWorker(deviceTypeAttr.getValue()) ||
          getWorkerValue(deviceTypeAttr.getValue()) ||
          hasGang(deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Num,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Dim,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Static,
                       deviceTypeAttr.getValue()))
        return emitError()
               << "gang, worker or vector cannot appear with the seq attr";
    }
  }

  // Private and reduction operands are checked against their recipe symbol
  // lists (presumably one recipe per operand; see checkSymOperandList).
  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
          *this, getPrivatizations(), getPrivateOperands(), "private",
          "privatizations", false)))
    return failure();

  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
          *this, getReductionRecipes(), getReductionOperands(), "reduction",
          "reductions", false)))
    return failure();

  // A loop can only be the loop half of a combined construct.
  if (getCombined().has_value() &&
      (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
       getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
       getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
    return emitError("unexpected combined constructs attribute");
  }

  // Check non-empty body().
  if (getRegion().empty())
    return emitError("expected non-empty body.");

  return success();
}
1911
1912unsigned LoopOp::getNumDataOperands() {
1913 return getReductionOperands().size() + getPrivateOperands().size();
1914}
1915
1916Value LoopOp::getDataOperand(unsigned i) {
1917 unsigned numOptional =
1918 getLowerbound().size() + getUpperbound().size() + getStep().size();
1919 numOptional += getGangOperands().size();
1920 numOptional += getVectorOperands().size();
1921 numOptional += getWorkerNumOperands().size();
1922 numOptional += getTileOperands().size();
1923 numOptional += getCacheOperands().size();
1924 return getOperand(numOptional + i);
1925}
1926
1927bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
1928
1929bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1930 return hasDeviceType(getAuto_(), deviceType);
1931}
1932
1933bool LoopOp::hasIndependent() {
1934 return hasIndependent(mlir::acc::DeviceType::None);
1935}
1936
1937bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1938 return hasDeviceType(getIndependent(), deviceType);
1939}
1940
1941bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
1942
1943bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1944 return hasDeviceType(getSeq(), deviceType);
1945}
1946
1947mlir::Value LoopOp::getVectorValue() {
1948 return getVectorValue(mlir::acc::DeviceType::None);
1949}
1950
1951mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1952 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
1953 getVectorOperands(), deviceType);
1954}
1955
1956bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
1957
1958bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1959 return hasDeviceType(getVector(), deviceType);
1960}
1961
1962mlir::Value LoopOp::getWorkerValue() {
1963 return getWorkerValue(mlir::acc::DeviceType::None);
1964}
1965
1966mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
1967 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
1968 getWorkerNumOperands(), deviceType);
1969}
1970
1971bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
1972
1973bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
1974 return hasDeviceType(getWorker(), deviceType);
1975}
1976
1977mlir::Operation::operand_range LoopOp::getTileValues() {
1978 return getTileValues(mlir::acc::DeviceType::None);
1979}
1980
1981mlir::Operation::operand_range
1982LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
1983 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
1984 getTileOperandsSegments(), deviceType);
1985}
1986
1987std::optional<int64_t> LoopOp::getCollapseValue() {
1988 return getCollapseValue(mlir::acc::DeviceType::None);
1989}
1990
1991std::optional<int64_t>
1992LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
1993 if (!getCollapseAttr())
1994 return std::nullopt;
1995 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
1996 auto intAttr =
1997 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
1998 return intAttr.getValue().getZExtValue();
1999 }
2000 return std::nullopt;
2001}
2002
2003mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2004 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2005}
2006
2007mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2008 mlir::acc::DeviceType deviceType) {
2009 if (getGangOperands().empty())
2010 return {};
2011 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
2012 int32_t nbOperandsBefore = 0;
2013 for (unsigned i = 0; i < *pos; ++i)
2014 nbOperandsBefore += (*getGangOperandsSegments())[i];
2015 mlir::Operation::operand_range values =
2016 getGangOperands()
2017 .drop_front(nbOperandsBefore)
2018 .take_front((*getGangOperandsSegments())[*pos]);
2019
2020 int32_t argTypeIdx = nbOperandsBefore;
2021 for (auto value : values) {
2022 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2023 (*getGangOperandsArgType())[argTypeIdx]);
2024 if (gangArgTypeAttr.getValue() == gangArgType)
2025 return value;
2026 ++argTypeIdx;
2027 }
2028 }
2029 return {};
2030}
2031
2032bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2033
2034bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2035 return hasDeviceType(getGang(), deviceType);
2036}
2037
2038llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2039 return {&getRegion()};
2040}
2041
2042/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2043/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2044/// `(` ssa-id-and-type-list `)`
2045/// region
2046ParseResult
2047parseLoopControl(OpAsmParser &parser, Region &region,
2048 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerbound,
2049 SmallVectorImpl<Type> &lowerboundType,
2050 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperbound,
2051 SmallVectorImpl<Type> &upperboundType,
2052 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &step,
2053 SmallVectorImpl<Type> &stepType) {
2054
2055 SmallVector<OpAsmParser::Argument> inductionVars;
2056 if (succeeded(
2057 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2058 if (parser.parseLParen() ||
2059 parser.parseArgumentList(result&: inductionVars, delimiter: OpAsmParser::Delimiter::None,
2060 /*allowType=*/true) ||
2061 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2062 parser.parseOperandList(result&: lowerbound, requiredOperandCount: inductionVars.size(),
2063 delimiter: OpAsmParser::Delimiter::None) ||
2064 parser.parseColonTypeList(result&: lowerboundType) || parser.parseRParen() ||
2065 parser.parseKeyword(keyword: "to") || parser.parseLParen() ||
2066 parser.parseOperandList(result&: upperbound, requiredOperandCount: inductionVars.size(),
2067 delimiter: OpAsmParser::Delimiter::None) ||
2068 parser.parseColonTypeList(result&: upperboundType) || parser.parseRParen() ||
2069 parser.parseKeyword(keyword: "step") || parser.parseLParen() ||
2070 parser.parseOperandList(result&: step, requiredOperandCount: inductionVars.size(),
2071 delimiter: OpAsmParser::Delimiter::None) ||
2072 parser.parseColonTypeList(result&: stepType) || parser.parseRParen())
2073 return failure();
2074 }
2075 return parser.parseRegion(region, arguments: inductionVars);
2076}
2077
2078void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
2079 ValueRange lowerbound, TypeRange lowerboundType,
2080 ValueRange upperbound, TypeRange upperboundType,
2081 ValueRange steps, TypeRange stepType) {
2082 ValueRange regionArgs = region.front().getArguments();
2083 if (!regionArgs.empty()) {
2084 p << acc::LoopOp::getControlKeyword() << "(";
2085 llvm::interleaveComma(c: regionArgs, os&: p,
2086 each_fn: [&p](Value v) { p << v << " : " << v.getType(); });
2087 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2088 << upperbound << " : " << upperboundType << ") "
2089 << " step (" << steps << " : " << stepType << ") ";
2090 }
2091 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
2092}
2093
2094//===----------------------------------------------------------------------===//
2095// DataOp
2096//===----------------------------------------------------------------------===//
2097
2098LogicalResult acc::DataOp::verify() {
2099 // 2.6.5. Data Construct restriction
2100 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2101 // attach, or default clause must appear on a data construct.
2102 if (getOperands().empty() && !getDefaultAttr())
2103 return emitError("at least one operand or the default attribute "
2104 "must appear on the data operation");
2105
2106 for (mlir::Value operand : getDataClauseOperands())
2107 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2108 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2109 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2110 operand.getDefiningOp()))
2111 return emitError("expect data entry/exit operation or acc.getdeviceptr "
2112 "as defining op");
2113
2114 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2115 return failure();
2116
2117 return success();
2118}
2119
2120unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2121
2122Value DataOp::getDataOperand(unsigned i) {
2123 unsigned numOptional = getIfCond() ? 1 : 0;
2124 numOptional += getAsyncOperands().size() ? 1 : 0;
2125 numOptional += getWaitOperands().size();
2126 return getOperand(numOptional + i);
2127}
2128
2129bool acc::DataOp::hasAsyncOnly() {
2130 return hasAsyncOnly(mlir::acc::DeviceType::None);
2131}
2132
2133bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2134 return hasDeviceType(getAsyncOnly(), deviceType);
2135}
2136
2137mlir::Value DataOp::getAsyncValue() {
2138 return getAsyncValue(mlir::acc::DeviceType::None);
2139}
2140
2141mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2142 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
2143 getAsyncOperands(), deviceType);
2144}
2145
2146bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2147
2148bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2149 return hasDeviceType(getWaitOnly(), deviceType);
2150}
2151
2152mlir::Operation::operand_range DataOp::getWaitValues() {
2153 return getWaitValues(mlir::acc::DeviceType::None);
2154}
2155
2156mlir::Operation::operand_range
2157DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2158 return getWaitValuesWithoutDevnum(
2159 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2160 getHasWaitDevnum(), deviceType);
2161}
2162
2163mlir::Value DataOp::getWaitDevnum() {
2164 return getWaitDevnum(mlir::acc::DeviceType::None);
2165}
2166
2167mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2168 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2169 getWaitOperandsSegments(), getHasWaitDevnum(),
2170 deviceType);
2171}
2172
2173//===----------------------------------------------------------------------===//
2174// ExitDataOp
2175//===----------------------------------------------------------------------===//
2176
2177LogicalResult acc::ExitDataOp::verify() {
2178 // 2.6.6. Data Exit Directive restriction
2179 // At least one copyout, delete, or detach clause must appear on an exit data
2180 // directive.
2181 if (getDataClauseOperands().empty())
2182 return emitError("at least one operand must be present in dataOperands on "
2183 "the exit data operation");
2184
2185 // The async attribute represent the async clause without value. Therefore the
2186 // attribute and operand cannot appear at the same time.
2187 if (getAsyncOperand() && getAsync())
2188 return emitError("async attribute cannot appear with asyncOperand");
2189
2190 // The wait attribute represent the wait clause without values. Therefore the
2191 // attribute and operands cannot appear at the same time.
2192 if (!getWaitOperands().empty() && getWait())
2193 return emitError("wait attribute cannot appear with waitOperands");
2194
2195 if (getWaitDevnum() && getWaitOperands().empty())
2196 return emitError("wait_devnum cannot appear without waitOperands");
2197
2198 return success();
2199}
2200
2201unsigned ExitDataOp::getNumDataOperands() {
2202 return getDataClauseOperands().size();
2203}
2204
2205Value ExitDataOp::getDataOperand(unsigned i) {
2206 unsigned numOptional = getIfCond() ? 1 : 0;
2207 numOptional += getAsyncOperand() ? 1 : 0;
2208 numOptional += getWaitDevnum() ? 1 : 0;
2209 return getOperand(getWaitOperands().size() + numOptional + i);
2210}
2211
2212void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2213 MLIRContext *context) {
2214 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
2215}
2216
2217//===----------------------------------------------------------------------===//
2218// EnterDataOp
2219//===----------------------------------------------------------------------===//
2220
2221LogicalResult acc::EnterDataOp::verify() {
2222 // 2.6.6. Data Enter Directive restriction
2223 // At least one copyin, create, or attach clause must appear on an enter data
2224 // directive.
2225 if (getDataClauseOperands().empty())
2226 return emitError("at least one operand must be present in dataOperands on "
2227 "the enter data operation");
2228
2229 // The async attribute represent the async clause without value. Therefore the
2230 // attribute and operand cannot appear at the same time.
2231 if (getAsyncOperand() && getAsync())
2232 return emitError("async attribute cannot appear with asyncOperand");
2233
2234 // The wait attribute represent the wait clause without values. Therefore the
2235 // attribute and operands cannot appear at the same time.
2236 if (!getWaitOperands().empty() && getWait())
2237 return emitError("wait attribute cannot appear with waitOperands");
2238
2239 if (getWaitDevnum() && getWaitOperands().empty())
2240 return emitError("wait_devnum cannot appear without waitOperands");
2241
2242 for (mlir::Value operand : getDataClauseOperands())
2243 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2244 operand.getDefiningOp()))
2245 return emitError("expect data entry operation as defining op");
2246
2247 return success();
2248}
2249
2250unsigned EnterDataOp::getNumDataOperands() {
2251 return getDataClauseOperands().size();
2252}
2253
2254Value EnterDataOp::getDataOperand(unsigned i) {
2255 unsigned numOptional = getIfCond() ? 1 : 0;
2256 numOptional += getAsyncOperand() ? 1 : 0;
2257 numOptional += getWaitDevnum() ? 1 : 0;
2258 return getOperand(getWaitOperands().size() + numOptional + i);
2259}
2260
2261void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2262 MLIRContext *context) {
2263 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
2264}
2265
2266//===----------------------------------------------------------------------===//
2267// AtomicReadOp
2268//===----------------------------------------------------------------------===//
2269
2270LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
2271
2272//===----------------------------------------------------------------------===//
2273// AtomicWriteOp
2274//===----------------------------------------------------------------------===//
2275
2276LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
2277
2278//===----------------------------------------------------------------------===//
2279// AtomicUpdateOp
2280//===----------------------------------------------------------------------===//
2281
2282LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2283 PatternRewriter &rewriter) {
2284 if (op.isNoOp()) {
2285 rewriter.eraseOp(op);
2286 return success();
2287 }
2288
2289 if (Value writeVal = op.getWriteOpVal()) {
2290 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
2291 return success();
2292 }
2293
2294 return failure();
2295}
2296
2297LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
2298
2299LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
2300
2301//===----------------------------------------------------------------------===//
2302// AtomicCaptureOp
2303//===----------------------------------------------------------------------===//
2304
2305AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2306 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2307 return op;
2308 return dyn_cast<AtomicReadOp>(getSecondOp());
2309}
2310
2311AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2312 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2313 return op;
2314 return dyn_cast<AtomicWriteOp>(getSecondOp());
2315}
2316
2317AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2318 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2319 return op;
2320 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2321}
2322
2323LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
2324
2325//===----------------------------------------------------------------------===//
2326// DeclareEnterOp
2327//===----------------------------------------------------------------------===//
2328
2329template <typename Op>
2330static LogicalResult
2331checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
2332 bool requireAtLeastOneOperand = true) {
2333 if (operands.empty() && requireAtLeastOneOperand)
2334 return emitError(
2335 op->getLoc(),
2336 "at least one operand must appear on the declare operation");
2337
2338 for (mlir::Value operand : operands) {
2339 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2340 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2341 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2342 operand.getDefiningOp()))
2343 return op.emitError(
2344 "expect valid declare data entry operation or acc.getdeviceptr "
2345 "as defining op");
2346
2347 mlir::Value varPtr{getVarPtr(accDataClauseOp: operand.getDefiningOp())};
2348 assert(varPtr && "declare operands can only be data entry operations which "
2349 "must have varPtr");
2350 std::optional<mlir::acc::DataClause> dataClauseOptional{
2351 getDataClause(operand.getDefiningOp())};
2352 assert(dataClauseOptional.has_value() &&
2353 "declare operands can only be data entry operations which must have "
2354 "dataClause");
2355
2356 // If varPtr has no defining op - there is nothing to check further.
2357 if (!varPtr.getDefiningOp())
2358 continue;
2359
2360 // Check that the varPtr has a declare attribute.
2361 auto declareAttribute{
2362 varPtr.getDefiningOp()->getAttr(name: mlir::acc::getDeclareAttrName())};
2363 if (!declareAttribute)
2364 return op.emitError(
2365 "expect declare attribute on variable in declare operation");
2366
2367 auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2368 if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2369 return op.emitError(
2370 "expect matching declare attribute on variable in declare operation");
2371
2372 // If the variable is marked with implicit attribute, the matching declare
2373 // data action must also be marked implicit. The reverse is not checked
2374 // since implicit data action may be inserted to do actions like updating
2375 // device copy, in which case the variable is not necessarily implicitly
2376 // declare'd.
2377 if (declAttr.getImplicit() &&
2378 declAttr.getImplicit() != acc::getImplicitFlag(accDataEntryOp: operand.getDefiningOp()))
2379 return op.emitError(
2380 "implicitness must match between declare op and flag on variable");
2381 }
2382
2383 return success();
2384}
2385
2386LogicalResult acc::DeclareEnterOp::verify() {
2387 return checkDeclareOperands(*this, this->getDataClauseOperands());
2388}
2389
2390//===----------------------------------------------------------------------===//
2391// DeclareExitOp
2392//===----------------------------------------------------------------------===//
2393
2394LogicalResult acc::DeclareExitOp::verify() {
2395 if (getToken())
2396 return checkDeclareOperands(*this, this->getDataClauseOperands(),
2397 /*requireAtLeastOneOperand=*/false);
2398 return checkDeclareOperands(*this, this->getDataClauseOperands());
2399}
2400
2401//===----------------------------------------------------------------------===//
2402// DeclareOp
2403//===----------------------------------------------------------------------===//
2404
2405LogicalResult acc::DeclareOp::verify() {
2406 return checkDeclareOperands(*this, this->getDataClauseOperands());
2407}
2408
2409//===----------------------------------------------------------------------===//
2410// RoutineOp
2411//===----------------------------------------------------------------------===//
2412
2413static unsigned getParallelismForDeviceType(acc::RoutineOp op,
2414 acc::DeviceType dtype) {
2415 unsigned parallelism = 0;
2416 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2417 parallelism += op.hasWorker(dtype) ? 1 : 0;
2418 parallelism += op.hasVector(dtype) ? 1 : 0;
2419 parallelism += op.hasSeq(dtype) ? 1 : 0;
2420 return parallelism;
2421}
2422
2423LogicalResult acc::RoutineOp::verify() {
2424 unsigned baseParallelism =
2425 getParallelismForDeviceType(*this, acc::DeviceType::None);
2426
2427 if (baseParallelism > 1)
2428 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2429 "be present at the same time";
2430
2431 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2432 ++dtypeInt) {
2433 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
2434 if (dtype == acc::DeviceType::None)
2435 continue;
2436 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
2437
2438 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2439 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
2440 "be present at the same time";
2441 }
2442
2443 return success();
2444}
2445
2446static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
2447 mlir::ArrayAttr &deviceTypes) {
2448 llvm::SmallVector<mlir::Attribute> bindNameAttrs;
2449 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
2450
2451 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
2452 if (parser.parseAttribute(result&: bindNameAttrs.emplace_back()))
2453 return failure();
2454 if (failed(result: parser.parseOptionalLSquare())) {
2455 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2456 parser.getContext(), mlir::acc::DeviceType::None));
2457 } else {
2458 if (parser.parseAttribute(result&: deviceTypeAttrs.emplace_back()) ||
2459 parser.parseRSquare())
2460 return failure();
2461 }
2462 return success();
2463 })))
2464 return failure();
2465
2466 bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
2467 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2468
2469 return success();
2470}
2471
2472static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
2473 std::optional<mlir::ArrayAttr> bindName,
2474 std::optional<mlir::ArrayAttr> deviceTypes) {
2475 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2476 [&](const auto &pair) {
2477 p << std::get<0>(pair);
2478 printSingleDeviceType(p, std::get<1>(pair));
2479 });
2480}
2481
2482static ParseResult parseRoutineGangClause(OpAsmParser &parser,
2483 mlir::ArrayAttr &gang,
2484 mlir::ArrayAttr &gangDim,
2485 mlir::ArrayAttr &gangDimDeviceTypes) {
2486
2487 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
2488 gangDimDeviceTypeAttrs;
2489 bool needCommaBeforeOperands = false;
2490
2491 // Gang keyword only
2492 if (failed(result: parser.parseOptionalLParen())) {
2493 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2494 parser.getContext(), mlir::acc::DeviceType::None));
2495 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2496 return success();
2497 }
2498
2499 // Parse keyword only attributes
2500 if (succeeded(result: parser.parseOptionalLSquare())) {
2501 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
2502 if (parser.parseAttribute(result&: gangAttrs.emplace_back()))
2503 return failure();
2504 return success();
2505 })))
2506 return failure();
2507 if (parser.parseRSquare())
2508 return failure();
2509 needCommaBeforeOperands = true;
2510 }
2511
2512 if (needCommaBeforeOperands && failed(result: parser.parseComma()))
2513 return failure();
2514
2515 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
2516 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2517 parser.parseColon() ||
2518 parser.parseAttribute(gangDimAttrs.emplace_back()))
2519 return failure();
2520 if (succeeded(result: parser.parseOptionalLSquare())) {
2521 if (parser.parseAttribute(result&: gangDimDeviceTypeAttrs.emplace_back()) ||
2522 parser.parseRSquare())
2523 return failure();
2524 } else {
2525 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2526 parser.getContext(), mlir::acc::DeviceType::None));
2527 }
2528 return success();
2529 })))
2530 return failure();
2531
2532 if (failed(result: parser.parseRParen()))
2533 return failure();
2534
2535 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
2536 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
2537 gangDimDeviceTypes =
2538 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
2539
2540 return success();
2541}
2542
2543void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
2544 std::optional<mlir::ArrayAttr> gang,
2545 std::optional<mlir::ArrayAttr> gangDim,
2546 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2547
2548 if (!hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes) && hasDeviceTypeValues(arrayAttr: gang) &&
2549 gang->size() == 1) {
2550 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2551 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2552 return;
2553 }
2554
2555 p << "(";
2556
2557 printDeviceTypes(p, deviceTypes: gang);
2558
2559 if (hasDeviceTypeValues(arrayAttr: gang) && hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes))
2560 p << ", ";
2561
2562 if (hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes))
2563 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2564 [&](const auto &pair) {
2565 p << acc::RoutineOp::getGangDimKeyword() << ": ";
2566 p << std::get<0>(pair);
2567 printSingleDeviceType(p, std::get<1>(pair));
2568 });
2569
2570 p << ")";
2571}
2572
2573static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
2574 mlir::ArrayAttr &deviceTypes) {
2575 llvm::SmallVector<mlir::Attribute> attributes;
2576 // Keyword only
2577 if (failed(result: parser.parseOptionalLParen())) {
2578 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2579 parser.getContext(), mlir::acc::DeviceType::None));
2580 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2581 return success();
2582 }
2583
2584 // Parse device type attributes
2585 if (succeeded(result: parser.parseOptionalLSquare())) {
2586 if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
2587 if (parser.parseAttribute(result&: attributes.emplace_back()))
2588 return failure();
2589 return success();
2590 })))
2591 return failure();
2592 if (parser.parseRSquare() || parser.parseRParen())
2593 return failure();
2594 }
2595 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
2596 return success();
2597}
2598
2599static void
2600printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
2601 std::optional<mlir::ArrayAttr> deviceTypes) {
2602
2603 if (hasDeviceTypeValues(arrayAttr: deviceTypes) && deviceTypes->size() == 1) {
2604 auto deviceTypeAttr =
2605 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2606 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
2607 return;
2608 }
2609
2610 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
2611 return;
2612
2613 p << "([";
2614 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
2615 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2616 p << dTypeAttr;
2617 });
2618 p << "])";
2619}
2620
2621bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2622
2623bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2624 return hasDeviceType(getWorker(), deviceType);
2625}
2626
2627bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2628
2629bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2630 return hasDeviceType(getVector(), deviceType);
2631}
2632
2633bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2634
2635bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2636 return hasDeviceType(getSeq(), deviceType);
2637}
2638
2639std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2640 return getBindNameValue(mlir::acc::DeviceType::None);
2641}
2642
2643std::optional<llvm::StringRef>
2644RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2645 if (!hasDeviceTypeValues(getBindNameDeviceType()))
2646 return std::nullopt;
2647 if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
2648 auto attr = (*getBindName())[*pos];
2649 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2650 return stringAttr.getValue();
2651 }
2652 return std::nullopt;
2653}
2654
2655bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2656
2657bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2658 return hasDeviceType(getGang(), deviceType);
2659}
2660
2661std::optional<int64_t> RoutineOp::getGangDimValue() {
2662 return getGangDimValue(mlir::acc::DeviceType::None);
2663}
2664
2665std::optional<int64_t>
2666RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2667 if (!hasDeviceTypeValues(getGangDimDeviceType()))
2668 return std::nullopt;
2669 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
2670 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2671 return intAttr.getInt();
2672 }
2673 return std::nullopt;
2674}
2675
2676//===----------------------------------------------------------------------===//
2677// InitOp
2678//===----------------------------------------------------------------------===//
2679
2680LogicalResult acc::InitOp::verify() {
2681 Operation *currOp = *this;
2682 while ((currOp = currOp->getParentOp()))
2683 if (isComputeOperation(currOp))
2684 return emitOpError("cannot be nested in a compute operation");
2685 return success();
2686}
2687
2688//===----------------------------------------------------------------------===//
2689// ShutdownOp
2690//===----------------------------------------------------------------------===//
2691
2692LogicalResult acc::ShutdownOp::verify() {
2693 Operation *currOp = *this;
2694 while ((currOp = currOp->getParentOp()))
2695 if (isComputeOperation(currOp))
2696 return emitOpError("cannot be nested in a compute operation");
2697 return success();
2698}
2699
2700//===----------------------------------------------------------------------===//
2701// SetOp
2702//===----------------------------------------------------------------------===//
2703
2704LogicalResult acc::SetOp::verify() {
2705 Operation *currOp = *this;
2706 while ((currOp = currOp->getParentOp()))
2707 if (isComputeOperation(currOp))
2708 return emitOpError("cannot be nested in a compute operation");
2709 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2710 return emitOpError("at least one default_async, device_num, or device_type "
2711 "operand must appear");
2712 return success();
2713}
2714
2715//===----------------------------------------------------------------------===//
2716// UpdateOp
2717//===----------------------------------------------------------------------===//
2718
2719LogicalResult acc::UpdateOp::verify() {
2720 // At least one of host or device should have a value.
2721 if (getDataClauseOperands().empty())
2722 return emitError("at least one value must be present in dataOperands");
2723
2724 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2725 getAsyncOperandsDeviceTypeAttr(),
2726 "async")))
2727 return failure();
2728
2729 if (failed(verifyDeviceTypeAndSegmentCountMatch(
2730 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2731 getWaitOperandsDeviceTypeAttr(), "wait")))
2732 return failure();
2733
2734 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
2735 return failure();
2736
2737 for (mlir::Value operand : getDataClauseOperands())
2738 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2739 operand.getDefiningOp()))
2740 return emitError("expect data entry/exit operation or acc.getdeviceptr "
2741 "as defining op");
2742
2743 return success();
2744}
2745
2746unsigned UpdateOp::getNumDataOperands() {
2747 return getDataClauseOperands().size();
2748}
2749
2750Value UpdateOp::getDataOperand(unsigned i) {
2751 unsigned numOptional = getAsyncOperands().size();
2752 numOptional += getIfCond() ? 1 : 0;
2753 return getOperand(getWaitOperands().size() + numOptional + i);
2754}
2755
2756void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
2757 MLIRContext *context) {
2758 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
2759}
2760
2761bool UpdateOp::hasAsyncOnly() {
2762 return hasAsyncOnly(mlir::acc::DeviceType::None);
2763}
2764
2765bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2766 return hasDeviceType(getAsync(), deviceType);
2767}
2768
2769mlir::Value UpdateOp::getAsyncValue() {
2770 return getAsyncValue(mlir::acc::DeviceType::None);
2771}
2772
2773mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2774 if (!hasDeviceTypeValues(getAsyncOperandsDeviceType()))
2775 return {};
2776
2777 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
2778 return getAsyncOperands()[*pos];
2779
2780 return {};
2781}
2782
2783bool UpdateOp::hasWaitOnly() {
2784 return hasWaitOnly(mlir::acc::DeviceType::None);
2785}
2786
2787bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2788 return hasDeviceType(getWaitOnly(), deviceType);
2789}
2790
2791mlir::Operation::operand_range UpdateOp::getWaitValues() {
2792 return getWaitValues(mlir::acc::DeviceType::None);
2793}
2794
2795mlir::Operation::operand_range
2796UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2797 return getWaitValuesWithoutDevnum(
2798 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2799 getHasWaitDevnum(), deviceType);
2800}
2801
2802mlir::Value UpdateOp::getWaitDevnum() {
2803 return getWaitDevnum(mlir::acc::DeviceType::None);
2804}
2805
2806mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2807 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2808 getWaitOperandsSegments(), getHasWaitDevnum(),
2809 deviceType);
2810}
2811
2812//===----------------------------------------------------------------------===//
2813// WaitOp
2814//===----------------------------------------------------------------------===//
2815
2816LogicalResult acc::WaitOp::verify() {
2817 // The async attribute represent the async clause without value. Therefore the
2818 // attribute and operand cannot appear at the same time.
2819 if (getAsyncOperand() && getAsync())
2820 return emitError("async attribute cannot appear with asyncOperand");
2821
2822 if (getWaitDevnum() && getWaitOperands().empty())
2823 return emitError("wait_devnum cannot appear without waitOperands");
2824
2825 return success();
2826}
2827
2828#define GET_OP_CLASSES
2829#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2830
2831#define GET_ATTRDEF_CLASSES
2832#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2833
2834#define GET_TYPEDEF_CLASSES
2835#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2836
2837//===----------------------------------------------------------------------===//
2838// acc dialect utilities
2839//===----------------------------------------------------------------------===//
2840
2841mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) {
2842 auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2843 .Case<ACC_DATA_ENTRY_OPS>(
2844 [&](auto entry) { return entry.getVarPtr(); })
2845 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2846 [&](auto exit) { return exit.getVarPtr(); })
2847 .Default([&](mlir::Operation *) { return mlir::Value(); })};
2848 return varPtr;
2849}
2850
2851mlir::Value mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) {
2852 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2853 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
2854 [&](auto dataClause) { return dataClause.getAccPtr(); })
2855 .Default([&](mlir::Operation *) { return mlir::Value(); })};
2856 return accPtr;
2857}
2858
2859mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) {
2860 auto varPtrPtr{
2861 llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
2862 .Case<ACC_DATA_ENTRY_OPS>(
2863 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
2864 .Default([&](mlir::Operation *) { return mlir::Value(); })};
2865 return varPtrPtr;
2866}
2867
2868mlir::SmallVector<mlir::Value>
2869mlir::acc::getBounds(mlir::Operation *accDataClauseOp) {
2870 mlir::SmallVector<mlir::Value> bounds{
2871 llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
2872 accDataClauseOp)
2873 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
2874 return mlir::SmallVector<mlir::Value>(
2875 dataClause.getBounds().begin(), dataClause.getBounds().end());
2876 })
2877 .Default([&](mlir::Operation *) {
2878 return mlir::SmallVector<mlir::Value, 0>();
2879 })};
2880 return bounds;
2881}
2882
2883std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
2884 auto name{
2885 llvm::TypeSwitch<mlir::Operation *, std::optional<llvm::StringRef>>(accOp)
2886 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
2887 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
2888 return {};
2889 })};
2890 return name;
2891}
2892
2893std::optional<mlir::acc::DataClause>
2894mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) {
2895 auto dataClause{
2896 llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>(
2897 accDataEntryOp)
2898 .Case<ACC_DATA_ENTRY_OPS>(
2899 [&](auto entry) { return entry.getDataClause(); })
2900 .Default([&](mlir::Operation *) { return std::nullopt; })};
2901 return dataClause;
2902}
2903
2904bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) {
2905 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
2906 .Case<ACC_DATA_ENTRY_OPS>(
2907 [&](auto entry) { return entry.getImplicit(); })
2908 .Default([&](mlir::Operation *) { return false; })};
2909 return implicit;
2910}
2911
2912mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) {
2913 auto dataOperands{
2914 llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp)
2915 .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
2916 [&](auto entry) { return entry.getDataClauseOperands(); })
2917 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
2918 return dataOperands;
2919}
2920
2921mlir::MutableOperandRange
2922mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
2923 auto dataOperands{
2924 llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp)
2925 .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
2926 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
2927 .Default([&](mlir::Operation *) { return nullptr; })};
2928 return dataOperands;
2929}
2930

// source code of mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp