1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// =============================================================================
8
9#include "mlir/Dialect/OpenACC/OpenACC.h"
10#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
12#include "mlir/Dialect/MemRef/IR/MemRef.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/BuiltinAttributes.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/DialectImplementation.h"
17#include "mlir/IR/Matchers.h"
18#include "mlir/IR/OpImplementation.h"
19#include "mlir/Support/LLVM.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "llvm/ADT/SmallSet.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/LogicalResult.h"
24
25using namespace mlir;
26using namespace acc;
27
28#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
33
34namespace {
35
36static bool isScalarLikeType(Type type) {
37 return type.isIntOrIndexOrFloat() || isa<ComplexType>(Val: type);
38}
39
40struct MemRefPointerLikeModel
41 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
42 MemRefType> {
43 Type getElementType(Type pointer) const {
44 return cast<MemRefType>(Val&: pointer).getElementType();
45 }
46 mlir::acc::VariableTypeCategory
47 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
48 Type varType) const {
49 if (auto mappableTy = dyn_cast<MappableType>(Val&: varType)) {
50 return mappableTy.getTypeCategory(var: varPtr);
51 }
52 auto memrefTy = cast<MemRefType>(Val&: pointer);
53 if (!memrefTy.hasRank()) {
54 // This memref is unranked - aka it could have any rank, including a
55 // rank of 0 which could mean scalar. For now, return uncategorized.
56 return mlir::acc::VariableTypeCategory::uncategorized;
57 }
58
59 if (memrefTy.getRank() == 0) {
60 if (isScalarLikeType(type: memrefTy.getElementType())) {
61 return mlir::acc::VariableTypeCategory::scalar;
62 }
63 // Zero-rank non-scalar - need further analysis to determine the type
64 // category. For now, return uncategorized.
65 return mlir::acc::VariableTypeCategory::uncategorized;
66 }
67
68 // It has a rank - must be an array.
69 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
70 return mlir::acc::VariableTypeCategory::array;
71 }
72};
73
74struct LLVMPointerPointerLikeModel
75 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
77 Type getElementType(Type pointer) const { return Type(); }
78};
79
80/// Helper function for any of the times we need to modify an ArrayAttr based on
81/// a device type list. Returns a new ArrayAttr with all of the
82/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
83/// list is empty).
84mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
85 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
86 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
87 llvm::SmallVector<mlir::Attribute> deviceTypes;
88 if (existingDeviceTypes)
89 llvm::copy(Range&: existingDeviceTypes, Out: std::back_inserter(x&: deviceTypes));
90
91 if (newDeviceTypes.empty())
92 deviceTypes.push_back(
93 Elt: acc::DeviceTypeAttr::get(context, value: acc::DeviceType::None));
94
95 for (DeviceType DT : newDeviceTypes)
96 deviceTypes.push_back(Elt: acc::DeviceTypeAttr::get(context, value: DT));
97
98 return mlir::ArrayAttr::get(context, value: deviceTypes);
99}
100
101/// Helper function for any of the times we need to add operands that are
102/// affected by a device type list. Returns a new ArrayAttr with all of the
103/// existingDeviceTypes, plus the effective new ones (or an added none, if the
104/// new list is empty). Additionally, adds the arguments to the argCollection
105/// the correct number of times. This will also update a 'segments' array, even
106/// if it won't be used.
107mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
108 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
109 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
110 mlir::MutableOperandRange argCollection,
111 llvm::SmallVector<int32_t> &segments) {
112 llvm::SmallVector<mlir::Attribute> deviceTypes;
113 if (existingDeviceTypes)
114 llvm::copy(Range&: existingDeviceTypes, Out: std::back_inserter(x&: deviceTypes));
115
116 if (newDeviceTypes.empty()) {
117 argCollection.append(values: arguments);
118 segments.push_back(Elt: arguments.size());
119 deviceTypes.push_back(
120 Elt: acc::DeviceTypeAttr::get(context, value: acc::DeviceType::None));
121 }
122
123 for (DeviceType DT : newDeviceTypes) {
124 argCollection.append(values: arguments);
125 segments.push_back(Elt: arguments.size());
126 deviceTypes.push_back(Elt: acc::DeviceTypeAttr::get(context, value: DT));
127 }
128
129 return mlir::ArrayAttr::get(context, value: deviceTypes);
130}
131
132/// Overload for when the 'segments' aren't needed.
133mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
134 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
135 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
136 mlir::MutableOperandRange argCollection) {
137 llvm::SmallVector<int32_t> segments;
138 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
139 newDeviceTypes, arguments,
140 argCollection, segments);
141}
142} // namespace
143
144//===----------------------------------------------------------------------===//
145// OpenACC operations
146//===----------------------------------------------------------------------===//
147
148void OpenACCDialect::initialize() {
149 addOperations<
150#define GET_OP_LIST
151#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
152 >();
153 addAttributes<
154#define GET_ATTRDEF_LIST
155#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
156 >();
157 addTypes<
158#define GET_TYPEDEF_LIST
159#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
160 >();
161
162 // By attaching interfaces here, we make the OpenACC dialect dependent on
163 // the other dialects. This is probably better than having dialects like LLVM
164 // and memref be dependent on OpenACC.
165 MemRefType::attachInterface<MemRefPointerLikeModel>(context&: *getContext());
166 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
167 context&: *getContext());
168}
169
170//===----------------------------------------------------------------------===//
171// device_type support helpers
172//===----------------------------------------------------------------------===//
173
174static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
175 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
176 return true;
177 return false;
178}
179
180static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
181 mlir::acc::DeviceType deviceType) {
182 if (!hasDeviceTypeValues(arrayAttr))
183 return false;
184
185 for (auto attr : *arrayAttr) {
186 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val&: attr);
187 if (deviceTypeAttr.getValue() == deviceType)
188 return true;
189 }
190
191 return false;
192}
193
194static void printDeviceTypes(mlir::OpAsmPrinter &p,
195 std::optional<mlir::ArrayAttr> deviceTypes) {
196 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
197 return;
198
199 p << "[";
200 llvm::interleaveComma(c: *deviceTypes, os&: p,
201 each_fn: [&](mlir::Attribute attr) { p << attr; });
202 p << "]";
203}
204
205static std::optional<unsigned> findSegment(ArrayAttr segments,
206 mlir::acc::DeviceType deviceType) {
207 unsigned segmentIdx = 0;
208 for (auto attr : segments) {
209 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val&: attr);
210 if (deviceTypeAttr.getValue() == deviceType)
211 return std::make_optional(t&: segmentIdx);
212 ++segmentIdx;
213 }
214 return std::nullopt;
215}
216
217static mlir::Operation::operand_range
218getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
219 mlir::Operation::operand_range range,
220 std::optional<llvm::ArrayRef<int32_t>> segments,
221 mlir::acc::DeviceType deviceType) {
222 if (!arrayAttr)
223 return range.take_front(n: 0);
224 if (auto pos = findSegment(segments: *arrayAttr, deviceType)) {
225 int32_t nbOperandsBefore = 0;
226 for (unsigned i = 0; i < *pos; ++i)
227 nbOperandsBefore += (*segments)[i];
228 return range.drop_front(n: nbOperandsBefore).take_front(n: (*segments)[*pos]);
229 }
230 return range.take_front(n: 0);
231}
232
233static mlir::Value
234getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
235 mlir::Operation::operand_range operands,
236 std::optional<llvm::ArrayRef<int32_t>> segments,
237 std::optional<mlir::ArrayAttr> hasWaitDevnum,
238 mlir::acc::DeviceType deviceType) {
239 if (!hasDeviceTypeValues(arrayAttr: deviceTypeAttr))
240 return {};
241 if (auto pos = findSegment(segments: *deviceTypeAttr, deviceType))
242 if (hasWaitDevnum->getValue()[*pos])
243 return getValuesFromSegments(arrayAttr: deviceTypeAttr, range: operands, segments,
244 deviceType)
245 .front();
246 return {};
247}
248
249static mlir::Operation::operand_range
250getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
251 mlir::Operation::operand_range operands,
252 std::optional<llvm::ArrayRef<int32_t>> segments,
253 std::optional<mlir::ArrayAttr> hasWaitDevnum,
254 mlir::acc::DeviceType deviceType) {
255 auto range =
256 getValuesFromSegments(arrayAttr: deviceTypeAttr, range: operands, segments, deviceType);
257 if (range.empty())
258 return range;
259 if (auto pos = findSegment(segments: *deviceTypeAttr, deviceType)) {
260 if (hasWaitDevnum && *hasWaitDevnum) {
261 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(Val: (*hasWaitDevnum)[*pos]);
262 if (boolAttr.getValue())
263 return range.drop_front(n: 1); // first value is devnum
264 }
265 }
266 return range;
267}
268
269template <typename Op>
270static LogicalResult checkWaitAndAsyncConflict(Op op) {
271 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
272 ++dtypeInt) {
273 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
274
275 // The asyncOnly attribute represent the async clause without value.
276 // Therefore the attribute and operand cannot appear at the same time.
277 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
278 op.hasAsyncOnly(dtype))
279 return op.emitError(
280 "asyncOnly attribute cannot appear with asyncOperand");
281
282 // The wait attribute represent the wait clause without values. Therefore
283 // the attribute and operands cannot appear at the same time.
284 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
285 op.hasWaitOnly(dtype))
286 return op.emitError("wait attribute cannot appear with waitOperands");
287 }
288 return success();
289}
290
291template <typename Op>
292static LogicalResult checkVarAndVarType(Op op) {
293 if (!op.getVar())
294 return op.emitError("must have var operand");
295
296 // A variable must have a type that is either pointer-like or mappable.
297 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
298 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
299 return op.emitError("var must be mappable or pointer-like");
300
301 // When it is a pointer-like type, the varType must capture the target type.
302 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
303 op.getVarType() == op.getVar().getType())
304 return op.emitError("varType must capture the element type of var");
305
306 return success();
307}
308
309template <typename Op>
310static LogicalResult checkVarAndAccVar(Op op) {
311 if (op.getVar().getType() != op.getAccVar().getType())
312 return op.emitError("input and output types must match");
313
314 return success();
315}
316
317template <typename Op>
318static LogicalResult checkNoModifier(Op op) {
319 if (op.getModifiers() != acc::DataClauseModifier::none)
320 return op.emitError("no data clause modifiers are allowed");
321 return success();
322}
323
324template <typename Op>
325static LogicalResult
326checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
327 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
328 return op.emitError(
329 "invalid data clause modifiers: " +
330 acc::stringifyDataClauseModifier(symbol: op.getModifiers() & ~validModifiers));
331
332 return success();
333}
334
335static ParseResult parseVar(mlir::OpAsmParser &parser,
336 OpAsmParser::UnresolvedOperand &var) {
337 // Either `var` or `varPtr` keyword is required.
338 if (failed(Result: parser.parseOptionalKeyword(keyword: "varPtr"))) {
339 if (failed(Result: parser.parseKeyword(keyword: "var")))
340 return failure();
341 }
342 if (failed(Result: parser.parseLParen()))
343 return failure();
344 if (failed(Result: parser.parseOperand(result&: var)))
345 return failure();
346
347 return success();
348}
349
350static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op,
351 mlir::Value var) {
352 if (mlir::isa<mlir::acc::PointerLikeType>(Val: var.getType()))
353 p << "varPtr(";
354 else
355 p << "var(";
356 p.printOperand(value: var);
357}
358
359static ParseResult parseAccVar(mlir::OpAsmParser &parser,
360 OpAsmParser::UnresolvedOperand &var,
361 mlir::Type &accVarType) {
362 // Either `accVar` or `accPtr` keyword is required.
363 if (failed(Result: parser.parseOptionalKeyword(keyword: "accPtr"))) {
364 if (failed(Result: parser.parseKeyword(keyword: "accVar")))
365 return failure();
366 }
367 if (failed(Result: parser.parseLParen()))
368 return failure();
369 if (failed(Result: parser.parseOperand(result&: var)))
370 return failure();
371 if (failed(Result: parser.parseColon()))
372 return failure();
373 if (failed(Result: parser.parseType(result&: accVarType)))
374 return failure();
375 if (failed(Result: parser.parseRParen()))
376 return failure();
377
378 return success();
379}
380
381static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op,
382 mlir::Value accVar, mlir::Type accVarType) {
383 if (mlir::isa<mlir::acc::PointerLikeType>(Val: accVar.getType()))
384 p << "accPtr(";
385 else
386 p << "accVar(";
387 p.printOperand(value: accVar);
388 p << " : ";
389 p.printType(type: accVarType);
390 p << ")";
391}
392
393static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
394 mlir::Type &varPtrType,
395 mlir::TypeAttr &varTypeAttr) {
396 if (failed(Result: parser.parseType(result&: varPtrType)))
397 return failure();
398 if (failed(Result: parser.parseRParen()))
399 return failure();
400
401 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "varType"))) {
402 if (failed(Result: parser.parseLParen()))
403 return failure();
404 mlir::Type varType;
405 if (failed(Result: parser.parseType(result&: varType)))
406 return failure();
407 varTypeAttr = mlir::TypeAttr::get(type: varType);
408 if (failed(Result: parser.parseRParen()))
409 return failure();
410 } else {
411 // Set `varType` from the element type of the type of `varPtr`.
412 if (mlir::isa<mlir::acc::PointerLikeType>(Val: varPtrType))
413 varTypeAttr = mlir::TypeAttr::get(
414 type: mlir::cast<mlir::acc::PointerLikeType>(Val&: varPtrType).getElementType());
415 else
416 varTypeAttr = mlir::TypeAttr::get(type: varPtrType);
417 }
418
419 return success();
420}
421
422static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op,
423 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
424 p.printType(type: varPtrType);
425 p << ")";
426
427 // Print the `varType` only if it differs from the element type of
428 // `varPtr`'s type.
429 mlir::Type varType = varTypeAttr.getValue();
430 mlir::Type typeToCheckAgainst =
431 mlir::isa<mlir::acc::PointerLikeType>(Val: varPtrType)
432 ? mlir::cast<mlir::acc::PointerLikeType>(Val&: varPtrType).getElementType()
433 : varPtrType;
434 if (typeToCheckAgainst != varType) {
435 p << " varType(";
436 p.printType(type: varType);
437 p << ")";
438 }
439}
440
441//===----------------------------------------------------------------------===//
442// DataBoundsOp
443//===----------------------------------------------------------------------===//
444LogicalResult acc::DataBoundsOp::verify() {
445 auto extent = getExtent();
446 auto upperbound = getUpperbound();
447 if (!extent && !upperbound)
448 return emitError(message: "expected extent or upperbound.");
449 return success();
450}
451
452//===----------------------------------------------------------------------===//
453// PrivateOp
454//===----------------------------------------------------------------------===//
455LogicalResult acc::PrivateOp::verify() {
456 if (getDataClause() != acc::DataClause::acc_private)
457 return emitError(
458 message: "data clause associated with private operation must match its intent");
459 if (failed(Result: checkVarAndVarType(op: *this)))
460 return failure();
461 if (failed(Result: checkNoModifier(op: *this)))
462 return failure();
463 return success();
464}
465
466//===----------------------------------------------------------------------===//
467// FirstprivateOp
468//===----------------------------------------------------------------------===//
469LogicalResult acc::FirstprivateOp::verify() {
470 if (getDataClause() != acc::DataClause::acc_firstprivate)
471 return emitError(message: "data clause associated with firstprivate operation must "
472 "match its intent");
473 if (failed(Result: checkVarAndVarType(op: *this)))
474 return failure();
475 if (failed(Result: checkNoModifier(op: *this)))
476 return failure();
477 return success();
478}
479
480//===----------------------------------------------------------------------===//
481// ReductionOp
482//===----------------------------------------------------------------------===//
483LogicalResult acc::ReductionOp::verify() {
484 if (getDataClause() != acc::DataClause::acc_reduction)
485 return emitError(message: "data clause associated with reduction operation must "
486 "match its intent");
487 if (failed(Result: checkVarAndVarType(op: *this)))
488 return failure();
489 if (failed(Result: checkNoModifier(op: *this)))
490 return failure();
491 return success();
492}
493
494//===----------------------------------------------------------------------===//
495// DevicePtrOp
496//===----------------------------------------------------------------------===//
497LogicalResult acc::DevicePtrOp::verify() {
498 if (getDataClause() != acc::DataClause::acc_deviceptr)
499 return emitError(message: "data clause associated with deviceptr operation must "
500 "match its intent");
501 if (failed(Result: checkVarAndVarType(op: *this)))
502 return failure();
503 if (failed(Result: checkVarAndAccVar(op: *this)))
504 return failure();
505 if (failed(Result: checkNoModifier(op: *this)))
506 return failure();
507 return success();
508}
509
510//===----------------------------------------------------------------------===//
511// PresentOp
512//===----------------------------------------------------------------------===//
513LogicalResult acc::PresentOp::verify() {
514 if (getDataClause() != acc::DataClause::acc_present)
515 return emitError(
516 message: "data clause associated with present operation must match its intent");
517 if (failed(Result: checkVarAndVarType(op: *this)))
518 return failure();
519 if (failed(Result: checkVarAndAccVar(op: *this)))
520 return failure();
521 if (failed(Result: checkNoModifier(op: *this)))
522 return failure();
523 return success();
524}
525
526//===----------------------------------------------------------------------===//
527// CopyinOp
528//===----------------------------------------------------------------------===//
529LogicalResult acc::CopyinOp::verify() {
530 // Test for all clauses this operation can be decomposed from:
531 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
532 getDataClause() != acc::DataClause::acc_copyin_readonly &&
533 getDataClause() != acc::DataClause::acc_copy &&
534 getDataClause() != acc::DataClause::acc_reduction)
535 return emitError(
536 message: "data clause associated with copyin operation must match its intent"
537 " or specify original clause this operation was decomposed from");
538 if (failed(Result: checkVarAndVarType(op: *this)))
539 return failure();
540 if (failed(Result: checkVarAndAccVar(op: *this)))
541 return failure();
542 if (failed(Result: checkValidModifier(op: *this, validModifiers: acc::DataClauseModifier::readonly |
543 acc::DataClauseModifier::always |
544 acc::DataClauseModifier::capture)))
545 return failure();
546 return success();
547}
548
549bool acc::CopyinOp::isCopyinReadonly() {
550 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
551 acc::bitEnumContainsAny(bits: getModifiers(),
552 bit: acc::DataClauseModifier::readonly);
553}
554
555//===----------------------------------------------------------------------===//
556// CreateOp
557//===----------------------------------------------------------------------===//
558LogicalResult acc::CreateOp::verify() {
559 // Test for all clauses this operation can be decomposed from:
560 if (getDataClause() != acc::DataClause::acc_create &&
561 getDataClause() != acc::DataClause::acc_create_zero &&
562 getDataClause() != acc::DataClause::acc_copyout &&
563 getDataClause() != acc::DataClause::acc_copyout_zero)
564 return emitError(
565 message: "data clause associated with create operation must match its intent"
566 " or specify original clause this operation was decomposed from");
567 if (failed(Result: checkVarAndVarType(op: *this)))
568 return failure();
569 if (failed(Result: checkVarAndAccVar(op: *this)))
570 return failure();
571 // this op is the entry part of copyout, so it also needs to allow all
572 // modifiers allowed on copyout.
573 if (failed(Result: checkValidModifier(op: *this, validModifiers: acc::DataClauseModifier::zero |
574 acc::DataClauseModifier::always |
575 acc::DataClauseModifier::capture)))
576 return failure();
577 return success();
578}
579
580bool acc::CreateOp::isCreateZero() {
581 // The zero modifier is encoded in the data clause.
582 return getDataClause() == acc::DataClause::acc_create_zero ||
583 getDataClause() == acc::DataClause::acc_copyout_zero ||
584 acc::bitEnumContainsAny(bits: getModifiers(), bit: acc::DataClauseModifier::zero);
585}
586
587//===----------------------------------------------------------------------===//
588// NoCreateOp
589//===----------------------------------------------------------------------===//
590LogicalResult acc::NoCreateOp::verify() {
591 if (getDataClause() != acc::DataClause::acc_no_create)
592 return emitError(message: "data clause associated with no_create operation must "
593 "match its intent");
594 if (failed(Result: checkVarAndVarType(op: *this)))
595 return failure();
596 if (failed(Result: checkVarAndAccVar(op: *this)))
597 return failure();
598 if (failed(Result: checkNoModifier(op: *this)))
599 return failure();
600 return success();
601}
602
603//===----------------------------------------------------------------------===//
604// AttachOp
605//===----------------------------------------------------------------------===//
606LogicalResult acc::AttachOp::verify() {
607 if (getDataClause() != acc::DataClause::acc_attach)
608 return emitError(
609 message: "data clause associated with attach operation must match its intent");
610 if (failed(Result: checkVarAndVarType(op: *this)))
611 return failure();
612 if (failed(Result: checkVarAndAccVar(op: *this)))
613 return failure();
614 if (failed(Result: checkNoModifier(op: *this)))
615 return failure();
616 return success();
617}
618
619//===----------------------------------------------------------------------===//
620// DeclareDeviceResidentOp
621//===----------------------------------------------------------------------===//
622
623LogicalResult acc::DeclareDeviceResidentOp::verify() {
624 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
625 return emitError(message: "data clause associated with device_resident operation "
626 "must match its intent");
627 if (failed(Result: checkVarAndVarType(op: *this)))
628 return failure();
629 if (failed(Result: checkVarAndAccVar(op: *this)))
630 return failure();
631 if (failed(Result: checkNoModifier(op: *this)))
632 return failure();
633 return success();
634}
635
636//===----------------------------------------------------------------------===//
637// DeclareLinkOp
638//===----------------------------------------------------------------------===//
639
640LogicalResult acc::DeclareLinkOp::verify() {
641 if (getDataClause() != acc::DataClause::acc_declare_link)
642 return emitError(
643 message: "data clause associated with link operation must match its intent");
644 if (failed(Result: checkVarAndVarType(op: *this)))
645 return failure();
646 if (failed(Result: checkVarAndAccVar(op: *this)))
647 return failure();
648 if (failed(Result: checkNoModifier(op: *this)))
649 return failure();
650 return success();
651}
652
653//===----------------------------------------------------------------------===//
654// CopyoutOp
655//===----------------------------------------------------------------------===//
656LogicalResult acc::CopyoutOp::verify() {
657 // Test for all clauses this operation can be decomposed from:
658 if (getDataClause() != acc::DataClause::acc_copyout &&
659 getDataClause() != acc::DataClause::acc_copyout_zero &&
660 getDataClause() != acc::DataClause::acc_copy &&
661 getDataClause() != acc::DataClause::acc_reduction)
662 return emitError(
663 message: "data clause associated with copyout operation must match its intent"
664 " or specify original clause this operation was decomposed from");
665 if (!getVar() || !getAccVar())
666 return emitError(message: "must have both host and device pointers");
667 if (failed(Result: checkVarAndVarType(op: *this)))
668 return failure();
669 if (failed(Result: checkVarAndAccVar(op: *this)))
670 return failure();
671 if (failed(Result: checkValidModifier(op: *this, validModifiers: acc::DataClauseModifier::zero |
672 acc::DataClauseModifier::always |
673 acc::DataClauseModifier::capture)))
674 return failure();
675 return success();
676}
677
678bool acc::CopyoutOp::isCopyoutZero() {
679 return getDataClause() == acc::DataClause::acc_copyout_zero ||
680 acc::bitEnumContainsAny(bits: getModifiers(), bit: acc::DataClauseModifier::zero);
681}
682
683//===----------------------------------------------------------------------===//
684// DeleteOp
685//===----------------------------------------------------------------------===//
686LogicalResult acc::DeleteOp::verify() {
687 // Test for all clauses this operation can be decomposed from:
688 if (getDataClause() != acc::DataClause::acc_delete &&
689 getDataClause() != acc::DataClause::acc_create &&
690 getDataClause() != acc::DataClause::acc_create_zero &&
691 getDataClause() != acc::DataClause::acc_copyin &&
692 getDataClause() != acc::DataClause::acc_copyin_readonly &&
693 getDataClause() != acc::DataClause::acc_present &&
694 getDataClause() != acc::DataClause::acc_no_create &&
695 getDataClause() != acc::DataClause::acc_declare_device_resident &&
696 getDataClause() != acc::DataClause::acc_declare_link)
697 return emitError(
698 message: "data clause associated with delete operation must match its intent"
699 " or specify original clause this operation was decomposed from");
700 if (!getAccVar())
701 return emitError(message: "must have device pointer");
702 // This op is the exit part of copyin and create - thus allow all modifiers
703 // allowed on either case.
704 if (failed(Result: checkValidModifier(op: *this, validModifiers: acc::DataClauseModifier::zero |
705 acc::DataClauseModifier::readonly |
706 acc::DataClauseModifier::always |
707 acc::DataClauseModifier::capture)))
708 return failure();
709 return success();
710}
711
712//===----------------------------------------------------------------------===//
713// DetachOp
714//===----------------------------------------------------------------------===//
715LogicalResult acc::DetachOp::verify() {
716 // Test for all clauses this operation can be decomposed from:
717 if (getDataClause() != acc::DataClause::acc_detach &&
718 getDataClause() != acc::DataClause::acc_attach)
719 return emitError(
720 message: "data clause associated with detach operation must match its intent"
721 " or specify original clause this operation was decomposed from");
722 if (!getAccVar())
723 return emitError(message: "must have device pointer");
724 if (failed(Result: checkNoModifier(op: *this)))
725 return failure();
726 return success();
727}
728
729//===----------------------------------------------------------------------===//
// UpdateHostOp
731//===----------------------------------------------------------------------===//
732LogicalResult acc::UpdateHostOp::verify() {
733 // Test for all clauses this operation can be decomposed from:
734 if (getDataClause() != acc::DataClause::acc_update_host &&
735 getDataClause() != acc::DataClause::acc_update_self)
736 return emitError(
737 message: "data clause associated with host operation must match its intent"
738 " or specify original clause this operation was decomposed from");
739 if (!getVar() || !getAccVar())
740 return emitError(message: "must have both host and device pointers");
741 if (failed(Result: checkVarAndVarType(op: *this)))
742 return failure();
743 if (failed(Result: checkVarAndAccVar(op: *this)))
744 return failure();
745 if (failed(Result: checkNoModifier(op: *this)))
746 return failure();
747 return success();
748}
749
750//===----------------------------------------------------------------------===//
// UpdateDeviceOp
752//===----------------------------------------------------------------------===//
753LogicalResult acc::UpdateDeviceOp::verify() {
754 // Test for all clauses this operation can be decomposed from:
755 if (getDataClause() != acc::DataClause::acc_update_device)
756 return emitError(
757 message: "data clause associated with device operation must match its intent"
758 " or specify original clause this operation was decomposed from");
759 if (failed(Result: checkVarAndVarType(op: *this)))
760 return failure();
761 if (failed(Result: checkVarAndAccVar(op: *this)))
762 return failure();
763 if (failed(Result: checkNoModifier(op: *this)))
764 return failure();
765 return success();
766}
767
768//===----------------------------------------------------------------------===//
769// UseDeviceOp
770//===----------------------------------------------------------------------===//
771LogicalResult acc::UseDeviceOp::verify() {
772 // Test for all clauses this operation can be decomposed from:
773 if (getDataClause() != acc::DataClause::acc_use_device)
774 return emitError(
775 message: "data clause associated with use_device operation must match its intent"
776 " or specify original clause this operation was decomposed from");
777 if (failed(Result: checkVarAndVarType(op: *this)))
778 return failure();
779 if (failed(Result: checkVarAndAccVar(op: *this)))
780 return failure();
781 if (failed(Result: checkNoModifier(op: *this)))
782 return failure();
783 return success();
784}
785
786//===----------------------------------------------------------------------===//
787// CacheOp
788//===----------------------------------------------------------------------===//
789LogicalResult acc::CacheOp::verify() {
790 // Test for all clauses this operation can be decomposed from:
791 if (getDataClause() != acc::DataClause::acc_cache &&
792 getDataClause() != acc::DataClause::acc_cache_readonly)
793 return emitError(
794 message: "data clause associated with cache operation must match its intent"
795 " or specify original clause this operation was decomposed from");
796 if (failed(Result: checkVarAndVarType(op: *this)))
797 return failure();
798 if (failed(Result: checkVarAndAccVar(op: *this)))
799 return failure();
800 if (failed(Result: checkValidModifier(op: *this, validModifiers: acc::DataClauseModifier::readonly)))
801 return failure();
802 return success();
803}
804
805bool acc::CacheOp::isCacheReadonly() {
806 return getDataClause() == acc::DataClause::acc_cache_readonly ||
807 acc::bitEnumContainsAny(bits: getModifiers(),
808 bit: acc::DataClauseModifier::readonly);
809}
810
811template <typename StructureOp>
812static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
813 unsigned nRegions = 1) {
814
815 SmallVector<Region *, 2> regions;
816 for (unsigned i = 0; i < nRegions; ++i)
817 regions.push_back(Elt: state.addRegion());
818
819 for (Region *region : regions)
820 if (parser.parseRegion(region&: *region, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {}))
821 return failure();
822
823 return success();
824}
825
826static bool isComputeOperation(Operation *op) {
827 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(Val: op);
828}
829
830namespace {
831/// Pattern to remove operation without region that have constant false `ifCond`
832/// and remove the condition from the operation if the `ifCond` is a true
833/// constant.
834template <typename OpTy>
835struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
836 using OpRewritePattern<OpTy>::OpRewritePattern;
837
838 LogicalResult matchAndRewrite(OpTy op,
839 PatternRewriter &rewriter) const override {
840 // Early return if there is no condition.
841 Value ifCond = op.getIfCond();
842 if (!ifCond)
843 return failure();
844
845 IntegerAttr constAttr;
846 if (!matchPattern(value: ifCond, pattern: m_Constant(bind_value: &constAttr)))
847 return failure();
848 if (constAttr.getInt())
849 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
850 else
851 rewriter.eraseOp(op);
852
853 return success();
854 }
855};
856
857/// Replaces the given op with the contents of the given single-block region,
858/// using the operands of the block terminator to replace operation results.
859static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
860 Region &region, ValueRange blockArgs = {}) {
861 assert(llvm::hasSingleElement(region) && "expected single-region block");
862 Block *block = &region.front();
863 Operation *terminator = block->getTerminator();
864 ValueRange results = terminator->getOperands();
865 rewriter.inlineBlockBefore(source: block, op, argValues: blockArgs);
866 rewriter.replaceOp(op, newValues: results);
867 rewriter.eraseOp(op: terminator);
868}
869
870/// Pattern to remove operation with region that have constant false `ifCond`
871/// and remove the condition from the operation if the `ifCond` is constant
872/// true.
873template <typename OpTy>
874struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
875 using OpRewritePattern<OpTy>::OpRewritePattern;
876
877 LogicalResult matchAndRewrite(OpTy op,
878 PatternRewriter &rewriter) const override {
879 // Early return if there is no condition.
880 Value ifCond = op.getIfCond();
881 if (!ifCond)
882 return failure();
883
884 IntegerAttr constAttr;
885 if (!matchPattern(value: ifCond, pattern: m_Constant(bind_value: &constAttr)))
886 return failure();
887 if (constAttr.getInt())
888 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
889 else
890 replaceOpWithRegion(rewriter, op, op.getRegion());
891
892 return success();
893 }
894};
895
896} // namespace
897
898//===----------------------------------------------------------------------===//
899// PrivateRecipeOp
900//===----------------------------------------------------------------------===//
901
902static LogicalResult verifyInitLikeSingleArgRegion(
903 Operation *op, Region &region, StringRef regionType, StringRef regionName,
904 Type type, bool verifyYield, bool optional = false) {
905 if (optional && region.empty())
906 return success();
907
908 if (region.empty())
909 return op->emitOpError() << "expects non-empty " << regionName << " region";
910 Block &firstBlock = region.front();
911 if (firstBlock.getNumArguments() < 1 ||
912 firstBlock.getArgument(i: 0).getType() != type)
913 return op->emitOpError() << "expects " << regionName
914 << " region first "
915 "argument of the "
916 << regionType << " type";
917
918 if (verifyYield) {
919 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
920 if (yieldOp.getOperands().size() != 1 ||
921 yieldOp.getOperands().getTypes()[0] != type)
922 return op->emitOpError() << "expects " << regionName
923 << " region to "
924 "yield a value of the "
925 << regionType << " type";
926 }
927 }
928 return success();
929}
930
931LogicalResult acc::PrivateRecipeOp::verifyRegions() {
932 if (failed(Result: verifyInitLikeSingleArgRegion(op: *this, region&: getInitRegion(),
933 regionType: "privatization", regionName: "init", type: getType(),
934 /*verifyYield=*/false)))
935 return failure();
936 if (failed(Result: verifyInitLikeSingleArgRegion(
937 op: *this, region&: getDestroyRegion(), regionType: "privatization", regionName: "destroy", type: getType(),
938 /*verifyYield=*/false, /*optional=*/true)))
939 return failure();
940 return success();
941}
942
943//===----------------------------------------------------------------------===//
944// FirstprivateRecipeOp
945//===----------------------------------------------------------------------===//
946
947LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
948 if (failed(Result: verifyInitLikeSingleArgRegion(op: *this, region&: getInitRegion(),
949 regionType: "privatization", regionName: "init", type: getType(),
950 /*verifyYield=*/false)))
951 return failure();
952
953 if (getCopyRegion().empty())
954 return emitOpError() << "expects non-empty copy region";
955
956 Block &firstBlock = getCopyRegion().front();
957 if (firstBlock.getNumArguments() < 2 ||
958 firstBlock.getArgument(i: 0).getType() != getType())
959 return emitOpError() << "expects copy region with two arguments of the "
960 "privatization type";
961
962 if (getDestroyRegion().empty())
963 return success();
964
965 if (failed(Result: verifyInitLikeSingleArgRegion(op: *this, region&: getDestroyRegion(),
966 regionType: "privatization", regionName: "destroy",
967 type: getType(), /*verifyYield=*/false)))
968 return failure();
969
970 return success();
971}
972
973//===----------------------------------------------------------------------===//
974// ReductionRecipeOp
975//===----------------------------------------------------------------------===//
976
977LogicalResult acc::ReductionRecipeOp::verifyRegions() {
978 if (failed(Result: verifyInitLikeSingleArgRegion(op: *this, region&: getInitRegion(), regionType: "reduction",
979 regionName: "init", type: getType(),
980 /*verifyYield=*/false)))
981 return failure();
982
983 if (getCombinerRegion().empty())
984 return emitOpError() << "expects non-empty combiner region";
985
986 Block &reductionBlock = getCombinerRegion().front();
987 if (reductionBlock.getNumArguments() < 2 ||
988 reductionBlock.getArgument(i: 0).getType() != getType() ||
989 reductionBlock.getArgument(i: 1).getType() != getType())
990 return emitOpError() << "expects combiner region with the first two "
991 << "arguments of the reduction type";
992
993 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
994 if (yieldOp.getOperands().size() != 1 ||
995 yieldOp.getOperands().getTypes()[0] != getType())
996 return emitOpError() << "expects combiner region to yield a value "
997 "of the reduction type";
998 }
999
1000 return success();
1001}
1002
1003//===----------------------------------------------------------------------===//
// Custom parser and printer for `@symbol -> %operand : type` operand lists
1005//===----------------------------------------------------------------------===//
1006
1007static ParseResult parseSymOperandList(
1008 mlir::OpAsmParser &parser,
1009 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1010 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
1011 llvm::SmallVector<SymbolRefAttr> attributes;
1012 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1013 if (parser.parseAttribute(result&: attributes.emplace_back()) ||
1014 parser.parseArrow() ||
1015 parser.parseOperand(result&: operands.emplace_back()) ||
1016 parser.parseColonType(result&: types.emplace_back()))
1017 return failure();
1018 return success();
1019 })))
1020 return failure();
1021 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1022 attributes.end());
1023 symbols = ArrayAttr::get(context: parser.getContext(), value: arrayAttr);
1024 return success();
1025}
1026
1027static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
1028 mlir::OperandRange operands,
1029 mlir::TypeRange types,
1030 std::optional<mlir::ArrayAttr> attributes) {
1031 llvm::interleaveComma(c: llvm::zip(t&: *attributes, u&: operands), os&: p, each_fn: [&](auto it) {
1032 p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
1033 << std::get<1>(it).getType();
1034 });
1035}
1036
1037//===----------------------------------------------------------------------===//
1038// ParallelOp
1039//===----------------------------------------------------------------------===//
1040
1041/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1042template <typename Op>
1043static LogicalResult checkDataOperands(Op op,
1044 const mlir::ValueRange &operands) {
1045 for (mlir::Value operand : operands)
1046 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1047 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1048 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1049 Val: operand.getDefiningOp()))
1050 return op.emitError(
1051 "expect data entry/exit operation or acc.getdeviceptr "
1052 "as defining op");
1053 return success();
1054}
1055
1056template <typename Op>
1057static LogicalResult
1058checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
1059 mlir::OperandRange operands, llvm::StringRef operandName,
1060 llvm::StringRef symbolName, bool checkOperandType = true) {
1061 if (!operands.empty()) {
1062 if (!attributes || attributes->size() != operands.size())
1063 return op->emitOpError()
1064 << "expected as many " << symbolName << " symbol reference as "
1065 << operandName << " operands";
1066 } else {
1067 if (attributes)
1068 return op->emitOpError()
1069 << "unexpected " << symbolName << " symbol reference";
1070 return success();
1071 }
1072
1073 llvm::DenseSet<Value> set;
1074 for (auto args : llvm::zip(t&: operands, u&: *attributes)) {
1075 mlir::Value operand = std::get<0>(t&: args);
1076
1077 if (!set.insert(V: operand).second)
1078 return op->emitOpError()
1079 << operandName << " operand appears more than once";
1080
1081 mlir::Type varType = operand.getType();
1082 auto symbolRef = llvm::cast<SymbolRefAttr>(Val: std::get<1>(t&: args));
1083 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1084 if (!decl)
1085 return op->emitOpError()
1086 << "expected symbol reference " << symbolRef << " to point to a "
1087 << operandName << " declaration";
1088
1089 if (checkOperandType && decl.getType() && decl.getType() != varType)
1090 return op->emitOpError() << "expected " << operandName << " (" << varType
1091 << ") to be the same type as " << operandName
1092 << " declaration (" << decl.getType() << ")";
1093 }
1094
1095 return success();
1096}
1097
1098unsigned ParallelOp::getNumDataOperands() {
1099 return getReductionOperands().size() + getPrivateOperands().size() +
1100 getFirstprivateOperands().size() + getDataClauseOperands().size();
1101}
1102
1103Value ParallelOp::getDataOperand(unsigned i) {
1104 unsigned numOptional = getAsyncOperands().size();
1105 numOptional += getNumGangs().size();
1106 numOptional += getNumWorkers().size();
1107 numOptional += getVectorLength().size();
1108 numOptional += getIfCond() ? 1 : 0;
1109 numOptional += getSelfCond() ? 1 : 0;
1110 return getOperand(i: getWaitOperands().size() + numOptional + i);
1111}
1112
1113template <typename Op>
1114static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1115 ArrayAttr deviceTypes,
1116 llvm::StringRef keyword) {
1117 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1118 return op.emitOpError() << keyword << " operands count must match "
1119 << keyword << " device_type count";
1120 return success();
1121}
1122
1123template <typename Op>
1124static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
1125 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1126 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1127 std::size_t numOperandsInSegments = 0;
1128 std::size_t nbOfSegments = 0;
1129
1130 if (segments) {
1131 for (auto segCount : segments.asArrayRef()) {
1132 if (maxInSegment != 0 && segCount > maxInSegment)
1133 return op.emitOpError() << keyword << " expects a maximum of "
1134 << maxInSegment << " values per segment";
1135 numOperandsInSegments += segCount;
1136 ++nbOfSegments;
1137 }
1138 }
1139
1140 if ((numOperandsInSegments != operands.size()) ||
1141 (!deviceTypes && !operands.empty()))
1142 return op.emitOpError()
1143 << keyword << " operand count does not match count in segments";
1144 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1145 return op.emitOpError()
1146 << keyword << " segment count does not match device_type count";
1147 return success();
1148}
1149
1150LogicalResult acc::ParallelOp::verify() {
1151 if (failed(Result: checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1152 op: *this, attributes: getPrivatizationRecipes(), operands: getPrivateOperands(), operandName: "private",
1153 symbolName: "privatizations", /*checkOperandType=*/false)))
1154 return failure();
1155 if (failed(Result: checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1156 op: *this, attributes: getFirstprivatizationRecipes(), operands: getFirstprivateOperands(),
1157 operandName: "firstprivate", symbolName: "firstprivatizations", /*checkOperandType=*/false)))
1158 return failure();
1159 if (failed(Result: checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1160 op: *this, attributes: getReductionRecipes(), operands: getReductionOperands(), operandName: "reduction",
1161 symbolName: "reductions", checkOperandType: false)))
1162 return failure();
1163
1164 if (failed(Result: verifyDeviceTypeAndSegmentCountMatch(
1165 op: *this, operands: getNumGangs(), segments: getNumGangsSegmentsAttr(),
1166 deviceTypes: getNumGangsDeviceTypeAttr(), keyword: "num_gangs", maxInSegment: 3)))
1167 return failure();
1168
1169 if (failed(Result: verifyDeviceTypeAndSegmentCountMatch(
1170 op: *this, operands: getWaitOperands(), segments: getWaitOperandsSegmentsAttr(),
1171 deviceTypes: getWaitOperandsDeviceTypeAttr(), keyword: "wait")))
1172 return failure();
1173
1174 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getNumWorkers(),
1175 deviceTypes: getNumWorkersDeviceTypeAttr(),
1176 keyword: "num_workers")))
1177 return failure();
1178
1179 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getVectorLength(),
1180 deviceTypes: getVectorLengthDeviceTypeAttr(),
1181 keyword: "vector_length")))
1182 return failure();
1183
1184 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getAsyncOperands(),
1185 deviceTypes: getAsyncOperandsDeviceTypeAttr(),
1186 keyword: "async")))
1187 return failure();
1188
1189 if (failed(Result: checkWaitAndAsyncConflict<acc::ParallelOp>(op: *this)))
1190 return failure();
1191
1192 return checkDataOperands<acc::ParallelOp>(op: *this, operands: getDataClauseOperands());
1193}
1194
1195static mlir::Value
1196getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1197 mlir::Operation::operand_range range,
1198 mlir::acc::DeviceType deviceType) {
1199 if (!arrayAttr)
1200 return {};
1201 if (auto pos = findSegment(segments: *arrayAttr, deviceType))
1202 return range[*pos];
1203 return {};
1204}
1205
1206bool acc::ParallelOp::hasAsyncOnly() {
1207 return hasAsyncOnly(deviceType: mlir::acc::DeviceType::None);
1208}
1209
1210bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1211 return hasDeviceType(arrayAttr: getAsyncOnly(), deviceType);
1212}
1213
1214mlir::Value acc::ParallelOp::getAsyncValue() {
1215 return getAsyncValue(deviceType: mlir::acc::DeviceType::None);
1216}
1217
1218mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1219 return getValueInDeviceTypeSegment(arrayAttr: getAsyncOperandsDeviceType(),
1220 range: getAsyncOperands(), deviceType);
1221}
1222
1223mlir::Value acc::ParallelOp::getNumWorkersValue() {
1224 return getNumWorkersValue(deviceType: mlir::acc::DeviceType::None);
1225}
1226
1227mlir::Value
1228acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1229 return getValueInDeviceTypeSegment(arrayAttr: getNumWorkersDeviceType(), range: getNumWorkers(),
1230 deviceType);
1231}
1232
1233mlir::Value acc::ParallelOp::getVectorLengthValue() {
1234 return getVectorLengthValue(deviceType: mlir::acc::DeviceType::None);
1235}
1236
1237mlir::Value
1238acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1239 return getValueInDeviceTypeSegment(arrayAttr: getVectorLengthDeviceType(),
1240 range: getVectorLength(), deviceType);
1241}
1242
1243mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1244 return getNumGangsValues(deviceType: mlir::acc::DeviceType::None);
1245}
1246
1247mlir::Operation::operand_range
1248ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1249 return getValuesFromSegments(arrayAttr: getNumGangsDeviceType(), range: getNumGangs(),
1250 segments: getNumGangsSegments(), deviceType);
1251}
1252
1253bool acc::ParallelOp::hasWaitOnly() {
1254 return hasWaitOnly(deviceType: mlir::acc::DeviceType::None);
1255}
1256
1257bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1258 return hasDeviceType(arrayAttr: getWaitOnly(), deviceType);
1259}
1260
1261mlir::Operation::operand_range ParallelOp::getWaitValues() {
1262 return getWaitValues(deviceType: mlir::acc::DeviceType::None);
1263}
1264
1265mlir::Operation::operand_range
1266ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1267 return getWaitValuesWithoutDevnum(
1268 deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(), segments: getWaitOperandsSegments(),
1269 hasWaitDevnum: getHasWaitDevnum(), deviceType);
1270}
1271
1272mlir::Value ParallelOp::getWaitDevnum() {
1273 return getWaitDevnum(deviceType: mlir::acc::DeviceType::None);
1274}
1275
1276mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1277 return getWaitDevnumValue(deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(),
1278 segments: getWaitOperandsSegments(), hasWaitDevnum: getHasWaitDevnum(),
1279 deviceType);
1280}
1281
1282void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1283 mlir::OperationState &odsState,
1284 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1285 mlir::ValueRange vectorLength,
1286 mlir::ValueRange asyncOperands,
1287 mlir::ValueRange waitOperands, mlir::Value ifCond,
1288 mlir::Value selfCond, mlir::ValueRange reductionOperands,
1289 mlir::ValueRange gangPrivateOperands,
1290 mlir::ValueRange gangFirstPrivateOperands,
1291 mlir::ValueRange dataClauseOperands) {
1292
1293 ParallelOp::build(
1294 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1295 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1296 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1297 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1298 /*numGangsDeviceType=*/nullptr, numWorkers,
1299 /*numWorkersDeviceType=*/nullptr, vectorLength,
1300 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1301 /*selfAttr=*/nullptr, reductionOperands, /*reductionRecipes=*/nullptr,
1302 privateOperands: gangPrivateOperands, /*privatizations=*/privatizationRecipes: nullptr, firstprivateOperands: gangFirstPrivateOperands,
1303 /*firstprivatizations=*/firstprivatizationRecipes: nullptr, dataClauseOperands,
1304 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1305}
1306
1307void acc::ParallelOp::addNumWorkersOperand(
1308 MLIRContext *context, mlir::Value newValue,
1309 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1310 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1311 context, existingDeviceTypes: getNumWorkersDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
1312 argCollection: getNumWorkersMutable()));
1313}
1314void acc::ParallelOp::addVectorLengthOperand(
1315 MLIRContext *context, mlir::Value newValue,
1316 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1317 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1318 context, existingDeviceTypes: getVectorLengthDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
1319 argCollection: getVectorLengthMutable()));
1320}
1321
1322void acc::ParallelOp::addAsyncOnly(
1323 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1324 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1325 context, existingDeviceTypes: getAsyncOnlyAttr(), newDeviceTypes: effectiveDeviceTypes));
1326}
1327
1328void acc::ParallelOp::addAsyncOperand(
1329 MLIRContext *context, mlir::Value newValue,
1330 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1331 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1332 context, existingDeviceTypes: getAsyncOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
1333 argCollection: getAsyncOperandsMutable()));
1334}
1335
1336void acc::ParallelOp::addNumGangsOperands(
1337 MLIRContext *context, mlir::ValueRange newValues,
1338 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1339 llvm::SmallVector<int32_t> segments;
1340 if (getNumGangsSegments())
1341 llvm::copy(Range: *getNumGangsSegments(), Out: std::back_inserter(x&: segments));
1342
1343 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1344 context, existingDeviceTypes: getNumGangsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValues,
1345 argCollection: getNumGangsMutable(), segments));
1346
1347 setNumGangsSegments(segments);
1348}
1349void acc::ParallelOp::addWaitOnly(
1350 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1351 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getWaitOnlyAttr(),
1352 newDeviceTypes: effectiveDeviceTypes));
1353}
1354void acc::ParallelOp::addWaitOperands(
1355 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1356 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1357
1358 llvm::SmallVector<int32_t> segments;
1359 if (getWaitOperandsSegments())
1360 llvm::copy(Range: *getWaitOperandsSegments(), Out: std::back_inserter(x&: segments));
1361
1362 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1363 context, existingDeviceTypes: getWaitOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValues,
1364 argCollection: getWaitOperandsMutable(), segments));
1365 setWaitOperandsSegments(segments);
1366
1367 llvm::SmallVector<mlir::Attribute> hasDevnums;
1368 if (getHasWaitDevnumAttr())
1369 llvm::copy(Range: getHasWaitDevnumAttr(), Out: std::back_inserter(x&: hasDevnums));
1370 hasDevnums.insert(
1371 I: hasDevnums.end(),
1372 NumToInsert: std::max(a: effectiveDeviceTypes.size(), b: static_cast<size_t>(1)),
1373 Elt: mlir::BoolAttr::get(context, value: hasDevnum));
1374 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, value: hasDevnums));
1375}
1376
1377static ParseResult parseNumGangs(
1378 mlir::OpAsmParser &parser,
1379 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1380 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1381 mlir::DenseI32ArrayAttr &segments) {
1382 llvm::SmallVector<DeviceTypeAttr> attributes;
1383 llvm::SmallVector<int32_t> seg;
1384
1385 do {
1386 if (failed(Result: parser.parseLBrace()))
1387 return failure();
1388
1389 int32_t crtOperandsSize = operands.size();
1390 if (failed(Result: parser.parseCommaSeparatedList(
1391 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
1392 if (parser.parseOperand(result&: operands.emplace_back()) ||
1393 parser.parseColonType(result&: types.emplace_back()))
1394 return failure();
1395 return success();
1396 })))
1397 return failure();
1398 seg.push_back(Elt: operands.size() - crtOperandsSize);
1399
1400 if (failed(Result: parser.parseRBrace()))
1401 return failure();
1402
1403 if (succeeded(Result: parser.parseOptionalLSquare())) {
1404 if (parser.parseAttribute(result&: attributes.emplace_back()) ||
1405 parser.parseRSquare())
1406 return failure();
1407 } else {
1408 attributes.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
1409 context: parser.getContext(), value: mlir::acc::DeviceType::None));
1410 }
1411 } while (succeeded(Result: parser.parseOptionalComma()));
1412
1413 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1414 attributes.end());
1415 deviceTypes = ArrayAttr::get(context: parser.getContext(), value: arrayAttr);
1416 segments = DenseI32ArrayAttr::get(context: parser.getContext(), content: seg);
1417
1418 return success();
1419}
1420
1421static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) {
1422 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val&: attr);
1423 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1424 p << " [" << attr << "]";
1425}
1426
1427static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op,
1428 mlir::OperandRange operands, mlir::TypeRange types,
1429 std::optional<mlir::ArrayAttr> deviceTypes,
1430 std::optional<mlir::DenseI32ArrayAttr> segments) {
1431 unsigned opIdx = 0;
1432 llvm::interleaveComma(c: llvm::enumerate(First&: *deviceTypes), os&: p, each_fn: [&](auto it) {
1433 p << "{";
1434 llvm::interleaveComma(
1435 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1436 p << operands[opIdx] << " : " << operands[opIdx].getType();
1437 ++opIdx;
1438 });
1439 p << "}";
1440 printSingleDeviceType(p, it.value());
1441 });
1442}
1443
1444static ParseResult parseDeviceTypeOperandsWithSegment(
1445 mlir::OpAsmParser &parser,
1446 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1447 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1448 mlir::DenseI32ArrayAttr &segments) {
1449 llvm::SmallVector<DeviceTypeAttr> attributes;
1450 llvm::SmallVector<int32_t> seg;
1451
1452 do {
1453 if (failed(Result: parser.parseLBrace()))
1454 return failure();
1455
1456 int32_t crtOperandsSize = operands.size();
1457
1458 if (failed(Result: parser.parseCommaSeparatedList(
1459 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
1460 if (parser.parseOperand(result&: operands.emplace_back()) ||
1461 parser.parseColonType(result&: types.emplace_back()))
1462 return failure();
1463 return success();
1464 })))
1465 return failure();
1466
1467 seg.push_back(Elt: operands.size() - crtOperandsSize);
1468
1469 if (failed(Result: parser.parseRBrace()))
1470 return failure();
1471
1472 if (succeeded(Result: parser.parseOptionalLSquare())) {
1473 if (parser.parseAttribute(result&: attributes.emplace_back()) ||
1474 parser.parseRSquare())
1475 return failure();
1476 } else {
1477 attributes.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
1478 context: parser.getContext(), value: mlir::acc::DeviceType::None));
1479 }
1480 } while (succeeded(Result: parser.parseOptionalComma()));
1481
1482 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1483 attributes.end());
1484 deviceTypes = ArrayAttr::get(context: parser.getContext(), value: arrayAttr);
1485 segments = DenseI32ArrayAttr::get(context: parser.getContext(), content: seg);
1486
1487 return success();
1488}
1489
1490static void printDeviceTypeOperandsWithSegment(
1491 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1492 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1493 std::optional<mlir::DenseI32ArrayAttr> segments) {
1494 unsigned opIdx = 0;
1495 llvm::interleaveComma(c: llvm::enumerate(First&: *deviceTypes), os&: p, each_fn: [&](auto it) {
1496 p << "{";
1497 llvm::interleaveComma(
1498 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1499 p << operands[opIdx] << " : " << operands[opIdx].getType();
1500 ++opIdx;
1501 });
1502 p << "}";
1503 printSingleDeviceType(p, it.value());
1504 });
1505}
1506
1507static ParseResult parseWaitClause(
1508 mlir::OpAsmParser &parser,
1509 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1510 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1511 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
1512 mlir::ArrayAttr &keywordOnly) {
1513 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
1514 llvm::SmallVector<int32_t> seg;
1515
1516 bool needCommaBeforeOperands = false;
1517
1518 // Keyword only
1519 if (failed(Result: parser.parseOptionalLParen())) {
1520 keywordAttrs.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
1521 context: parser.getContext(), value: mlir::acc::DeviceType::None));
1522 keywordOnly = ArrayAttr::get(context: parser.getContext(), value: keywordAttrs);
1523 return success();
1524 }
1525
1526 // Parse keyword only attributes
1527 if (succeeded(Result: parser.parseOptionalLSquare())) {
1528 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1529 if (parser.parseAttribute(result&: keywordAttrs.emplace_back()))
1530 return failure();
1531 return success();
1532 })))
1533 return failure();
1534 if (parser.parseRSquare())
1535 return failure();
1536 needCommaBeforeOperands = true;
1537 }
1538
1539 if (needCommaBeforeOperands && failed(Result: parser.parseComma()))
1540 return failure();
1541
1542 do {
1543 if (failed(Result: parser.parseLBrace()))
1544 return failure();
1545
1546 int32_t crtOperandsSize = operands.size();
1547
1548 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "devnum"))) {
1549 if (failed(Result: parser.parseColon()))
1550 return failure();
1551 devnum.push_back(Elt: BoolAttr::get(context: parser.getContext(), value: true));
1552 } else {
1553 devnum.push_back(Elt: BoolAttr::get(context: parser.getContext(), value: false));
1554 }
1555
1556 if (failed(Result: parser.parseCommaSeparatedList(
1557 delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() {
1558 if (parser.parseOperand(result&: operands.emplace_back()) ||
1559 parser.parseColonType(result&: types.emplace_back()))
1560 return failure();
1561 return success();
1562 })))
1563 return failure();
1564
1565 seg.push_back(Elt: operands.size() - crtOperandsSize);
1566
1567 if (failed(Result: parser.parseRBrace()))
1568 return failure();
1569
1570 if (succeeded(Result: parser.parseOptionalLSquare())) {
1571 if (parser.parseAttribute(result&: deviceTypeAttrs.emplace_back()) ||
1572 parser.parseRSquare())
1573 return failure();
1574 } else {
1575 deviceTypeAttrs.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
1576 context: parser.getContext(), value: mlir::acc::DeviceType::None));
1577 }
1578 } while (succeeded(Result: parser.parseOptionalComma()));
1579
1580 if (failed(Result: parser.parseRParen()))
1581 return failure();
1582
1583 deviceTypes = ArrayAttr::get(context: parser.getContext(), value: deviceTypeAttrs);
1584 keywordOnly = ArrayAttr::get(context: parser.getContext(), value: keywordAttrs);
1585 segments = DenseI32ArrayAttr::get(context: parser.getContext(), content: seg);
1586 hasDevNum = ArrayAttr::get(context: parser.getContext(), value: devnum);
1587
1588 return success();
1589}
1590
1591static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
1592 if (!hasDeviceTypeValues(arrayAttr: attrs))
1593 return false;
1594 if (attrs->size() != 1)
1595 return false;
1596 if (auto deviceTypeAttr =
1597 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val: (*attrs)[0]))
1598 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
1599 return false;
1600}
1601
1602static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
1603 mlir::OperandRange operands, mlir::TypeRange types,
1604 std::optional<mlir::ArrayAttr> deviceTypes,
1605 std::optional<mlir::DenseI32ArrayAttr> segments,
1606 std::optional<mlir::ArrayAttr> hasDevNum,
1607 std::optional<mlir::ArrayAttr> keywordOnly) {
1608
1609 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(attrs: keywordOnly))
1610 return;
1611
1612 p << "(";
1613
1614 printDeviceTypes(p, deviceTypes: keywordOnly);
1615 if (hasDeviceTypeValues(arrayAttr: keywordOnly) && hasDeviceTypeValues(arrayAttr: deviceTypes))
1616 p << ", ";
1617
1618 if (hasDeviceTypeValues(arrayAttr: deviceTypes)) {
1619 unsigned opIdx = 0;
1620 llvm::interleaveComma(c: llvm::enumerate(First&: *deviceTypes), os&: p, each_fn: [&](auto it) {
1621 p << "{";
1622 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1623 if (boolAttr && boolAttr.getValue())
1624 p << "devnum: ";
1625 llvm::interleaveComma(
1626 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1627 p << operands[opIdx] << " : " << operands[opIdx].getType();
1628 ++opIdx;
1629 });
1630 p << "}";
1631 printSingleDeviceType(p, it.value());
1632 });
1633 }
1634
1635 p << ")";
1636}
1637
1638static ParseResult parseDeviceTypeOperands(
1639 mlir::OpAsmParser &parser,
1640 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1641 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
1642 llvm::SmallVector<DeviceTypeAttr> attributes;
1643 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1644 if (parser.parseOperand(result&: operands.emplace_back()) ||
1645 parser.parseColonType(result&: types.emplace_back()))
1646 return failure();
1647 if (succeeded(Result: parser.parseOptionalLSquare())) {
1648 if (parser.parseAttribute(result&: attributes.emplace_back()) ||
1649 parser.parseRSquare())
1650 return failure();
1651 } else {
1652 attributes.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
1653 context: parser.getContext(), value: mlir::acc::DeviceType::None));
1654 }
1655 return success();
1656 })))
1657 return failure();
1658 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1659 attributes.end());
1660 deviceTypes = ArrayAttr::get(context: parser.getContext(), value: arrayAttr);
1661 return success();
1662}
1663
1664static void
1665printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op,
1666 mlir::OperandRange operands, mlir::TypeRange types,
1667 std::optional<mlir::ArrayAttr> deviceTypes) {
1668 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
1669 return;
1670 llvm::interleaveComma(c: llvm::zip(t&: *deviceTypes, u&: operands), os&: p, each_fn: [&](auto it) {
1671 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
1672 printSingleDeviceType(p, std::get<0>(it));
1673 });
1674}
1675
1676static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
1677 mlir::OpAsmParser &parser,
1678 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1679 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1680 mlir::ArrayAttr &keywordOnlyDeviceType) {
1681
1682 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
1683 bool needCommaBeforeOperands = false;
1684
1685 if (failed(Result: parser.parseOptionalLParen())) {
1686 // Keyword only
1687 keywordOnlyDeviceTypeAttributes.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
1688 context: parser.getContext(), value: mlir::acc::DeviceType::None));
1689 keywordOnlyDeviceType =
1690 ArrayAttr::get(context: parser.getContext(), value: keywordOnlyDeviceTypeAttributes);
1691 return success();
1692 }
1693
1694 // Parse keyword only attributes
1695 if (succeeded(Result: parser.parseOptionalLSquare())) {
1696 // Parse keyword only attributes
1697 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1698 if (parser.parseAttribute(
1699 result&: keywordOnlyDeviceTypeAttributes.emplace_back()))
1700 return failure();
1701 return success();
1702 })))
1703 return failure();
1704 if (parser.parseRSquare())
1705 return failure();
1706 needCommaBeforeOperands = true;
1707 }
1708
1709 if (needCommaBeforeOperands && failed(Result: parser.parseComma()))
1710 return failure();
1711
1712 llvm::SmallVector<DeviceTypeAttr> attributes;
1713 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1714 if (parser.parseOperand(result&: operands.emplace_back()) ||
1715 parser.parseColonType(result&: types.emplace_back()))
1716 return failure();
1717 if (succeeded(Result: parser.parseOptionalLSquare())) {
1718 if (parser.parseAttribute(result&: attributes.emplace_back()) ||
1719 parser.parseRSquare())
1720 return failure();
1721 } else {
1722 attributes.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
1723 context: parser.getContext(), value: mlir::acc::DeviceType::None));
1724 }
1725 return success();
1726 })))
1727 return failure();
1728
1729 if (failed(Result: parser.parseRParen()))
1730 return failure();
1731
1732 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
1733 attributes.end());
1734 deviceTypes = ArrayAttr::get(context: parser.getContext(), value: arrayAttr);
1735 return success();
1736}
1737
1738static void printDeviceTypeOperandsWithKeywordOnly(
1739 mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
1740 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
1741 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1742
1743 if (operands.begin() == operands.end() &&
1744 hasOnlyDeviceTypeNone(attrs: keywordOnlyDeviceTypes)) {
1745 return;
1746 }
1747
1748 p << "(";
1749 printDeviceTypes(p, deviceTypes: keywordOnlyDeviceTypes);
1750 if (hasDeviceTypeValues(arrayAttr: keywordOnlyDeviceTypes) &&
1751 hasDeviceTypeValues(arrayAttr: deviceTypes))
1752 p << ", ";
1753 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
1754 p << ")";
1755}
1756
1757static ParseResult parseOperandWithKeywordOnly(
1758 mlir::OpAsmParser &parser,
1759 std::optional<OpAsmParser::UnresolvedOperand> &operand,
1760 mlir::Type &operandType, mlir::UnitAttr &attr) {
1761 // Keyword only
1762 if (failed(Result: parser.parseOptionalLParen())) {
1763 attr = mlir::UnitAttr::get(context: parser.getContext());
1764 return success();
1765 }
1766
1767 OpAsmParser::UnresolvedOperand op;
1768 if (failed(Result: parser.parseOperand(result&: op)))
1769 return failure();
1770 operand = op;
1771 if (failed(Result: parser.parseColon()))
1772 return failure();
1773 if (failed(Result: parser.parseType(result&: operandType)))
1774 return failure();
1775 if (failed(Result: parser.parseRParen()))
1776 return failure();
1777
1778 return success();
1779}
1780
1781static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p,
1782 mlir::Operation *op,
1783 std::optional<mlir::Value> operand,
1784 mlir::Type operandType,
1785 mlir::UnitAttr attr) {
1786 if (attr)
1787 return;
1788
1789 p << "(";
1790 p.printOperand(value: *operand);
1791 p << " : ";
1792 p.printType(type: operandType);
1793 p << ")";
1794}
1795
1796static ParseResult parseOperandsWithKeywordOnly(
1797 mlir::OpAsmParser &parser,
1798 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
1799 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
1800 // Keyword only
1801 if (failed(Result: parser.parseOptionalLParen())) {
1802 attr = mlir::UnitAttr::get(context: parser.getContext());
1803 return success();
1804 }
1805
1806 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1807 if (parser.parseOperand(result&: operands.emplace_back()))
1808 return failure();
1809 return success();
1810 })))
1811 return failure();
1812 if (failed(Result: parser.parseColon()))
1813 return failure();
1814 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
1815 if (parser.parseType(result&: types.emplace_back()))
1816 return failure();
1817 return success();
1818 })))
1819 return failure();
1820 if (failed(Result: parser.parseRParen()))
1821 return failure();
1822
1823 return success();
1824}
1825
1826static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p,
1827 mlir::Operation *op,
1828 mlir::OperandRange operands,
1829 mlir::TypeRange types,
1830 mlir::UnitAttr attr) {
1831 if (attr)
1832 return;
1833
1834 p << "(";
1835 llvm::interleaveComma(c: operands, os&: p, each_fn: [&](auto it) { p << it; });
1836 p << " : ";
1837 llvm::interleaveComma(c: types, os&: p, each_fn: [&](auto it) { p << it; });
1838 p << ")";
1839}
1840
1841static ParseResult
1842parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
1843 mlir::acc::CombinedConstructsTypeAttr &attr) {
1844 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "kernels"))) {
1845 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1846 context: parser.getContext(), value: mlir::acc::CombinedConstructsType::KernelsLoop);
1847 } else if (succeeded(Result: parser.parseOptionalKeyword(keyword: "parallel"))) {
1848 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1849 context: parser.getContext(), value: mlir::acc::CombinedConstructsType::ParallelLoop);
1850 } else if (succeeded(Result: parser.parseOptionalKeyword(keyword: "serial"))) {
1851 attr = mlir::acc::CombinedConstructsTypeAttr::get(
1852 context: parser.getContext(), value: mlir::acc::CombinedConstructsType::SerialLoop);
1853 } else {
1854 parser.emitError(loc: parser.getCurrentLocation(),
1855 message: "expected compute construct name");
1856 return failure();
1857 }
1858 return success();
1859}
1860
1861static void
1862printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
1863 mlir::acc::CombinedConstructsTypeAttr attr) {
1864 if (attr) {
1865 switch (attr.getValue()) {
1866 case mlir::acc::CombinedConstructsType::KernelsLoop:
1867 p << "kernels";
1868 break;
1869 case mlir::acc::CombinedConstructsType::ParallelLoop:
1870 p << "parallel";
1871 break;
1872 case mlir::acc::CombinedConstructsType::SerialLoop:
1873 p << "serial";
1874 break;
1875 };
1876 }
1877}
1878
1879//===----------------------------------------------------------------------===//
1880// SerialOp
1881//===----------------------------------------------------------------------===//
1882
1883unsigned SerialOp::getNumDataOperands() {
1884 return getReductionOperands().size() + getPrivateOperands().size() +
1885 getFirstprivateOperands().size() + getDataClauseOperands().size();
1886}
1887
1888Value SerialOp::getDataOperand(unsigned i) {
1889 unsigned numOptional = getAsyncOperands().size();
1890 numOptional += getIfCond() ? 1 : 0;
1891 numOptional += getSelfCond() ? 1 : 0;
1892 return getOperand(i: getWaitOperands().size() + numOptional + i);
1893}
1894
1895bool acc::SerialOp::hasAsyncOnly() {
1896 return hasAsyncOnly(deviceType: mlir::acc::DeviceType::None);
1897}
1898
1899bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1900 return hasDeviceType(arrayAttr: getAsyncOnly(), deviceType);
1901}
1902
1903mlir::Value acc::SerialOp::getAsyncValue() {
1904 return getAsyncValue(deviceType: mlir::acc::DeviceType::None);
1905}
1906
1907mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1908 return getValueInDeviceTypeSegment(arrayAttr: getAsyncOperandsDeviceType(),
1909 range: getAsyncOperands(), deviceType);
1910}
1911
1912bool acc::SerialOp::hasWaitOnly() {
1913 return hasWaitOnly(deviceType: mlir::acc::DeviceType::None);
1914}
1915
1916bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1917 return hasDeviceType(arrayAttr: getWaitOnly(), deviceType);
1918}
1919
1920mlir::Operation::operand_range SerialOp::getWaitValues() {
1921 return getWaitValues(deviceType: mlir::acc::DeviceType::None);
1922}
1923
1924mlir::Operation::operand_range
1925SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1926 return getWaitValuesWithoutDevnum(
1927 deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(), segments: getWaitOperandsSegments(),
1928 hasWaitDevnum: getHasWaitDevnum(), deviceType);
1929}
1930
1931mlir::Value SerialOp::getWaitDevnum() {
1932 return getWaitDevnum(deviceType: mlir::acc::DeviceType::None);
1933}
1934
1935mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1936 return getWaitDevnumValue(deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(),
1937 segments: getWaitOperandsSegments(), hasWaitDevnum: getHasWaitDevnum(),
1938 deviceType);
1939}
1940
1941LogicalResult acc::SerialOp::verify() {
1942 if (failed(Result: checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1943 op: *this, attributes: getPrivatizationRecipes(), operands: getPrivateOperands(), operandName: "private",
1944 symbolName: "privatizations", /*checkOperandType=*/false)))
1945 return failure();
1946 if (failed(Result: checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1947 op: *this, attributes: getFirstprivatizationRecipes(), operands: getFirstprivateOperands(),
1948 operandName: "firstprivate", symbolName: "firstprivatizations", /*checkOperandType=*/false)))
1949 return failure();
1950 if (failed(Result: checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1951 op: *this, attributes: getReductionRecipes(), operands: getReductionOperands(), operandName: "reduction",
1952 symbolName: "reductions", checkOperandType: false)))
1953 return failure();
1954
1955 if (failed(Result: verifyDeviceTypeAndSegmentCountMatch(
1956 op: *this, operands: getWaitOperands(), segments: getWaitOperandsSegmentsAttr(),
1957 deviceTypes: getWaitOperandsDeviceTypeAttr(), keyword: "wait")))
1958 return failure();
1959
1960 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getAsyncOperands(),
1961 deviceTypes: getAsyncOperandsDeviceTypeAttr(),
1962 keyword: "async")))
1963 return failure();
1964
1965 if (failed(Result: checkWaitAndAsyncConflict<acc::SerialOp>(op: *this)))
1966 return failure();
1967
1968 return checkDataOperands<acc::SerialOp>(op: *this, operands: getDataClauseOperands());
1969}
1970
1971void acc::SerialOp::addAsyncOnly(
1972 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1973 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1974 context, existingDeviceTypes: getAsyncOnlyAttr(), newDeviceTypes: effectiveDeviceTypes));
1975}
1976
1977void acc::SerialOp::addAsyncOperand(
1978 MLIRContext *context, mlir::Value newValue,
1979 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1980 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1981 context, existingDeviceTypes: getAsyncOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
1982 argCollection: getAsyncOperandsMutable()));
1983}
1984
1985void acc::SerialOp::addWaitOnly(
1986 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1987 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getWaitOnlyAttr(),
1988 newDeviceTypes: effectiveDeviceTypes));
1989}
1990void acc::SerialOp::addWaitOperands(
1991 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1992 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1993
1994 llvm::SmallVector<int32_t> segments;
1995 if (getWaitOperandsSegments())
1996 llvm::copy(Range: *getWaitOperandsSegments(), Out: std::back_inserter(x&: segments));
1997
1998 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1999 context, existingDeviceTypes: getWaitOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValues,
2000 argCollection: getWaitOperandsMutable(), segments));
2001 setWaitOperandsSegments(segments);
2002
2003 llvm::SmallVector<mlir::Attribute> hasDevnums;
2004 if (getHasWaitDevnumAttr())
2005 llvm::copy(Range: getHasWaitDevnumAttr(), Out: std::back_inserter(x&: hasDevnums));
2006 hasDevnums.insert(
2007 I: hasDevnums.end(),
2008 NumToInsert: std::max(a: effectiveDeviceTypes.size(), b: static_cast<size_t>(1)),
2009 Elt: mlir::BoolAttr::get(context, value: hasDevnum));
2010 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, value: hasDevnums));
2011}
2012
2013//===----------------------------------------------------------------------===//
2014// KernelsOp
2015//===----------------------------------------------------------------------===//
2016
2017unsigned KernelsOp::getNumDataOperands() {
2018 return getDataClauseOperands().size();
2019}
2020
2021Value KernelsOp::getDataOperand(unsigned i) {
2022 unsigned numOptional = getAsyncOperands().size();
2023 numOptional += getWaitOperands().size();
2024 numOptional += getNumGangs().size();
2025 numOptional += getNumWorkers().size();
2026 numOptional += getVectorLength().size();
2027 numOptional += getIfCond() ? 1 : 0;
2028 numOptional += getSelfCond() ? 1 : 0;
2029 return getOperand(i: numOptional + i);
2030}
2031
2032bool acc::KernelsOp::hasAsyncOnly() {
2033 return hasAsyncOnly(deviceType: mlir::acc::DeviceType::None);
2034}
2035
2036bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2037 return hasDeviceType(arrayAttr: getAsyncOnly(), deviceType);
2038}
2039
2040mlir::Value acc::KernelsOp::getAsyncValue() {
2041 return getAsyncValue(deviceType: mlir::acc::DeviceType::None);
2042}
2043
2044mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2045 return getValueInDeviceTypeSegment(arrayAttr: getAsyncOperandsDeviceType(),
2046 range: getAsyncOperands(), deviceType);
2047}
2048
2049mlir::Value acc::KernelsOp::getNumWorkersValue() {
2050 return getNumWorkersValue(deviceType: mlir::acc::DeviceType::None);
2051}
2052
2053mlir::Value
2054acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2055 return getValueInDeviceTypeSegment(arrayAttr: getNumWorkersDeviceType(), range: getNumWorkers(),
2056 deviceType);
2057}
2058
2059mlir::Value acc::KernelsOp::getVectorLengthValue() {
2060 return getVectorLengthValue(deviceType: mlir::acc::DeviceType::None);
2061}
2062
2063mlir::Value
2064acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2065 return getValueInDeviceTypeSegment(arrayAttr: getVectorLengthDeviceType(),
2066 range: getVectorLength(), deviceType);
2067}
2068
2069mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2070 return getNumGangsValues(deviceType: mlir::acc::DeviceType::None);
2071}
2072
2073mlir::Operation::operand_range
2074KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2075 return getValuesFromSegments(arrayAttr: getNumGangsDeviceType(), range: getNumGangs(),
2076 segments: getNumGangsSegments(), deviceType);
2077}
2078
2079bool acc::KernelsOp::hasWaitOnly() {
2080 return hasWaitOnly(deviceType: mlir::acc::DeviceType::None);
2081}
2082
2083bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2084 return hasDeviceType(arrayAttr: getWaitOnly(), deviceType);
2085}
2086
2087mlir::Operation::operand_range KernelsOp::getWaitValues() {
2088 return getWaitValues(deviceType: mlir::acc::DeviceType::None);
2089}
2090
2091mlir::Operation::operand_range
2092KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2093 return getWaitValuesWithoutDevnum(
2094 deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(), segments: getWaitOperandsSegments(),
2095 hasWaitDevnum: getHasWaitDevnum(), deviceType);
2096}
2097
2098mlir::Value KernelsOp::getWaitDevnum() {
2099 return getWaitDevnum(deviceType: mlir::acc::DeviceType::None);
2100}
2101
2102mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2103 return getWaitDevnumValue(deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(),
2104 segments: getWaitOperandsSegments(), hasWaitDevnum: getHasWaitDevnum(),
2105 deviceType);
2106}
2107
2108LogicalResult acc::KernelsOp::verify() {
2109 if (failed(Result: verifyDeviceTypeAndSegmentCountMatch(
2110 op: *this, operands: getNumGangs(), segments: getNumGangsSegmentsAttr(),
2111 deviceTypes: getNumGangsDeviceTypeAttr(), keyword: "num_gangs", maxInSegment: 3)))
2112 return failure();
2113
2114 if (failed(Result: verifyDeviceTypeAndSegmentCountMatch(
2115 op: *this, operands: getWaitOperands(), segments: getWaitOperandsSegmentsAttr(),
2116 deviceTypes: getWaitOperandsDeviceTypeAttr(), keyword: "wait")))
2117 return failure();
2118
2119 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getNumWorkers(),
2120 deviceTypes: getNumWorkersDeviceTypeAttr(),
2121 keyword: "num_workers")))
2122 return failure();
2123
2124 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getVectorLength(),
2125 deviceTypes: getVectorLengthDeviceTypeAttr(),
2126 keyword: "vector_length")))
2127 return failure();
2128
2129 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getAsyncOperands(),
2130 deviceTypes: getAsyncOperandsDeviceTypeAttr(),
2131 keyword: "async")))
2132 return failure();
2133
2134 if (failed(Result: checkWaitAndAsyncConflict<acc::KernelsOp>(op: *this)))
2135 return failure();
2136
2137 return checkDataOperands<acc::KernelsOp>(op: *this, operands: getDataClauseOperands());
2138}
2139
2140void acc::KernelsOp::addNumWorkersOperand(
2141 MLIRContext *context, mlir::Value newValue,
2142 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2143 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2144 context, existingDeviceTypes: getNumWorkersDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
2145 argCollection: getNumWorkersMutable()));
2146}
2147
2148void acc::KernelsOp::addVectorLengthOperand(
2149 MLIRContext *context, mlir::Value newValue,
2150 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2151 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2152 context, existingDeviceTypes: getVectorLengthDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
2153 argCollection: getVectorLengthMutable()));
2154}
2155void acc::KernelsOp::addAsyncOnly(
2156 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2157 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2158 context, existingDeviceTypes: getAsyncOnlyAttr(), newDeviceTypes: effectiveDeviceTypes));
2159}
2160
2161void acc::KernelsOp::addAsyncOperand(
2162 MLIRContext *context, mlir::Value newValue,
2163 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2164 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2165 context, existingDeviceTypes: getAsyncOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
2166 argCollection: getAsyncOperandsMutable()));
2167}
2168
2169void acc::KernelsOp::addNumGangsOperands(
2170 MLIRContext *context, mlir::ValueRange newValues,
2171 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2172 llvm::SmallVector<int32_t> segments;
2173 if (getNumGangsSegmentsAttr())
2174 llvm::copy(Range: *getNumGangsSegments(), Out: std::back_inserter(x&: segments));
2175
2176 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2177 context, existingDeviceTypes: getNumGangsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValues,
2178 argCollection: getNumGangsMutable(), segments));
2179
2180 setNumGangsSegments(segments);
2181}
2182
2183void acc::KernelsOp::addWaitOnly(
2184 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2185 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getWaitOnlyAttr(),
2186 newDeviceTypes: effectiveDeviceTypes));
2187}
2188void acc::KernelsOp::addWaitOperands(
2189 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2190 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2191
2192 llvm::SmallVector<int32_t> segments;
2193 if (getWaitOperandsSegments())
2194 llvm::copy(Range: *getWaitOperandsSegments(), Out: std::back_inserter(x&: segments));
2195
2196 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2197 context, existingDeviceTypes: getWaitOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValues,
2198 argCollection: getWaitOperandsMutable(), segments));
2199 setWaitOperandsSegments(segments);
2200
2201 llvm::SmallVector<mlir::Attribute> hasDevnums;
2202 if (getHasWaitDevnumAttr())
2203 llvm::copy(Range: getHasWaitDevnumAttr(), Out: std::back_inserter(x&: hasDevnums));
2204 hasDevnums.insert(
2205 I: hasDevnums.end(),
2206 NumToInsert: std::max(a: effectiveDeviceTypes.size(), b: static_cast<size_t>(1)),
2207 Elt: mlir::BoolAttr::get(context, value: hasDevnum));
2208 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, value: hasDevnums));
2209}
2210
2211//===----------------------------------------------------------------------===//
2212// HostDataOp
2213//===----------------------------------------------------------------------===//
2214
2215LogicalResult acc::HostDataOp::verify() {
2216 if (getDataClauseOperands().empty())
2217 return emitError(message: "at least one operand must appear on the host_data "
2218 "operation");
2219
2220 for (mlir::Value operand : getDataClauseOperands())
2221 if (!mlir::isa<acc::UseDeviceOp>(Val: operand.getDefiningOp()))
2222 return emitError(message: "expect data entry operation as defining op");
2223 return success();
2224}
2225
2226void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2227 MLIRContext *context) {
2228 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(arg&: context);
2229}
2230
2231//===----------------------------------------------------------------------===//
2232// LoopOp
2233//===----------------------------------------------------------------------===//
2234
2235static ParseResult parseGangValue(
2236 OpAsmParser &parser, llvm::StringRef keyword,
2237 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
2238 llvm::SmallVectorImpl<Type> &types,
2239 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2240 bool &needCommaBetweenValues, bool &newValue) {
2241 if (succeeded(Result: parser.parseOptionalKeyword(keyword))) {
2242 if (parser.parseEqual())
2243 return failure();
2244 if (parser.parseOperand(result&: operands.emplace_back()) ||
2245 parser.parseColonType(result&: types.emplace_back()))
2246 return failure();
2247 attributes.push_back(Elt: gangArgType);
2248 needCommaBetweenValues = true;
2249 newValue = true;
2250 }
2251 return success();
2252}
2253
2254static ParseResult parseGangClause(
2255 OpAsmParser &parser,
2256 llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands,
2257 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2258 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2259 mlir::ArrayAttr &gangOnlyDeviceType) {
2260 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2261 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2262 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2263 llvm::SmallVector<int32_t> seg;
2264 bool needCommaBetweenValues = false;
2265 bool needCommaBeforeOperands = false;
2266
2267 if (failed(Result: parser.parseOptionalLParen())) {
2268 // Gang only keyword
2269 gangOnlyDeviceTypeAttributes.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
2270 context: parser.getContext(), value: mlir::acc::DeviceType::None));
2271 gangOnlyDeviceType =
2272 ArrayAttr::get(context: parser.getContext(), value: gangOnlyDeviceTypeAttributes);
2273 return success();
2274 }
2275
2276 // Parse gang only attributes
2277 if (succeeded(Result: parser.parseOptionalLSquare())) {
2278 // Parse gang only attributes
2279 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
2280 if (parser.parseAttribute(
2281 result&: gangOnlyDeviceTypeAttributes.emplace_back()))
2282 return failure();
2283 return success();
2284 })))
2285 return failure();
2286 if (parser.parseRSquare())
2287 return failure();
2288 needCommaBeforeOperands = true;
2289 }
2290
2291 auto argNum = mlir::acc::GangArgTypeAttr::get(context: parser.getContext(),
2292 value: mlir::acc::GangArgType::Num);
2293 auto argDim = mlir::acc::GangArgTypeAttr::get(context: parser.getContext(),
2294 value: mlir::acc::GangArgType::Dim);
2295 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2296 context: parser.getContext(), value: mlir::acc::GangArgType::Static);
2297
2298 do {
2299 if (needCommaBeforeOperands) {
2300 needCommaBeforeOperands = false;
2301 continue;
2302 }
2303
2304 if (failed(Result: parser.parseLBrace()))
2305 return failure();
2306
2307 int32_t crtOperandsSize = gangOperands.size();
2308 while (true) {
2309 bool newValue = false;
2310 bool needValue = false;
2311 if (needCommaBetweenValues) {
2312 if (succeeded(Result: parser.parseOptionalComma()))
2313 needValue = true; // expect a new value after comma.
2314 else
2315 break;
2316 }
2317
2318 if (failed(Result: parseGangValue(parser, keyword: LoopOp::getGangNumKeyword(),
2319 operands&: gangOperands, types&: gangOperandsType,
2320 attributes&: gangArgTypeAttributes, gangArgType: argNum,
2321 needCommaBetweenValues, newValue)))
2322 return failure();
2323 if (failed(Result: parseGangValue(parser, keyword: LoopOp::getGangDimKeyword(),
2324 operands&: gangOperands, types&: gangOperandsType,
2325 attributes&: gangArgTypeAttributes, gangArgType: argDim,
2326 needCommaBetweenValues, newValue)))
2327 return failure();
2328 if (failed(Result: parseGangValue(parser, keyword: LoopOp::getGangStaticKeyword(),
2329 operands&: gangOperands, types&: gangOperandsType,
2330 attributes&: gangArgTypeAttributes, gangArgType: argStatic,
2331 needCommaBetweenValues, newValue)))
2332 return failure();
2333
2334 if (!newValue && needValue) {
2335 parser.emitError(loc: parser.getCurrentLocation(),
2336 message: "new value expected after comma");
2337 return failure();
2338 }
2339
2340 if (!newValue)
2341 break;
2342 }
2343
2344 if (gangOperands.empty())
2345 return parser.emitError(
2346 loc: parser.getCurrentLocation(),
2347 message: "expect at least one of num, dim or static values");
2348
2349 if (failed(Result: parser.parseRBrace()))
2350 return failure();
2351
2352 if (succeeded(Result: parser.parseOptionalLSquare())) {
2353 if (parser.parseAttribute(result&: deviceTypeAttributes.emplace_back()) ||
2354 parser.parseRSquare())
2355 return failure();
2356 } else {
2357 deviceTypeAttributes.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
2358 context: parser.getContext(), value: mlir::acc::DeviceType::None));
2359 }
2360
2361 seg.push_back(Elt: gangOperands.size() - crtOperandsSize);
2362
2363 } while (succeeded(Result: parser.parseOptionalComma()));
2364
2365 if (failed(Result: parser.parseRParen()))
2366 return failure();
2367
2368 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
2369 gangArgTypeAttributes.end());
2370 gangArgType = ArrayAttr::get(context: parser.getContext(), value: arrayAttr);
2371 deviceType = ArrayAttr::get(context: parser.getContext(), value: deviceTypeAttributes);
2372
2373 llvm::SmallVector<mlir::Attribute> gangOnlyAttr(
2374 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2375 gangOnlyDeviceType = ArrayAttr::get(context: parser.getContext(), value: gangOnlyAttr);
2376
2377 segments = DenseI32ArrayAttr::get(context: parser.getContext(), content: seg);
2378 return success();
2379}
2380
2381void printGangClause(OpAsmPrinter &p, Operation *op,
2382 mlir::OperandRange operands, mlir::TypeRange types,
2383 std::optional<mlir::ArrayAttr> gangArgTypes,
2384 std::optional<mlir::ArrayAttr> deviceTypes,
2385 std::optional<mlir::DenseI32ArrayAttr> segments,
2386 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2387
2388 if (operands.begin() == operands.end() &&
2389 hasOnlyDeviceTypeNone(attrs: gangOnlyDeviceTypes)) {
2390 return;
2391 }
2392
2393 p << "(";
2394
2395 printDeviceTypes(p, deviceTypes: gangOnlyDeviceTypes);
2396
2397 if (hasDeviceTypeValues(arrayAttr: gangOnlyDeviceTypes) &&
2398 hasDeviceTypeValues(arrayAttr: deviceTypes))
2399 p << ", ";
2400
2401 if (hasDeviceTypeValues(arrayAttr: deviceTypes)) {
2402 unsigned opIdx = 0;
2403 llvm::interleaveComma(c: llvm::enumerate(First&: *deviceTypes), os&: p, each_fn: [&](auto it) {
2404 p << "{";
2405 llvm::interleaveComma(
2406 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2407 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2408 Val: (*gangArgTypes)[opIdx]);
2409 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2410 p << LoopOp::getGangNumKeyword();
2411 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2412 p << LoopOp::getGangDimKeyword();
2413 else if (gangArgTypeAttr.getValue() ==
2414 mlir::acc::GangArgType::Static)
2415 p << LoopOp::getGangStaticKeyword();
2416 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
2417 ++opIdx;
2418 });
2419 p << "}";
2420 printSingleDeviceType(p, it.value());
2421 });
2422 }
2423 p << ")";
2424}
2425
2426bool hasDuplicateDeviceTypes(
2427 std::optional<mlir::ArrayAttr> segments,
2428 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2429 if (!segments)
2430 return false;
2431 for (auto attr : *segments) {
2432 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val&: attr);
2433 if (!deviceTypes.insert(V: deviceTypeAttr.getValue()).second)
2434 return true;
2435 }
2436 return false;
2437}
2438
2439/// Check for duplicates in the DeviceType array attribute.
2440LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
2441 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2442 if (!deviceTypes)
2443 return success();
2444 for (auto attr : deviceTypes) {
2445 auto deviceTypeAttr =
2446 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(Val&: attr);
2447 if (!deviceTypeAttr)
2448 return failure();
2449 if (!crtDeviceTypes.insert(V: deviceTypeAttr.getValue()).second)
2450 return failure();
2451 }
2452 return success();
2453}
2454
2455LogicalResult acc::LoopOp::verify() {
2456 if (getUpperbound().size() != getStep().size())
2457 return emitError() << "number of upperbounds expected to be the same as "
2458 "number of steps";
2459
2460 if (getUpperbound().size() != getLowerbound().size())
2461 return emitError() << "number of upperbounds expected to be the same as "
2462 "number of lowerbounds";
2463
2464 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2465 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2466 return emitError() << "inclusiveUpperbound size is expected to be the same"
2467 << " as upperbound size";
2468
2469 // Check collapse
2470 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2471 return emitOpError() << "collapse device_type attr must be define when"
2472 << " collapse attr is present";
2473
2474 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2475 getCollapseAttr().getValue().size() !=
2476 getCollapseDeviceTypeAttr().getValue().size())
2477 return emitOpError() << "collapse attribute count must match collapse"
2478 << " device_type count";
2479 if (failed(Result: checkDeviceTypes(deviceTypes: getCollapseDeviceTypeAttr())))
2480 return emitOpError()
2481 << "duplicate device_type found in collapseDeviceType attribute";
2482
2483 // Check gang
2484 if (!getGangOperands().empty()) {
2485 if (!getGangOperandsArgType())
2486 return emitOpError() << "gangOperandsArgType attribute must be defined"
2487 << " when gang operands are present";
2488
2489 if (getGangOperands().size() !=
2490 getGangOperandsArgTypeAttr().getValue().size())
2491 return emitOpError() << "gangOperandsArgType attribute count must match"
2492 << " gangOperands count";
2493 }
2494 if (getGangAttr() && failed(Result: checkDeviceTypes(deviceTypes: getGangAttr())))
2495 return emitOpError() << "duplicate device_type found in gang attribute";
2496
2497 if (failed(Result: verifyDeviceTypeAndSegmentCountMatch(
2498 op: *this, operands: getGangOperands(), segments: getGangOperandsSegmentsAttr(),
2499 deviceTypes: getGangOperandsDeviceTypeAttr(), keyword: "gang")))
2500 return failure();
2501
2502 // Check worker
2503 if (failed(Result: checkDeviceTypes(deviceTypes: getWorkerAttr())))
2504 return emitOpError() << "duplicate device_type found in worker attribute";
2505 if (failed(Result: checkDeviceTypes(deviceTypes: getWorkerNumOperandsDeviceTypeAttr())))
2506 return emitOpError() << "duplicate device_type found in "
2507 "workerNumOperandsDeviceType attribute";
2508 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getWorkerNumOperands(),
2509 deviceTypes: getWorkerNumOperandsDeviceTypeAttr(),
2510 keyword: "worker")))
2511 return failure();
2512
2513 // Check vector
2514 if (failed(Result: checkDeviceTypes(deviceTypes: getVectorAttr())))
2515 return emitOpError() << "duplicate device_type found in vector attribute";
2516 if (failed(Result: checkDeviceTypes(deviceTypes: getVectorOperandsDeviceTypeAttr())))
2517 return emitOpError() << "duplicate device_type found in "
2518 "vectorOperandsDeviceType attribute";
2519 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getVectorOperands(),
2520 deviceTypes: getVectorOperandsDeviceTypeAttr(),
2521 keyword: "vector")))
2522 return failure();
2523
2524 if (failed(Result: verifyDeviceTypeAndSegmentCountMatch(
2525 op: *this, operands: getTileOperands(), segments: getTileOperandsSegmentsAttr(),
2526 deviceTypes: getTileOperandsDeviceTypeAttr(), keyword: "tile")))
2527 return failure();
2528
2529 // auto, independent and seq attribute are mutually exclusive.
2530 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2531 if (hasDuplicateDeviceTypes(segments: getAuto_(), deviceTypes) ||
2532 hasDuplicateDeviceTypes(segments: getIndependent(), deviceTypes) ||
2533 hasDuplicateDeviceTypes(segments: getSeq(), deviceTypes)) {
2534 return emitError() << "only one of auto, independent, seq can be present "
2535 "at the same time";
2536 }
2537
2538 // Check that at least one of auto, independent, or seq is present
2539 // for the device-independent default clauses.
2540 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
2541 return attr.getValue() == mlir::acc::DeviceType::None;
2542 };
2543 bool hasDefaultSeq =
2544 getSeqAttr()
2545 ? llvm::any_of(Range: getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2546 P: hasDeviceNone)
2547 : false;
2548 bool hasDefaultIndependent =
2549 getIndependentAttr()
2550 ? llvm::any_of(
2551 Range: getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2552 P: hasDeviceNone)
2553 : false;
2554 bool hasDefaultAuto =
2555 getAuto_Attr()
2556 ? llvm::any_of(Range: getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2557 P: hasDeviceNone)
2558 : false;
2559 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
2560 return emitError()
2561 << "at least one of auto, independent, seq must be present";
2562 }
2563
2564 // Gang, worker and vector are incompatible with seq.
2565 if (getSeqAttr()) {
2566 for (auto attr : getSeqAttr()) {
2567 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val&: attr);
2568 if (hasVector(deviceType: deviceTypeAttr.getValue()) ||
2569 getVectorValue(deviceType: deviceTypeAttr.getValue()) ||
2570 hasWorker(deviceType: deviceTypeAttr.getValue()) ||
2571 getWorkerValue(deviceType: deviceTypeAttr.getValue()) ||
2572 hasGang(deviceType: deviceTypeAttr.getValue()) ||
2573 getGangValue(gangArgType: mlir::acc::GangArgType::Num,
2574 deviceType: deviceTypeAttr.getValue()) ||
2575 getGangValue(gangArgType: mlir::acc::GangArgType::Dim,
2576 deviceType: deviceTypeAttr.getValue()) ||
2577 getGangValue(gangArgType: mlir::acc::GangArgType::Static,
2578 deviceType: deviceTypeAttr.getValue()))
2579 return emitError() << "gang, worker or vector cannot appear with seq";
2580 }
2581 }
2582
2583 if (failed(Result: checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2584 op: *this, attributes: getPrivatizationRecipes(), operands: getPrivateOperands(), operandName: "private",
2585 symbolName: "privatizations", checkOperandType: false)))
2586 return failure();
2587
2588 if (failed(Result: checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2589 op: *this, attributes: getReductionRecipes(), operands: getReductionOperands(), operandName: "reduction",
2590 symbolName: "reductions", checkOperandType: false)))
2591 return failure();
2592
2593 if (getCombined().has_value() &&
2594 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2595 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2596 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2597 return emitError(message: "unexpected combined constructs attribute");
2598 }
2599
2600 // Check non-empty body().
2601 if (getRegion().empty())
2602 return emitError(message: "expected non-empty body.");
2603
2604 // When it is container-like - it is expected to hold a loop-like operation.
2605 if (isContainerLike()) {
2606 // Obtain the maximum collapse count - we use this to check that there
2607 // are enough loops contained.
2608 uint64_t collapseCount = getCollapseValue().value_or(u: 1);
2609 if (getCollapseAttr()) {
2610 for (auto collapseEntry : getCollapseAttr()) {
2611 auto intAttr = mlir::dyn_cast<IntegerAttr>(Val&: collapseEntry);
2612 if (intAttr.getValue().getZExtValue() > collapseCount)
2613 collapseCount = intAttr.getValue().getZExtValue();
2614 }
2615 }
2616
2617 // We want to check that we find enough loop-like operations inside.
2618 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
2619 // level.
2620 mlir::Operation *expectedParent = this->getOperation();
2621 bool foundSibling = false;
2622 getRegion().walk<WalkOrder::PreOrder>(callback: [&](mlir::Operation *op) {
2623 if (mlir::isa<mlir::LoopLikeOpInterface>(Val: op)) {
2624 // This effectively checks that we are not looking at a sibling loop.
2625 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2626 expectedParent) {
2627 foundSibling = true;
2628 return mlir::WalkResult::interrupt();
2629 }
2630
2631 collapseCount--;
2632 expectedParent = op;
2633 }
2634 // We found enough contained loops.
2635 if (collapseCount == 0)
2636 return mlir::WalkResult::interrupt();
2637 return mlir::WalkResult::advance();
2638 });
2639
2640 if (foundSibling)
2641 return emitError(message: "found sibling loops inside container-like acc.loop");
2642 if (collapseCount != 0)
2643 return emitError(message: "failed to find enough loop-like operations inside "
2644 "container-like acc.loop");
2645 }
2646
2647 return success();
2648}
2649
2650unsigned LoopOp::getNumDataOperands() {
2651 return getReductionOperands().size() + getPrivateOperands().size();
2652}
2653
2654Value LoopOp::getDataOperand(unsigned i) {
2655 unsigned numOptional =
2656 getLowerbound().size() + getUpperbound().size() + getStep().size();
2657 numOptional += getGangOperands().size();
2658 numOptional += getVectorOperands().size();
2659 numOptional += getWorkerNumOperands().size();
2660 numOptional += getTileOperands().size();
2661 numOptional += getCacheOperands().size();
2662 return getOperand(i: numOptional + i);
2663}
2664
2665bool LoopOp::hasAuto() { return hasAuto(deviceType: mlir::acc::DeviceType::None); }
2666
2667bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2668 return hasDeviceType(arrayAttr: getAuto_(), deviceType);
2669}
2670
2671bool LoopOp::hasIndependent() {
2672 return hasIndependent(deviceType: mlir::acc::DeviceType::None);
2673}
2674
2675bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2676 return hasDeviceType(arrayAttr: getIndependent(), deviceType);
2677}
2678
2679bool LoopOp::hasSeq() { return hasSeq(deviceType: mlir::acc::DeviceType::None); }
2680
2681bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2682 return hasDeviceType(arrayAttr: getSeq(), deviceType);
2683}
2684
2685mlir::Value LoopOp::getVectorValue() {
2686 return getVectorValue(deviceType: mlir::acc::DeviceType::None);
2687}
2688
2689mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2690 return getValueInDeviceTypeSegment(arrayAttr: getVectorOperandsDeviceType(),
2691 range: getVectorOperands(), deviceType);
2692}
2693
2694bool LoopOp::hasVector() { return hasVector(deviceType: mlir::acc::DeviceType::None); }
2695
2696bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2697 return hasDeviceType(arrayAttr: getVector(), deviceType);
2698}
2699
2700mlir::Value LoopOp::getWorkerValue() {
2701 return getWorkerValue(deviceType: mlir::acc::DeviceType::None);
2702}
2703
2704mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2705 return getValueInDeviceTypeSegment(arrayAttr: getWorkerNumOperandsDeviceType(),
2706 range: getWorkerNumOperands(), deviceType);
2707}
2708
2709bool LoopOp::hasWorker() { return hasWorker(deviceType: mlir::acc::DeviceType::None); }
2710
2711bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2712 return hasDeviceType(arrayAttr: getWorker(), deviceType);
2713}
2714
2715mlir::Operation::operand_range LoopOp::getTileValues() {
2716 return getTileValues(deviceType: mlir::acc::DeviceType::None);
2717}
2718
2719mlir::Operation::operand_range
2720LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2721 return getValuesFromSegments(arrayAttr: getTileOperandsDeviceType(), range: getTileOperands(),
2722 segments: getTileOperandsSegments(), deviceType);
2723}
2724
2725std::optional<int64_t> LoopOp::getCollapseValue() {
2726 return getCollapseValue(deviceType: mlir::acc::DeviceType::None);
2727}
2728
2729std::optional<int64_t>
2730LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2731 if (!getCollapseAttr())
2732 return std::nullopt;
2733 if (auto pos = findSegment(segments: getCollapseDeviceTypeAttr(), deviceType)) {
2734 auto intAttr =
2735 mlir::dyn_cast<IntegerAttr>(Val: getCollapseAttr().getValue()[*pos]);
2736 return intAttr.getValue().getZExtValue();
2737 }
2738 return std::nullopt;
2739}
2740
2741mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2742 return getGangValue(gangArgType, deviceType: mlir::acc::DeviceType::None);
2743}
2744
2745mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2746 mlir::acc::DeviceType deviceType) {
2747 if (getGangOperands().empty())
2748 return {};
2749 if (auto pos = findSegment(segments: *getGangOperandsDeviceType(), deviceType)) {
2750 int32_t nbOperandsBefore = 0;
2751 for (unsigned i = 0; i < *pos; ++i)
2752 nbOperandsBefore += (*getGangOperandsSegments())[i];
2753 mlir::Operation::operand_range values =
2754 getGangOperands()
2755 .drop_front(n: nbOperandsBefore)
2756 .take_front(n: (*getGangOperandsSegments())[*pos]);
2757
2758 int32_t argTypeIdx = nbOperandsBefore;
2759 for (auto value : values) {
2760 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2761 Val: (*getGangOperandsArgType())[argTypeIdx]);
2762 if (gangArgTypeAttr.getValue() == gangArgType)
2763 return value;
2764 ++argTypeIdx;
2765 }
2766 }
2767 return {};
2768}
2769
2770bool LoopOp::hasGang() { return hasGang(deviceType: mlir::acc::DeviceType::None); }
2771
2772bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2773 return hasDeviceType(arrayAttr: getGang(), deviceType);
2774}
2775
2776llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
2777 return {&getRegion()};
2778}
2779
2780/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
2781/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
2782/// `(` ssa-id-and-type-list `)`
2783/// region
2784ParseResult
2785parseLoopControl(OpAsmParser &parser, Region &region,
2786 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerbound,
2787 SmallVectorImpl<Type> &lowerboundType,
2788 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperbound,
2789 SmallVectorImpl<Type> &upperboundType,
2790 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &step,
2791 SmallVectorImpl<Type> &stepType) {
2792
2793 SmallVector<OpAsmParser::Argument> inductionVars;
2794 if (succeeded(
2795 Result: parser.parseOptionalKeyword(keyword: acc::LoopOp::getControlKeyword()))) {
2796 if (parser.parseLParen() ||
2797 parser.parseArgumentList(result&: inductionVars, delimiter: OpAsmParser::Delimiter::None,
2798 /*allowType=*/true) ||
2799 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
2800 parser.parseOperandList(result&: lowerbound, requiredOperandCount: inductionVars.size(),
2801 delimiter: OpAsmParser::Delimiter::None) ||
2802 parser.parseColonTypeList(result&: lowerboundType) || parser.parseRParen() ||
2803 parser.parseKeyword(keyword: "to") || parser.parseLParen() ||
2804 parser.parseOperandList(result&: upperbound, requiredOperandCount: inductionVars.size(),
2805 delimiter: OpAsmParser::Delimiter::None) ||
2806 parser.parseColonTypeList(result&: upperboundType) || parser.parseRParen() ||
2807 parser.parseKeyword(keyword: "step") || parser.parseLParen() ||
2808 parser.parseOperandList(result&: step, requiredOperandCount: inductionVars.size(),
2809 delimiter: OpAsmParser::Delimiter::None) ||
2810 parser.parseColonTypeList(result&: stepType) || parser.parseRParen())
2811 return failure();
2812 }
2813 return parser.parseRegion(region, arguments: inductionVars);
2814}
2815
2816void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
2817 ValueRange lowerbound, TypeRange lowerboundType,
2818 ValueRange upperbound, TypeRange upperboundType,
2819 ValueRange steps, TypeRange stepType) {
2820 ValueRange regionArgs = region.front().getArguments();
2821 if (!regionArgs.empty()) {
2822 p << acc::LoopOp::getControlKeyword() << "(";
2823 llvm::interleaveComma(c: regionArgs, os&: p,
2824 each_fn: [&p](Value v) { p << v << " : " << v.getType(); });
2825 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
2826 << upperbound << " : " << upperboundType << ") " << " step (" << steps
2827 << " : " << stepType << ") ";
2828 }
2829 p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false);
2830}
2831
2832void acc::LoopOp::addSeq(MLIRContext *context,
2833 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2834 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getSeqAttr(),
2835 newDeviceTypes: effectiveDeviceTypes));
2836}
2837
2838void acc::LoopOp::addIndependent(
2839 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2840 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2841 context, existingDeviceTypes: getIndependentAttr(), newDeviceTypes: effectiveDeviceTypes));
2842}
2843
2844void acc::LoopOp::addAuto(MLIRContext *context,
2845 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2846 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getAuto_Attr(),
2847 newDeviceTypes: effectiveDeviceTypes));
2848}
2849
2850void acc::LoopOp::setCollapseForDeviceTypes(
2851 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2852 llvm::APInt value) {
2853 llvm::SmallVector<mlir::Attribute> newValues;
2854 llvm::SmallVector<mlir::Attribute> newDeviceTypes;
2855
2856 assert((getCollapseAttr() == nullptr) ==
2857 (getCollapseDeviceTypeAttr() == nullptr));
2858 assert(value.getBitWidth() == 64);
2859
2860 if (getCollapseAttr()) {
2861 for (const auto &existing :
2862 llvm::zip_equal(t: getCollapseAttr(), u: getCollapseDeviceTypeAttr())) {
2863 newValues.push_back(Elt: std::get<0>(t: existing));
2864 newDeviceTypes.push_back(Elt: std::get<1>(t: existing));
2865 }
2866 }
2867
2868 if (effectiveDeviceTypes.empty()) {
2869 // If the effective device-types list is empty, this is before there are any
2870 // being applied by device_type, so this should be added as a 'none'.
2871 newValues.push_back(
2872 Elt: mlir::IntegerAttr::get(type: mlir::IntegerType::get(context, width: 64), value));
2873 newDeviceTypes.push_back(
2874 Elt: acc::DeviceTypeAttr::get(context, value: DeviceType::None));
2875 } else {
2876 for (DeviceType DT : effectiveDeviceTypes) {
2877 newValues.push_back(
2878 Elt: mlir::IntegerAttr::get(type: mlir::IntegerType::get(context, width: 64), value));
2879 newDeviceTypes.push_back(Elt: acc::DeviceTypeAttr::get(context, value: DT));
2880 }
2881 }
2882
2883 setCollapseAttr(ArrayAttr::get(context, value: newValues));
2884 setCollapseDeviceTypeAttr(ArrayAttr::get(context, value: newDeviceTypes));
2885}
2886
2887void acc::LoopOp::setTileForDeviceTypes(
2888 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2889 ValueRange values) {
2890 llvm::SmallVector<int32_t> segments;
2891 if (getTileOperandsSegments())
2892 llvm::copy(Range: *getTileOperandsSegments(), Out: std::back_inserter(x&: segments));
2893
2894 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2895 context, existingDeviceTypes: getTileOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: values,
2896 argCollection: getTileOperandsMutable(), segments));
2897
2898 setTileOperandsSegments(segments);
2899}
2900
2901void acc::LoopOp::addVectorOperand(
2902 MLIRContext *context, mlir::Value newValue,
2903 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2904 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2905 context, existingDeviceTypes: getVectorOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes,
2906 arguments: newValue, argCollection: getVectorOperandsMutable()));
2907}
2908
2909void acc::LoopOp::addEmptyVector(
2910 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2911 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getVectorAttr(),
2912 newDeviceTypes: effectiveDeviceTypes));
2913}
2914
2915void acc::LoopOp::addWorkerNumOperand(
2916 MLIRContext *context, mlir::Value newValue,
2917 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2918 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2919 context, existingDeviceTypes: getWorkerNumOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes,
2920 arguments: newValue, argCollection: getWorkerNumOperandsMutable()));
2921}
2922
2923void acc::LoopOp::addEmptyWorker(
2924 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2925 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getWorkerAttr(),
2926 newDeviceTypes: effectiveDeviceTypes));
2927}
2928
2929void acc::LoopOp::addEmptyGang(
2930 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2931 setGangAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getGangAttr(),
2932 newDeviceTypes: effectiveDeviceTypes));
2933}
2934
2935bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
2936 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
2937 return attr.getValue() == dt;
2938 };
2939 auto testFromArr = [=](ArrayAttr arr) -> bool {
2940 return llvm::any_of(Range: arr.getAsRange<DeviceTypeAttr>(), P: hasDevice);
2941 };
2942
2943 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
2944 return true;
2945 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
2946 return true;
2947 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
2948 return true;
2949
2950 return false;
2951}
2952
2953bool acc::LoopOp::hasDefaultGangWorkerVector() {
2954 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
2955 hasGang() || getGangValue(gangArgType: GangArgType::Num) ||
2956 getGangValue(gangArgType: GangArgType::Dim) || getGangValue(gangArgType: GangArgType::Static);
2957}
2958
2959void acc::LoopOp::addGangOperands(
2960 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
2961 llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
2962 llvm::SmallVector<int32_t> segments;
2963 if (std::optional<ArrayRef<int32_t>> existingSegments =
2964 getGangOperandsSegments())
2965 llvm::copy(Range&: *existingSegments, Out: std::back_inserter(x&: segments));
2966
2967 unsigned beforeCount = segments.size();
2968
2969 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2970 context, existingDeviceTypes: getGangOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: values,
2971 argCollection: getGangOperandsMutable(), segments));
2972
2973 setGangOperandsSegments(segments);
2974
2975 // This is a bit of extra work to make sure we update the 'types' correctly by
2976 // adding to the types collection the correct number of times. We could
2977 // potentially add something similar to the
2978 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
2979 // excessive for a one-off case.
2980 unsigned numAdded = segments.size() - beforeCount;
2981
2982 if (numAdded > 0) {
2983 llvm::SmallVector<mlir::Attribute> gangTypes;
2984 if (getGangOperandsArgTypeAttr())
2985 llvm::copy(Range: getGangOperandsArgTypeAttr(), Out: std::back_inserter(x&: gangTypes));
2986
2987 for (auto i : llvm::index_range(0u, numAdded)) {
2988 llvm::transform(Range&: argTypes, d_first: std::back_inserter(x&: gangTypes),
2989 F: [=](mlir::acc::GangArgType gangTy) {
2990 return mlir::acc::GangArgTypeAttr::get(context, value: gangTy);
2991 });
2992 (void)i;
2993 }
2994
2995 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, value: gangTypes));
2996 }
2997}
2998
2999//===----------------------------------------------------------------------===//
3000// DataOp
3001//===----------------------------------------------------------------------===//
3002
3003LogicalResult acc::DataOp::verify() {
3004 // 2.6.5. Data Construct restriction
3005 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3006 // attach, or default clause must appear on a data construct.
3007 if (getOperands().empty() && !getDefaultAttr())
3008 return emitError(message: "at least one operand or the default attribute "
3009 "must appear on the data operation");
3010
3011 for (mlir::Value operand : getDataClauseOperands())
3012 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3013 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3014 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3015 Val: operand.getDefiningOp()))
3016 return emitError(message: "expect data entry/exit operation or acc.getdeviceptr "
3017 "as defining op");
3018
3019 if (failed(Result: checkWaitAndAsyncConflict<acc::DataOp>(op: *this)))
3020 return failure();
3021
3022 return success();
3023}
3024
3025unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3026
3027Value DataOp::getDataOperand(unsigned i) {
3028 unsigned numOptional = getIfCond() ? 1 : 0;
3029 numOptional += getAsyncOperands().size() ? 1 : 0;
3030 numOptional += getWaitOperands().size();
3031 return getOperand(i: numOptional + i);
3032}
3033
3034bool acc::DataOp::hasAsyncOnly() {
3035 return hasAsyncOnly(deviceType: mlir::acc::DeviceType::None);
3036}
3037
3038bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3039 return hasDeviceType(arrayAttr: getAsyncOnly(), deviceType);
3040}
3041
3042mlir::Value DataOp::getAsyncValue() {
3043 return getAsyncValue(deviceType: mlir::acc::DeviceType::None);
3044}
3045
3046mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3047 return getValueInDeviceTypeSegment(arrayAttr: getAsyncOperandsDeviceType(),
3048 range: getAsyncOperands(), deviceType);
3049}
3050
3051bool DataOp::hasWaitOnly() { return hasWaitOnly(deviceType: mlir::acc::DeviceType::None); }
3052
3053bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3054 return hasDeviceType(arrayAttr: getWaitOnly(), deviceType);
3055}
3056
3057mlir::Operation::operand_range DataOp::getWaitValues() {
3058 return getWaitValues(deviceType: mlir::acc::DeviceType::None);
3059}
3060
3061mlir::Operation::operand_range
3062DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3063 return getWaitValuesWithoutDevnum(
3064 deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(), segments: getWaitOperandsSegments(),
3065 hasWaitDevnum: getHasWaitDevnum(), deviceType);
3066}
3067
3068mlir::Value DataOp::getWaitDevnum() {
3069 return getWaitDevnum(deviceType: mlir::acc::DeviceType::None);
3070}
3071
3072mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3073 return getWaitDevnumValue(deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(),
3074 segments: getWaitOperandsSegments(), hasWaitDevnum: getHasWaitDevnum(),
3075 deviceType);
3076}
3077
3078void acc::DataOp::addAsyncOnly(
3079 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3080 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3081 context, existingDeviceTypes: getAsyncOnlyAttr(), newDeviceTypes: effectiveDeviceTypes));
3082}
3083
3084void acc::DataOp::addAsyncOperand(
3085 MLIRContext *context, mlir::Value newValue,
3086 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3087 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3088 context, existingDeviceTypes: getAsyncOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
3089 argCollection: getAsyncOperandsMutable()));
3090}
3091
3092void acc::DataOp::addWaitOnly(MLIRContext *context,
3093 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3094 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getWaitOnlyAttr(),
3095 newDeviceTypes: effectiveDeviceTypes));
3096}
3097
3098void acc::DataOp::addWaitOperands(
3099 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3100 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3101
3102 llvm::SmallVector<int32_t> segments;
3103 if (getWaitOperandsSegments())
3104 llvm::copy(Range: *getWaitOperandsSegments(), Out: std::back_inserter(x&: segments));
3105
3106 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3107 context, existingDeviceTypes: getWaitOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValues,
3108 argCollection: getWaitOperandsMutable(), segments));
3109 setWaitOperandsSegments(segments);
3110
3111 llvm::SmallVector<mlir::Attribute> hasDevnums;
3112 if (getHasWaitDevnumAttr())
3113 llvm::copy(Range: getHasWaitDevnumAttr(), Out: std::back_inserter(x&: hasDevnums));
3114 hasDevnums.insert(
3115 I: hasDevnums.end(),
3116 NumToInsert: std::max(a: effectiveDeviceTypes.size(), b: static_cast<size_t>(1)),
3117 Elt: mlir::BoolAttr::get(context, value: hasDevnum));
3118 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, value: hasDevnums));
3119}
3120
3121//===----------------------------------------------------------------------===//
3122// ExitDataOp
3123//===----------------------------------------------------------------------===//
3124
3125LogicalResult acc::ExitDataOp::verify() {
3126 // 2.6.6. Data Exit Directive restriction
3127 // At least one copyout, delete, or detach clause must appear on an exit data
3128 // directive.
3129 if (getDataClauseOperands().empty())
3130 return emitError(message: "at least one operand must be present in dataOperands on "
3131 "the exit data operation");
3132
3133 // The async attribute represent the async clause without value. Therefore the
3134 // attribute and operand cannot appear at the same time.
3135 if (getAsyncOperand() && getAsync())
3136 return emitError(message: "async attribute cannot appear with asyncOperand");
3137
3138 // The wait attribute represent the wait clause without values. Therefore the
3139 // attribute and operands cannot appear at the same time.
3140 if (!getWaitOperands().empty() && getWait())
3141 return emitError(message: "wait attribute cannot appear with waitOperands");
3142
3143 if (getWaitDevnum() && getWaitOperands().empty())
3144 return emitError(message: "wait_devnum cannot appear without waitOperands");
3145
3146 return success();
3147}
3148
3149unsigned ExitDataOp::getNumDataOperands() {
3150 return getDataClauseOperands().size();
3151}
3152
3153Value ExitDataOp::getDataOperand(unsigned i) {
3154 unsigned numOptional = getIfCond() ? 1 : 0;
3155 numOptional += getAsyncOperand() ? 1 : 0;
3156 numOptional += getWaitDevnum() ? 1 : 0;
3157 return getOperand(i: getWaitOperands().size() + numOptional + i);
3158}
3159
3160void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3161 MLIRContext *context) {
3162 results.add<RemoveConstantIfCondition<ExitDataOp>>(arg&: context);
3163}
3164
3165void ExitDataOp::addAsyncOnly(MLIRContext *context,
3166 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3167 assert(effectiveDeviceTypes.empty());
3168 assert(!getAsyncAttr());
3169 assert(!getAsyncOperand());
3170
3171 setAsyncAttr(mlir::UnitAttr::get(context));
3172}
3173
3174void ExitDataOp::addAsyncOperand(
3175 MLIRContext *context, mlir::Value newValue,
3176 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3177 assert(effectiveDeviceTypes.empty());
3178 assert(!getAsyncAttr());
3179 assert(!getAsyncOperand());
3180
3181 getAsyncOperandMutable().append(values: newValue);
3182}
3183
3184void ExitDataOp::addWaitOnly(MLIRContext *context,
3185 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3186 assert(effectiveDeviceTypes.empty());
3187 assert(!getWaitAttr());
3188 assert(getWaitOperands().empty());
3189 assert(!getWaitDevnum());
3190
3191 setWaitAttr(mlir::UnitAttr::get(context));
3192}
3193
3194void ExitDataOp::addWaitOperands(
3195 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3196 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3197 assert(effectiveDeviceTypes.empty());
3198 assert(!getWaitAttr());
3199 assert(getWaitOperands().empty());
3200 assert(!getWaitDevnum());
3201
3202 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3203 // operands list.
3204 if (hasDevnum) {
3205 getWaitDevnumMutable().append(values: newValues.front());
3206 newValues = newValues.drop_front();
3207 }
3208
3209 getWaitOperandsMutable().append(values: newValues);
3210}
3211
3212//===----------------------------------------------------------------------===//
3213// EnterDataOp
3214//===----------------------------------------------------------------------===//
3215
3216LogicalResult acc::EnterDataOp::verify() {
3217 // 2.6.6. Data Enter Directive restriction
3218 // At least one copyin, create, or attach clause must appear on an enter data
3219 // directive.
3220 if (getDataClauseOperands().empty())
3221 return emitError(message: "at least one operand must be present in dataOperands on "
3222 "the enter data operation");
3223
3224 // The async attribute represent the async clause without value. Therefore the
3225 // attribute and operand cannot appear at the same time.
3226 if (getAsyncOperand() && getAsync())
3227 return emitError(message: "async attribute cannot appear with asyncOperand");
3228
3229 // The wait attribute represent the wait clause without values. Therefore the
3230 // attribute and operands cannot appear at the same time.
3231 if (!getWaitOperands().empty() && getWait())
3232 return emitError(message: "wait attribute cannot appear with waitOperands");
3233
3234 if (getWaitDevnum() && getWaitOperands().empty())
3235 return emitError(message: "wait_devnum cannot appear without waitOperands");
3236
3237 for (mlir::Value operand : getDataClauseOperands())
3238 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3239 Val: operand.getDefiningOp()))
3240 return emitError(message: "expect data entry operation as defining op");
3241
3242 return success();
3243}
3244
3245unsigned EnterDataOp::getNumDataOperands() {
3246 return getDataClauseOperands().size();
3247}
3248
3249Value EnterDataOp::getDataOperand(unsigned i) {
3250 unsigned numOptional = getIfCond() ? 1 : 0;
3251 numOptional += getAsyncOperand() ? 1 : 0;
3252 numOptional += getWaitDevnum() ? 1 : 0;
3253 return getOperand(i: getWaitOperands().size() + numOptional + i);
3254}
3255
3256void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3257 MLIRContext *context) {
3258 results.add<RemoveConstantIfCondition<EnterDataOp>>(arg&: context);
3259}
3260
3261void EnterDataOp::addAsyncOnly(
3262 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3263 assert(effectiveDeviceTypes.empty());
3264 assert(!getAsyncAttr());
3265 assert(!getAsyncOperand());
3266
3267 setAsyncAttr(mlir::UnitAttr::get(context));
3268}
3269
3270void EnterDataOp::addAsyncOperand(
3271 MLIRContext *context, mlir::Value newValue,
3272 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3273 assert(effectiveDeviceTypes.empty());
3274 assert(!getAsyncAttr());
3275 assert(!getAsyncOperand());
3276
3277 getAsyncOperandMutable().append(values: newValue);
3278}
3279
3280void EnterDataOp::addWaitOnly(MLIRContext *context,
3281 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3282 assert(effectiveDeviceTypes.empty());
3283 assert(!getWaitAttr());
3284 assert(getWaitOperands().empty());
3285 assert(!getWaitDevnum());
3286
3287 setWaitAttr(mlir::UnitAttr::get(context));
3288}
3289
3290void EnterDataOp::addWaitOperands(
3291 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3292 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3293 assert(effectiveDeviceTypes.empty());
3294 assert(!getWaitAttr());
3295 assert(getWaitOperands().empty());
3296 assert(!getWaitDevnum());
3297
3298 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3299 // operands list.
3300 if (hasDevnum) {
3301 getWaitDevnumMutable().append(values: newValues.front());
3302 newValues = newValues.drop_front();
3303 }
3304
3305 getWaitOperandsMutable().append(values: newValues);
3306}
3307
3308//===----------------------------------------------------------------------===//
3309// AtomicReadOp
3310//===----------------------------------------------------------------------===//
3311
3312LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
3313
3314//===----------------------------------------------------------------------===//
3315// AtomicWriteOp
3316//===----------------------------------------------------------------------===//
3317
3318LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
3319
3320//===----------------------------------------------------------------------===//
3321// AtomicUpdateOp
3322//===----------------------------------------------------------------------===//
3323
3324LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3325 PatternRewriter &rewriter) {
3326 if (op.isNoOp()) {
3327 rewriter.eraseOp(op);
3328 return success();
3329 }
3330
3331 if (Value writeVal = op.getWriteOpVal()) {
3332 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, args: op.getX(), args&: writeVal);
3333 return success();
3334 }
3335
3336 return failure();
3337}
3338
3339LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
3340
3341LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3342
3343//===----------------------------------------------------------------------===//
3344// AtomicCaptureOp
3345//===----------------------------------------------------------------------===//
3346
3347AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3348 if (auto op = dyn_cast<AtomicReadOp>(Val: getFirstOp()))
3349 return op;
3350 return dyn_cast<AtomicReadOp>(Val: getSecondOp());
3351}
3352
3353AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3354 if (auto op = dyn_cast<AtomicWriteOp>(Val: getFirstOp()))
3355 return op;
3356 return dyn_cast<AtomicWriteOp>(Val: getSecondOp());
3357}
3358
3359AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3360 if (auto op = dyn_cast<AtomicUpdateOp>(Val: getFirstOp()))
3361 return op;
3362 return dyn_cast<AtomicUpdateOp>(Val: getSecondOp());
3363}
3364
3365LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
3366
3367//===----------------------------------------------------------------------===//
3368// DeclareEnterOp
3369//===----------------------------------------------------------------------===//
3370
3371template <typename Op>
3372static LogicalResult
3373checkDeclareOperands(Op &op, const mlir::ValueRange &operands,
3374 bool requireAtLeastOneOperand = true) {
3375 if (operands.empty() && requireAtLeastOneOperand)
3376 return emitError(
3377 op->getLoc(),
3378 "at least one operand must appear on the declare operation");
3379
3380 for (mlir::Value operand : operands) {
3381 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3382 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3383 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3384 Val: operand.getDefiningOp()))
3385 return op.emitError(
3386 "expect valid declare data entry operation or acc.getdeviceptr "
3387 "as defining op");
3388
3389 mlir::Value var{getVar(accDataClauseOp: operand.getDefiningOp())};
3390 assert(var && "declare operands can only be data entry operations which "
3391 "must have var");
3392 (void)var;
3393 std::optional<mlir::acc::DataClause> dataClauseOptional{
3394 getDataClause(accDataEntryOp: operand.getDefiningOp())};
3395 assert(dataClauseOptional.has_value() &&
3396 "declare operands can only be data entry operations which must have "
3397 "dataClause");
3398 (void)dataClauseOptional;
3399 }
3400
3401 return success();
3402}
3403
3404LogicalResult acc::DeclareEnterOp::verify() {
3405 return checkDeclareOperands(op&: *this, operands: this->getDataClauseOperands());
3406}
3407
3408//===----------------------------------------------------------------------===//
3409// DeclareExitOp
3410//===----------------------------------------------------------------------===//
3411
3412LogicalResult acc::DeclareExitOp::verify() {
3413 if (getToken())
3414 return checkDeclareOperands(op&: *this, operands: this->getDataClauseOperands(),
3415 /*requireAtLeastOneOperand=*/false);
3416 return checkDeclareOperands(op&: *this, operands: this->getDataClauseOperands());
3417}
3418
3419//===----------------------------------------------------------------------===//
3420// DeclareOp
3421//===----------------------------------------------------------------------===//
3422
3423LogicalResult acc::DeclareOp::verify() {
3424 return checkDeclareOperands(op&: *this, operands: this->getDataClauseOperands());
3425}
3426
3427//===----------------------------------------------------------------------===//
3428// RoutineOp
3429//===----------------------------------------------------------------------===//
3430
3431static unsigned getParallelismForDeviceType(acc::RoutineOp op,
3432 acc::DeviceType dtype) {
3433 unsigned parallelism = 0;
3434 parallelism += (op.hasGang(deviceType: dtype) || op.getGangDimValue(deviceType: dtype)) ? 1 : 0;
3435 parallelism += op.hasWorker(deviceType: dtype) ? 1 : 0;
3436 parallelism += op.hasVector(deviceType: dtype) ? 1 : 0;
3437 parallelism += op.hasSeq(deviceType: dtype) ? 1 : 0;
3438 return parallelism;
3439}
3440
3441LogicalResult acc::RoutineOp::verify() {
3442 unsigned baseParallelism =
3443 getParallelismForDeviceType(op: *this, dtype: acc::DeviceType::None);
3444
3445 if (baseParallelism > 1)
3446 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3447 "be present at the same time";
3448
3449 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3450 ++dtypeInt) {
3451 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
3452 if (dtype == acc::DeviceType::None)
3453 continue;
3454 unsigned parallelism = getParallelismForDeviceType(op: *this, dtype);
3455
3456 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3457 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
3458 "be present at the same time";
3459 }
3460
3461 return success();
3462}
3463
3464static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
3465 mlir::ArrayAttr &deviceTypes) {
3466 llvm::SmallVector<mlir::Attribute> bindNameAttrs;
3467 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
3468
3469 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
3470 if (parser.parseAttribute(result&: bindNameAttrs.emplace_back()))
3471 return failure();
3472 if (failed(Result: parser.parseOptionalLSquare())) {
3473 deviceTypeAttrs.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
3474 context: parser.getContext(), value: mlir::acc::DeviceType::None));
3475 } else {
3476 if (parser.parseAttribute(result&: deviceTypeAttrs.emplace_back()) ||
3477 parser.parseRSquare())
3478 return failure();
3479 }
3480 return success();
3481 })))
3482 return failure();
3483
3484 bindName = ArrayAttr::get(context: parser.getContext(), value: bindNameAttrs);
3485 deviceTypes = ArrayAttr::get(context: parser.getContext(), value: deviceTypeAttrs);
3486
3487 return success();
3488}
3489
3490static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
3491 std::optional<mlir::ArrayAttr> bindName,
3492 std::optional<mlir::ArrayAttr> deviceTypes) {
3493 llvm::interleaveComma(c: llvm::zip(t&: *bindName, u&: *deviceTypes), os&: p,
3494 each_fn: [&](const auto &pair) {
3495 p << std::get<0>(pair);
3496 printSingleDeviceType(p, std::get<1>(pair));
3497 });
3498}
3499
3500static ParseResult parseRoutineGangClause(OpAsmParser &parser,
3501 mlir::ArrayAttr &gang,
3502 mlir::ArrayAttr &gangDim,
3503 mlir::ArrayAttr &gangDimDeviceTypes) {
3504
3505 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
3506 gangDimDeviceTypeAttrs;
3507 bool needCommaBeforeOperands = false;
3508
3509 // Gang keyword only
3510 if (failed(Result: parser.parseOptionalLParen())) {
3511 gangAttrs.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
3512 context: parser.getContext(), value: mlir::acc::DeviceType::None));
3513 gang = ArrayAttr::get(context: parser.getContext(), value: gangAttrs);
3514 return success();
3515 }
3516
3517 // Parse keyword only attributes
3518 if (succeeded(Result: parser.parseOptionalLSquare())) {
3519 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
3520 if (parser.parseAttribute(result&: gangAttrs.emplace_back()))
3521 return failure();
3522 return success();
3523 })))
3524 return failure();
3525 if (parser.parseRSquare())
3526 return failure();
3527 needCommaBeforeOperands = true;
3528 }
3529
3530 if (needCommaBeforeOperands && failed(Result: parser.parseComma()))
3531 return failure();
3532
3533 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
3534 if (parser.parseKeyword(keyword: acc::RoutineOp::getGangDimKeyword()) ||
3535 parser.parseColon() ||
3536 parser.parseAttribute(result&: gangDimAttrs.emplace_back()))
3537 return failure();
3538 if (succeeded(Result: parser.parseOptionalLSquare())) {
3539 if (parser.parseAttribute(result&: gangDimDeviceTypeAttrs.emplace_back()) ||
3540 parser.parseRSquare())
3541 return failure();
3542 } else {
3543 gangDimDeviceTypeAttrs.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
3544 context: parser.getContext(), value: mlir::acc::DeviceType::None));
3545 }
3546 return success();
3547 })))
3548 return failure();
3549
3550 if (failed(Result: parser.parseRParen()))
3551 return failure();
3552
3553 gang = ArrayAttr::get(context: parser.getContext(), value: gangAttrs);
3554 gangDim = ArrayAttr::get(context: parser.getContext(), value: gangDimAttrs);
3555 gangDimDeviceTypes =
3556 ArrayAttr::get(context: parser.getContext(), value: gangDimDeviceTypeAttrs);
3557
3558 return success();
3559}
3560
3561void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
3562 std::optional<mlir::ArrayAttr> gang,
3563 std::optional<mlir::ArrayAttr> gangDim,
3564 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
3565
3566 if (!hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes) && hasDeviceTypeValues(arrayAttr: gang) &&
3567 gang->size() == 1) {
3568 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val: (*gang)[0]);
3569 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
3570 return;
3571 }
3572
3573 p << "(";
3574
3575 printDeviceTypes(p, deviceTypes: gang);
3576
3577 if (hasDeviceTypeValues(arrayAttr: gang) && hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes))
3578 p << ", ";
3579
3580 if (hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes))
3581 llvm::interleaveComma(c: llvm::zip(t&: *gangDim, u&: *gangDimDeviceTypes), os&: p,
3582 each_fn: [&](const auto &pair) {
3583 p << acc::RoutineOp::getGangDimKeyword() << ": ";
3584 p << std::get<0>(pair);
3585 printSingleDeviceType(p, std::get<1>(pair));
3586 });
3587
3588 p << ")";
3589}
3590
3591static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
3592 mlir::ArrayAttr &deviceTypes) {
3593 llvm::SmallVector<mlir::Attribute> attributes;
3594 // Keyword only
3595 if (failed(Result: parser.parseOptionalLParen())) {
3596 attributes.push_back(Elt: mlir::acc::DeviceTypeAttr::get(
3597 context: parser.getContext(), value: mlir::acc::DeviceType::None));
3598 deviceTypes = ArrayAttr::get(context: parser.getContext(), value: attributes);
3599 return success();
3600 }
3601
3602 // Parse device type attributes
3603 if (succeeded(Result: parser.parseOptionalLSquare())) {
3604 if (failed(Result: parser.parseCommaSeparatedList(parseElementFn: [&]() {
3605 if (parser.parseAttribute(result&: attributes.emplace_back()))
3606 return failure();
3607 return success();
3608 })))
3609 return failure();
3610 if (parser.parseRSquare() || parser.parseRParen())
3611 return failure();
3612 }
3613 deviceTypes = ArrayAttr::get(context: parser.getContext(), value: attributes);
3614 return success();
3615}
3616
3617static void
3618printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
3619 std::optional<mlir::ArrayAttr> deviceTypes) {
3620
3621 if (hasDeviceTypeValues(arrayAttr: deviceTypes) && deviceTypes->size() == 1) {
3622 auto deviceTypeAttr =
3623 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val: (*deviceTypes)[0]);
3624 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
3625 return;
3626 }
3627
3628 if (!hasDeviceTypeValues(arrayAttr: deviceTypes))
3629 return;
3630
3631 p << "([";
3632 llvm::interleaveComma(c: *deviceTypes, os&: p, each_fn: [&](mlir::Attribute attr) {
3633 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(Val&: attr);
3634 p << dTypeAttr;
3635 });
3636 p << "])";
3637}
3638
3639bool RoutineOp::hasWorker() { return hasWorker(deviceType: mlir::acc::DeviceType::None); }
3640
3641bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3642 return hasDeviceType(arrayAttr: getWorker(), deviceType);
3643}
3644
3645bool RoutineOp::hasVector() { return hasVector(deviceType: mlir::acc::DeviceType::None); }
3646
3647bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3648 return hasDeviceType(arrayAttr: getVector(), deviceType);
3649}
3650
3651bool RoutineOp::hasSeq() { return hasSeq(deviceType: mlir::acc::DeviceType::None); }
3652
3653bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3654 return hasDeviceType(arrayAttr: getSeq(), deviceType);
3655}
3656
3657std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3658 return getBindNameValue(deviceType: mlir::acc::DeviceType::None);
3659}
3660
3661std::optional<llvm::StringRef>
3662RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3663 if (!hasDeviceTypeValues(arrayAttr: getBindNameDeviceType()))
3664 return std::nullopt;
3665 if (auto pos = findSegment(segments: *getBindNameDeviceType(), deviceType)) {
3666 auto attr = (*getBindName())[*pos];
3667 auto stringAttr = dyn_cast<mlir::StringAttr>(Val&: attr);
3668 return stringAttr.getValue();
3669 }
3670 return std::nullopt;
3671}
3672
3673bool RoutineOp::hasGang() { return hasGang(deviceType: mlir::acc::DeviceType::None); }
3674
3675bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3676 return hasDeviceType(arrayAttr: getGang(), deviceType);
3677}
3678
3679std::optional<int64_t> RoutineOp::getGangDimValue() {
3680 return getGangDimValue(deviceType: mlir::acc::DeviceType::None);
3681}
3682
3683std::optional<int64_t>
3684RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3685 if (!hasDeviceTypeValues(arrayAttr: getGangDimDeviceType()))
3686 return std::nullopt;
3687 if (auto pos = findSegment(segments: *getGangDimDeviceType(), deviceType)) {
3688 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(Val: (*getGangDim())[*pos]);
3689 return intAttr.getInt();
3690 }
3691 return std::nullopt;
3692}
3693
3694//===----------------------------------------------------------------------===//
3695// InitOp
3696//===----------------------------------------------------------------------===//
3697
3698LogicalResult acc::InitOp::verify() {
3699 Operation *currOp = *this;
3700 while ((currOp = currOp->getParentOp()))
3701 if (isComputeOperation(op: currOp))
3702 return emitOpError(message: "cannot be nested in a compute operation");
3703 return success();
3704}
3705
3706void acc::InitOp::addDeviceType(MLIRContext *context,
3707 mlir::acc::DeviceType deviceType) {
3708 llvm::SmallVector<mlir::Attribute> deviceTypes;
3709 if (getDeviceTypesAttr())
3710 llvm::copy(Range: getDeviceTypesAttr(), Out: std::back_inserter(x&: deviceTypes));
3711
3712 deviceTypes.push_back(Elt: acc::DeviceTypeAttr::get(context, value: deviceType));
3713 setDeviceTypesAttr(mlir::ArrayAttr::get(context, value: deviceTypes));
3714}
3715
3716//===----------------------------------------------------------------------===//
3717// ShutdownOp
3718//===----------------------------------------------------------------------===//
3719
3720LogicalResult acc::ShutdownOp::verify() {
3721 Operation *currOp = *this;
3722 while ((currOp = currOp->getParentOp()))
3723 if (isComputeOperation(op: currOp))
3724 return emitOpError(message: "cannot be nested in a compute operation");
3725 return success();
3726}
3727
3728void acc::ShutdownOp::addDeviceType(MLIRContext *context,
3729 mlir::acc::DeviceType deviceType) {
3730 llvm::SmallVector<mlir::Attribute> deviceTypes;
3731 if (getDeviceTypesAttr())
3732 llvm::copy(Range: getDeviceTypesAttr(), Out: std::back_inserter(x&: deviceTypes));
3733
3734 deviceTypes.push_back(Elt: acc::DeviceTypeAttr::get(context, value: deviceType));
3735 setDeviceTypesAttr(mlir::ArrayAttr::get(context, value: deviceTypes));
3736}
3737
3738//===----------------------------------------------------------------------===//
3739// SetOp
3740//===----------------------------------------------------------------------===//
3741
3742LogicalResult acc::SetOp::verify() {
3743 Operation *currOp = *this;
3744 while ((currOp = currOp->getParentOp()))
3745 if (isComputeOperation(op: currOp))
3746 return emitOpError(message: "cannot be nested in a compute operation");
3747 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3748 return emitOpError(message: "at least one default_async, device_num, or device_type "
3749 "operand must appear");
3750 return success();
3751}
3752
3753//===----------------------------------------------------------------------===//
3754// UpdateOp
3755//===----------------------------------------------------------------------===//
3756
3757LogicalResult acc::UpdateOp::verify() {
3758 // At least one of host or device should have a value.
3759 if (getDataClauseOperands().empty())
3760 return emitError(message: "at least one value must be present in dataOperands");
3761
3762 if (failed(Result: verifyDeviceTypeCountMatch(op: *this, operands: getAsyncOperands(),
3763 deviceTypes: getAsyncOperandsDeviceTypeAttr(),
3764 keyword: "async")))
3765 return failure();
3766
3767 if (failed(Result: verifyDeviceTypeAndSegmentCountMatch(
3768 op: *this, operands: getWaitOperands(), segments: getWaitOperandsSegmentsAttr(),
3769 deviceTypes: getWaitOperandsDeviceTypeAttr(), keyword: "wait")))
3770 return failure();
3771
3772 if (failed(Result: checkWaitAndAsyncConflict<acc::UpdateOp>(op: *this)))
3773 return failure();
3774
3775 for (mlir::Value operand : getDataClauseOperands())
3776 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3777 Val: operand.getDefiningOp()))
3778 return emitError(message: "expect data entry/exit operation or acc.getdeviceptr "
3779 "as defining op");
3780
3781 return success();
3782}
3783
3784unsigned UpdateOp::getNumDataOperands() {
3785 return getDataClauseOperands().size();
3786}
3787
3788Value UpdateOp::getDataOperand(unsigned i) {
3789 unsigned numOptional = getAsyncOperands().size();
3790 numOptional += getIfCond() ? 1 : 0;
3791 return getOperand(i: getWaitOperands().size() + numOptional + i);
3792}
3793
3794void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
3795 MLIRContext *context) {
3796 results.add<RemoveConstantIfCondition<UpdateOp>>(arg&: context);
3797}
3798
3799bool UpdateOp::hasAsyncOnly() {
3800 return hasAsyncOnly(deviceType: mlir::acc::DeviceType::None);
3801}
3802
3803bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3804 return hasDeviceType(arrayAttr: getAsyncOnly(), deviceType);
3805}
3806
3807mlir::Value UpdateOp::getAsyncValue() {
3808 return getAsyncValue(deviceType: mlir::acc::DeviceType::None);
3809}
3810
3811mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3812 if (!hasDeviceTypeValues(arrayAttr: getAsyncOperandsDeviceType()))
3813 return {};
3814
3815 if (auto pos = findSegment(segments: *getAsyncOperandsDeviceType(), deviceType))
3816 return getAsyncOperands()[*pos];
3817
3818 return {};
3819}
3820
3821bool UpdateOp::hasWaitOnly() {
3822 return hasWaitOnly(deviceType: mlir::acc::DeviceType::None);
3823}
3824
3825bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3826 return hasDeviceType(arrayAttr: getWaitOnly(), deviceType);
3827}
3828
3829mlir::Operation::operand_range UpdateOp::getWaitValues() {
3830 return getWaitValues(deviceType: mlir::acc::DeviceType::None);
3831}
3832
3833mlir::Operation::operand_range
3834UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3835 return getWaitValuesWithoutDevnum(
3836 deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(), segments: getWaitOperandsSegments(),
3837 hasWaitDevnum: getHasWaitDevnum(), deviceType);
3838}
3839
3840mlir::Value UpdateOp::getWaitDevnum() {
3841 return getWaitDevnum(deviceType: mlir::acc::DeviceType::None);
3842}
3843
3844mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3845 return getWaitDevnumValue(deviceTypeAttr: getWaitOperandsDeviceType(), operands: getWaitOperands(),
3846 segments: getWaitOperandsSegments(), hasWaitDevnum: getHasWaitDevnum(),
3847 deviceType);
3848}
3849
3850void UpdateOp::addAsyncOnly(MLIRContext *context,
3851 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3852 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3853 context, existingDeviceTypes: getAsyncOnlyAttr(), newDeviceTypes: effectiveDeviceTypes));
3854}
3855
3856void UpdateOp::addAsyncOperand(
3857 MLIRContext *context, mlir::Value newValue,
3858 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3859 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3860 context, existingDeviceTypes: getAsyncOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValue,
3861 argCollection: getAsyncOperandsMutable()));
3862}
3863
3864void UpdateOp::addWaitOnly(MLIRContext *context,
3865 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3866 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes: getWaitOnlyAttr(),
3867 newDeviceTypes: effectiveDeviceTypes));
3868}
3869
3870void UpdateOp::addWaitOperands(
3871 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3872 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3873
3874 llvm::SmallVector<int32_t> segments;
3875 if (getWaitOperandsSegments())
3876 llvm::copy(Range: *getWaitOperandsSegments(), Out: std::back_inserter(x&: segments));
3877
3878 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3879 context, existingDeviceTypes: getWaitOperandsDeviceTypeAttr(), newDeviceTypes: effectiveDeviceTypes, arguments: newValues,
3880 argCollection: getWaitOperandsMutable(), segments));
3881 setWaitOperandsSegments(segments);
3882
3883 llvm::SmallVector<mlir::Attribute> hasDevnums;
3884 if (getHasWaitDevnumAttr())
3885 llvm::copy(Range: getHasWaitDevnumAttr(), Out: std::back_inserter(x&: hasDevnums));
3886 hasDevnums.insert(
3887 I: hasDevnums.end(),
3888 NumToInsert: std::max(a: effectiveDeviceTypes.size(), b: static_cast<size_t>(1)),
3889 Elt: mlir::BoolAttr::get(context, value: hasDevnum));
3890 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, value: hasDevnums));
3891}
3892
3893//===----------------------------------------------------------------------===//
3894// WaitOp
3895//===----------------------------------------------------------------------===//
3896
3897LogicalResult acc::WaitOp::verify() {
3898 // The async attribute represent the async clause without value. Therefore the
3899 // attribute and operand cannot appear at the same time.
3900 if (getAsyncOperand() && getAsync())
3901 return emitError(message: "async attribute cannot appear with asyncOperand");
3902
3903 if (getWaitDevnum() && getWaitOperands().empty())
3904 return emitError(message: "wait_devnum cannot appear without waitOperands");
3905
3906 return success();
3907}
3908
3909#define GET_OP_CLASSES
3910#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3911
3912#define GET_ATTRDEF_CLASSES
3913#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3914
3915#define GET_TYPEDEF_CLASSES
3916#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3917
3918//===----------------------------------------------------------------------===//
3919// acc dialect utilities
3920//===----------------------------------------------------------------------===//
3921
3922mlir::TypedValue<mlir::acc::PointerLikeType>
3923mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) {
3924 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
3925 mlir::TypedValue<mlir::acc::PointerLikeType>>(
3926 accDataClauseOp)
3927 .Case<ACC_DATA_ENTRY_OPS>(
3928 caseFn: [&](auto entry) { return entry.getVarPtr(); })
3929 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3930 caseFn: [&](auto exit) { return exit.getVarPtr(); })
3931 .Default(defaultFn: [&](mlir::Operation *) {
3932 return mlir::TypedValue<mlir::acc::PointerLikeType>();
3933 })};
3934 return varPtr;
3935}
3936
3937mlir::Value mlir::acc::getVar(mlir::Operation *accDataClauseOp) {
3938 auto varPtr{
3939 llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3940 .Case<ACC_DATA_ENTRY_OPS>(caseFn: [&](auto entry) { return entry.getVar(); })
3941 .Default(defaultFn: [&](mlir::Operation *) { return mlir::Value(); })};
3942 return varPtr;
3943}
3944
3945mlir::Type mlir::acc::getVarType(mlir::Operation *accDataClauseOp) {
3946 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
3947 .Case<ACC_DATA_ENTRY_OPS>(
3948 caseFn: [&](auto entry) { return entry.getVarType(); })
3949 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3950 caseFn: [&](auto exit) { return exit.getVarType(); })
3951 .Default(defaultFn: [&](mlir::Operation *) { return mlir::Type(); })};
3952 return varType;
3953}
3954
3955mlir::TypedValue<mlir::acc::PointerLikeType>
3956mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) {
3957 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
3958 mlir::TypedValue<mlir::acc::PointerLikeType>>(
3959 accDataClauseOp)
3960 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3961 caseFn: [&](auto dataClause) { return dataClause.getAccPtr(); })
3962 .Default(defaultFn: [&](mlir::Operation *) {
3963 return mlir::TypedValue<mlir::acc::PointerLikeType>();
3964 })};
3965 return accPtr;
3966}
3967
3968mlir::Value mlir::acc::getAccVar(mlir::Operation *accDataClauseOp) {
3969 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3970 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3971 caseFn: [&](auto dataClause) { return dataClause.getAccVar(); })
3972 .Default(defaultFn: [&](mlir::Operation *) { return mlir::Value(); })};
3973 return accPtr;
3974}
3975
3976mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) {
3977 auto varPtrPtr{
3978 llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
3979 .Case<ACC_DATA_ENTRY_OPS>(
3980 caseFn: [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
3981 .Default(defaultFn: [&](mlir::Operation *) { return mlir::Value(); })};
3982 return varPtrPtr;
3983}
3984
3985mlir::SmallVector<mlir::Value>
3986mlir::acc::getBounds(mlir::Operation *accDataClauseOp) {
3987 mlir::SmallVector<mlir::Value> bounds{
3988 llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
3989 accDataClauseOp)
3990 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(caseFn: [&](auto dataClause) {
3991 return mlir::SmallVector<mlir::Value>(
3992 dataClause.getBounds().begin(), dataClause.getBounds().end());
3993 })
3994 .Default(defaultFn: [&](mlir::Operation *) {
3995 return mlir::SmallVector<mlir::Value, 0>();
3996 })};
3997 return bounds;
3998}
3999
4000mlir::SmallVector<mlir::Value>
4001mlir::acc::getAsyncOperands(mlir::Operation *accDataClauseOp) {
4002 return llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>(
4003 accDataClauseOp)
4004 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(caseFn: [&](auto dataClause) {
4005 return mlir::SmallVector<mlir::Value>(
4006 dataClause.getAsyncOperands().begin(),
4007 dataClause.getAsyncOperands().end());
4008 })
4009 .Default(defaultFn: [&](mlir::Operation *) {
4010 return mlir::SmallVector<mlir::Value, 0>();
4011 });
4012}
4013
4014mlir::ArrayAttr
4015mlir::acc::getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp) {
4016 return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
4017 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(caseFn: [&](auto dataClause) {
4018 return dataClause.getAsyncOperandsDeviceTypeAttr();
4019 })
4020 .Default(defaultFn: [&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4021}
4022
4023mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4024 return llvm::TypeSwitch<mlir::Operation *, mlir::ArrayAttr>(accDataClauseOp)
4025 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4026 caseFn: [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4027 .Default(defaultFn: [&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4028}
4029
4030std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4031 auto name{
4032 llvm::TypeSwitch<mlir::Operation *, std::optional<llvm::StringRef>>(accOp)
4033 .Case<ACC_DATA_ENTRY_OPS>(caseFn: [&](auto entry) { return entry.getName(); })
4034 .Default(defaultFn: [&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4035 return {};
4036 })};
4037 return name;
4038}
4039
4040std::optional<mlir::acc::DataClause>
4041mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) {
4042 auto dataClause{
4043 llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>(
4044 accDataEntryOp)
4045 .Case<ACC_DATA_ENTRY_OPS>(
4046 caseFn: [&](auto entry) { return entry.getDataClause(); })
4047 .Default(defaultFn: [&](mlir::Operation *) { return std::nullopt; })};
4048 return dataClause;
4049}
4050
4051bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) {
4052 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
4053 .Case<ACC_DATA_ENTRY_OPS>(
4054 caseFn: [&](auto entry) { return entry.getImplicit(); })
4055 .Default(defaultFn: [&](mlir::Operation *) { return false; })};
4056 return implicit;
4057}
4058
4059mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) {
4060 auto dataOperands{
4061 llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp)
4062 .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
4063 caseFn: [&](auto entry) { return entry.getDataClauseOperands(); })
4064 .Default(defaultFn: [&](mlir::Operation *) { return mlir::ValueRange(); })};
4065 return dataOperands;
4066}
4067
4068mlir::MutableOperandRange
4069mlir::acc::getMutableDataOperands(mlir::Operation *accOp) {
4070 auto dataOperands{
4071 llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp)
4072 .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>(
4073 caseFn: [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
4074 .Default(defaultFn: [&](mlir::Operation *) { return nullptr; })};
4075 return dataOperands;
4076}
4077
4078mlir::Operation *mlir::acc::getEnclosingComputeOp(mlir::Region &region) {
4079 mlir::Operation *parentOp = region.getParentOp();
4080 while (parentOp) {
4081 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(Val: parentOp)) {
4082 return parentOp;
4083 }
4084 parentOp = parentOp->getParentOp();
4085 }
4086 return nullptr;
4087}
4088

source code of mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp