1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// =============================================================================
8
9#include "mlir/Dialect/OpenACC/OpenACC.h"
10#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/DialectImplementation.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/OpImplementation.h"
19#include "mlir/Support/LLVM.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "llvm/ADT/SmallSet.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/LogicalResult.h"
24
25using namespace mlir;
26using namespace acc;
27
28#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
33
34namespace {
35
36static bool isScalarLikeType(Type type) {
37 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
38}
39
40struct MemRefPointerLikeModel
41 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
42 MemRefType> {
43 Type getElementType(Type pointer) const {
44 return cast<MemRefType>(pointer).getElementType();
45 }
46 mlir::acc::VariableTypeCategory
47 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
48 Type varType) const {
49 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
50 return mappableTy.getTypeCategory(varPtr);
51 }
52 auto memrefTy = cast<MemRefType>(pointer);
53 if (!memrefTy.hasRank()) {
54 // This memref is unranked - aka it could have any rank, including a
55 // rank of 0 which could mean scalar. For now, return uncategorized.
56 return mlir::acc::VariableTypeCategory::uncategorized;
57 }
58
59 if (memrefTy.getRank() == 0) {
60 if (isScalarLikeType(memrefTy.getElementType())) {
61 return mlir::acc::VariableTypeCategory::scalar;
62 }
63 // Zero-rank non-scalar - need further analysis to determine the type
64 // category. For now, return uncategorized.
65 return mlir::acc::VariableTypeCategory::uncategorized;
66 }
67
68 // It has a rank - must be an array.
69 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
70 return mlir::acc::VariableTypeCategory::array;
71 }
72};
73
74struct LLVMPointerPointerLikeModel
75 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
77 Type getElementType(Type pointer) const { return Type(); }
78};
79
80/// Helper function for any of the times we need to modify an ArrayAttr based on
81/// a device type list. Returns a new ArrayAttr with all of the
82/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
83/// list is empty).
84mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
85 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
86 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
87 llvm::SmallVector<mlir::Attribute> deviceTypes;
88 if (existingDeviceTypes)
89 llvm::copy(existingDeviceTypes, std::back_inserter(x&: deviceTypes));
90
91 if (newDeviceTypes.empty())
92 deviceTypes.push_back(
93 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
94
95 for (DeviceType DT : newDeviceTypes)
96 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
97
98 return mlir::ArrayAttr::get(context, deviceTypes);
99}
100
101/// Helper function for any of the times we need to add operands that are
102/// affected by a device type list. Returns a new ArrayAttr with all of the
103/// existingDeviceTypes, plus the effective new ones (or an added none, if the
104/// new list is empty). Additionally, adds the arguments to the argCollection
105/// the correct number of times. This will also update a 'segments' array, even
106/// if it won't be used.
107mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
108 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
109 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
110 mlir::MutableOperandRange argCollection,
111 llvm::SmallVector<int32_t> &segments) {
112 llvm::SmallVector<mlir::Attribute> deviceTypes;
113 if (existingDeviceTypes)
114 llvm::copy(existingDeviceTypes, std::back_inserter(x&: deviceTypes));
115
116 if (newDeviceTypes.empty()) {
117 argCollection.append(values: arguments);
118 segments.push_back(Elt: arguments.size());
119 deviceTypes.push_back(
120 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
121 }
122
123 for (DeviceType DT : newDeviceTypes) {
124 argCollection.append(arguments);
125 segments.push_back(arguments.size());
126 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
127 }
128
129 return mlir::ArrayAttr::get(context, deviceTypes);
130}
131
132/// Overload for when the 'segments' aren't needed.
133mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
134 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
135 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
136 mlir::MutableOperandRange argCollection) {
137 llvm::SmallVector<int32_t> segments;
138 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
139 newDeviceTypes, arguments,
140 argCollection, segments);
141}
142} // namespace
143
144//===----------------------------------------------------------------------===//
145// OpenACC operations
146//===----------------------------------------------------------------------===//
147
148void OpenACCDialect::initialize() {
149 addOperations<
150#define GET_OP_LIST
151#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
152 >();
153 addAttributes<
154#define GET_ATTRDEF_LIST
155#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
156 >();
157 addTypes<
158#define GET_TYPEDEF_LIST
159#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
160 >();
161
162 // By attaching interfaces here, we make the OpenACC dialect dependent on
163 // the other dialects. This is probably better than having dialects like LLVM
164 // and memref be dependent on OpenACC.
165 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
166 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
167 *getContext());
168}
169
170//===----------------------------------------------------------------------===//
171// device_type support helpers
172//===----------------------------------------------------------------------===//
173
174static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
175 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
176 return true;
177 return false;
178}
179
180static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
181 mlir::acc::DeviceType deviceType) {
182 if (!hasDeviceTypeValues(arrayAttr))
183 return false;
184
185 for (auto attr : *arrayAttr) {
186 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
187 if (deviceTypeAttr.getValue() == deviceType)
188 return true;
189 }
190
191 return false;
192}
193
194static void printDeviceTypes(mlir::OpAsmPrinter &p,
195 std::optional<mlir::ArrayAttr> deviceTypes) {
196 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
197 return;
198
199 p << "[";
200 llvm::interleaveComma(*deviceTypes, p,
201 [&](mlir::Attribute attr) { p << attr; });
202 p << "]";
203}
204
205static std::optional<unsigned> findSegment(ArrayAttr segments,
206 mlir::acc::DeviceType deviceType) {
207 unsigned segmentIdx = 0;
208 for (auto attr : segments) {
209 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
210 if (deviceTypeAttr.getValue() == deviceType)
211 return std::make_optional(segmentIdx);
212 ++segmentIdx;
213 }
214 return std::nullopt;
215}
216
217static mlir::Operation::operand_range
218getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
219 mlir::Operation::operand_range range,
220 std::optional<llvm::ArrayRef<int32_t>> segments,
221 mlir::acc::DeviceType deviceType) {
222 if (!arrayAttr)
223 return range.take_front(n: 0);
224 if (auto pos = findSegment(*arrayAttr, deviceType)) {
225 int32_t nbOperandsBefore = 0;
226 for (unsigned i = 0; i < *pos; ++i)
227 nbOperandsBefore += (*segments)[i];
228 return range.drop_front(n: nbOperandsBefore).take_front(n: (*segments)[*pos]);
229 }
230 return range.take_front(n: 0);
231}
232
233static mlir::Value
234getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
235 mlir::Operation::operand_range operands,
236 std::optional<llvm::ArrayRef<int32_t>> segments,
237 std::optional<mlir::ArrayAttr> hasWaitDevnum,
238 mlir::acc::DeviceType deviceType) {
239 if (!hasDeviceTypeValues(arrayAttr: deviceTypeAttr))
240 return {};
241 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
242 if (hasWaitDevnum->getValue()[*pos])
243 return getValuesFromSegments(deviceTypeAttr, operands, segments,
244 deviceType)
245 .front();
246 return {};
247}
248
249static mlir::Operation::operand_range
250getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
251 mlir::Operation::operand_range operands,
252 std::optional<llvm::ArrayRef<int32_t>> segments,
253 std::optional<mlir::ArrayAttr> hasWaitDevnum,
254 mlir::acc::DeviceType deviceType) {
255 auto range =
256 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
257 if (range.empty())
258 return range;
259 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
260 if (hasWaitDevnum && *hasWaitDevnum) {
261 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
262 if (boolAttr.getValue())
263 return range.drop_front(1); // first value is devnum
264 }
265 }
266 return range;
267}
268
269template <typename Op>
270static LogicalResult checkWaitAndAsyncConflict(Op op) {
271 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
272 ++dtypeInt) {
273 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
274
275 // The asyncOnly attribute represent the async clause without value.
276 // Therefore the attribute and operand cannot appear at the same time.
277 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
278 op.hasAsyncOnly(dtype))
279 return op.emitError(
280 "asyncOnly attribute cannot appear with asyncOperand");
281
282 // The wait attribute represent the wait clause without values. Therefore
283 // the attribute and operands cannot appear at the same time.
284 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
285 op.hasWaitOnly(dtype))
286 return op.emitError("wait attribute cannot appear with waitOperands");
287 }
288 return success();
289}
290
291template <typename Op>
292static LogicalResult checkVarAndVarType(Op op) {
293 if (!op.getVar())
294 return op.emitError("must have var operand");
295
296 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
297 mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
298 // TODO: If a type implements both interfaces (mappable and pointer-like),
299 // it is unclear which semantics to apply without additional info which
300 // would need captured in the data operation. For now restrict this case
301 // unless a compelling reason to support disambiguating between the two.
302 return op.emitError("var must be mappable or pointer-like (not both)");
303 }
304
305 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
306 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
307 return op.emitError("var must be mappable or pointer-like");
308
309 if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
310 op.getVarType() != op.getVar().getType())
311 return op.emitError("varType must match when var is mappable");
312
313 return success();
314}
315
316template <typename Op>
317static LogicalResult checkVarAndAccVar(Op op) {
318 if (op.getVar().getType() != op.getAccVar().getType())
319 return op.emitError("input and output types must match");
320
321 return success();
322}
323
324static ParseResult parseVar(mlir::OpAsmParser &parser,
325 OpAsmParser::UnresolvedOperand &var) {
326 // Either `var` or `varPtr` keyword is required.
327 if (failed(Result: parser.parseOptionalKeyword(keyword: "varPtr"))) {
328 if (failed(Result: parser.parseKeyword(keyword: "var")))
329 return failure();
330 }
331 if (failed(Result: parser.parseLParen()))
332 return failure();
333 if (failed(Result: parser.parseOperand(result&: var)))
334 return failure();
335
336 return success();
337}
338
339static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op,
340 mlir::Value var) {
341 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
342 p << "varPtr(";
343 else
344 p << "var(";
345 p.printOperand(value: var);
346}
347
348static ParseResult parseAccVar(mlir::OpAsmParser &parser,
349 OpAsmParser::UnresolvedOperand &var,
350 mlir::Type &accVarType) {
351 // Either `accVar` or `accPtr` keyword is required.
352 if (failed(Result: parser.parseOptionalKeyword(keyword: "accPtr"))) {
353 if (failed(Result: parser.parseKeyword(keyword: "accVar")))
354 return failure();
355 }
356 if (failed(Result: parser.parseLParen()))
357 return failure();
358 if (failed(Result: parser.parseOperand(result&: var)))
359 return failure();
360 if (failed(Result: parser.parseColon()))
361 return failure();
362 if (failed(Result: parser.parseType(result&: accVarType)))
363 return failure();
364 if (failed(Result: parser.parseRParen()))
365 return failure();
366
367 return success();
368}
369
370static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op,
371 mlir::Value accVar, mlir::Type accVarType) {
372 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
373 p << "accPtr(";
374 else
375 p << "accVar(";
376 p.printOperand(value: accVar);
377 p << " : ";
378 p.printType(type: accVarType);
379 p << ")";
380}
381
382static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
383 mlir::Type &varPtrType,
384 mlir::TypeAttr &varTypeAttr) {
385 if (failed(Result: parser.parseType(result&: varPtrType)))
386 return failure();
387 if (failed(Result: parser.parseRParen()))
388 return failure();
389
390 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "varType"))) {
391 if (failed(Result: parser.parseLParen()))
392 return failure();
393 mlir::Type varType;
394 if (failed(Result: parser.parseType(result&: varType)))
395 return failure();
396 varTypeAttr = mlir::TypeAttr::get(varType);
397 if (failed(Result: parser.parseRParen()))
398 return failure();
399 } else {
400 // Set `varType` from the element type of the type of `varPtr`.
401 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
402 varTypeAttr = mlir::TypeAttr::get(
403 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
404 else
405 varTypeAttr = mlir::TypeAttr::get(varPtrType);
406 }
407
408 return success();
409}
410
411static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
412 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
413 p.printType(type: varPtrType);
414 p << ")";
415
416 // Print the `varType` only if it differs from the element type of
417 // `varPtr`'s type.
418 mlir::Type varType = varTypeAttr.getValue();
419 mlir::Type typeToCheckAgainst =
420 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
421 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
422 : varPtrType;
423 if (typeToCheckAgainst != varType) {
424 p << " varType(";
425 p.printType(type: varType);
426 p << ")";
427 }
428}
429
430//===----------------------------------------------------------------------===//
431// DataBoundsOp
432//===----------------------------------------------------------------------===//
433LogicalResult acc::DataBoundsOp::verify() {
434 auto extent = getExtent();
435 auto upperbound = getUpperbound();
436 if (!extent && !upperbound)
437 return emitError("expected extent or upperbound.");
438 return success();
439}
440
441//===----------------------------------------------------------------------===//
442// PrivateOp
443//===----------------------------------------------------------------------===//
444LogicalResult acc::PrivateOp::verify() {
445 if (getDataClause() != acc::DataClause::acc_private)
446 return emitError(
447 "data clause associated with private operation must match its intent");
448 if (failed(checkVarAndVarType(*this)))
449 return failure();
450 return success();
451}
452
453//===----------------------------------------------------------------------===//
454// FirstprivateOp
455//===----------------------------------------------------------------------===//
456LogicalResult acc::FirstprivateOp::verify() {
457 if (getDataClause() != acc::DataClause::acc_firstprivate)
458 return emitError("data clause associated with firstprivate operation must "
459 "match its intent");
460 if (failed(checkVarAndVarType(*this)))
461 return failure();
462 return success();
463}
464
465//===----------------------------------------------------------------------===//
466// ReductionOp
467//===----------------------------------------------------------------------===//
468LogicalResult acc::ReductionOp::verify() {
469 if (getDataClause() != acc::DataClause::acc_reduction)
470 return emitError("data clause associated with reduction operation must "
471 "match its intent");
472 if (failed(checkVarAndVarType(*this)))
473 return failure();
474 return success();
475}
476
477//===----------------------------------------------------------------------===//
478// DevicePtrOp
479//===----------------------------------------------------------------------===//
480LogicalResult acc::DevicePtrOp::verify() {
481 if (getDataClause() != acc::DataClause::acc_deviceptr)
482 return emitError("data clause associated with deviceptr operation must "
483 "match its intent");
484 if (failed(checkVarAndVarType(*this)))
485 return failure();
486 if (failed(checkVarAndAccVar(*this)))
487 return failure();
488 return success();
489}
490
491//===----------------------------------------------------------------------===//
492// PresentOp
493//===----------------------------------------------------------------------===//
494LogicalResult acc::PresentOp::verify() {
495 if (getDataClause() != acc::DataClause::acc_present)
496 return emitError(
497 "data clause associated with present operation must match its intent");
498 if (failed(checkVarAndVarType(*this)))
499 return failure();
500 if (failed(checkVarAndAccVar(*this)))
501 return failure();
502 return success();
503}
504
505//===----------------------------------------------------------------------===//
506// CopyinOp
507//===----------------------------------------------------------------------===//
508LogicalResult acc::CopyinOp::verify() {
509 // Test for all clauses this operation can be decomposed from:
510 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
511 getDataClause() != acc::DataClause::acc_copyin_readonly &&
512 getDataClause() != acc::DataClause::acc_copy &&
513 getDataClause() != acc::DataClause::acc_reduction)
514 return emitError(
515 "data clause associated with copyin operation must match its intent"
516 " or specify original clause this operation was decomposed from");
517 if (failed(checkVarAndVarType(*this)))
518 return failure();
519 if (failed(checkVarAndAccVar(*this)))
520 return failure();
521 return success();
522}
523
524bool acc::CopyinOp::isCopyinReadonly() {
525 return getDataClause() == acc::DataClause::acc_copyin_readonly;
526}
527
528//===----------------------------------------------------------------------===//
529// CreateOp
530//===----------------------------------------------------------------------===//
531LogicalResult acc::CreateOp::verify() {
532 // Test for all clauses this operation can be decomposed from:
533 if (getDataClause() != acc::DataClause::acc_create &&
534 getDataClause() != acc::DataClause::acc_create_zero &&
535 getDataClause() != acc::DataClause::acc_copyout &&
536 getDataClause() != acc::DataClause::acc_copyout_zero)
537 return emitError(
538 "data clause associated with create operation must match its intent"
539 " or specify original clause this operation was decomposed from");
540 if (failed(checkVarAndVarType(*this)))
541 return failure();
542 if (failed(checkVarAndAccVar(*this)))
543 return failure();
544 return success();
545}
546
547bool acc::CreateOp::isCreateZero() {
548 // The zero modifier is encoded in the data clause.
549 return getDataClause() == acc::DataClause::acc_create_zero ||
550 getDataClause() == acc::DataClause::acc_copyout_zero;
551}
552
553//===----------------------------------------------------------------------===//
554// NoCreateOp
555//===----------------------------------------------------------------------===//
556LogicalResult acc::NoCreateOp::verify() {
557 if (getDataClause() != acc::DataClause::acc_no_create)
558 return emitError("data clause associated with no_create operation must "
559 "match its intent");
560 if (failed(checkVarAndVarType(*this)))
561 return failure();
562 if (failed(checkVarAndAccVar(*this)))
563 return failure();
564 return success();
565}
566
567//===----------------------------------------------------------------------===//
568// AttachOp
569//===----------------------------------------------------------------------===//
570LogicalResult acc::AttachOp::verify() {
571 if (getDataClause() != acc::DataClause::acc_attach)
572 return emitError(
573 "data clause associated with attach operation must match its intent");
574 if (failed(checkVarAndVarType(*this)))
575 return failure();
576 if (failed(checkVarAndAccVar(*this)))
577 return failure();
578 return success();
579}
580
581//===----------------------------------------------------------------------===//
582// DeclareDeviceResidentOp
583//===----------------------------------------------------------------------===//
584
585LogicalResult acc::DeclareDeviceResidentOp::verify() {
586 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
587 return emitError("data clause associated with device_resident operation "
588 "must match its intent");
589 if (failed(checkVarAndVarType(*this)))
590 return failure();
591 if (failed(checkVarAndAccVar(*this)))
592 return failure();
593 return success();
594}
595
596//===----------------------------------------------------------------------===//
597// DeclareLinkOp
598//===----------------------------------------------------------------------===//
599
600LogicalResult acc::DeclareLinkOp::verify() {
601 if (getDataClause() != acc::DataClause::acc_declare_link)
602 return emitError(
603 "data clause associated with link operation must match its intent");
604 if (failed(checkVarAndVarType(*this)))
605 return failure();
606 if (failed(checkVarAndAccVar(*this)))
607 return failure();
608 return success();
609}
610
611//===----------------------------------------------------------------------===//
612// CopyoutOp
613//===----------------------------------------------------------------------===//
614LogicalResult acc::CopyoutOp::verify() {
615 // Test for all clauses this operation can be decomposed from:
616 if (getDataClause() != acc::DataClause::acc_copyout &&
617 getDataClause() != acc::DataClause::acc_copyout_zero &&
618 getDataClause() != acc::DataClause::acc_copy &&
619 getDataClause() != acc::DataClause::acc_reduction)
620 return emitError(
621 "data clause associated with copyout operation must match its intent"
622 " or specify original clause this operation was decomposed from");
623 if (!getVar() || !getAccVar())
624 return emitError("must have both host and device pointers");
625 if (failed(checkVarAndVarType(*this)))
626 return failure();
627 if (failed(checkVarAndAccVar(*this)))
628 return failure();
629 return success();
630}
631
632bool acc::CopyoutOp::isCopyoutZero() {
633 return getDataClause() == acc::DataClause::acc_copyout_zero;
634}
635
636//===----------------------------------------------------------------------===//
637// DeleteOp
638//===----------------------------------------------------------------------===//
639LogicalResult acc::DeleteOp::verify() {
640 // Test for all clauses this operation can be decomposed from:
641 if (getDataClause() != acc::DataClause::acc_delete &&
642 getDataClause() != acc::DataClause::acc_create &&
643 getDataClause() != acc::DataClause::acc_create_zero &&
644 getDataClause() != acc::DataClause::acc_copyin &&
645 getDataClause() != acc::DataClause::acc_copyin_readonly &&
646 getDataClause() != acc::DataClause::acc_present &&
647 getDataClause() != acc::DataClause::acc_no_create &&
648 getDataClause() != acc::DataClause::acc_declare_device_resident &&
649 getDataClause() != acc::DataClause::acc_declare_link)
650 return emitError(
651 "data clause associated with delete operation must match its intent"
652 " or specify original clause this operation was decomposed from");
653 if (!getAccVar())
654 return emitError("must have device pointer");
655 return success();
656}
657
658//===----------------------------------------------------------------------===//
659// DetachOp
660//===----------------------------------------------------------------------===//
661LogicalResult acc::DetachOp::verify() {
662 // Test for all clauses this operation can be decomposed from:
663 if (getDataClause() != acc::DataClause::acc_detach &&
664 getDataClause() != acc::DataClause::acc_attach)
665 return emitError(
666 "data clause associated with detach operation must match its intent"
667 " or specify original clause this operation was decomposed from");
668 if (!getAccVar())
669 return emitError("must have device pointer");
670 return success();
671}
672
673//===----------------------------------------------------------------------===//
// UpdateHostOp
675//===----------------------------------------------------------------------===//
676LogicalResult acc::UpdateHostOp::verify() {
677 // Test for all clauses this operation can be decomposed from:
678 if (getDataClause() != acc::DataClause::acc_update_host &&
679 getDataClause() != acc::DataClause::acc_update_self)
680 return emitError(
681 "data clause associated with host operation must match its intent"
682 " or specify original clause this operation was decomposed from");
683 if (!getVar() || !getAccVar())
684 return emitError("must have both host and device pointers");
685 if (failed(checkVarAndVarType(*this)))
686 return failure();
687 if (failed(checkVarAndAccVar(*this)))
688 return failure();
689 return success();
690}
691
692//===----------------------------------------------------------------------===//
// UpdateDeviceOp
694//===----------------------------------------------------------------------===//
695LogicalResult acc::UpdateDeviceOp::verify() {
696 // Test for all clauses this operation can be decomposed from:
697 if (getDataClause() != acc::DataClause::acc_update_device)
698 return emitError(
699 "data clause associated with device operation must match its intent"
700 " or specify original clause this operation was decomposed from");
701 if (failed(checkVarAndVarType(*this)))
702 return failure();
703 if (failed(checkVarAndAccVar(*this)))
704 return failure();
705 return success();
706}
707
708//===----------------------------------------------------------------------===//
709// UseDeviceOp
710//===----------------------------------------------------------------------===//
711LogicalResult acc::UseDeviceOp::verify() {
712 // Test for all clauses this operation can be decomposed from:
713 if (getDataClause() != acc::DataClause::acc_use_device)
714 return emitError(
715 "data clause associated with use_device operation must match its intent"
716 " or specify original clause this operation was decomposed from");
717 if (failed(checkVarAndVarType(*this)))
718 return failure();
719 if (failed(checkVarAndAccVar(*this)))
720 return failure();
721 return success();
722}
723
724//===----------------------------------------------------------------------===//
725// CacheOp
726//===----------------------------------------------------------------------===//
727LogicalResult acc::CacheOp::verify() {
728 // Test for all clauses this operation can be decomposed from:
729 if (getDataClause() != acc::DataClause::acc_cache &&
730 getDataClause() != acc::DataClause::acc_cache_readonly)
731 return emitError(
732 "data clause associated with cache operation must match its intent"
733 " or specify original clause this operation was decomposed from");
734 if (failed(checkVarAndVarType(*this)))
735 return failure();
736 if (failed(checkVarAndAccVar(*this)))
737 return failure();
738 return success();
739}
740
741template <typename StructureOp>
742static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
743 unsigned nRegions = 1) {
744
745 SmallVector<Region *, 2> regions;
746 for (unsigned i = 0; i < nRegions; ++i)
747 regions.push_back(Elt: state.addRegion());
748
749 for (Region *region : regions)
750 if (parser.parseRegion(region&: *region, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
751 return failure();
752
753 return success();
754}
755
756static bool isComputeOperation(Operation *op) {
757 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
758}
759
760namespace {
761/// Pattern to remove operation without region that have constant false `ifCond`
762/// and remove the condition from the operation if the `ifCond` is a true
763/// constant.
764template <typename OpTy>
765struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
766 using OpRewritePattern<OpTy>::OpRewritePattern;
767
768 LogicalResult matchAndRewrite(OpTy op,
769 PatternRewriter &rewriter) const override {
770 // Early return if there is no condition.
771 Value ifCond = op.getIfCond();
772 if (!ifCond)
773 return failure();
774
775 IntegerAttr constAttr;
776 if (!matchPattern(ifCond, m_Constant(&constAttr)))
777 return failure();
778 if (constAttr.getInt())
779 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
780 else
781 rewriter.eraseOp(op);
782
783 return success();
784 }
785};
786
787/// Replaces the given op with the contents of the given single-block region,
788/// using the operands of the block terminator to replace operation results.
789static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
790 Region &region, ValueRange blockArgs = {}) {
791 assert(llvm::hasSingleElement(region) && "expected single-region block");
792 Block *block = &region.front();
793 Operation *terminator = block->getTerminator();
794 ValueRange results = terminator->getOperands();
795 rewriter.inlineBlockBefore(source: block, op, argValues: blockArgs);
796 rewriter.replaceOp(op, newValues: results);
797 rewriter.eraseOp(op: terminator);
798}
799
800/// Pattern to remove operation with region that have constant false `ifCond`
801/// and remove the condition from the operation if the `ifCond` is constant
802/// true.
803template <typename OpTy>
804struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
805 using OpRewritePattern<OpTy>::OpRewritePattern;
806
807 LogicalResult matchAndRewrite(OpTy op,
808 PatternRewriter &rewriter) const override {
809 // Early return if there is no condition.
810 Value ifCond = op.getIfCond();
811 if (!ifCond)
812 return failure();
813
814 IntegerAttr constAttr;
815 if (!matchPattern(ifCond, m_Constant(&constAttr)))
816 return failure();
817 if (constAttr.getInt())
818 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
819 else
820 replaceOpWithRegion(rewriter, op, op.getRegion());
821
822 return success();
823 }
824};
825
826} // namespace
827
828//===----------------------------------------------------------------------===//
829// PrivateRecipeOp
830//===----------------------------------------------------------------------===//
831
832static LogicalResult verifyInitLikeSingleArgRegion(
833 Operation *op, Region &region, StringRef regionType, StringRef regionName,
834 Type type, bool verifyYield, bool optional = false) {
835 if (optional && region.empty())
836 return success();
837
838 if (region.empty())
839 return op->emitOpError() << "expects non-empty " << regionName << " region";
840 Block &firstBlock = region.front();
841 if (firstBlock.getNumArguments() < 1 ||
842 firstBlock.getArgument(i: 0).getType() != type)
843 return op->emitOpError() << "expects " << regionName
844 << " region first "
845 "argument of the "
846 << regionType << " type";
847
848 if (verifyYield) {
849 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
850 if (yieldOp.getOperands().size() != 1 ||
851 yieldOp.getOperands().getTypes()[0] != type)
852 return op->emitOpError() << "expects " << regionName
853 << " region to "
854 "yield a value of the "
855 << regionType << " type";
856 }
857 }
858 return success();
859}
860
861LogicalResult acc::PrivateRecipeOp::verifyRegions() {
862 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
863 "privatization", "init", getType(),
864 /*verifyYield=*/false)))
865 return failure();
866 if (failed(verifyInitLikeSingleArgRegion(
867 *this, getDestroyRegion(), "privatization", "destroy", getType(),
868 /*verifyYield=*/false, /*optional=*/true)))
869 return failure();
870 return success();
871}
872
873//===----------------------------------------------------------------------===//
874// FirstprivateRecipeOp
875//===----------------------------------------------------------------------===//
876
877LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
878 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
879 "privatization", "init", getType(),
880 /*verifyYield=*/false)))
881 return failure();
882
883 if (getCopyRegion().empty())
884 return emitOpError() << "expects non-empty copy region";
885
886 Block &firstBlock = getCopyRegion().front();
887 if (firstBlock.getNumArguments() < 2 ||
888 firstBlock.getArgument(0).getType() != getType())
889 return emitOpError() << "expects copy region with two arguments of the "
890 "privatization type";
891
892 if (getDestroyRegion().empty())
893 return success();
894
895 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
896 "privatization", "destroy",
897 getType(), /*verifyYield=*/false)))
898 return failure();
899
900 return success();
901}
902
903//===----------------------------------------------------------------------===//
904// ReductionRecipeOp
905//===----------------------------------------------------------------------===//
906
907LogicalResult acc::ReductionRecipeOp::verifyRegions() {
908 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
909 "init", getType(),
910 /*verifyYield=*/false)))
911 return failure();
912
913 if (getCombinerRegion().empty())
914 return emitOpError() << "expects non-empty combiner region";
915
916 Block &reductionBlock = getCombinerRegion().front();
917 if (reductionBlock.getNumArguments() < 2 ||
918 reductionBlock.getArgument(0).getType() != getType() ||
919 reductionBlock.getArgument(1).getType() != getType())
920 return emitOpError() << "expects combiner region with the first two "
921 << "arguments of the reduction type";
922
923 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
924 if (yieldOp.getOperands().size() != 1 ||
925 yieldOp.getOperands().getTypes()[0] != getType())
926 return emitOpError() << "expects combiner region to yield a value "
927 "of the reduction type";
928 }
929
930 return success();
931}
932
933//===----------------------------------------------------------------------===//
// Custom parser and printer for symbol/operand lists (private clause)
935//===----------------------------------------------------------------------===//
936
937static ParseResult parseSymOperandList(
938 mlir::OpAsmParser &parser,
939 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
940 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
941 llvm::SmallVector<SymbolRefAttr> attributes;
942 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
943 if (parser.parseAttribute(attributes.emplace_back()) ||
944 parser.parseArrow() ||
945 parser.parseOperand(result&: operands.emplace_back()) ||
946 parser.parseColonType(result&: types.emplace_back()))
947 return failure();
948 return success();
949 })))
950 return failure();
951 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
952 attributes.end());
953 symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
954 return success();
955}
956
957static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
958 mlir::OperandRange operands,
959 mlir::TypeRange types,
960 std::optional<mlir::ArrayAttr> attributes) {
961 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) {
962 p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
963 << std::get<1>(it).getType();
964 });
965}
966
967//===----------------------------------------------------------------------===//
968// ParallelOp
969//===----------------------------------------------------------------------===//
970
971/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
972template <typename Op>
973static LogicalResult checkDataOperands(Op op,
974 const mlir::ValueRange &operands) {
975 for (mlir::Value operand : operands)
976 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
977 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
978 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
979 operand.getDefiningOp()))
980 return op.emitError(
981 "expect data entry/exit operation or acc.getdeviceptr "
982 "as defining op");
983 return success();
984}
985
986template <typename Op>
987static LogicalResult
988checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
989 mlir::OperandRange operands, llvm::StringRef operandName,
990 llvm::StringRef symbolName, bool checkOperandType = true) {
991 if (!operands.empty()) {
992 if (!attributes || attributes->size() != operands.size())
993 return op->emitOpError()
994 << "expected as many " << symbolName << " symbol reference as "
995 << operandName << " operands";
996 } else {
997 if (attributes)
998 return op->emitOpError()
999 << "unexpected " << symbolName << " symbol reference";
1000 return success();
1001 }
1002
1003 llvm::DenseSet<Value> set;
1004 for (auto args : llvm::zip(operands, *attributes)) {
1005 mlir::Value operand = std::get<0>(args);
1006
1007 if (!set.insert(operand).second)
1008 return op->emitOpError()
1009 << operandName << " operand appears more than once";
1010
1011 mlir::Type varType = operand.getType();
1012 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1013 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1014 if (!decl)
1015 return op->emitOpError()
1016 << "expected symbol reference " << symbolRef << " to point to a "
1017 << operandName << " declaration";
1018
1019 if (checkOperandType && decl.getType() && decl.getType() != varType)
1020 return op->emitOpError() << "expected " << operandName << " (" << varType
1021 << ") to be the same type as " << operandName
1022 << " declaration (" << decl.getType() << ")";
1023 }
1024
1025 return success();
1026}
1027
1028unsigned ParallelOp::getNumDataOperands() {
1029 return getReductionOperands().size() + getPrivateOperands().size() +
1030 getFirstprivateOperands().size() + getDataClauseOperands().size();
1031}
1032
1033Value ParallelOp::getDataOperand(unsigned i) {
1034 unsigned numOptional = getAsyncOperands().size();
1035 numOptional += getNumGangs().size();
1036 numOptional += getNumWorkers().size();
1037 numOptional += getVectorLength().size();
1038 numOptional += getIfCond() ? 1 : 0;
1039 numOptional += getSelfCond() ? 1 : 0;
1040 return getOperand(getWaitOperands().size() + numOptional + i);
1041}
1042
1043template <typename Op>
1044static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1045 ArrayAttr deviceTypes,
1046 llvm::StringRef keyword) {
1047 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1048 return op.emitOpError() << keyword << " operands count must match "
1049 << keyword << " device_type count";
1050 return success();
1051}
1052
1053template <typename Op>
1054static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
1055 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1056 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1057 std::size_t numOperandsInSegments = 0;
1058 std::size_t nbOfSegments = 0;
1059
1060 if (segments) {
1061 for (auto segCount : segments.asArrayRef()) {
1062 if (maxInSegment != 0 && segCount > maxInSegment)
1063 return op.emitOpError() << keyword << " expects a maximum of "
1064 << maxInSegment << " values per segment";
1065 numOperandsInSegments += segCount;
1066 ++nbOfSegments;
1067 }
1068 }
1069
1070 if ((numOperandsInSegments != operands.size()) ||
1071 (!deviceTypes && !operands.empty()))
1072 return op.emitOpError()
1073 << keyword << " operand count does not match count in segments";
1074 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1075 return op.emitOpError()
1076 << keyword << " segment count does not match device_type count";
1077 return success();
1078}
1079
1080LogicalResult acc::ParallelOp::verify() {
1081 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1082 *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
1083 "privatizations", /*checkOperandType=*/false)))
1084 return failure();
1085 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1086 *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1087 "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1088 return failure();
1089 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1090 *this, getReductionRecipes(), getReductionOperands(), "reduction",
1091 "reductions", false)))
1092 return failure();
1093
1094 if (failed(verifyDeviceTypeAndSegmentCountMatch(
1095 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1096 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1097 return failure();
1098
1099 if (failed(verifyDeviceTypeAndSegmentCountMatch(
1100 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1101 getWaitOperandsDeviceTypeAttr(), "wait")))
1102 return failure();
1103
1104 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1105 getNumWorkersDeviceTypeAttr(),
1106 "num_workers")))
1107 return failure();
1108
1109 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1110 getVectorLengthDeviceTypeAttr(),
1111 "vector_length")))
1112 return failure();
1113
1114 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1115 getAsyncOperandsDeviceTypeAttr(),
1116 "async")))
1117 return failure();
1118
1119 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
1120 return failure();
1121
1122 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1123}
1124
1125static mlir::Value
1126getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1127 mlir::Operation::operand_range range,
1128 mlir::acc::DeviceType deviceType) {
1129 if (!arrayAttr)
1130 return {};
1131 if (auto pos = findSegment(*arrayAttr, deviceType))
1132 return range[*pos];
1133 return {};
1134}
1135
1136bool acc::ParallelOp::hasAsyncOnly() {
1137 return hasAsyncOnly(mlir::acc::DeviceType::None);
1138}
1139
1140bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1141 return hasDeviceType(getAsyncOnly(), deviceType);
1142}
1143
1144mlir::Value acc::ParallelOp::getAsyncValue() {
1145 return getAsyncValue(mlir::acc::DeviceType::None);
1146}
1147
1148mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1149 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1150 getAsyncOperands(), deviceType);
1151}
1152
1153mlir::Value acc::ParallelOp::getNumWorkersValue() {
1154 return getNumWorkersValue(mlir::acc::DeviceType::None);
1155}
1156
1157mlir::Value
1158acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1159 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1160 deviceType);
1161}
1162
1163mlir::Value acc::ParallelOp::getVectorLengthValue() {
1164 return getVectorLengthValue(mlir::acc::DeviceType::None);
1165}
1166
1167mlir::Value
1168acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1169 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1170 getVectorLength(), deviceType);
1171}
1172
1173mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1174 return getNumGangsValues(mlir::acc::DeviceType::None);
1175}
1176
1177mlir::Operation::operand_range
1178ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1179 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1180 getNumGangsSegments(), deviceType);
1181}
1182
1183bool acc::ParallelOp::hasWaitOnly() {
1184 return hasWaitOnly(mlir::acc::DeviceType::None);
1185}
1186
1187bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1188 return hasDeviceType(getWaitOnly(), deviceType);
1189}
1190
1191mlir::Operation::operand_range ParallelOp::getWaitValues() {
1192 return getWaitValues(mlir::acc::DeviceType::None);
1193}
1194
1195mlir::Operation::operand_range
1196ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1197 return getWaitValuesWithoutDevnum(
1198 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1199 getHasWaitDevnum(), deviceType);
1200}
1201
1202mlir::Value ParallelOp::getWaitDevnum() {
1203 return getWaitDevnum(mlir::acc::DeviceType::None);
1204}
1205
1206mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1207 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1208 getWaitOperandsSegments(), getHasWaitDevnum(),
1209 deviceType);
1210}
1211
1212void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1213 mlir::OperationState &odsState,
1214 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1215 mlir::ValueRange vectorLength,
1216 mlir::ValueRange asyncOperands,
1217 mlir::ValueRange waitOperands, mlir::Value ifCond,
1218 mlir::Value selfCond, mlir::ValueRange reductionOperands,
1219 mlir::ValueRange gangPrivateOperands,
1220 mlir::ValueRange gangFirstPrivateOperands,
1221 mlir::ValueRange dataClauseOperands) {
1222
1223 ParallelOp::build(
1224 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1225 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1226 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1227 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1228 /*numGangsDeviceType=*/nullptr, numWorkers,
1229 /*numWorkersDeviceType=*/nullptr, vectorLength,
1230 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1231 /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
1232 gangPrivateOperands, /*privatizations=*/nullptr, gangFirstPrivateOperands,
1233 /*firstprivatizations=*/nullptr, dataClauseOperands,
1234 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1235}
1236
1237void acc::ParallelOp::addNumWorkersOperand(
1238 MLIRContext *context, mlir::Value newValue,
1239 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1240 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1241 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1242 getNumWorkersMutable()));
1243}
1244void acc::ParallelOp::addVectorLengthOperand(
1245 MLIRContext *context, mlir::Value newValue,
1246 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1247 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1248 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1249 getVectorLengthMutable()));
1250}
1251
1252void acc::ParallelOp::addAsyncOnly(
1253 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1254 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1255 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1256}
1257
1258void acc::ParallelOp::addAsyncOperand(
1259 MLIRContext *context, mlir::Value newValue,
1260 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1261 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1262 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1263 getAsyncOperandsMutable()));
1264}
1265
1266void acc::ParallelOp::addNumGangsOperands(
1267 MLIRContext *context, mlir::ValueRange newValues,
1268 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1269 llvm::SmallVector<int32_t> segments;
1270 if (getNumGangsSegments())
1271 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1272
1273 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1274 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1275 getNumGangsMutable(), segments));
1276
1277 setNumGangsSegments(segments);
1278}
1279void acc::ParallelOp::addWaitOnly(
1280 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1281 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1282 effectiveDeviceTypes));
1283}
1284void acc::ParallelOp::addWaitOperands(
1285 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1286 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1287
1288 llvm::SmallVector<int32_t> segments;
1289 if (getWaitOperandsSegments())
1290 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1291
1292 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1293 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1294 getWaitOperandsMutable(), segments));
1295 setWaitOperandsSegments(segments);
1296
1297 llvm::SmallVector<mlir::Attribute> hasDevnums;
1298 if (getHasWaitDevnumAttr())
1299 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1300 hasDevnums.insert(
1301 hasDevnums.end(),
1302 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1303 mlir::BoolAttr::get(context, hasDevnum));
1304 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1305}
1306
1307static ParseResult parseNumGangs(
1308 mlir::OpAsmParser &parser,
1309 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1310 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1311 mlir::DenseI32ArrayAttr &segments) {
1312 llvm::SmallVector<DeviceTypeAttr> attributes;
1313 llvm::SmallVector<int32_t> seg;
1314
1315 do {
1316 if (failed(Result: parser.parseLBrace()))
1317 return failure();
1318
1319 int32_t crtOperandsSize = operands.size();
1320 if (failed(Result: parser.parseCommaSeparatedList(
1321 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
1322 if (parser.parseOperand(result&: operands.emplace_back()) ||
1323 parser.parseColonType(result&: types.emplace_back()))
1324 return failure();
1325 return success();
1326 })))
1327 return failure();
1328 seg.push_back(Elt: operands.size() - crtOperandsSize);
1329
1330 if (failed(Result: parser.parseRBrace()))
1331 return failure();
1332
1333 if (succeeded(Result: parser.parseOptionalLSquare())) {
1334 if (parser.parseAttribute(attributes.emplace_back()) ||
1335 parser.parseRSquare())
1336 return failure();
1337 } else {
1338 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1339 parser.getContext(), mlir::acc::DeviceType::None));
1340 }
1341 } while (succeeded(Result: parser.parseOptionalComma()));
1342
1343 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1344 attributes.end());
1345 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1346 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1347
1348 return success();
1349}
1350
1351static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
1352 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1353 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1354 p << " [" << attr << "]";
1355}
1356
1357static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op,
1358 mlir::OperandRange operands, mlir::TypeRange types,
1359 std::optional<mlir::ArrayAttr> deviceTypes,
1360 std::optional<mlir::DenseI32ArrayAttr> segments) {
1361 unsigned opIdx = 0;
1362 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1363 p << "{";
1364 llvm::interleaveComma(
1365 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1366 p << operands[opIdx] << " : " << operands[opIdx].getType();
1367 ++opIdx;
1368 });
1369 p << "}";
1370 printSingleDeviceType(p, it.value());
1371 });
1372}
1373
1374static ParseResult parseDeviceTypeOperandsWithSegment(
1375 mlir::OpAsmParser &parser,
1376 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1377 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1378 mlir::DenseI32ArrayAttr &segments) {
1379 llvm::SmallVector<DeviceTypeAttr> attributes;
1380 llvm::SmallVector<int32_t> seg;
1381
1382 do {
1383 if (failed(Result: parser.parseLBrace()))
1384 return failure();
1385
1386 int32_t crtOperandsSize = operands.size();
1387
1388 if (failed(Result: parser.parseCommaSeparatedList(
1389 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
1390 if (parser.parseOperand(result&: operands.emplace_back()) ||
1391 parser.parseColonType(result&: types.emplace_back()))
1392 return failure();
1393 return success();
1394 })))
1395 return failure();
1396
1397 seg.push_back(Elt: operands.size() - crtOperandsSize);
1398
1399 if (failed(Result: parser.parseRBrace()))
1400 return failure();
1401
1402 if (succeeded(Result: parser.parseOptionalLSquare())) {
1403 if (parser.parseAttribute(attributes.emplace_back()) ||
1404 parser.parseRSquare())
1405 return failure();
1406 } else {
1407 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1408 parser.getContext(), mlir::acc::DeviceType::None));
1409 }
1410 } while (succeeded(Result: parser.parseOptionalComma()));
1411
1412 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1413 attributes.end());
1414 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1415 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1416
1417 return success();
1418}
1419
1420static void printDeviceTypeOperandsWithSegment(
1421 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1422 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1423 std::optional<mlir::DenseI32ArrayAttr> segments) {
1424 unsigned opIdx = 0;
1425 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1426 p << "{";
1427 llvm::interleaveComma(
1428 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1429 p << operands[opIdx] << " : " << operands[opIdx].getType();
1430 ++opIdx;
1431 });
1432 p << "}";
1433 printSingleDeviceType(p, it.value());
1434 });
1435}
1436
1437static ParseResult parseWaitClause(
1438 mlir::OpAsmParser &parser,
1439 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1440 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1441 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1442 mlir::ArrayAttr &keywordOnly) {
1443 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1444 llvm::SmallVector<int32_t> seg;
1445
1446 bool needCommaBeforeOperands = false;
1447
1448 // Keyword only
1449 if (failed(Result: parser.parseOptionalLParen())) {
1450 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1451 parser.getContext(), mlir::acc::DeviceType::None));
1452 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1453 return success();
1454 }
1455
1456 // Parse keyword only attributes
1457 if (succeeded(Result: parser.parseOptionalLSquare())) {
1458 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1459 if (parser.parseAttribute(result&: keywordAttrs.emplace_back()))
1460 return failure();
1461 return success();
1462 })))
1463 return failure();
1464 if (parser.parseRSquare())
1465 return failure();
1466 needCommaBeforeOperands = true;
1467 }
1468
1469 if (needCommaBeforeOperands && failed(Result: parser.parseComma()))
1470 return failure();
1471
1472 do {
1473 if (failed(Result: parser.parseLBrace()))
1474 return failure();
1475
1476 int32_t crtOperandsSize = operands.size();
1477
1478 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "devnum"))) {
1479 if (failed(Result: parser.parseColon()))
1480 return failure();
1481 devnum.push_back(Elt: BoolAttr::get(context: parser.getContext(), value: true));
1482 } else {
1483 devnum.push_back(Elt: BoolAttr::get(context: parser.getContext(), value: false));
1484 }
1485
1486 if (failed(Result: parser.parseCommaSeparatedList(
1487 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
1488 if (parser.parseOperand(result&: operands.emplace_back()) ||
1489 parser.parseColonType(result&: types.emplace_back()))
1490 return failure();
1491 return success();
1492 })))
1493 return failure();
1494
1495 seg.push_back(Elt: operands.size() - crtOperandsSize);
1496
1497 if (failed(Result: parser.parseRBrace()))
1498 return failure();
1499
1500 if (succeeded(Result: parser.parseOptionalLSquare())) {
1501 if (parser.parseAttribute(result&: deviceTypeAttrs.emplace_back()) ||
1502 parser.parseRSquare())
1503 return failure();
1504 } else {
1505 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1506 parser.getContext(), mlir::acc::DeviceType::None));
1507 }
1508 } while (succeeded(Result: parser.parseOptionalComma()));
1509
1510 if (failed(Result: parser.parseRParen()))
1511 return failure();
1512
1513 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
1514 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
1515 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
1516 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
1517
1518 return success();
1519}
1520
1521static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1522 if (!hasDeviceTypeValues(arrayAttr: attrs))
1523 return false;
1524 if (attrs->size() != 1)
1525 return false;
1526 if (auto deviceTypeAttr =
1527 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1528 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1529 return false;
1530}
1531
1532static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
1533 mlir::OperandRange operands, mlir::TypeRange types,
1534 std::optional<mlir::ArrayAttr> deviceTypes,
1535 std::optional<mlir::DenseI32ArrayAttr> segments,
1536 std::optional<mlir::ArrayAttr> hasDevNum,
1537 std::optional<mlir::ArrayAttr> keywordOnly) {
1538
1539 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(attrs: keywordOnly))
1540 return;
1541
1542 p << "(";
1543
1544 printDeviceTypes(p, deviceTypes: keywordOnly);
1545 if (hasDeviceTypeValues(arrayAttr: keywordOnly) && hasDeviceTypeValues(arrayAttr: deviceTypes))
1546 p << ", ";
1547
1548 if (hasDeviceTypeValues(arrayAttr: deviceTypes)) {
1549 unsigned opIdx = 0;
1550 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1551 p << "{";
1552 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1553 if (boolAttr && boolAttr.getValue())
1554 p << "devnum: ";
1555 llvm::interleaveComma(
1556 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1557 p << operands[opIdx] << " : " << operands[opIdx].getType();
1558 ++opIdx;
1559 });
1560 p << "}";
1561 printSingleDeviceType(p, it.value());
1562 });
1563 }
1564
1565 p << ")";
1566}
1567
1568static ParseResult parseDeviceTypeOperands(
1569 mlir::OpAsmParser &parser,
1570 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1571 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1572 llvm::SmallVector<DeviceTypeAttr> attributes;
1573 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1574 if (parser.parseOperand(result&: operands.emplace_back()) ||
1575 parser.parseColonType(result&: types.emplace_back()))
1576 return failure();
1577 if (succeeded(Result: parser.parseOptionalLSquare())) {
1578 if (parser.parseAttribute(attributes.emplace_back()) ||
1579 parser.parseRSquare())
1580 return failure();
1581 } else {
1582 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1583 parser.getContext(), mlir::acc::DeviceType::None));
1584 }
1585 return success();
1586 })))
1587 return failure();
1588 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1589 attributes.end());
1590 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1591 return success();
1592}
1593
1594static void
1595printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
1596 mlir::OperandRange operands, mlir::TypeRange types,
1597 std::optional<mlir::ArrayAttr> deviceTypes) {
1598 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
1599 return;
1600 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
1601 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1602 printSingleDeviceType(p, std::get<0>(it));
1603 });
1604}
1605
1606static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
1607 mlir::OpAsmParser &parser,
1608 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1609 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1610 mlir::ArrayAttr &keywordOnlyDeviceType) {
1611
1612 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1613 bool needCommaBeforeOperands = false;
1614
1615 if (failed(Result: parser.parseOptionalLParen())) {
1616 // Keyword only
1617 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
1618 parser.getContext(), mlir::acc::DeviceType::None));
1619 keywordOnlyDeviceType =
1620 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
1621 return success();
1622 }
1623
1624 // Parse keyword only attributes
1625 if (succeeded(Result: parser.parseOptionalLSquare())) {
1626 // Parse keyword only attributes
1627 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1628 if (parser.parseAttribute(
1629 result&: keywordOnlyDeviceTypeAttributes.emplace_back()))
1630 return failure();
1631 return success();
1632 })))
1633 return failure();
1634 if (parser.parseRSquare())
1635 return failure();
1636 needCommaBeforeOperands = true;
1637 }
1638
1639 if (needCommaBeforeOperands && failed(Result: parser.parseComma()))
1640 return failure();
1641
1642 llvm::SmallVector<DeviceTypeAttr> attributes;
1643 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1644 if (parser.parseOperand(result&: operands.emplace_back()) ||
1645 parser.parseColonType(result&: types.emplace_back()))
1646 return failure();
1647 if (succeeded(Result: parser.parseOptionalLSquare())) {
1648 if (parser.parseAttribute(attributes.emplace_back()) ||
1649 parser.parseRSquare())
1650 return failure();
1651 } else {
1652 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1653 parser.getContext(), mlir::acc::DeviceType::None));
1654 }
1655 return success();
1656 })))
1657 return failure();
1658
1659 if (failed(Result: parser.parseRParen()))
1660 return failure();
1661
1662 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1663 attributes.end());
1664 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
1665 return success();
1666}
1667
1668static void printDeviceTypeOperandsWithKeywordOnly(
1669 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1670 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1671 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1672
1673 if (operands.begin() == operands.end() &&
1674 hasOnlyDeviceTypeNone(attrs: keywordOnlyDeviceTypes)) {
1675 return;
1676 }
1677
1678 p << "(";
1679 printDeviceTypes(p, deviceTypes: keywordOnlyDeviceTypes);
1680 if (hasDeviceTypeValues(arrayAttr: keywordOnlyDeviceTypes) &&
1681 hasDeviceTypeValues(arrayAttr: deviceTypes))
1682 p << ", ";
1683 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1684 p << ")";
1685}
1686
1687static ParseResult parseOperandWithKeywordOnly(
1688 mlir::OpAsmParser &parser,
1689 std::optional<OpAsmParser::UnresolvedOperand> &operand,
1690 mlir::Type &operandType, mlir::UnitAttr &attr) {
1691 // Keyword only
1692 if (failed(Result: parser.parseOptionalLParen())) {
1693 attr = mlir::UnitAttr::get(parser.getContext());
1694 return success();
1695 }
1696
1697 OpAsmParser::UnresolvedOperand op;
1698 if (failed(Result: parser.parseOperand(result&: op)))
1699 return failure();
1700 operand = op;
1701 if (failed(Result: parser.parseColon()))
1702 return failure();
1703 if (failed(Result: parser.parseType(result&: operandType)))
1704 return failure();
1705 if (failed(Result: parser.parseRParen()))
1706 return failure();
1707
1708 return success();
1709}
1710
1711static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p,
1712 mlir::Operation *op,
1713 std::optional<mlir::Value> operand,
1714 mlir::Type operandType,
1715 mlir::UnitAttr attr) {
1716 if (attr)
1717 return;
1718
1719 p << "(";
1720 p.printOperand(value: *operand);
1721 p << " : ";
1722 p.printType(type: operandType);
1723 p << ")";
1724}
1725
1726static ParseResult parseOperandsWithKeywordOnly(
1727 mlir::OpAsmParser &parser,
1728 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1729 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
1730 // Keyword only
1731 if (failed(Result: parser.parseOptionalLParen())) {
1732 attr = mlir::UnitAttr::get(parser.getContext());
1733 return success();
1734 }
1735
1736 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1737 if (parser.parseOperand(result&: operands.emplace_back()))
1738 return failure();
1739 return success();
1740 })))
1741 return failure();
1742 if (failed(Result: parser.parseColon()))
1743 return failure();
1744 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1745 if (parser.parseType(result&: types.emplace_back()))
1746 return failure();
1747 return success();
1748 })))
1749 return failure();
1750 if (failed(Result: parser.parseRParen()))
1751 return failure();
1752
1753 return success();
1754}
1755
1756static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p,
1757 mlir::Operation *op,
1758 mlir::OperandRange operands,
1759 mlir::TypeRange types,
1760 mlir::UnitAttr attr) {
1761 if (attr)
1762 return;
1763
1764 p << "(";
1765 llvm::interleaveComma(c: operands, os&: p, each_fn: [&](auto it) { p << it; });
1766 p << " : ";
1767 llvm::interleaveComma(c: types, os&: p, each_fn: [&](auto it) { p << it; });
1768 p << ")";
1769}
1770
1771static ParseResult
1772parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
1773 mlir::acc::CombinedConstructsTypeAttr &attr) {
1774 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "kernels"))) {
1775 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1776 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1777 } else if (succeeded(Result: parser.parseOptionalKeyword(keyword: "parallel"))) {
1778 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1779 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1780 } else if (succeeded(Result: parser.parseOptionalKeyword(keyword: "serial"))) {
1781 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1782 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1783 } else {
1784 parser.emitError(loc: parser.getCurrentLocation(),
1785 message: "expected compute construct name");
1786 return failure();
1787 }
1788 return success();
1789}
1790
1791static void
1792printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
1793 mlir::acc::CombinedConstructsTypeAttr attr) {
1794 if (attr) {
1795 switch (attr.getValue()) {
1796 case mlir::acc::CombinedConstructsType::KernelsLoop:
1797 p << "kernels";
1798 break;
1799 case mlir::acc::CombinedConstructsType::ParallelLoop:
1800 p << "parallel";
1801 break;
1802 case mlir::acc::CombinedConstructsType::SerialLoop:
1803 p << "serial";
1804 break;
1805 };
1806 }
1807}
1808
1809//===----------------------------------------------------------------------===//
1810// SerialOp
1811//===----------------------------------------------------------------------===//
1812
1813unsigned SerialOp::getNumDataOperands() {
1814 return getReductionOperands().size() + getPrivateOperands().size() +
1815 getFirstprivateOperands().size() + getDataClauseOperands().size();
1816}
1817
1818Value SerialOp::getDataOperand(unsigned i) {
1819 unsigned numOptional = getAsyncOperands().size();
1820 numOptional += getIfCond() ? 1 : 0;
1821 numOptional += getSelfCond() ? 1 : 0;
1822 return getOperand(getWaitOperands().size() + numOptional + i);
1823}
1824
1825bool acc::SerialOp::hasAsyncOnly() {
1826 return hasAsyncOnly(mlir::acc::DeviceType::None);
1827}
1828
1829bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1830 return hasDeviceType(getAsyncOnly(), deviceType);
1831}
1832
1833mlir::Value acc::SerialOp::getAsyncValue() {
1834 return getAsyncValue(mlir::acc::DeviceType::None);
1835}
1836
1837mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1838 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1839 getAsyncOperands(), deviceType);
1840}
1841
1842bool acc::SerialOp::hasWaitOnly() {
1843 return hasWaitOnly(mlir::acc::DeviceType::None);
1844}
1845
1846bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1847 return hasDeviceType(getWaitOnly(), deviceType);
1848}
1849
1850mlir::Operation::operand_range SerialOp::getWaitValues() {
1851 return getWaitValues(mlir::acc::DeviceType::None);
1852}
1853
1854mlir::Operation::operand_range
1855SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1856 return getWaitValuesWithoutDevnum(
1857 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1858 getHasWaitDevnum(), deviceType);
1859}
1860
1861mlir::Value SerialOp::getWaitDevnum() {
1862 return getWaitDevnum(mlir::acc::DeviceType::None);
1863}
1864
1865mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1866 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1867 getWaitOperandsSegments(), getHasWaitDevnum(),
1868 deviceType);
1869}
1870
1871LogicalResult acc::SerialOp::verify() {
1872 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1873 *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
1874 "privatizations", /*checkOperandType=*/false)))
1875 return failure();
1876 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1877 *this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1878 "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
1879 return failure();
1880 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1881 *this, getReductionRecipes(), getReductionOperands(), "reduction",
1882 "reductions", false)))
1883 return failure();
1884
1885 if (failed(verifyDeviceTypeAndSegmentCountMatch(
1886 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1887 getWaitOperandsDeviceTypeAttr(), "wait")))
1888 return failure();
1889
1890 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
1891 getAsyncOperandsDeviceTypeAttr(),
1892 "async")))
1893 return failure();
1894
1895 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
1896 return failure();
1897
1898 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
1899}
1900
1901void acc::SerialOp::addAsyncOnly(
1902 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1903 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1904 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1905}
1906
1907void acc::SerialOp::addAsyncOperand(
1908 MLIRContext *context, mlir::Value newValue,
1909 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1910 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1911 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1912 getAsyncOperandsMutable()));
1913}
1914
1915void acc::SerialOp::addWaitOnly(
1916 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1917 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1918 effectiveDeviceTypes));
1919}
1920void acc::SerialOp::addWaitOperands(
1921 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1922 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1923
1924 llvm::SmallVector<int32_t> segments;
1925 if (getWaitOperandsSegments())
1926 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1927
1928 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1929 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1930 getWaitOperandsMutable(), segments));
1931 setWaitOperandsSegments(segments);
1932
1933 llvm::SmallVector<mlir::Attribute> hasDevnums;
1934 if (getHasWaitDevnumAttr())
1935 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1936 hasDevnums.insert(
1937 hasDevnums.end(),
1938 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1939 mlir::BoolAttr::get(context, hasDevnum));
1940 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1941}
1942
1943//===----------------------------------------------------------------------===//
1944// KernelsOp
1945//===----------------------------------------------------------------------===//
1946
1947unsigned KernelsOp::getNumDataOperands() {
1948 return getDataClauseOperands().size();
1949}
1950
1951Value KernelsOp::getDataOperand(unsigned i) {
1952 unsigned numOptional = getAsyncOperands().size();
1953 numOptional += getWaitOperands().size();
1954 numOptional += getNumGangs().size();
1955 numOptional += getNumWorkers().size();
1956 numOptional += getVectorLength().size();
1957 numOptional += getIfCond() ? 1 : 0;
1958 numOptional += getSelfCond() ? 1 : 0;
1959 return getOperand(numOptional + i);
1960}
1961
1962bool acc::KernelsOp::hasAsyncOnly() {
1963 return hasAsyncOnly(mlir::acc::DeviceType::None);
1964}
1965
1966bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1967 return hasDeviceType(getAsyncOnly(), deviceType);
1968}
1969
1970mlir::Value acc::KernelsOp::getAsyncValue() {
1971 return getAsyncValue(mlir::acc::DeviceType::None);
1972}
1973
1974mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1975 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
1976 getAsyncOperands(), deviceType);
1977}
1978
1979mlir::Value acc::KernelsOp::getNumWorkersValue() {
1980 return getNumWorkersValue(mlir::acc::DeviceType::None);
1981}
1982
1983mlir::Value
1984acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1985 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1986 deviceType);
1987}
1988
1989mlir::Value acc::KernelsOp::getVectorLengthValue() {
1990 return getVectorLengthValue(mlir::acc::DeviceType::None);
1991}
1992
1993mlir::Value
1994acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1995 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1996 getVectorLength(), deviceType);
1997}
1998
1999mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2000 return getNumGangsValues(mlir::acc::DeviceType::None);
2001}
2002
2003mlir::Operation::operand_range
2004KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2005 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2006 getNumGangsSegments(), deviceType);
2007}
2008
2009bool acc::KernelsOp::hasWaitOnly() {
2010 return hasWaitOnly(mlir::acc::DeviceType::None);
2011}
2012
2013bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2014 return hasDeviceType(getWaitOnly(), deviceType);
2015}
2016
2017mlir::Operation::operand_range KernelsOp::getWaitValues() {
2018 return getWaitValues(mlir::acc::DeviceType::None);
2019}
2020
2021mlir::Operation::operand_range
2022KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2023 return getWaitValuesWithoutDevnum(
2024 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2025 getHasWaitDevnum(), deviceType);
2026}
2027
2028mlir::Value KernelsOp::getWaitDevnum() {
2029 return getWaitDevnum(mlir::acc::DeviceType::None);
2030}
2031
2032mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2033 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2034 getWaitOperandsSegments(), getHasWaitDevnum(),
2035 deviceType);
2036}
2037
2038LogicalResult acc::KernelsOp::verify() {
2039 if (failed(verifyDeviceTypeAndSegmentCountMatch(
2040 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2041 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2042 return failure();
2043
2044 if (failed(verifyDeviceTypeAndSegmentCountMatch(
2045 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2046 getWaitOperandsDeviceTypeAttr(), "wait")))
2047 return failure();
2048
2049 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2050 getNumWorkersDeviceTypeAttr(),
2051 "num_workers")))
2052 return failure();
2053
2054 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2055 getVectorLengthDeviceTypeAttr(),
2056 "vector_length")))
2057 return failure();
2058
2059 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
2060 getAsyncOperandsDeviceTypeAttr(),
2061 "async")))
2062 return failure();
2063
2064 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
2065 return failure();
2066
2067 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2068}
2069
2070void acc::KernelsOp::addNumWorkersOperand(
2071 MLIRContext *context, mlir::Value newValue,
2072 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2073 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2074 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2075 getNumWorkersMutable()));
2076}
2077
2078void acc::KernelsOp::addVectorLengthOperand(
2079 MLIRContext *context, mlir::Value newValue,
2080 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2081 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2082 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2083 getVectorLengthMutable()));
2084}
2085void acc::KernelsOp::addAsyncOnly(
2086 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2087 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2088 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2089}
2090
2091void acc::KernelsOp::addAsyncOperand(
2092 MLIRContext *context, mlir::Value newValue,
2093 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2094 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2095 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2096 getAsyncOperandsMutable()));
2097}
2098
2099void acc::KernelsOp::addNumGangsOperands(
2100 MLIRContext *context, mlir::ValueRange newValues,
2101 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2102 llvm::SmallVector<int32_t> segments;
2103 if (getNumGangsSegmentsAttr())
2104 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2105
2106 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2107 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2108 getNumGangsMutable(), segments));
2109
2110 setNumGangsSegments(segments);
2111}
2112
2113void acc::KernelsOp::addWaitOnly(
2114 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2115 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2116 effectiveDeviceTypes));
2117}
2118void acc::KernelsOp::addWaitOperands(
2119 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2120 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2121
2122 llvm::SmallVector<int32_t> segments;
2123 if (getWaitOperandsSegments())
2124 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2125
2126 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2127 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2128 getWaitOperandsMutable(), segments));
2129 setWaitOperandsSegments(segments);
2130
2131 llvm::SmallVector<mlir::Attribute> hasDevnums;
2132 if (getHasWaitDevnumAttr())
2133 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2134 hasDevnums.insert(
2135 hasDevnums.end(),
2136 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2137 mlir::BoolAttr::get(context, hasDevnum));
2138 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2139}
2140
2141//===----------------------------------------------------------------------===//
2142// HostDataOp
2143//===----------------------------------------------------------------------===//
2144
2145LogicalResult acc::HostDataOp::verify() {
2146 if (getDataClauseOperands().empty())
2147 return emitError("at least one operand must appear on the host_data "
2148 "operation");
2149
2150 for (mlir::Value operand : getDataClauseOperands())
2151 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2152 return emitError("expect data entry operation as defining op");
2153 return success();
2154}
2155
2156void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2157 MLIRContext *context) {
2158 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2159}
2160
2161//===----------------------------------------------------------------------===//
2162// LoopOp
2163//===----------------------------------------------------------------------===//
2164
2165static ParseResult parseGangValue(
2166 OpAsmParser &parser, llvm::StringRef keyword,
2167 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
2168 llvm::SmallVectorImpl<Type> &types,
2169 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2170 bool &needCommaBetweenValues, bool &newValue) {
2171 if (succeeded(Result: parser.parseOptionalKeyword(keyword))) {
2172 if (parser.parseEqual())
2173 return failure();
2174 if (parser.parseOperand(result&: operands.emplace_back()) ||
2175 parser.parseColonType(result&: types.emplace_back()))
2176 return failure();
2177 attributes.push_back(gangArgType);
2178 needCommaBetweenValues = true;
2179 newValue = true;
2180 }
2181 return success();
2182}
2183
2184static ParseResult parseGangClause(
2185 OpAsmParser &parser,
2186 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
2187 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2188 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2189 mlir::ArrayAttr &gangOnlyDeviceType) {
2190 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2191 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2192 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2193 llvm::SmallVector<int32_t> seg;
2194 bool needCommaBetweenValues = false;
2195 bool needCommaBeforeOperands = false;
2196
2197 if (failed(Result: parser.parseOptionalLParen())) {
2198 // Gang only keyword
2199 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2200 parser.getContext(), mlir::acc::DeviceType::None));
2201 gangOnlyDeviceType =
2202 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2203 return success();
2204 }
2205
2206 // Parse gang only attributes
2207 if (succeeded(Result: parser.parseOptionalLSquare())) {
2208 // Parse gang only attributes
2209 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
2210 if (parser.parseAttribute(
2211 result&: gangOnlyDeviceTypeAttributes.emplace_back()))
2212 return failure();
2213 return success();
2214 })))
2215 return failure();
2216 if (parser.parseRSquare())
2217 return failure();
2218 needCommaBeforeOperands = true;
2219 }
2220
2221 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2222 mlir::acc::GangArgType::Num);
2223 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2224 mlir::acc::GangArgType::Dim);
2225 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2226 parser.getContext(), mlir::acc::GangArgType::Static);
2227
2228 do {
2229 if (needCommaBeforeOperands) {
2230 needCommaBeforeOperands = false;
2231 continue;
2232 }
2233
2234 if (failed(Result: parser.parseLBrace()))
2235 return failure();
2236
2237 int32_t crtOperandsSize = gangOperands.size();
2238 while (true) {
2239 bool newValue = false;
2240 bool needValue = false;
2241 if (needCommaBetweenValues) {
2242 if (succeeded(Result: parser.parseOptionalComma()))
2243 needValue = true; // expect a new value after comma.
2244 else
2245 break;
2246 }
2247
2248 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
2249 gangOperands, gangOperandsType,
2250 gangArgTypeAttributes, argNum,
2251 needCommaBetweenValues, newValue)))
2252 return failure();
2253 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
2254 gangOperands, gangOperandsType,
2255 gangArgTypeAttributes, argDim,
2256 needCommaBetweenValues, newValue)))
2257 return failure();
2258 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2259 gangOperands, gangOperandsType,
2260 gangArgTypeAttributes, argStatic,
2261 needCommaBetweenValues, newValue)))
2262 return failure();
2263
2264 if (!newValue && needValue) {
2265 parser.emitError(loc: parser.getCurrentLocation(),
2266 message: "new value expected after comma");
2267 return failure();
2268 }
2269
2270 if (!newValue)
2271 break;
2272 }
2273
2274 if (gangOperands.empty())
2275 return parser.emitError(
2276 loc: parser.getCurrentLocation(),
2277 message: "expect at least one of num, dim or static values");
2278
2279 if (failed(Result: parser.parseRBrace()))
2280 return failure();
2281
2282 if (succeeded(Result: parser.parseOptionalLSquare())) {
2283 if (parser.parseAttribute(result&: deviceTypeAttributes.emplace_back()) ||
2284 parser.parseRSquare())
2285 return failure();
2286 } else {
2287 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2288 parser.getContext(), mlir::acc::DeviceType::None));
2289 }
2290
2291 seg.push_back(Elt: gangOperands.size() - crtOperandsSize);
2292
2293 } while (succeeded(Result: parser.parseOptionalComma()));
2294
2295 if (failed(Result: parser.parseRParen()))
2296 return failure();
2297
2298 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
2299 gangArgTypeAttributes.end());
2300 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
2301 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
2302
2303 llvm::SmallVector<mlir::Attribute> gangOnlyAttr(
2304 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2305 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
2306
2307 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2308 return success();
2309}
2310
2311void printGangClause(OpAsmPrinter &p, Operation *op,
2312 mlir::OperandRange operands, mlir::TypeRange types,
2313 std::optional<mlir::ArrayAttr> gangArgTypes,
2314 std::optional<mlir::ArrayAttr> deviceTypes,
2315 std::optional<mlir::DenseI32ArrayAttr> segments,
2316 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2317
2318 if (operands.begin() == operands.end() &&
2319 hasOnlyDeviceTypeNone(attrs: gangOnlyDeviceTypes)) {
2320 return;
2321 }
2322
2323 p << "(";
2324
2325 printDeviceTypes(p, deviceTypes: gangOnlyDeviceTypes);
2326
2327 if (hasDeviceTypeValues(arrayAttr: gangOnlyDeviceTypes) &&
2328 hasDeviceTypeValues(arrayAttr: deviceTypes))
2329 p << ", ";
2330
2331 if (hasDeviceTypeValues(arrayAttr: deviceTypes)) {
2332 unsigned opIdx = 0;
2333 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2334 p << "{";
2335 llvm::interleaveComma(
2336 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2337 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2338 (*gangArgTypes)[opIdx]);
2339 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2340 p << LoopOp::getGangNumKeyword();
2341 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2342 p << LoopOp::getGangDimKeyword();
2343 else if (gangArgTypeAttr.getValue() ==
2344 mlir::acc::GangArgType::Static)
2345 p << LoopOp::getGangStaticKeyword();
2346 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
2347 ++opIdx;
2348 });
2349 p << "}";
2350 printSingleDeviceType(p, it.value());
2351 });
2352 }
2353 p << ")";
2354}
2355
2356bool hasDuplicateDeviceTypes(
2357 std::optional<mlir::ArrayAttr> segments,
2358 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2359 if (!segments)
2360 return false;
2361 for (auto attr : *segments) {
2362 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2363 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2364 return true;
2365 }
2366 return false;
2367}
2368
2369/// Check for duplicates in the DeviceType array attribute.
2370LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2371 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2372 if (!deviceTypes)
2373 return success();
2374 for (auto attr : deviceTypes) {
2375 auto deviceTypeAttr =
2376 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2377 if (!deviceTypeAttr)
2378 return failure();
2379 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2380 return failure();
2381 }
2382 return success();
2383}
2384
2385LogicalResult acc::LoopOp::verify() {
2386 if (getUpperbound().size() != getStep().size())
2387 return emitError() << "number of upperbounds expected to be the same as "
2388 "number of steps";
2389
2390 if (getUpperbound().size() != getLowerbound().size())
2391 return emitError() << "number of upperbounds expected to be the same as "
2392 "number of lowerbounds";
2393
2394 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2395 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2396 return emitError() << "inclusiveUpperbound size is expected to be the same"
2397 << " as upperbound size";
2398
2399 // Check collapse
2400 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2401 return emitOpError() << "collapse device_type attr must be define when"
2402 << " collapse attr is present";
2403
2404 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2405 getCollapseAttr().getValue().size() !=
2406 getCollapseDeviceTypeAttr().getValue().size())
2407 return emitOpError() << "collapse attribute count must match collapse"
2408 << " device_type count";
2409 if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr())))
2410 return emitOpError()
2411 << "duplicate device_type found in collapseDeviceType attribute";
2412
2413 // Check gang
2414 if (!getGangOperands().empty()) {
2415 if (!getGangOperandsArgType())
2416 return emitOpError() << "gangOperandsArgType attribute must be defined"
2417 << " when gang operands are present";
2418
2419 if (getGangOperands().size() !=
2420 getGangOperandsArgTypeAttr().getValue().size())
2421 return emitOpError() << "gangOperandsArgType attribute count must match"
2422 << " gangOperands count";
2423 }
2424 if (getGangAttr() && failed(checkDeviceTypes(getGangAttr())))
2425 return emitOpError() << "duplicate device_type found in gang attribute";
2426
2427 if (failed(verifyDeviceTypeAndSegmentCountMatch(
2428 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
2429 getGangOperandsDeviceTypeAttr(), "gang")))
2430 return failure();
2431
2432 // Check worker
2433 if (failed(checkDeviceTypes(getWorkerAttr())))
2434 return emitOpError() << "duplicate device_type found in worker attribute";
2435 if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr())))
2436 return emitOpError() << "duplicate device_type found in "
2437 "workerNumOperandsDeviceType attribute";
2438 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
2439 getWorkerNumOperandsDeviceTypeAttr(),
2440 "worker")))
2441 return failure();
2442
2443 // Check vector
2444 if (failed(checkDeviceTypes(getVectorAttr())))
2445 return emitOpError() << "duplicate device_type found in vector attribute";
2446 if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr())))
2447 return emitOpError() << "duplicate device_type found in "
2448 "vectorOperandsDeviceType attribute";
2449 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
2450 getVectorOperandsDeviceTypeAttr(),
2451 "vector")))
2452 return failure();
2453
2454 if (failed(verifyDeviceTypeAndSegmentCountMatch(
2455 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
2456 getTileOperandsDeviceTypeAttr(), "tile")))
2457 return failure();
2458
2459 // auto, independent and seq attribute are mutually exclusive.
2460 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2461 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
2462 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
2463 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
2464 return emitError() << "only one of auto, independent, seq can be present "
2465 "at the same time";
2466 }
2467
2468 // Check that at least one of auto, independent, or seq is present
2469 // for the device-independent default clauses.
2470 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
2471 return attr.getValue() == mlir::acc::DeviceType::None;
2472 };
2473 bool hasDefaultSeq =
2474 getSeqAttr()
2475 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2476 hasDeviceNone)
2477 : false;
2478 bool hasDefaultIndependent =
2479 getIndependentAttr()
2480 ? llvm::any_of(
2481 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2482 hasDeviceNone)
2483 : false;
2484 bool hasDefaultAuto =
2485 getAuto_Attr()
2486 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2487 hasDeviceNone)
2488 : false;
2489 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
2490 return emitError()
2491 << "at least one of auto, independent, seq must be present";
2492 }
2493
2494 // Gang, worker and vector are incompatible with seq.
2495 if (getSeqAttr()) {
2496 for (auto attr : getSeqAttr()) {
2497 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2498 if (hasVector(deviceTypeAttr.getValue()) ||
2499 getVectorValue(deviceTypeAttr.getValue()) ||
2500 hasWorker(deviceTypeAttr.getValue()) ||
2501 getWorkerValue(deviceTypeAttr.getValue()) ||
2502 hasGang(deviceTypeAttr.getValue()) ||
2503 getGangValue(mlir::acc::GangArgType::Num,
2504 deviceTypeAttr.getValue()) ||
2505 getGangValue(mlir::acc::GangArgType::Dim,
2506 deviceTypeAttr.getValue()) ||
2507 getGangValue(mlir::acc::GangArgType::Static,
2508 deviceTypeAttr.getValue()))
2509 return emitError() << "gang, worker or vector cannot appear with seq";
2510 }
2511 }
2512
2513 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2514 *this, getPrivatizationRecipes(), getPrivateOperands(), "private",
2515 "privatizations", false)))
2516 return failure();
2517
2518 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2519 *this, getReductionRecipes(), getReductionOperands(), "reduction",
2520 "reductions", false)))
2521 return failure();
2522
2523 if (getCombined().has_value() &&
2524 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2525 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2526 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2527 return emitError("unexpected combined constructs attribute");
2528 }
2529
2530 // Check non-empty body().
2531 if (getRegion().empty())
2532 return emitError("expected non-empty body.");
2533
2534 // When it is container-like - it is expected to hold a loop-like operation.
2535 if (isContainerLike()) {
2536 // Obtain the maximum collapse count - we use this to check that there
2537 // are enough loops contained.
2538 uint64_t collapseCount = getCollapseValue().value_or(1);
2539 if (getCollapseAttr()) {
2540 for (auto collapseEntry : getCollapseAttr()) {
2541 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
2542 if (intAttr.getValue().getZExtValue() > collapseCount)
2543 collapseCount = intAttr.getValue().getZExtValue();
2544 }
2545 }
2546
2547 // We want to check that we find enough loop-like operations inside.
2548 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
2549 // level.
2550 mlir::Operation *expectedParent = this->getOperation();
2551 bool foundSibling = false;
2552 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
2553 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
2554 // This effectively checks that we are not looking at a sibling loop.
2555 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2556 expectedParent) {
2557 foundSibling = true;
2558 return mlir::WalkResult::interrupt();
2559 }
2560
2561 collapseCount--;
2562 expectedParent = op;
2563 }
2564 // We found enough contained loops.
2565 if (collapseCount == 0)
2566 return mlir::WalkResult::interrupt();
2567 return mlir::WalkResult::advance();
2568 });
2569
2570 if (foundSibling)
2571 return emitError("found sibling loops inside container-like acc.loop");
2572 if (collapseCount != 0)
2573 return emitError("failed to find enough loop-like operations inside "
2574 "container-like acc.loop");
2575 }
2576
2577 return success();
2578}
2579
2580unsigned LoopOp::getNumDataOperands() {
2581 return getReductionOperands().size() + getPrivateOperands().size();
2582}
2583
2584Value LoopOp::getDataOperand(unsigned i) {
2585 unsigned numOptional =
2586 getLowerbound().size() + getUpperbound().size() + getStep().size();
2587 numOptional += getGangOperands().size();
2588 numOptional += getVectorOperands().size();
2589 numOptional += getWorkerNumOperands().size();
2590 numOptional += getTileOperands().size();
2591 numOptional += getCacheOperands().size();
2592 return getOperand(numOptional + i);
2593}
2594
2595bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
2596
2597bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2598 return hasDeviceType(getAuto_(), deviceType);
2599}
2600
2601bool LoopOp::hasIndependent() {
2602 return hasIndependent(mlir::acc::DeviceType::None);
2603}
2604
2605bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2606 return hasDeviceType(getIndependent(), deviceType);
2607}
2608
2609bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
2610
2611bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2612 return hasDeviceType(getSeq(), deviceType);
2613}
2614
2615mlir::Value LoopOp::getVectorValue() {
2616 return getVectorValue(mlir::acc::DeviceType::None);
2617}
2618
2619mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2620 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
2621 getVectorOperands(), deviceType);
2622}
2623
2624bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
2625
2626bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2627 return hasDeviceType(getVector(), deviceType);
2628}
2629
2630mlir::Value LoopOp::getWorkerValue() {
2631 return getWorkerValue(mlir::acc::DeviceType::None);
2632}
2633
2634mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2635 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
2636 getWorkerNumOperands(), deviceType);
2637}
2638
2639bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
2640
2641bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2642 return hasDeviceType(getWorker(), deviceType);
2643}
2644
2645mlir::Operation::operand_range LoopOp::getTileValues() {
2646 return getTileValues(mlir::acc::DeviceType::None);
2647}
2648
2649mlir::Operation::operand_range
2650LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2651 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
2652 getTileOperandsSegments(), deviceType);
2653}
2654
2655std::optional<int64_t> LoopOp::getCollapseValue() {
2656 return getCollapseValue(mlir::acc::DeviceType::None);
2657}
2658
2659std::optional<int64_t>
2660LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2661 if (!getCollapseAttr())
2662 return std::nullopt;
2663 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2664 auto intAttr =
2665 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2666 return intAttr.getValue().getZExtValue();
2667 }
2668 return std::nullopt;
2669}
2670
2671mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2672 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
2673}
2674
2675mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2676 mlir::acc::DeviceType deviceType) {
2677 if (getGangOperands().empty())
2678 return {};
2679 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
2680 int32_t nbOperandsBefore = 0;
2681 for (unsigned i = 0; i < *pos; ++i)
2682 nbOperandsBefore += (*getGangOperandsSegments())[i];
2683 mlir::Operation::operand_range values =
2684 getGangOperands()
2685 .drop_front(nbOperandsBefore)
2686 .take_front((*getGangOperandsSegments())[*pos]);
2687
2688 int32_t argTypeIdx = nbOperandsBefore;
2689 for (auto value : values) {
2690 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2691 (*getGangOperandsArgType())[argTypeIdx]);
2692 if (gangArgTypeAttr.getValue() == gangArgType)
2693 return value;
2694 ++argTypeIdx;
2695 }
2696 }
2697 return {};
2698}
2699
2700bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
2701
2702bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2703 return hasDeviceType(getGang(), deviceType);
2704}
2705
2706llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2707 return {&getRegion()};
2708}
2709
2710/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2711/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2712/// `(` ssa-id-and-type-list `)`
2713/// region
2714ParseResult
2715parseLoopControl(OpAsmParser &parser, Region &region,
2716 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerbound,
2717 SmallVectorImpl<Type> &lowerboundType,
2718 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperbound,
2719 SmallVectorImpl<Type> &upperboundType,
2720 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &step,
2721 SmallVectorImpl<Type> &stepType) {
2722
2723 SmallVector<OpAsmParser::Argument> inductionVars;
2724 if (succeeded(
2725 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
2726 if (parser.parseLParen() ||
2727 parser.parseArgumentList(result&: inductionVars, delimiter: OpAsmParser::Delimiter::None,
2728 /*allowType=*/true) ||
2729 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2730 parser.parseOperandList(result&: lowerbound, requiredOperandCount: inductionVars.size(),
2731 delimiter: OpAsmParser::Delimiter::None) ||
2732 parser.parseColonTypeList(result&: lowerboundType) || parser.parseRParen() ||
2733 parser.parseKeyword(keyword: "to") || parser.parseLParen() ||
2734 parser.parseOperandList(result&: upperbound, requiredOperandCount: inductionVars.size(),
2735 delimiter: OpAsmParser::Delimiter::None) ||
2736 parser.parseColonTypeList(result&: upperboundType) || parser.parseRParen() ||
2737 parser.parseKeyword(keyword: "step") || parser.parseLParen() ||
2738 parser.parseOperandList(result&: step, requiredOperandCount: inductionVars.size(),
2739 delimiter: OpAsmParser::Delimiter::None) ||
2740 parser.parseColonTypeList(result&: stepType) || parser.parseRParen())
2741 return failure();
2742 }
2743 return parser.parseRegion(region, arguments: inductionVars);
2744}
2745
2746void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
2747 ValueRange lowerbound, TypeRange lowerboundType,
2748 ValueRange upperbound, TypeRange upperboundType,
2749 ValueRange steps, TypeRange stepType) {
2750 ValueRange regionArgs = region.front().getArguments();
2751 if (!regionArgs.empty()) {
2752 p << acc::LoopOp::getControlKeyword() << "(";
2753 llvm::interleaveComma(c: regionArgs, os&: p,
2754 each_fn: [&p](Value v) { p << v << " : " << v.getType(); });
2755 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2756 << upperbound << " : " << upperboundType << ") " << " step (" << steps
2757 << " : " << stepType << ") ";
2758 }
2759 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
2760}
2761
2762void acc::LoopOp::addSeq(MLIRContext *context,
2763 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2764 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
2765 effectiveDeviceTypes));
2766}
2767
2768void acc::LoopOp::addIndependent(
2769 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2770 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2771 context, getIndependentAttr(), effectiveDeviceTypes));
2772}
2773
2774void acc::LoopOp::addAuto(MLIRContext *context,
2775 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2776 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
2777 effectiveDeviceTypes));
2778}
2779
2780void acc::LoopOp::setCollapseForDeviceTypes(
2781 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2782 llvm::APInt value) {
2783 llvm::SmallVector<mlir::Attribute> newValues;
2784 llvm::SmallVector<mlir::Attribute> newDeviceTypes;
2785
2786 assert((getCollapseAttr() == nullptr) ==
2787 (getCollapseDeviceTypeAttr() == nullptr));
2788 assert(value.getBitWidth() == 64);
2789
2790 if (getCollapseAttr()) {
2791 for (const auto &existing :
2792 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
2793 newValues.push_back(std::get<0>(existing));
2794 newDeviceTypes.push_back(std::get<1>(existing));
2795 }
2796 }
2797
2798 if (effectiveDeviceTypes.empty()) {
2799 // If the effective device-types list is empty, this is before there are any
2800 // being applied by device_type, so this should be added as a 'none'.
2801 newValues.push_back(
2802 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
2803 newDeviceTypes.push_back(
2804 acc::DeviceTypeAttr::get(context, DeviceType::None));
2805 } else {
2806 for (DeviceType DT : effectiveDeviceTypes) {
2807 newValues.push_back(
2808 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
2809 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, DT));
2810 }
2811 }
2812
2813 setCollapseAttr(ArrayAttr::get(context, newValues));
2814 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
2815}
2816
2817void acc::LoopOp::setTileForDeviceTypes(
2818 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2819 ValueRange values) {
2820 llvm::SmallVector<int32_t> segments;
2821 if (getTileOperandsSegments())
2822 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
2823
2824 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2825 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2826 getTileOperandsMutable(), segments));
2827
2828 setTileOperandsSegments(segments);
2829}
2830
2831void acc::LoopOp::addVectorOperand(
2832 MLIRContext *context, mlir::Value newValue,
2833 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2834 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2835 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2836 newValue, getVectorOperandsMutable()));
2837}
2838
2839void acc::LoopOp::addEmptyVector(
2840 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2841 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
2842 effectiveDeviceTypes));
2843}
2844
2845void acc::LoopOp::addWorkerNumOperand(
2846 MLIRContext *context, mlir::Value newValue,
2847 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2848 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2849 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2850 newValue, getWorkerNumOperandsMutable()));
2851}
2852
2853void acc::LoopOp::addEmptyWorker(
2854 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2855 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
2856 effectiveDeviceTypes));
2857}
2858
2859void acc::LoopOp::addEmptyGang(
2860 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2861 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
2862 effectiveDeviceTypes));
2863}
2864
2865bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
2866 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
2867 return attr.getValue() == dt;
2868 };
2869 auto testFromArr = [=](ArrayAttr arr) -> bool {
2870 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
2871 };
2872
2873 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
2874 return true;
2875 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
2876 return true;
2877 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
2878 return true;
2879
2880 return false;
2881}
2882
2883bool acc::LoopOp::hasDefaultGangWorkerVector() {
2884 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
2885 hasGang() || getGangValue(GangArgType::Num) ||
2886 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
2887}
2888
2889void acc::LoopOp::addGangOperands(
2890 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2891 llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
2892 llvm::SmallVector<int32_t> segments;
2893 if (std::optional<ArrayRef<int32_t>> existingSegments =
2894 getGangOperandsSegments())
2895 llvm::copy(*existingSegments, std::back_inserter(segments));
2896
2897 unsigned beforeCount = segments.size();
2898
2899 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2900 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2901 getGangOperandsMutable(), segments));
2902
2903 setGangOperandsSegments(segments);
2904
2905 // This is a bit of extra work to make sure we update the 'types' correctly by
2906 // adding to the types collection the correct number of times. We could
2907 // potentially add something similar to the
2908 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
2909 // excessive for a one-off case.
2910 unsigned numAdded = segments.size() - beforeCount;
2911
2912 if (numAdded > 0) {
2913 llvm::SmallVector<mlir::Attribute> gangTypes;
2914 if (getGangOperandsArgTypeAttr())
2915 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
2916
2917 for (auto i : llvm::index_range(0u, numAdded)) {
2918 llvm::transform(argTypes, std::back_inserter(gangTypes),
2919 [=](mlir::acc::GangArgType gangTy) {
2920 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
2921 });
2922 (void)i;
2923 }
2924
2925 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
2926 }
2927}
2928
2929//===----------------------------------------------------------------------===//
2930// DataOp
2931//===----------------------------------------------------------------------===//
2932
2933LogicalResult acc::DataOp::verify() {
2934 // 2.6.5. Data Construct restriction
2935 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
2936 // attach, or default clause must appear on a data construct.
2937 if (getOperands().empty() && !getDefaultAttr())
2938 return emitError("at least one operand or the default attribute "
2939 "must appear on the data operation");
2940
2941 for (mlir::Value operand : getDataClauseOperands())
2942 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2943 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2944 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2945 operand.getDefiningOp()))
2946 return emitError("expect data entry/exit operation or acc.getdeviceptr "
2947 "as defining op");
2948
2949 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this)))
2950 return failure();
2951
2952 return success();
2953}
2954
2955unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
2956
2957Value DataOp::getDataOperand(unsigned i) {
2958 unsigned numOptional = getIfCond() ? 1 : 0;
2959 numOptional += getAsyncOperands().size() ? 1 : 0;
2960 numOptional += getWaitOperands().size();
2961 return getOperand(numOptional + i);
2962}
2963
2964bool acc::DataOp::hasAsyncOnly() {
2965 return hasAsyncOnly(mlir::acc::DeviceType::None);
2966}
2967
2968bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2969 return hasDeviceType(getAsyncOnly(), deviceType);
2970}
2971
2972mlir::Value DataOp::getAsyncValue() {
2973 return getAsyncValue(mlir::acc::DeviceType::None);
2974}
2975
2976mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2977 return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(),
2978 getAsyncOperands(), deviceType);
2979}
2980
2981bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
2982
2983bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2984 return hasDeviceType(getWaitOnly(), deviceType);
2985}
2986
2987mlir::Operation::operand_range DataOp::getWaitValues() {
2988 return getWaitValues(mlir::acc::DeviceType::None);
2989}
2990
2991mlir::Operation::operand_range
2992DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2993 return getWaitValuesWithoutDevnum(
2994 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2995 getHasWaitDevnum(), deviceType);
2996}
2997
2998mlir::Value DataOp::getWaitDevnum() {
2999 return getWaitDevnum(mlir::acc::DeviceType::None);
3000}
3001
3002mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3003 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3004 getWaitOperandsSegments(), getHasWaitDevnum(),
3005 deviceType);
3006}
3007
3008void acc::DataOp::addAsyncOnly(
3009 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3010 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3011 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3012}
3013
3014void acc::DataOp::addAsyncOperand(
3015 MLIRContext *context, mlir::Value newValue,
3016 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3017 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3018 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3019 getAsyncOperandsMutable()));
3020}
3021
3022void acc::DataOp::addWaitOnly(MLIRContext *context,
3023 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3024 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3025 effectiveDeviceTypes));
3026}
3027
3028void acc::DataOp::addWaitOperands(
3029 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3030 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3031
3032 llvm::SmallVector<int32_t> segments;
3033 if (getWaitOperandsSegments())
3034 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3035
3036 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3037 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3038 getWaitOperandsMutable(), segments));
3039 setWaitOperandsSegments(segments);
3040
3041 llvm::SmallVector<mlir::Attribute> hasDevnums;
3042 if (getHasWaitDevnumAttr())
3043 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3044 hasDevnums.insert(
3045 hasDevnums.end(),
3046 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3047 mlir::BoolAttr::get(context, hasDevnum));
3048 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3049}
3050
3051//===----------------------------------------------------------------------===//
3052// ExitDataOp
3053//===----------------------------------------------------------------------===//
3054
3055LogicalResult acc::ExitDataOp::verify() {
3056 // 2.6.6. Data Exit Directive restriction
3057 // At least one copyout, delete, or detach clause must appear on an exit data
3058 // directive.
3059 if (getDataClauseOperands().empty())
3060 return emitError("at least one operand must be present in dataOperands on "
3061 "the exit data operation");
3062
3063 // The async attribute represent the async clause without value. Therefore the
3064 // attribute and operand cannot appear at the same time.
3065 if (getAsyncOperand() && getAsync())
3066 return emitError("async attribute cannot appear with asyncOperand");
3067
3068 // The wait attribute represent the wait clause without values. Therefore the
3069 // attribute and operands cannot appear at the same time.
3070 if (!getWaitOperands().empty() && getWait())
3071 return emitError("wait attribute cannot appear with waitOperands");
3072
3073 if (getWaitDevnum() && getWaitOperands().empty())
3074 return emitError("wait_devnum cannot appear without waitOperands");
3075
3076 return success();
3077}
3078
3079unsigned ExitDataOp::getNumDataOperands() {
3080 return getDataClauseOperands().size();
3081}
3082
3083Value ExitDataOp::getDataOperand(unsigned i) {
3084 unsigned numOptional = getIfCond() ? 1 : 0;
3085 numOptional += getAsyncOperand() ? 1 : 0;
3086 numOptional += getWaitDevnum() ? 1 : 0;
3087 return getOperand(getWaitOperands().size() + numOptional + i);
3088}
3089
3090void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3091 MLIRContext *context) {
3092 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3093}
3094
3095//===----------------------------------------------------------------------===//
3096// EnterDataOp
3097//===----------------------------------------------------------------------===//
3098
3099LogicalResult acc::EnterDataOp::verify() {
3100 // 2.6.6. Data Enter Directive restriction
3101 // At least one copyin, create, or attach clause must appear on an enter data
3102 // directive.
3103 if (getDataClauseOperands().empty())
3104 return emitError("at least one operand must be present in dataOperands on "
3105 "the enter data operation");
3106
3107 // The async attribute represent the async clause without value. Therefore the
3108 // attribute and operand cannot appear at the same time.
3109 if (getAsyncOperand() && getAsync())
3110 return emitError("async attribute cannot appear with asyncOperand");
3111
3112 // The wait attribute represent the wait clause without values. Therefore the
3113 // attribute and operands cannot appear at the same time.
3114 if (!getWaitOperands().empty() && getWait())
3115 return emitError("wait attribute cannot appear with waitOperands");
3116
3117 if (getWaitDevnum() && getWaitOperands().empty())
3118 return emitError("wait_devnum cannot appear without waitOperands");
3119
3120 for (mlir::Value operand : getDataClauseOperands())
3121 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3122 operand.getDefiningOp()))
3123 return emitError("expect data entry operation as defining op");
3124
3125 return success();
3126}
3127
3128unsigned EnterDataOp::getNumDataOperands() {
3129 return getDataClauseOperands().size();
3130}
3131
3132Value EnterDataOp::getDataOperand(unsigned i) {
3133 unsigned numOptional = getIfCond() ? 1 : 0;
3134 numOptional += getAsyncOperand() ? 1 : 0;
3135 numOptional += getWaitDevnum() ? 1 : 0;
3136 return getOperand(getWaitOperands().size() + numOptional + i);
3137}
3138
3139void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3140 MLIRContext *context) {
3141 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3142}
3143
3144//===----------------------------------------------------------------------===//
3145// AtomicReadOp
3146//===----------------------------------------------------------------------===//
3147
3148LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
3149
3150//===----------------------------------------------------------------------===//
3151// AtomicWriteOp
3152//===----------------------------------------------------------------------===//
3153
3154LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
3155
3156//===----------------------------------------------------------------------===//
3157// AtomicUpdateOp
3158//===----------------------------------------------------------------------===//
3159
3160LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3161 PatternRewriter &rewriter) {
3162 if (op.isNoOp()) {
3163 rewriter.eraseOp(op);
3164 return success();
3165 }
3166
3167 if (Value writeVal = op.getWriteOpVal()) {
3168 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal);
3169 return success();
3170 }
3171
3172 return failure();
3173}
3174
3175LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
3176
3177LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3178
3179//===----------------------------------------------------------------------===//
3180// AtomicCaptureOp
3181//===----------------------------------------------------------------------===//
3182
3183AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3184 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3185 return op;
3186 return dyn_cast<AtomicReadOp>(getSecondOp());
3187}
3188
3189AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3190 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3191 return op;
3192 return dyn_cast<AtomicWriteOp>(getSecondOp());
3193}
3194
3195AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3196 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3197 return op;
3198 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3199}
3200
3201LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
3202
3203//===----------------------------------------------------------------------===//
3204// DeclareEnterOp
3205//===----------------------------------------------------------------------===//
3206
3207template <typename Op>
3208static LogicalResult
3209checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
3210 bool requireAtLeastOneOperand = true) {
3211 if (operands.empty() && requireAtLeastOneOperand)
3212 return emitError(
3213 op->getLoc(),
3214 "at least one operand must appear on the declare operation");
3215
3216 for (mlir::Value operand : operands) {
3217 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3218 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3219 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3220 operand.getDefiningOp()))
3221 return op.emitError(
3222 "expect valid declare data entry operation or acc.getdeviceptr "
3223 "as defining op");
3224
3225 mlir::Value var{getVar(accDataClauseOp: operand.getDefiningOp())};
3226 assert(var && "declare operands can only be data entry operations which "
3227 "must have var");
3228 (void)var;
3229 std::optional<mlir::acc::DataClause> dataClauseOptional{
3230 getDataClause(operand.getDefiningOp())};
3231 assert(dataClauseOptional.has_value() &&
3232 "declare operands can only be data entry operations which must have "
3233 "dataClause");
3234 (void)dataClauseOptional;
3235 }
3236
3237 return success();
3238}
3239
3240LogicalResult acc::DeclareEnterOp::verify() {
3241 return checkDeclareOperands(*this, this->getDataClauseOperands());
3242}
3243
3244//===----------------------------------------------------------------------===//
3245// DeclareExitOp
3246//===----------------------------------------------------------------------===//
3247
3248LogicalResult acc::DeclareExitOp::verify() {
3249 if (getToken())
3250 return checkDeclareOperands(*this, this->getDataClauseOperands(),
3251 /*requireAtLeastOneOperand=*/false);
3252 return checkDeclareOperands(*this, this->getDataClauseOperands());
3253}
3254
3255//===----------------------------------------------------------------------===//
3256// DeclareOp
3257//===----------------------------------------------------------------------===//
3258
3259LogicalResult acc::DeclareOp::verify() {
3260 return checkDeclareOperands(*this, this->getDataClauseOperands());
3261}
3262
3263//===----------------------------------------------------------------------===//
3264// RoutineOp
3265//===----------------------------------------------------------------------===//
3266
3267static unsigned getParallelismForDeviceType(acc::RoutineOp op,
3268 acc::DeviceType dtype) {
3269 unsigned parallelism = 0;
3270 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3271 parallelism += op.hasWorker(dtype) ? 1 : 0;
3272 parallelism += op.hasVector(dtype) ? 1 : 0;
3273 parallelism += op.hasSeq(dtype) ? 1 : 0;
3274 return parallelism;
3275}
3276
3277LogicalResult acc::RoutineOp::verify() {
3278 unsigned baseParallelism =
3279 getParallelismForDeviceType(*this, acc::DeviceType::None);
3280
3281 if (baseParallelism > 1)
3282 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3283 "be present at the same time";
3284
3285 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3286 ++dtypeInt) {
3287 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
3288 if (dtype == acc::DeviceType::None)
3289 continue;
3290 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
3291
3292 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3293 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3294 "be present at the same time";
3295 }
3296
3297 return success();
3298}
3299
3300static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
3301 mlir::ArrayAttr &deviceTypes) {
3302 llvm::SmallVector<mlir::Attribute> bindNameAttrs;
3303 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
3304
3305 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
3306 if (parser.parseAttribute(result&: bindNameAttrs.emplace_back()))
3307 return failure();
3308 if (failed(Result: parser.parseOptionalLSquare())) {
3309 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3310 parser.getContext(), mlir::acc::DeviceType::None));
3311 } else {
3312 if (parser.parseAttribute(result&: deviceTypeAttrs.emplace_back()) ||
3313 parser.parseRSquare())
3314 return failure();
3315 }
3316 return success();
3317 })))
3318 return failure();
3319
3320 bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
3321 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
3322
3323 return success();
3324}
3325
3326static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
3327 std::optional<mlir::ArrayAttr> bindName,
3328 std::optional<mlir::ArrayAttr> deviceTypes) {
3329 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
3330 [&](const auto &pair) {
3331 p << std::get<0>(pair);
3332 printSingleDeviceType(p, std::get<1>(pair));
3333 });
3334}
3335
3336static ParseResult parseRoutineGangClause(OpAsmParser &parser,
3337 mlir::ArrayAttr &gang,
3338 mlir::ArrayAttr &gangDim,
3339 mlir::ArrayAttr &gangDimDeviceTypes) {
3340
3341 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
3342 gangDimDeviceTypeAttrs;
3343 bool needCommaBeforeOperands = false;
3344
3345 // Gang keyword only
3346 if (failed(Result: parser.parseOptionalLParen())) {
3347 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3348 parser.getContext(), mlir::acc::DeviceType::None));
3349 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
3350 return success();
3351 }
3352
3353 // Parse keyword only attributes
3354 if (succeeded(Result: parser.parseOptionalLSquare())) {
3355 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
3356 if (parser.parseAttribute(result&: gangAttrs.emplace_back()))
3357 return failure();
3358 return success();
3359 })))
3360 return failure();
3361 if (parser.parseRSquare())
3362 return failure();
3363 needCommaBeforeOperands = true;
3364 }
3365
3366 if (needCommaBeforeOperands && failed(Result: parser.parseComma()))
3367 return failure();
3368
3369 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
3370 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
3371 parser.parseColon() ||
3372 parser.parseAttribute(gangDimAttrs.emplace_back()))
3373 return failure();
3374 if (succeeded(Result: parser.parseOptionalLSquare())) {
3375 if (parser.parseAttribute(result&: gangDimDeviceTypeAttrs.emplace_back()) ||
3376 parser.parseRSquare())
3377 return failure();
3378 } else {
3379 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3380 parser.getContext(), mlir::acc::DeviceType::None));
3381 }
3382 return success();
3383 })))
3384 return failure();
3385
3386 if (failed(Result: parser.parseRParen()))
3387 return failure();
3388
3389 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
3390 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
3391 gangDimDeviceTypes =
3392 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
3393
3394 return success();
3395}
3396
3397void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
3398 std::optional<mlir::ArrayAttr> gang,
3399 std::optional<mlir::ArrayAttr> gangDim,
3400 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
3401
3402 if (!hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes) && hasDeviceTypeValues(arrayAttr: gang) &&
3403 gang->size() == 1) {
3404 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
3405 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
3406 return;
3407 }
3408
3409 p << "(";
3410
3411 printDeviceTypes(p, deviceTypes: gang);
3412
3413 if (hasDeviceTypeValues(arrayAttr: gang) && hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes))
3414 p << ", ";
3415
3416 if (hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes))
3417 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
3418 [&](const auto &pair) {
3419 p << acc::RoutineOp::getGangDimKeyword() << ": ";
3420 p << std::get<0>(pair);
3421 printSingleDeviceType(p, std::get<1>(pair));
3422 });
3423
3424 p << ")";
3425}
3426
3427static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
3428 mlir::ArrayAttr &deviceTypes) {
3429 llvm::SmallVector<mlir::Attribute> attributes;
3430 // Keyword only
3431 if (failed(Result: parser.parseOptionalLParen())) {
3432 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
3433 parser.getContext(), mlir::acc::DeviceType::None));
3434 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
3435 return success();
3436 }
3437
3438 // Parse device type attributes
3439 if (succeeded(Result: parser.parseOptionalLSquare())) {
3440 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
3441 if (parser.parseAttribute(result&: attributes.emplace_back()))
3442 return failure();
3443 return success();
3444 })))
3445 return failure();
3446 if (parser.parseRSquare() || parser.parseRParen())
3447 return failure();
3448 }
3449 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
3450 return success();
3451}
3452
3453static void
3454printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
3455 std::optional<mlir::ArrayAttr> deviceTypes) {
3456
3457 if (hasDeviceTypeValues(arrayAttr: deviceTypes) && deviceTypes->size() == 1) {
3458 auto deviceTypeAttr =
3459 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
3460 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
3461 return;
3462 }
3463
3464 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
3465 return;
3466
3467 p << "([";
3468 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
3469 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3470 p << dTypeAttr;
3471 });
3472 p << "])";
3473}
3474
3475bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3476
3477bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3478 return hasDeviceType(getWorker(), deviceType);
3479}
3480
3481bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3482
3483bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3484 return hasDeviceType(getVector(), deviceType);
3485}
3486
3487bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3488
3489bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3490 return hasDeviceType(getSeq(), deviceType);
3491}
3492
3493std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3494 return getBindNameValue(mlir::acc::DeviceType::None);
3495}
3496
3497std::optional<llvm::StringRef>
3498RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3499 if (!hasDeviceTypeValues(getBindNameDeviceType()))
3500 return std::nullopt;
3501 if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
3502 auto attr = (*getBindName())[*pos];
3503 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3504 return stringAttr.getValue();
3505 }
3506 return std::nullopt;
3507}
3508
3509bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3510
3511bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3512 return hasDeviceType(getGang(), deviceType);
3513}
3514
3515std::optional<int64_t> RoutineOp::getGangDimValue() {
3516 return getGangDimValue(mlir::acc::DeviceType::None);
3517}
3518
3519std::optional<int64_t>
3520RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3521 if (!hasDeviceTypeValues(getGangDimDeviceType()))
3522 return std::nullopt;
3523 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
3524 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
3525 return intAttr.getInt();
3526 }
3527 return std::nullopt;
3528}
3529
3530//===----------------------------------------------------------------------===//
3531// InitOp
3532//===----------------------------------------------------------------------===//
3533
3534LogicalResult acc::InitOp::verify() {
3535 Operation *currOp = *this;
3536 while ((currOp = currOp->getParentOp()))
3537 if (isComputeOperation(currOp))
3538 return emitOpError("cannot be nested in a compute operation");
3539 return success();
3540}
3541
3542void acc::InitOp::addDeviceType(MLIRContext *context,
3543 mlir::acc::DeviceType deviceType) {
3544 llvm::SmallVector<mlir::Attribute> deviceTypes;
3545 if (getDeviceTypesAttr())
3546 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3547
3548 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
3549 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
3550}
3551
3552//===----------------------------------------------------------------------===//
3553// ShutdownOp
3554//===----------------------------------------------------------------------===//
3555
3556LogicalResult acc::ShutdownOp::verify() {
3557 Operation *currOp = *this;
3558 while ((currOp = currOp->getParentOp()))
3559 if (isComputeOperation(currOp))
3560 return emitOpError("cannot be nested in a compute operation");
3561 return success();
3562}
3563
3564void acc::ShutdownOp::addDeviceType(MLIRContext *context,
3565 mlir::acc::DeviceType deviceType) {
3566 llvm::SmallVector<mlir::Attribute> deviceTypes;
3567 if (getDeviceTypesAttr())
3568 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3569
3570 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
3571 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
3572}
3573
3574//===----------------------------------------------------------------------===//
3575// SetOp
3576//===----------------------------------------------------------------------===//
3577
3578LogicalResult acc::SetOp::verify() {
3579 Operation *currOp = *this;
3580 while ((currOp = currOp->getParentOp()))
3581 if (isComputeOperation(currOp))
3582 return emitOpError("cannot be nested in a compute operation");
3583 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3584 return emitOpError("at least one default_async, device_num, or device_type "
3585 "operand must appear");
3586 return success();
3587}
3588
3589//===----------------------------------------------------------------------===//
3590// UpdateOp
3591//===----------------------------------------------------------------------===//
3592
3593LogicalResult acc::UpdateOp::verify() {
3594 // At least one of host or device should have a value.
3595 if (getDataClauseOperands().empty())
3596 return emitError("at least one value must be present in dataOperands");
3597
3598 if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(),
3599 getAsyncOperandsDeviceTypeAttr(),
3600 "async")))
3601 return failure();
3602
3603 if (failed(verifyDeviceTypeAndSegmentCountMatch(
3604 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3605 getWaitOperandsDeviceTypeAttr(), "wait")))
3606 return failure();
3607
3608 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this)))
3609 return failure();
3610
3611 for (mlir::Value operand : getDataClauseOperands())
3612 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3613 operand.getDefiningOp()))
3614 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3615 "as defining op");
3616
3617 return success();
3618}
3619
3620unsigned UpdateOp::getNumDataOperands() {
3621 return getDataClauseOperands().size();
3622}
3623
3624Value UpdateOp::getDataOperand(unsigned i) {
3625 unsigned numOptional = getAsyncOperands().size();
3626 numOptional += getIfCond() ? 1 : 0;
3627 return getOperand(getWaitOperands().size() + numOptional + i);
3628}
3629
3630void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
3631 MLIRContext *context) {
3632 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
3633}
3634
3635bool UpdateOp::hasAsyncOnly() {
3636 return hasAsyncOnly(mlir::acc::DeviceType::None);
3637}
3638
3639bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3640 return hasDeviceType(getAsyncOnly(), deviceType);
3641}
3642
3643mlir::Value UpdateOp::getAsyncValue() {
3644 return getAsyncValue(mlir::acc::DeviceType::None);
3645}
3646
3647mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3648 if (!hasDeviceTypeValues(getAsyncOperandsDeviceType()))
3649 return {};
3650
3651 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
3652 return getAsyncOperands()[*pos];
3653
3654 return {};
3655}
3656
3657bool UpdateOp::hasWaitOnly() {
3658 return hasWaitOnly(mlir::acc::DeviceType::None);
3659}
3660
3661bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3662 return hasDeviceType(getWaitOnly(), deviceType);
3663}
3664
3665mlir::Operation::operand_range UpdateOp::getWaitValues() {
3666 return getWaitValues(mlir::acc::DeviceType::None);
3667}
3668
3669mlir::Operation::operand_range
3670UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3671 return getWaitValuesWithoutDevnum(
3672 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3673 getHasWaitDevnum(), deviceType);
3674}
3675
3676mlir::Value UpdateOp::getWaitDevnum() {
3677 return getWaitDevnum(mlir::acc::DeviceType::None);
3678}
3679
3680mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3681 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3682 getWaitOperandsSegments(), getHasWaitDevnum(),
3683 deviceType);
3684}
3685
3686//===----------------------------------------------------------------------===//
3687// WaitOp
3688//===----------------------------------------------------------------------===//
3689
3690LogicalResult acc::WaitOp::verify() {
3691 // The async attribute represent the async clause without value. Therefore the
3692 // attribute and operand cannot appear at the same time.
3693 if (getAsyncOperand() && getAsync())
3694 return emitError("async attribute cannot appear with asyncOperand");
3695
3696 if (getWaitDevnum() && getWaitOperands().empty())
3697 return emitError("wait_devnum cannot appear without waitOperands");
3698
3699 return success();
3700}
3701
3702#define GET_OP_CLASSES
3703#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3704
3705#define GET_ATTRDEF_CLASSES
3706#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3707
3708#define GET_TYPEDEF_CLASSES
3709#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3710
3711//===----------------------------------------------------------------------===//
3712// acc dialect utilities
3713//===----------------------------------------------------------------------===//
3714
3715mlir::TypedValue<mlir::acc::PointerLikeType>
3716mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) {
3717 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
3718 mlir::TypedValue<mlir::acc::PointerLikeType>>(
3719 accDataClauseOp)
3720 .Case<ACC_DATA_ENTRY_OPS>(
3721 [&](auto entry) { return entry.getVarPtr(); })
3722 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3723 [&](auto exit) { return exit.getVarPtr(); })
3724 .Default([&](mlir::Operation *) {
3725 return mlir::TypedValue<mlir::acc::PointerLikeType>();
3726 })};
3727 return varPtr;
3728}
3729
3730mlir::Value mlir::acc::getVar(mlir::Operation *accDataClauseOp) {
3731 auto varPtr{
3732 llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3733 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
3734 .Default([&](mlir::Operation *) { return mlir::Value(); })};
3735 return varPtr;
3736}
3737
3738mlir::Type mlir::acc::getVarType(mlir::Operation *accDataClauseOp) {
3739 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
3740 .Case<ACC_DATA_ENTRY_OPS>(
3741 [&](auto entry) { return entry.getVarType(); })
3742 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3743 [&](auto exit) { return exit.getVarType(); })
3744 .Default([&](mlir::Operation *) { return mlir::Type(); })};
3745 return varType;
3746}
3747
3748mlir::TypedValue<mlir::acc::PointerLikeType>
3749mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) {
3750 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
3751 mlir::TypedValue<mlir::acc::PointerLikeType>>(
3752 accDataClauseOp)
3753 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3754 [&](auto dataClause) { return dataClause.getAccPtr(); })
3755 .Default([&](mlir::Operation *) {
3756 return mlir::TypedValue<mlir::acc::PointerLikeType>();
3757 })};
3758 return accPtr;
3759}
3760
3761mlir::Value mlir::acc::getAccVar(mlir::Operation *accDataClauseOp) {
3762 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3763 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3764 [&](auto dataClause) { return dataClause.getAccVar(); })
3765 .Default([&](mlir::Operation *) { return mlir::Value(); })};
3766 return accPtr;
3767}
3768
3769mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) {
3770 auto varPtrPtr{
3771 llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3772 .Case<ACC_DATA_ENTRY_OPS>(
3773 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
3774 .Default([&](mlir::Operation *) { return mlir::Value(); })};
3775 return varPtrPtr;
3776}
3777
3778mlir::SmallVector<mlir::Value>
3779mlir::acc::getBounds(mlir::Operation *accDataClauseOp) {
3780 mlir::SmallVector<mlir::Value> bounds{
3781 llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
3782 accDataClauseOp)
3783 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3784 return mlir::SmallVector<mlir::Value>(
3785 dataClause.getBounds().begin(), dataClause.getBounds().end());
3786 })
3787 .Default([&](mlir::Operation *) {
3788 return mlir::SmallVector<mlir::Value, 0>();
3789 })};
3790 return bounds;
3791}
3792
3793mlir::SmallVector<mlir::Value>
3794mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) {
3795 return llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
3796 accDataClauseOp)
3797 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3798 return mlir::SmallVector<mlir::Value>(
3799 dataClause.getAsyncOperands().begin(),
3800 dataClause.getAsyncOperands().end());
3801 })
3802 .Default([&](mlir::Operation *) {
3803 return mlir::SmallVector<mlir::Value, 0>();
3804 });
3805}
3806
3807mlir::ArrayAttr
3808mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) {
3809 return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
3810 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
3811 return dataClause.getAsyncOperandsDeviceTypeAttr();
3812 })
3813 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3814}
3815
3816mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
3817 return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
3818 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3819 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
3820 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
3821}
3822
3823std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
3824 auto name{
3825 llvm::TypeSwitch<mlir::Operation *, std::optional<llvm::StringRef>>(accOp)
3826 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
3827 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
3828 return {};
3829 })};
3830 return name;
3831}
3832
3833std::optional<mlir::acc::DataClause>
3834mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) {
3835 auto dataClause{
3836 llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>(
3837 accDataEntryOp)
3838 .Case<ACC_DATA_ENTRY_OPS>(
3839 [&](auto entry) { return entry.getDataClause(); })
3840 .Default([&](mlir::Operation *) { return std::nullopt; })};
3841 return dataClause;
3842}
3843
3844bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) {
3845 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
3846 .Case<ACC_DATA_ENTRY_OPS>(
3847 [&](auto entry) { return entry.getImplicit(); })
3848 .Default([&](mlir::Operation *) { return false; })};
3849 return implicit;
3850}
3851
3852mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) {
3853 auto dataOperands{
3854 llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp)
3855 .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
3856 [&](auto entry) { return entry.getDataClauseOperands(); })
3857 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
3858 return dataOperands;
3859}
3860
3861mlir::MutableOperandRange
3862mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
3863 auto dataOperands{
3864 llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp)
3865 .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
3866 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
3867 .Default([&](mlir::Operation *) { return nullptr; })};
3868 return dataOperands;
3869}
3870
3871mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
3872 mlir::Operation *parentOp = region.getParentOp();
3873 while (parentOp) {
3874 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
3875 return parentOp;
3876 }
3877 parentOp = parentOp->getParentOp();
3878 }
3879 return nullptr;
3880}
3881

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp