1 | //===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===// |
2 | // |
3 | // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | // ============================================================================= |
8 | |
9 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
10 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
11 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
12 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
13 | #include "mlir/IR/Builders.h" |
14 | #include "mlir/IR/BuiltinTypes.h" |
15 | #include "mlir/IR/DialectImplementation.h" |
16 | #include "mlir/IR/Matchers.h" |
17 | #include "mlir/IR/OpImplementation.h" |
18 | #include "mlir/Transforms/DialectConversion.h" |
19 | #include "llvm/ADT/SmallSet.h" |
20 | #include "llvm/ADT/TypeSwitch.h" |
21 | |
22 | using namespace mlir; |
23 | using namespace acc; |
24 | |
25 | #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc" |
26 | #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc" |
27 | #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc" |
28 | #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc" |
29 | |
30 | namespace { |
31 | struct MemRefPointerLikeModel |
32 | : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, |
33 | MemRefType> { |
34 | Type getElementType(Type pointer) const { |
35 | return llvm::cast<MemRefType>(pointer).getElementType(); |
36 | } |
37 | }; |
38 | |
39 | struct LLVMPointerPointerLikeModel |
40 | : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel, |
41 | LLVM::LLVMPointerType> { |
42 | Type getElementType(Type pointer) const { return Type(); } |
43 | }; |
44 | } // namespace |
45 | |
46 | //===----------------------------------------------------------------------===// |
47 | // OpenACC operations |
48 | //===----------------------------------------------------------------------===// |
49 | |
50 | void OpenACCDialect::initialize() { |
51 | addOperations< |
52 | #define GET_OP_LIST |
53 | #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" |
54 | >(); |
55 | addAttributes< |
56 | #define GET_ATTRDEF_LIST |
57 | #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" |
58 | >(); |
59 | addTypes< |
60 | #define GET_TYPEDEF_LIST |
61 | #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc" |
62 | >(); |
63 | |
64 | // By attaching interfaces here, we make the OpenACC dialect dependent on |
65 | // the other dialects. This is probably better than having dialects like LLVM |
66 | // and memref be dependent on OpenACC. |
67 | MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext()); |
68 | LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>( |
69 | *getContext()); |
70 | } |
71 | |
72 | //===----------------------------------------------------------------------===// |
73 | // device_type support helpers |
74 | //===----------------------------------------------------------------------===// |
75 | |
76 | static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) { |
77 | if (arrayAttr && *arrayAttr && arrayAttr->size() > 0) |
78 | return true; |
79 | return false; |
80 | } |
81 | |
82 | static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr, |
83 | mlir::acc::DeviceType deviceType) { |
84 | if (!hasDeviceTypeValues(arrayAttr)) |
85 | return false; |
86 | |
87 | for (auto attr : *arrayAttr) { |
88 | auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); |
89 | if (deviceTypeAttr.getValue() == deviceType) |
90 | return true; |
91 | } |
92 | |
93 | return false; |
94 | } |
95 | |
96 | static void printDeviceTypes(mlir::OpAsmPrinter &p, |
97 | std::optional<mlir::ArrayAttr> deviceTypes) { |
98 | if (!hasDeviceTypeValues(arrayAttr: deviceTypes)) |
99 | return; |
100 | |
101 | p << "[" ; |
102 | llvm::interleaveComma(*deviceTypes, p, |
103 | [&](mlir::Attribute attr) { p << attr; }); |
104 | p << "]" ; |
105 | } |
106 | |
107 | static std::optional<unsigned> findSegment(ArrayAttr segments, |
108 | mlir::acc::DeviceType deviceType) { |
109 | unsigned segmentIdx = 0; |
110 | for (auto attr : segments) { |
111 | auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); |
112 | if (deviceTypeAttr.getValue() == deviceType) |
113 | return std::make_optional(segmentIdx); |
114 | ++segmentIdx; |
115 | } |
116 | return std::nullopt; |
117 | } |
118 | |
119 | static mlir::Operation::operand_range |
120 | getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr, |
121 | mlir::Operation::operand_range range, |
122 | std::optional<llvm::ArrayRef<int32_t>> segments, |
123 | mlir::acc::DeviceType deviceType) { |
124 | if (!arrayAttr) |
125 | return range.take_front(n: 0); |
126 | if (auto pos = findSegment(*arrayAttr, deviceType)) { |
127 | int32_t nbOperandsBefore = 0; |
128 | for (unsigned i = 0; i < *pos; ++i) |
129 | nbOperandsBefore += (*segments)[i]; |
130 | return range.drop_front(n: nbOperandsBefore).take_front(n: (*segments)[*pos]); |
131 | } |
132 | return range.take_front(n: 0); |
133 | } |
134 | |
135 | static mlir::Value |
136 | getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr, |
137 | mlir::Operation::operand_range operands, |
138 | std::optional<llvm::ArrayRef<int32_t>> segments, |
139 | std::optional<mlir::ArrayAttr> hasWaitDevnum, |
140 | mlir::acc::DeviceType deviceType) { |
141 | if (!hasDeviceTypeValues(arrayAttr: deviceTypeAttr)) |
142 | return {}; |
143 | if (auto pos = findSegment(*deviceTypeAttr, deviceType)) |
144 | if (hasWaitDevnum->getValue()[*pos]) |
145 | return getValuesFromSegments(deviceTypeAttr, operands, segments, |
146 | deviceType) |
147 | .front(); |
148 | return {}; |
149 | } |
150 | |
151 | static mlir::Operation::operand_range |
152 | getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr, |
153 | mlir::Operation::operand_range operands, |
154 | std::optional<llvm::ArrayRef<int32_t>> segments, |
155 | std::optional<mlir::ArrayAttr> hasWaitDevnum, |
156 | mlir::acc::DeviceType deviceType) { |
157 | auto range = |
158 | getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType); |
159 | if (range.empty()) |
160 | return range; |
161 | if (auto pos = findSegment(*deviceTypeAttr, deviceType)) { |
162 | if (hasWaitDevnum && *hasWaitDevnum) { |
163 | auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]); |
164 | if (boolAttr.getValue()) |
165 | return range.drop_front(1); // first value is devnum |
166 | } |
167 | } |
168 | return range; |
169 | } |
170 | |
171 | template <typename Op> |
172 | static LogicalResult checkWaitAndAsyncConflict(Op op) { |
173 | for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); |
174 | ++dtypeInt) { |
175 | auto dtype = static_cast<acc::DeviceType>(dtypeInt); |
176 | |
177 | // The async attribute represent the async clause without value. Therefore |
178 | // the attribute and operand cannot appear at the same time. |
179 | if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) && |
180 | op.hasAsyncOnly(dtype)) |
181 | return op.emitError("async attribute cannot appear with asyncOperand" ); |
182 | |
183 | // The wait attribute represent the wait clause without values. Therefore |
184 | // the attribute and operands cannot appear at the same time. |
185 | if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) && |
186 | op.hasWaitOnly(dtype)) |
187 | return op.emitError("wait attribute cannot appear with waitOperands" ); |
188 | } |
189 | return success(); |
190 | } |
191 | |
192 | //===----------------------------------------------------------------------===// |
193 | // DataBoundsOp |
194 | //===----------------------------------------------------------------------===// |
195 | LogicalResult acc::DataBoundsOp::verify() { |
196 | auto extent = getExtent(); |
197 | auto upperbound = getUpperbound(); |
198 | if (!extent && !upperbound) |
199 | return emitError("expected extent or upperbound." ); |
200 | return success(); |
201 | } |
202 | |
203 | //===----------------------------------------------------------------------===// |
204 | // PrivateOp |
205 | //===----------------------------------------------------------------------===// |
206 | LogicalResult acc::PrivateOp::verify() { |
207 | if (getDataClause() != acc::DataClause::acc_private) |
208 | return emitError( |
209 | "data clause associated with private operation must match its intent" ); |
210 | return success(); |
211 | } |
212 | |
213 | //===----------------------------------------------------------------------===// |
214 | // FirstprivateOp |
215 | //===----------------------------------------------------------------------===// |
216 | LogicalResult acc::FirstprivateOp::verify() { |
217 | if (getDataClause() != acc::DataClause::acc_firstprivate) |
218 | return emitError("data clause associated with firstprivate operation must " |
219 | "match its intent" ); |
220 | return success(); |
221 | } |
222 | |
223 | //===----------------------------------------------------------------------===// |
224 | // ReductionOp |
225 | //===----------------------------------------------------------------------===// |
226 | LogicalResult acc::ReductionOp::verify() { |
227 | if (getDataClause() != acc::DataClause::acc_reduction) |
228 | return emitError("data clause associated with reduction operation must " |
229 | "match its intent" ); |
230 | return success(); |
231 | } |
232 | |
233 | //===----------------------------------------------------------------------===// |
234 | // DevicePtrOp |
235 | //===----------------------------------------------------------------------===// |
236 | LogicalResult acc::DevicePtrOp::verify() { |
237 | if (getDataClause() != acc::DataClause::acc_deviceptr) |
238 | return emitError("data clause associated with deviceptr operation must " |
239 | "match its intent" ); |
240 | return success(); |
241 | } |
242 | |
243 | //===----------------------------------------------------------------------===// |
244 | // PresentOp |
245 | //===----------------------------------------------------------------------===// |
246 | LogicalResult acc::PresentOp::verify() { |
247 | if (getDataClause() != acc::DataClause::acc_present) |
248 | return emitError( |
249 | "data clause associated with present operation must match its intent" ); |
250 | return success(); |
251 | } |
252 | |
253 | //===----------------------------------------------------------------------===// |
254 | // CopyinOp |
255 | //===----------------------------------------------------------------------===// |
256 | LogicalResult acc::CopyinOp::verify() { |
257 | // Test for all clauses this operation can be decomposed from: |
258 | if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin && |
259 | getDataClause() != acc::DataClause::acc_copyin_readonly && |
260 | getDataClause() != acc::DataClause::acc_copy && |
261 | getDataClause() != acc::DataClause::acc_reduction) |
262 | return emitError( |
263 | "data clause associated with copyin operation must match its intent" |
264 | " or specify original clause this operation was decomposed from" ); |
265 | return success(); |
266 | } |
267 | |
268 | bool acc::CopyinOp::isCopyinReadonly() { |
269 | return getDataClause() == acc::DataClause::acc_copyin_readonly; |
270 | } |
271 | |
272 | //===----------------------------------------------------------------------===// |
273 | // CreateOp |
274 | //===----------------------------------------------------------------------===// |
275 | LogicalResult acc::CreateOp::verify() { |
276 | // Test for all clauses this operation can be decomposed from: |
277 | if (getDataClause() != acc::DataClause::acc_create && |
278 | getDataClause() != acc::DataClause::acc_create_zero && |
279 | getDataClause() != acc::DataClause::acc_copyout && |
280 | getDataClause() != acc::DataClause::acc_copyout_zero) |
281 | return emitError( |
282 | "data clause associated with create operation must match its intent" |
283 | " or specify original clause this operation was decomposed from" ); |
284 | return success(); |
285 | } |
286 | |
287 | bool acc::CreateOp::isCreateZero() { |
288 | // The zero modifier is encoded in the data clause. |
289 | return getDataClause() == acc::DataClause::acc_create_zero || |
290 | getDataClause() == acc::DataClause::acc_copyout_zero; |
291 | } |
292 | |
293 | //===----------------------------------------------------------------------===// |
294 | // NoCreateOp |
295 | //===----------------------------------------------------------------------===// |
296 | LogicalResult acc::NoCreateOp::verify() { |
297 | if (getDataClause() != acc::DataClause::acc_no_create) |
298 | return emitError("data clause associated with no_create operation must " |
299 | "match its intent" ); |
300 | return success(); |
301 | } |
302 | |
303 | //===----------------------------------------------------------------------===// |
304 | // AttachOp |
305 | //===----------------------------------------------------------------------===// |
306 | LogicalResult acc::AttachOp::verify() { |
307 | if (getDataClause() != acc::DataClause::acc_attach) |
308 | return emitError( |
309 | "data clause associated with attach operation must match its intent" ); |
310 | return success(); |
311 | } |
312 | |
313 | //===----------------------------------------------------------------------===// |
314 | // DeclareDeviceResidentOp |
315 | //===----------------------------------------------------------------------===// |
316 | |
317 | LogicalResult acc::DeclareDeviceResidentOp::verify() { |
318 | if (getDataClause() != acc::DataClause::acc_declare_device_resident) |
319 | return emitError("data clause associated with device_resident operation " |
320 | "must match its intent" ); |
321 | return success(); |
322 | } |
323 | |
324 | //===----------------------------------------------------------------------===// |
325 | // DeclareLinkOp |
326 | //===----------------------------------------------------------------------===// |
327 | |
328 | LogicalResult acc::DeclareLinkOp::verify() { |
329 | if (getDataClause() != acc::DataClause::acc_declare_link) |
330 | return emitError( |
331 | "data clause associated with link operation must match its intent" ); |
332 | return success(); |
333 | } |
334 | |
335 | //===----------------------------------------------------------------------===// |
336 | // CopyoutOp |
337 | //===----------------------------------------------------------------------===// |
338 | LogicalResult acc::CopyoutOp::verify() { |
339 | // Test for all clauses this operation can be decomposed from: |
340 | if (getDataClause() != acc::DataClause::acc_copyout && |
341 | getDataClause() != acc::DataClause::acc_copyout_zero && |
342 | getDataClause() != acc::DataClause::acc_copy && |
343 | getDataClause() != acc::DataClause::acc_reduction) |
344 | return emitError( |
345 | "data clause associated with copyout operation must match its intent" |
346 | " or specify original clause this operation was decomposed from" ); |
347 | if (!getVarPtr() || !getAccPtr()) |
348 | return emitError("must have both host and device pointers" ); |
349 | return success(); |
350 | } |
351 | |
352 | bool acc::CopyoutOp::isCopyoutZero() { |
353 | return getDataClause() == acc::DataClause::acc_copyout_zero; |
354 | } |
355 | |
356 | //===----------------------------------------------------------------------===// |
357 | // DeleteOp |
358 | //===----------------------------------------------------------------------===// |
359 | LogicalResult acc::DeleteOp::verify() { |
360 | // Test for all clauses this operation can be decomposed from: |
361 | if (getDataClause() != acc::DataClause::acc_delete && |
362 | getDataClause() != acc::DataClause::acc_create && |
363 | getDataClause() != acc::DataClause::acc_create_zero && |
364 | getDataClause() != acc::DataClause::acc_copyin && |
365 | getDataClause() != acc::DataClause::acc_copyin_readonly && |
366 | getDataClause() != acc::DataClause::acc_present && |
367 | getDataClause() != acc::DataClause::acc_declare_device_resident && |
368 | getDataClause() != acc::DataClause::acc_declare_link) |
369 | return emitError( |
370 | "data clause associated with delete operation must match its intent" |
371 | " or specify original clause this operation was decomposed from" ); |
372 | if (!getAccPtr()) |
373 | return emitError("must have device pointer" ); |
374 | return success(); |
375 | } |
376 | |
377 | //===----------------------------------------------------------------------===// |
378 | // DetachOp |
379 | //===----------------------------------------------------------------------===// |
380 | LogicalResult acc::DetachOp::verify() { |
381 | // Test for all clauses this operation can be decomposed from: |
382 | if (getDataClause() != acc::DataClause::acc_detach && |
383 | getDataClause() != acc::DataClause::acc_attach) |
384 | return emitError( |
385 | "data clause associated with detach operation must match its intent" |
386 | " or specify original clause this operation was decomposed from" ); |
387 | if (!getAccPtr()) |
388 | return emitError("must have device pointer" ); |
389 | return success(); |
390 | } |
391 | |
392 | //===----------------------------------------------------------------------===// |
// UpdateHostOp
394 | //===----------------------------------------------------------------------===// |
395 | LogicalResult acc::UpdateHostOp::verify() { |
396 | // Test for all clauses this operation can be decomposed from: |
397 | if (getDataClause() != acc::DataClause::acc_update_host && |
398 | getDataClause() != acc::DataClause::acc_update_self) |
399 | return emitError( |
400 | "data clause associated with host operation must match its intent" |
401 | " or specify original clause this operation was decomposed from" ); |
402 | if (!getVarPtr() || !getAccPtr()) |
403 | return emitError("must have both host and device pointers" ); |
404 | return success(); |
405 | } |
406 | |
407 | //===----------------------------------------------------------------------===// |
// UpdateDeviceOp
409 | //===----------------------------------------------------------------------===// |
410 | LogicalResult acc::UpdateDeviceOp::verify() { |
411 | // Test for all clauses this operation can be decomposed from: |
412 | if (getDataClause() != acc::DataClause::acc_update_device) |
413 | return emitError( |
414 | "data clause associated with device operation must match its intent" |
415 | " or specify original clause this operation was decomposed from" ); |
416 | return success(); |
417 | } |
418 | |
419 | //===----------------------------------------------------------------------===// |
420 | // UseDeviceOp |
421 | //===----------------------------------------------------------------------===// |
422 | LogicalResult acc::UseDeviceOp::verify() { |
423 | // Test for all clauses this operation can be decomposed from: |
424 | if (getDataClause() != acc::DataClause::acc_use_device) |
425 | return emitError( |
426 | "data clause associated with use_device operation must match its intent" |
427 | " or specify original clause this operation was decomposed from" ); |
428 | return success(); |
429 | } |
430 | |
431 | //===----------------------------------------------------------------------===// |
432 | // CacheOp |
433 | //===----------------------------------------------------------------------===// |
434 | LogicalResult acc::CacheOp::verify() { |
435 | // Test for all clauses this operation can be decomposed from: |
436 | if (getDataClause() != acc::DataClause::acc_cache && |
437 | getDataClause() != acc::DataClause::acc_cache_readonly) |
438 | return emitError( |
439 | "data clause associated with cache operation must match its intent" |
440 | " or specify original clause this operation was decomposed from" ); |
441 | return success(); |
442 | } |
443 | |
444 | template <typename StructureOp> |
445 | static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, |
446 | unsigned nRegions = 1) { |
447 | |
448 | SmallVector<Region *, 2> regions; |
449 | for (unsigned i = 0; i < nRegions; ++i) |
450 | regions.push_back(Elt: state.addRegion()); |
451 | |
452 | for (Region *region : regions) |
453 | if (parser.parseRegion(region&: *region, /*arguments=*/{}, /*argTypes=*/enableNameShadowing: {})) |
454 | return failure(); |
455 | |
456 | return success(); |
457 | } |
458 | |
459 | static bool isComputeOperation(Operation *op) { |
460 | return isa<acc::ParallelOp, acc::LoopOp>(op); |
461 | } |
462 | |
463 | namespace { |
464 | /// Pattern to remove operation without region that have constant false `ifCond` |
465 | /// and remove the condition from the operation if the `ifCond` is a true |
466 | /// constant. |
467 | template <typename OpTy> |
468 | struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> { |
469 | using OpRewritePattern<OpTy>::OpRewritePattern; |
470 | |
471 | LogicalResult matchAndRewrite(OpTy op, |
472 | PatternRewriter &rewriter) const override { |
473 | // Early return if there is no condition. |
474 | Value ifCond = op.getIfCond(); |
475 | if (!ifCond) |
476 | return failure(); |
477 | |
478 | IntegerAttr constAttr; |
479 | if (!matchPattern(ifCond, m_Constant(&constAttr))) |
480 | return failure(); |
481 | if (constAttr.getInt()) |
482 | rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); |
483 | else |
484 | rewriter.eraseOp(op); |
485 | |
486 | return success(); |
487 | } |
488 | }; |
489 | |
490 | /// Replaces the given op with the contents of the given single-block region, |
491 | /// using the operands of the block terminator to replace operation results. |
492 | static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, |
493 | Region ®ion, ValueRange blockArgs = {}) { |
494 | assert(llvm::hasSingleElement(region) && "expected single-region block" ); |
495 | Block *block = ®ion.front(); |
496 | Operation *terminator = block->getTerminator(); |
497 | ValueRange results = terminator->getOperands(); |
498 | rewriter.inlineBlockBefore(source: block, op, argValues: blockArgs); |
499 | rewriter.replaceOp(op, newValues: results); |
500 | rewriter.eraseOp(op: terminator); |
501 | } |
502 | |
503 | /// Pattern to remove operation with region that have constant false `ifCond` |
504 | /// and remove the condition from the operation if the `ifCond` is constant |
505 | /// true. |
506 | template <typename OpTy> |
507 | struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> { |
508 | using OpRewritePattern<OpTy>::OpRewritePattern; |
509 | |
510 | LogicalResult matchAndRewrite(OpTy op, |
511 | PatternRewriter &rewriter) const override { |
512 | // Early return if there is no condition. |
513 | Value ifCond = op.getIfCond(); |
514 | if (!ifCond) |
515 | return failure(); |
516 | |
517 | IntegerAttr constAttr; |
518 | if (!matchPattern(ifCond, m_Constant(&constAttr))) |
519 | return failure(); |
520 | if (constAttr.getInt()) |
521 | rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); }); |
522 | else |
523 | replaceOpWithRegion(rewriter, op, op.getRegion()); |
524 | |
525 | return success(); |
526 | } |
527 | }; |
528 | |
529 | } // namespace |
530 | |
531 | //===----------------------------------------------------------------------===// |
532 | // PrivateRecipeOp |
533 | //===----------------------------------------------------------------------===// |
534 | |
535 | static LogicalResult verifyInitLikeSingleArgRegion( |
536 | Operation *op, Region ®ion, StringRef regionType, StringRef regionName, |
537 | Type type, bool verifyYield, bool optional = false) { |
538 | if (optional && region.empty()) |
539 | return success(); |
540 | |
541 | if (region.empty()) |
542 | return op->emitOpError() << "expects non-empty " << regionName << " region" ; |
543 | Block &firstBlock = region.front(); |
544 | if (firstBlock.getNumArguments() < 1 || |
545 | firstBlock.getArgument(i: 0).getType() != type) |
546 | return op->emitOpError() << "expects " << regionName |
547 | << " region first " |
548 | "argument of the " |
549 | << regionType << " type" ; |
550 | |
551 | if (verifyYield) { |
552 | for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) { |
553 | if (yieldOp.getOperands().size() != 1 || |
554 | yieldOp.getOperands().getTypes()[0] != type) |
555 | return op->emitOpError() << "expects " << regionName |
556 | << " region to " |
557 | "yield a value of the " |
558 | << regionType << " type" ; |
559 | } |
560 | } |
561 | return success(); |
562 | } |
563 | |
564 | LogicalResult acc::PrivateRecipeOp::verifyRegions() { |
565 | if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), |
566 | "privatization" , "init" , getType(), |
567 | /*verifyYield=*/false))) |
568 | return failure(); |
569 | if (failed(verifyInitLikeSingleArgRegion( |
570 | *this, getDestroyRegion(), "privatization" , "destroy" , getType(), |
571 | /*verifyYield=*/false, /*optional=*/true))) |
572 | return failure(); |
573 | return success(); |
574 | } |
575 | |
576 | //===----------------------------------------------------------------------===// |
577 | // FirstprivateRecipeOp |
578 | //===----------------------------------------------------------------------===// |
579 | |
580 | LogicalResult acc::FirstprivateRecipeOp::verifyRegions() { |
581 | if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), |
582 | "privatization" , "init" , getType(), |
583 | /*verifyYield=*/false))) |
584 | return failure(); |
585 | |
586 | if (getCopyRegion().empty()) |
587 | return emitOpError() << "expects non-empty copy region" ; |
588 | |
589 | Block &firstBlock = getCopyRegion().front(); |
590 | if (firstBlock.getNumArguments() < 2 || |
591 | firstBlock.getArgument(0).getType() != getType()) |
592 | return emitOpError() << "expects copy region with two arguments of the " |
593 | "privatization type" ; |
594 | |
595 | if (getDestroyRegion().empty()) |
596 | return success(); |
597 | |
598 | if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(), |
599 | "privatization" , "destroy" , |
600 | getType(), /*verifyYield=*/false))) |
601 | return failure(); |
602 | |
603 | return success(); |
604 | } |
605 | |
606 | //===----------------------------------------------------------------------===// |
607 | // ReductionRecipeOp |
608 | //===----------------------------------------------------------------------===// |
609 | |
610 | LogicalResult acc::ReductionRecipeOp::verifyRegions() { |
611 | if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction" , |
612 | "init" , getType(), |
613 | /*verifyYield=*/false))) |
614 | return failure(); |
615 | |
616 | if (getCombinerRegion().empty()) |
617 | return emitOpError() << "expects non-empty combiner region" ; |
618 | |
619 | Block &reductionBlock = getCombinerRegion().front(); |
620 | if (reductionBlock.getNumArguments() < 2 || |
621 | reductionBlock.getArgument(0).getType() != getType() || |
622 | reductionBlock.getArgument(1).getType() != getType()) |
623 | return emitOpError() << "expects combiner region with the first two " |
624 | << "arguments of the reduction type" ; |
625 | |
626 | for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) { |
627 | if (yieldOp.getOperands().size() != 1 || |
628 | yieldOp.getOperands().getTypes()[0] != getType()) |
629 | return emitOpError() << "expects combiner region to yield a value " |
630 | "of the reduction type" ; |
631 | } |
632 | |
633 | return success(); |
634 | } |
635 | |
636 | //===----------------------------------------------------------------------===// |
637 | // Custom parser and printer verifier for private clause |
638 | //===----------------------------------------------------------------------===// |
639 | |
640 | static ParseResult parseSymOperandList( |
641 | mlir::OpAsmParser &parser, |
642 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, |
643 | llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) { |
644 | llvm::SmallVector<SymbolRefAttr> attributes; |
645 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
646 | if (parser.parseAttribute(attributes.emplace_back()) || |
647 | parser.parseArrow() || |
648 | parser.parseOperand(result&: operands.emplace_back()) || |
649 | parser.parseColonType(result&: types.emplace_back())) |
650 | return failure(); |
651 | return success(); |
652 | }))) |
653 | return failure(); |
654 | llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
655 | attributes.end()); |
656 | symbols = ArrayAttr::get(parser.getContext(), arrayAttr); |
657 | return success(); |
658 | } |
659 | |
660 | static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, |
661 | mlir::OperandRange operands, |
662 | mlir::TypeRange types, |
663 | std::optional<mlir::ArrayAttr> attributes) { |
664 | llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](auto it) { |
665 | p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " |
666 | << std::get<1>(it).getType(); |
667 | }); |
668 | } |
669 | |
670 | //===----------------------------------------------------------------------===// |
671 | // ParallelOp |
672 | //===----------------------------------------------------------------------===// |
673 | |
674 | /// Check dataOperands for acc.parallel, acc.serial and acc.kernels. |
675 | template <typename Op> |
676 | static LogicalResult checkDataOperands(Op op, |
677 | const mlir::ValueRange &operands) { |
678 | for (mlir::Value operand : operands) |
679 | if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp, |
680 | acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp, |
681 | acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>( |
682 | operand.getDefiningOp())) |
683 | return op.emitError( |
684 | "expect data entry/exit operation or acc.getdeviceptr " |
685 | "as defining op" ); |
686 | return success(); |
687 | } |
688 | |
689 | template <typename Op> |
690 | static LogicalResult |
691 | checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes, |
692 | mlir::OperandRange operands, llvm::StringRef operandName, |
693 | llvm::StringRef symbolName, bool checkOperandType = true) { |
694 | if (!operands.empty()) { |
695 | if (!attributes || attributes->size() != operands.size()) |
696 | return op->emitOpError() |
697 | << "expected as many " << symbolName << " symbol reference as " |
698 | << operandName << " operands" ; |
699 | } else { |
700 | if (attributes) |
701 | return op->emitOpError() |
702 | << "unexpected " << symbolName << " symbol reference" ; |
703 | return success(); |
704 | } |
705 | |
706 | llvm::DenseSet<Value> set; |
707 | for (auto args : llvm::zip(operands, *attributes)) { |
708 | mlir::Value operand = std::get<0>(args); |
709 | |
710 | if (!set.insert(operand).second) |
711 | return op->emitOpError() |
712 | << operandName << " operand appears more than once" ; |
713 | |
714 | mlir::Type varType = operand.getType(); |
715 | auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args)); |
716 | auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef); |
717 | if (!decl) |
718 | return op->emitOpError() |
719 | << "expected symbol reference " << symbolRef << " to point to a " |
720 | << operandName << " declaration" ; |
721 | |
722 | if (checkOperandType && decl.getType() && decl.getType() != varType) |
723 | return op->emitOpError() << "expected " << operandName << " (" << varType |
724 | << ") to be the same type as " << operandName |
725 | << " declaration (" << decl.getType() << ")" ; |
726 | } |
727 | |
728 | return success(); |
729 | } |
730 | |
731 | unsigned ParallelOp::getNumDataOperands() { |
732 | return getReductionOperands().size() + getGangPrivateOperands().size() + |
733 | getGangFirstPrivateOperands().size() + getDataClauseOperands().size(); |
734 | } |
735 | |
736 | Value ParallelOp::getDataOperand(unsigned i) { |
737 | unsigned numOptional = getAsyncOperands().size(); |
738 | numOptional += getNumGangs().size(); |
739 | numOptional += getNumWorkers().size(); |
740 | numOptional += getVectorLength().size(); |
741 | numOptional += getIfCond() ? 1 : 0; |
742 | numOptional += getSelfCond() ? 1 : 0; |
743 | return getOperand(getWaitOperands().size() + numOptional + i); |
744 | } |
745 | |
746 | template <typename Op> |
747 | static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, |
748 | ArrayAttr deviceTypes, |
749 | llvm::StringRef keyword) { |
750 | if (!operands.empty() && deviceTypes.getValue().size() != operands.size()) |
751 | return op.emitOpError() << keyword << " operands count must match " |
752 | << keyword << " device_type count" ; |
753 | return success(); |
754 | } |
755 | |
756 | template <typename Op> |
757 | static LogicalResult verifyDeviceTypeAndSegmentCountMatch( |
758 | Op op, OperandRange operands, DenseI32ArrayAttr segments, |
759 | ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) { |
760 | std::size_t numOperandsInSegments = 0; |
761 | |
762 | if (!segments) |
763 | return success(); |
764 | |
765 | for (auto segCount : segments.asArrayRef()) { |
766 | if (maxInSegment != 0 && segCount > maxInSegment) |
767 | return op.emitOpError() << keyword << " expects a maximum of " |
768 | << maxInSegment << " values per segment" ; |
769 | numOperandsInSegments += segCount; |
770 | } |
771 | if (numOperandsInSegments != operands.size()) |
772 | return op.emitOpError() |
773 | << keyword << " operand count does not match count in segments" ; |
774 | if (deviceTypes.getValue().size() != (size_t)segments.size()) |
775 | return op.emitOpError() |
776 | << keyword << " segment count does not match device_type count" ; |
777 | return success(); |
778 | } |
779 | |
780 | LogicalResult acc::ParallelOp::verify() { |
781 | if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( |
782 | *this, getPrivatizations(), getGangPrivateOperands(), "private" , |
783 | "privatizations" , /*checkOperandType=*/false))) |
784 | return failure(); |
785 | if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( |
786 | *this, getReductionRecipes(), getReductionOperands(), "reduction" , |
787 | "reductions" , false))) |
788 | return failure(); |
789 | |
790 | if (failed(verifyDeviceTypeAndSegmentCountMatch( |
791 | *this, getNumGangs(), getNumGangsSegmentsAttr(), |
792 | getNumGangsDeviceTypeAttr(), "num_gangs" , 3))) |
793 | return failure(); |
794 | |
795 | if (failed(verifyDeviceTypeAndSegmentCountMatch( |
796 | *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), |
797 | getWaitOperandsDeviceTypeAttr(), "wait" ))) |
798 | return failure(); |
799 | |
800 | if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), |
801 | getNumWorkersDeviceTypeAttr(), |
802 | "num_workers" ))) |
803 | return failure(); |
804 | |
805 | if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), |
806 | getVectorLengthDeviceTypeAttr(), |
807 | "vector_length" ))) |
808 | return failure(); |
809 | |
810 | if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), |
811 | getAsyncOperandsDeviceTypeAttr(), |
812 | "async" ))) |
813 | return failure(); |
814 | |
815 | if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this))) |
816 | return failure(); |
817 | |
818 | return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands()); |
819 | } |
820 | |
821 | static mlir::Value |
822 | getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr, |
823 | mlir::Operation::operand_range range, |
824 | mlir::acc::DeviceType deviceType) { |
825 | if (!arrayAttr) |
826 | return {}; |
827 | if (auto pos = findSegment(*arrayAttr, deviceType)) |
828 | return range[*pos]; |
829 | return {}; |
830 | } |
831 | |
832 | bool acc::ParallelOp::hasAsyncOnly() { |
833 | return hasAsyncOnly(mlir::acc::DeviceType::None); |
834 | } |
835 | |
836 | bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { |
837 | return hasDeviceType(getAsyncOnly(), deviceType); |
838 | } |
839 | |
840 | mlir::Value acc::ParallelOp::getAsyncValue() { |
841 | return getAsyncValue(mlir::acc::DeviceType::None); |
842 | } |
843 | |
844 | mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) { |
845 | return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), |
846 | getAsyncOperands(), deviceType); |
847 | } |
848 | |
849 | mlir::Value acc::ParallelOp::getNumWorkersValue() { |
850 | return getNumWorkersValue(mlir::acc::DeviceType::None); |
851 | } |
852 | |
853 | mlir::Value |
854 | acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { |
855 | return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), |
856 | deviceType); |
857 | } |
858 | |
859 | mlir::Value acc::ParallelOp::getVectorLengthValue() { |
860 | return getVectorLengthValue(mlir::acc::DeviceType::None); |
861 | } |
862 | |
863 | mlir::Value |
864 | acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { |
865 | return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), |
866 | getVectorLength(), deviceType); |
867 | } |
868 | |
869 | mlir::Operation::operand_range ParallelOp::getNumGangsValues() { |
870 | return getNumGangsValues(mlir::acc::DeviceType::None); |
871 | } |
872 | |
873 | mlir::Operation::operand_range |
874 | ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { |
875 | return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), |
876 | getNumGangsSegments(), deviceType); |
877 | } |
878 | |
879 | bool acc::ParallelOp::hasWaitOnly() { |
880 | return hasWaitOnly(mlir::acc::DeviceType::None); |
881 | } |
882 | |
883 | bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { |
884 | return hasDeviceType(getWaitOnly(), deviceType); |
885 | } |
886 | |
887 | mlir::Operation::operand_range ParallelOp::getWaitValues() { |
888 | return getWaitValues(mlir::acc::DeviceType::None); |
889 | } |
890 | |
891 | mlir::Operation::operand_range |
892 | ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) { |
893 | return getWaitValuesWithoutDevnum( |
894 | getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), |
895 | getHasWaitDevnum(), deviceType); |
896 | } |
897 | |
898 | mlir::Value ParallelOp::getWaitDevnum() { |
899 | return getWaitDevnum(mlir::acc::DeviceType::None); |
900 | } |
901 | |
902 | mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { |
903 | return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), |
904 | getWaitOperandsSegments(), getHasWaitDevnum(), |
905 | deviceType); |
906 | } |
907 | |
908 | static ParseResult parseNumGangs( |
909 | mlir::OpAsmParser &parser, |
910 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, |
911 | llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes, |
912 | mlir::DenseI32ArrayAttr &segments) { |
913 | llvm::SmallVector<DeviceTypeAttr> attributes; |
914 | llvm::SmallVector<int32_t> seg; |
915 | |
916 | do { |
917 | if (failed(result: parser.parseLBrace())) |
918 | return failure(); |
919 | |
920 | int32_t crtOperandsSize = operands.size(); |
921 | if (failed(result: parser.parseCommaSeparatedList( |
922 | delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() { |
923 | if (parser.parseOperand(result&: operands.emplace_back()) || |
924 | parser.parseColonType(result&: types.emplace_back())) |
925 | return failure(); |
926 | return success(); |
927 | }))) |
928 | return failure(); |
929 | seg.push_back(Elt: operands.size() - crtOperandsSize); |
930 | |
931 | if (failed(result: parser.parseRBrace())) |
932 | return failure(); |
933 | |
934 | if (succeeded(result: parser.parseOptionalLSquare())) { |
935 | if (parser.parseAttribute(attributes.emplace_back()) || |
936 | parser.parseRSquare()) |
937 | return failure(); |
938 | } else { |
939 | attributes.push_back(mlir::acc::DeviceTypeAttr::get( |
940 | parser.getContext(), mlir::acc::DeviceType::None)); |
941 | } |
942 | } while (succeeded(result: parser.parseOptionalComma())); |
943 | |
944 | llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
945 | attributes.end()); |
946 | deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); |
947 | segments = DenseI32ArrayAttr::get(parser.getContext(), seg); |
948 | |
949 | return success(); |
950 | } |
951 | |
952 | static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr) { |
953 | auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); |
954 | if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None) |
955 | p << " [" << attr << "]" ; |
956 | } |
957 | |
958 | static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, |
959 | mlir::OperandRange operands, mlir::TypeRange types, |
960 | std::optional<mlir::ArrayAttr> deviceTypes, |
961 | std::optional<mlir::DenseI32ArrayAttr> segments) { |
962 | unsigned opIdx = 0; |
963 | llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { |
964 | p << "{" ; |
965 | llvm::interleaveComma( |
966 | llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) { |
967 | p << operands[opIdx] << " : " << operands[opIdx].getType(); |
968 | ++opIdx; |
969 | }); |
970 | p << "}" ; |
971 | printSingleDeviceType(p, it.value()); |
972 | }); |
973 | } |
974 | |
975 | static ParseResult parseDeviceTypeOperandsWithSegment( |
976 | mlir::OpAsmParser &parser, |
977 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, |
978 | llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes, |
979 | mlir::DenseI32ArrayAttr &segments) { |
980 | llvm::SmallVector<DeviceTypeAttr> attributes; |
981 | llvm::SmallVector<int32_t> seg; |
982 | |
983 | do { |
984 | if (failed(result: parser.parseLBrace())) |
985 | return failure(); |
986 | |
987 | int32_t crtOperandsSize = operands.size(); |
988 | |
989 | if (failed(result: parser.parseCommaSeparatedList( |
990 | delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() { |
991 | if (parser.parseOperand(result&: operands.emplace_back()) || |
992 | parser.parseColonType(result&: types.emplace_back())) |
993 | return failure(); |
994 | return success(); |
995 | }))) |
996 | return failure(); |
997 | |
998 | seg.push_back(Elt: operands.size() - crtOperandsSize); |
999 | |
1000 | if (failed(result: parser.parseRBrace())) |
1001 | return failure(); |
1002 | |
1003 | if (succeeded(result: parser.parseOptionalLSquare())) { |
1004 | if (parser.parseAttribute(attributes.emplace_back()) || |
1005 | parser.parseRSquare()) |
1006 | return failure(); |
1007 | } else { |
1008 | attributes.push_back(mlir::acc::DeviceTypeAttr::get( |
1009 | parser.getContext(), mlir::acc::DeviceType::None)); |
1010 | } |
1011 | } while (succeeded(result: parser.parseOptionalComma())); |
1012 | |
1013 | llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
1014 | attributes.end()); |
1015 | deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); |
1016 | segments = DenseI32ArrayAttr::get(parser.getContext(), seg); |
1017 | |
1018 | return success(); |
1019 | } |
1020 | |
1021 | static void printDeviceTypeOperandsWithSegment( |
1022 | mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, |
1023 | mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes, |
1024 | std::optional<mlir::DenseI32ArrayAttr> segments) { |
1025 | unsigned opIdx = 0; |
1026 | llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { |
1027 | p << "{" ; |
1028 | llvm::interleaveComma( |
1029 | llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) { |
1030 | p << operands[opIdx] << " : " << operands[opIdx].getType(); |
1031 | ++opIdx; |
1032 | }); |
1033 | p << "}" ; |
1034 | printSingleDeviceType(p, it.value()); |
1035 | }); |
1036 | } |
1037 | |
1038 | static ParseResult parseWaitClause( |
1039 | mlir::OpAsmParser &parser, |
1040 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, |
1041 | llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes, |
1042 | mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, |
1043 | mlir::ArrayAttr &keywordOnly) { |
1044 | llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum; |
1045 | llvm::SmallVector<int32_t> seg; |
1046 | |
1047 | bool needCommaBeforeOperands = false; |
1048 | |
1049 | // Keyword only |
1050 | if (failed(result: parser.parseOptionalLParen())) { |
1051 | keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get( |
1052 | parser.getContext(), mlir::acc::DeviceType::None)); |
1053 | keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs); |
1054 | return success(); |
1055 | } |
1056 | |
1057 | // Parse keyword only attributes |
1058 | if (succeeded(result: parser.parseOptionalLSquare())) { |
1059 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
1060 | if (parser.parseAttribute(result&: keywordAttrs.emplace_back())) |
1061 | return failure(); |
1062 | return success(); |
1063 | }))) |
1064 | return failure(); |
1065 | if (parser.parseRSquare()) |
1066 | return failure(); |
1067 | needCommaBeforeOperands = true; |
1068 | } |
1069 | |
1070 | if (needCommaBeforeOperands && failed(result: parser.parseComma())) |
1071 | return failure(); |
1072 | |
1073 | do { |
1074 | if (failed(result: parser.parseLBrace())) |
1075 | return failure(); |
1076 | |
1077 | int32_t crtOperandsSize = operands.size(); |
1078 | |
1079 | if (succeeded(result: parser.parseOptionalKeyword(keyword: "devnum" ))) { |
1080 | if (failed(result: parser.parseColon())) |
1081 | return failure(); |
1082 | devnum.push_back(Elt: BoolAttr::get(context: parser.getContext(), value: true)); |
1083 | } else { |
1084 | devnum.push_back(Elt: BoolAttr::get(context: parser.getContext(), value: false)); |
1085 | } |
1086 | |
1087 | if (failed(result: parser.parseCommaSeparatedList( |
1088 | delimiter: mlir::AsmParser::Delimiter::None, parseElementFn: [&]() { |
1089 | if (parser.parseOperand(result&: operands.emplace_back()) || |
1090 | parser.parseColonType(result&: types.emplace_back())) |
1091 | return failure(); |
1092 | return success(); |
1093 | }))) |
1094 | return failure(); |
1095 | |
1096 | seg.push_back(Elt: operands.size() - crtOperandsSize); |
1097 | |
1098 | if (failed(result: parser.parseRBrace())) |
1099 | return failure(); |
1100 | |
1101 | if (succeeded(result: parser.parseOptionalLSquare())) { |
1102 | if (parser.parseAttribute(result&: deviceTypeAttrs.emplace_back()) || |
1103 | parser.parseRSquare()) |
1104 | return failure(); |
1105 | } else { |
1106 | deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( |
1107 | parser.getContext(), mlir::acc::DeviceType::None)); |
1108 | } |
1109 | } while (succeeded(result: parser.parseOptionalComma())); |
1110 | |
1111 | if (failed(result: parser.parseRParen())) |
1112 | return failure(); |
1113 | |
1114 | deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); |
1115 | keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs); |
1116 | segments = DenseI32ArrayAttr::get(parser.getContext(), seg); |
1117 | hasDevNum = ArrayAttr::get(parser.getContext(), devnum); |
1118 | |
1119 | return success(); |
1120 | } |
1121 | |
1122 | static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) { |
1123 | if (!hasDeviceTypeValues(arrayAttr: attrs)) |
1124 | return false; |
1125 | if (attrs->size() != 1) |
1126 | return false; |
1127 | if (auto deviceTypeAttr = |
1128 | mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0])) |
1129 | return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None; |
1130 | return false; |
1131 | } |
1132 | |
1133 | static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, |
1134 | mlir::OperandRange operands, mlir::TypeRange types, |
1135 | std::optional<mlir::ArrayAttr> deviceTypes, |
1136 | std::optional<mlir::DenseI32ArrayAttr> segments, |
1137 | std::optional<mlir::ArrayAttr> hasDevNum, |
1138 | std::optional<mlir::ArrayAttr> keywordOnly) { |
1139 | |
1140 | if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(attrs: keywordOnly)) |
1141 | return; |
1142 | |
1143 | p << "(" ; |
1144 | |
1145 | printDeviceTypes(p, deviceTypes: keywordOnly); |
1146 | if (hasDeviceTypeValues(arrayAttr: keywordOnly) && hasDeviceTypeValues(arrayAttr: deviceTypes)) |
1147 | p << ", " ; |
1148 | |
1149 | unsigned opIdx = 0; |
1150 | llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { |
1151 | p << "{" ; |
1152 | auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]); |
1153 | if (boolAttr && boolAttr.getValue()) |
1154 | p << "devnum: " ; |
1155 | llvm::interleaveComma( |
1156 | llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) { |
1157 | p << operands[opIdx] << " : " << operands[opIdx].getType(); |
1158 | ++opIdx; |
1159 | }); |
1160 | p << "}" ; |
1161 | printSingleDeviceType(p, it.value()); |
1162 | }); |
1163 | |
1164 | p << ")" ; |
1165 | } |
1166 | |
1167 | static ParseResult parseDeviceTypeOperands( |
1168 | mlir::OpAsmParser &parser, |
1169 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, |
1170 | llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) { |
1171 | llvm::SmallVector<DeviceTypeAttr> attributes; |
1172 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
1173 | if (parser.parseOperand(result&: operands.emplace_back()) || |
1174 | parser.parseColonType(result&: types.emplace_back())) |
1175 | return failure(); |
1176 | if (succeeded(result: parser.parseOptionalLSquare())) { |
1177 | if (parser.parseAttribute(attributes.emplace_back()) || |
1178 | parser.parseRSquare()) |
1179 | return failure(); |
1180 | } else { |
1181 | attributes.push_back(mlir::acc::DeviceTypeAttr::get( |
1182 | parser.getContext(), mlir::acc::DeviceType::None)); |
1183 | } |
1184 | return success(); |
1185 | }))) |
1186 | return failure(); |
1187 | llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
1188 | attributes.end()); |
1189 | deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); |
1190 | return success(); |
1191 | } |
1192 | |
1193 | static void |
1194 | printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, |
1195 | mlir::OperandRange operands, mlir::TypeRange types, |
1196 | std::optional<mlir::ArrayAttr> deviceTypes) { |
1197 | if (!hasDeviceTypeValues(arrayAttr: deviceTypes)) |
1198 | return; |
1199 | llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) { |
1200 | p << std::get<1>(it) << " : " << std::get<1>(it).getType(); |
1201 | printSingleDeviceType(p, std::get<0>(it)); |
1202 | }); |
1203 | } |
1204 | |
1205 | static ParseResult parseDeviceTypeOperandsWithKeywordOnly( |
1206 | mlir::OpAsmParser &parser, |
1207 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, |
1208 | llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes, |
1209 | mlir::ArrayAttr &keywordOnlyDeviceType) { |
1210 | |
1211 | llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes; |
1212 | bool needCommaBeforeOperands = false; |
1213 | |
1214 | if (failed(result: parser.parseOptionalLParen())) { |
1215 | // Keyword only |
1216 | keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( |
1217 | parser.getContext(), mlir::acc::DeviceType::None)); |
1218 | keywordOnlyDeviceType = |
1219 | ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes); |
1220 | return success(); |
1221 | } |
1222 | |
1223 | // Parse keyword only attributes |
1224 | if (succeeded(result: parser.parseOptionalLSquare())) { |
1225 | // Parse keyword only attributes |
1226 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
1227 | if (parser.parseAttribute( |
1228 | result&: keywordOnlyDeviceTypeAttributes.emplace_back())) |
1229 | return failure(); |
1230 | return success(); |
1231 | }))) |
1232 | return failure(); |
1233 | if (parser.parseRSquare()) |
1234 | return failure(); |
1235 | needCommaBeforeOperands = true; |
1236 | } |
1237 | |
1238 | if (needCommaBeforeOperands && failed(result: parser.parseComma())) |
1239 | return failure(); |
1240 | |
1241 | llvm::SmallVector<DeviceTypeAttr> attributes; |
1242 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
1243 | if (parser.parseOperand(result&: operands.emplace_back()) || |
1244 | parser.parseColonType(result&: types.emplace_back())) |
1245 | return failure(); |
1246 | if (succeeded(result: parser.parseOptionalLSquare())) { |
1247 | if (parser.parseAttribute(attributes.emplace_back()) || |
1248 | parser.parseRSquare()) |
1249 | return failure(); |
1250 | } else { |
1251 | attributes.push_back(mlir::acc::DeviceTypeAttr::get( |
1252 | parser.getContext(), mlir::acc::DeviceType::None)); |
1253 | } |
1254 | return success(); |
1255 | }))) |
1256 | return failure(); |
1257 | |
1258 | if (failed(result: parser.parseRParen())) |
1259 | return failure(); |
1260 | |
1261 | llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
1262 | attributes.end()); |
1263 | deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr); |
1264 | return success(); |
1265 | } |
1266 | |
1267 | static void printDeviceTypeOperandsWithKeywordOnly( |
1268 | mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, |
1269 | mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes, |
1270 | std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) { |
1271 | |
1272 | if (operands.begin() == operands.end() && |
1273 | hasOnlyDeviceTypeNone(attrs: keywordOnlyDeviceTypes)) { |
1274 | return; |
1275 | } |
1276 | |
1277 | p << "(" ; |
1278 | printDeviceTypes(p, deviceTypes: keywordOnlyDeviceTypes); |
1279 | if (hasDeviceTypeValues(arrayAttr: keywordOnlyDeviceTypes) && |
1280 | hasDeviceTypeValues(arrayAttr: deviceTypes)) |
1281 | p << ", " ; |
1282 | printDeviceTypeOperands(p, op, operands, types, deviceTypes); |
1283 | p << ")" ; |
1284 | } |
1285 | |
1286 | static ParseResult |
1287 | parseCombinedConstructsLoop(mlir::OpAsmParser &parser, |
1288 | mlir::acc::CombinedConstructsTypeAttr &attr) { |
1289 | if (succeeded(result: parser.parseOptionalKeyword(keyword: "combined" ))) { |
1290 | if (parser.parseLParen()) |
1291 | return failure(); |
1292 | if (succeeded(result: parser.parseOptionalKeyword(keyword: "kernels" ))) { |
1293 | attr = mlir::acc::CombinedConstructsTypeAttr::get( |
1294 | parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop); |
1295 | } else if (succeeded(result: parser.parseOptionalKeyword(keyword: "parallel" ))) { |
1296 | attr = mlir::acc::CombinedConstructsTypeAttr::get( |
1297 | parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop); |
1298 | } else if (succeeded(result: parser.parseOptionalKeyword(keyword: "serial" ))) { |
1299 | attr = mlir::acc::CombinedConstructsTypeAttr::get( |
1300 | parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop); |
1301 | } else { |
1302 | parser.emitError(loc: parser.getCurrentLocation(), |
1303 | message: "expected compute construct name" ); |
1304 | return failure(); |
1305 | } |
1306 | if (parser.parseRParen()) |
1307 | return failure(); |
1308 | } |
1309 | return success(); |
1310 | } |
1311 | |
1312 | static void |
1313 | printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, |
1314 | mlir::acc::CombinedConstructsTypeAttr attr) { |
1315 | if (attr) { |
1316 | switch (attr.getValue()) { |
1317 | case mlir::acc::CombinedConstructsType::KernelsLoop: |
1318 | p << "combined(kernels)" ; |
1319 | break; |
1320 | case mlir::acc::CombinedConstructsType::ParallelLoop: |
1321 | p << "combined(parallel)" ; |
1322 | break; |
1323 | case mlir::acc::CombinedConstructsType::SerialLoop: |
1324 | p << "combined(serial)" ; |
1325 | break; |
1326 | }; |
1327 | } |
1328 | } |
1329 | |
1330 | //===----------------------------------------------------------------------===// |
1331 | // SerialOp |
1332 | //===----------------------------------------------------------------------===// |
1333 | |
1334 | unsigned SerialOp::getNumDataOperands() { |
1335 | return getReductionOperands().size() + getGangPrivateOperands().size() + |
1336 | getGangFirstPrivateOperands().size() + getDataClauseOperands().size(); |
1337 | } |
1338 | |
1339 | Value SerialOp::getDataOperand(unsigned i) { |
1340 | unsigned numOptional = getAsyncOperands().size(); |
1341 | numOptional += getIfCond() ? 1 : 0; |
1342 | numOptional += getSelfCond() ? 1 : 0; |
1343 | return getOperand(getWaitOperands().size() + numOptional + i); |
1344 | } |
1345 | |
1346 | bool acc::SerialOp::hasAsyncOnly() { |
1347 | return hasAsyncOnly(mlir::acc::DeviceType::None); |
1348 | } |
1349 | |
1350 | bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { |
1351 | return hasDeviceType(getAsyncOnly(), deviceType); |
1352 | } |
1353 | |
1354 | mlir::Value acc::SerialOp::getAsyncValue() { |
1355 | return getAsyncValue(mlir::acc::DeviceType::None); |
1356 | } |
1357 | |
1358 | mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) { |
1359 | return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), |
1360 | getAsyncOperands(), deviceType); |
1361 | } |
1362 | |
1363 | bool acc::SerialOp::hasWaitOnly() { |
1364 | return hasWaitOnly(mlir::acc::DeviceType::None); |
1365 | } |
1366 | |
1367 | bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { |
1368 | return hasDeviceType(getWaitOnly(), deviceType); |
1369 | } |
1370 | |
1371 | mlir::Operation::operand_range SerialOp::getWaitValues() { |
1372 | return getWaitValues(mlir::acc::DeviceType::None); |
1373 | } |
1374 | |
1375 | mlir::Operation::operand_range |
1376 | SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) { |
1377 | return getWaitValuesWithoutDevnum( |
1378 | getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), |
1379 | getHasWaitDevnum(), deviceType); |
1380 | } |
1381 | |
1382 | mlir::Value SerialOp::getWaitDevnum() { |
1383 | return getWaitDevnum(mlir::acc::DeviceType::None); |
1384 | } |
1385 | |
1386 | mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { |
1387 | return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), |
1388 | getWaitOperandsSegments(), getHasWaitDevnum(), |
1389 | deviceType); |
1390 | } |
1391 | |
1392 | LogicalResult acc::SerialOp::verify() { |
1393 | if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( |
1394 | *this, getPrivatizations(), getGangPrivateOperands(), "private" , |
1395 | "privatizations" , /*checkOperandType=*/false))) |
1396 | return failure(); |
1397 | if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( |
1398 | *this, getReductionRecipes(), getReductionOperands(), "reduction" , |
1399 | "reductions" , false))) |
1400 | return failure(); |
1401 | |
1402 | if (failed(verifyDeviceTypeAndSegmentCountMatch( |
1403 | *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), |
1404 | getWaitOperandsDeviceTypeAttr(), "wait" ))) |
1405 | return failure(); |
1406 | |
1407 | if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), |
1408 | getAsyncOperandsDeviceTypeAttr(), |
1409 | "async" ))) |
1410 | return failure(); |
1411 | |
1412 | if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this))) |
1413 | return failure(); |
1414 | |
1415 | return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands()); |
1416 | } |
1417 | |
1418 | //===----------------------------------------------------------------------===// |
1419 | // KernelsOp |
1420 | //===----------------------------------------------------------------------===// |
1421 | |
1422 | unsigned KernelsOp::getNumDataOperands() { |
1423 | return getDataClauseOperands().size(); |
1424 | } |
1425 | |
1426 | Value KernelsOp::getDataOperand(unsigned i) { |
1427 | unsigned numOptional = getAsyncOperands().size(); |
1428 | numOptional += getWaitOperands().size(); |
1429 | numOptional += getNumGangs().size(); |
1430 | numOptional += getNumWorkers().size(); |
1431 | numOptional += getVectorLength().size(); |
1432 | numOptional += getIfCond() ? 1 : 0; |
1433 | numOptional += getSelfCond() ? 1 : 0; |
1434 | return getOperand(numOptional + i); |
1435 | } |
1436 | |
1437 | bool acc::KernelsOp::hasAsyncOnly() { |
1438 | return hasAsyncOnly(mlir::acc::DeviceType::None); |
1439 | } |
1440 | |
1441 | bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { |
1442 | return hasDeviceType(getAsyncOnly(), deviceType); |
1443 | } |
1444 | |
1445 | mlir::Value acc::KernelsOp::getAsyncValue() { |
1446 | return getAsyncValue(mlir::acc::DeviceType::None); |
1447 | } |
1448 | |
1449 | mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) { |
1450 | return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), |
1451 | getAsyncOperands(), deviceType); |
1452 | } |
1453 | |
1454 | mlir::Value acc::KernelsOp::getNumWorkersValue() { |
1455 | return getNumWorkersValue(mlir::acc::DeviceType::None); |
1456 | } |
1457 | |
1458 | mlir::Value |
1459 | acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) { |
1460 | return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(), |
1461 | deviceType); |
1462 | } |
1463 | |
1464 | mlir::Value acc::KernelsOp::getVectorLengthValue() { |
1465 | return getVectorLengthValue(mlir::acc::DeviceType::None); |
1466 | } |
1467 | |
1468 | mlir::Value |
1469 | acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) { |
1470 | return getValueInDeviceTypeSegment(getVectorLengthDeviceType(), |
1471 | getVectorLength(), deviceType); |
1472 | } |
1473 | |
1474 | mlir::Operation::operand_range KernelsOp::getNumGangsValues() { |
1475 | return getNumGangsValues(mlir::acc::DeviceType::None); |
1476 | } |
1477 | |
1478 | mlir::Operation::operand_range |
1479 | KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) { |
1480 | return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(), |
1481 | getNumGangsSegments(), deviceType); |
1482 | } |
1483 | |
1484 | bool acc::KernelsOp::hasWaitOnly() { |
1485 | return hasWaitOnly(mlir::acc::DeviceType::None); |
1486 | } |
1487 | |
1488 | bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { |
1489 | return hasDeviceType(getWaitOnly(), deviceType); |
1490 | } |
1491 | |
1492 | mlir::Operation::operand_range KernelsOp::getWaitValues() { |
1493 | return getWaitValues(mlir::acc::DeviceType::None); |
1494 | } |
1495 | |
1496 | mlir::Operation::operand_range |
1497 | KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) { |
1498 | return getWaitValuesWithoutDevnum( |
1499 | getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), |
1500 | getHasWaitDevnum(), deviceType); |
1501 | } |
1502 | |
1503 | mlir::Value KernelsOp::getWaitDevnum() { |
1504 | return getWaitDevnum(mlir::acc::DeviceType::None); |
1505 | } |
1506 | |
1507 | mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { |
1508 | return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), |
1509 | getWaitOperandsSegments(), getHasWaitDevnum(), |
1510 | deviceType); |
1511 | } |
1512 | |
1513 | LogicalResult acc::KernelsOp::verify() { |
1514 | if (failed(verifyDeviceTypeAndSegmentCountMatch( |
1515 | *this, getNumGangs(), getNumGangsSegmentsAttr(), |
1516 | getNumGangsDeviceTypeAttr(), "num_gangs" , 3))) |
1517 | return failure(); |
1518 | |
1519 | if (failed(verifyDeviceTypeAndSegmentCountMatch( |
1520 | *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), |
1521 | getWaitOperandsDeviceTypeAttr(), "wait" ))) |
1522 | return failure(); |
1523 | |
1524 | if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(), |
1525 | getNumWorkersDeviceTypeAttr(), |
1526 | "num_workers" ))) |
1527 | return failure(); |
1528 | |
1529 | if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(), |
1530 | getVectorLengthDeviceTypeAttr(), |
1531 | "vector_length" ))) |
1532 | return failure(); |
1533 | |
1534 | if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), |
1535 | getAsyncOperandsDeviceTypeAttr(), |
1536 | "async" ))) |
1537 | return failure(); |
1538 | |
1539 | if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this))) |
1540 | return failure(); |
1541 | |
1542 | return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands()); |
1543 | } |
1544 | |
1545 | //===----------------------------------------------------------------------===// |
1546 | // HostDataOp |
1547 | //===----------------------------------------------------------------------===// |
1548 | |
1549 | LogicalResult acc::HostDataOp::verify() { |
1550 | if (getDataClauseOperands().empty()) |
1551 | return emitError("at least one operand must appear on the host_data " |
1552 | "operation" ); |
1553 | |
1554 | for (mlir::Value operand : getDataClauseOperands()) |
1555 | if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp())) |
1556 | return emitError("expect data entry operation as defining op" ); |
1557 | return success(); |
1558 | } |
1559 | |
1560 | void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results, |
1561 | MLIRContext *context) { |
1562 | results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context); |
1563 | } |
1564 | |
1565 | //===----------------------------------------------------------------------===// |
1566 | // LoopOp |
1567 | //===----------------------------------------------------------------------===// |
1568 | |
1569 | static ParseResult parseGangValue( |
1570 | OpAsmParser &parser, llvm::StringRef keyword, |
1571 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, |
1572 | llvm::SmallVectorImpl<Type> &types, |
1573 | llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType, |
1574 | bool &needCommaBetweenValues, bool &newValue) { |
1575 | if (succeeded(result: parser.parseOptionalKeyword(keyword))) { |
1576 | if (parser.parseEqual()) |
1577 | return failure(); |
1578 | if (parser.parseOperand(result&: operands.emplace_back()) || |
1579 | parser.parseColonType(result&: types.emplace_back())) |
1580 | return failure(); |
1581 | attributes.push_back(gangArgType); |
1582 | needCommaBetweenValues = true; |
1583 | newValue = true; |
1584 | } |
1585 | return success(); |
1586 | } |
1587 | |
1588 | static ParseResult parseGangClause( |
1589 | OpAsmParser &parser, |
1590 | llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &gangOperands, |
1591 | llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType, |
1592 | mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, |
1593 | mlir::ArrayAttr &gangOnlyDeviceType) { |
1594 | llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes; |
1595 | llvm::SmallVector<mlir::Attribute> deviceTypeAttributes; |
1596 | llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes; |
1597 | llvm::SmallVector<int32_t> seg; |
1598 | bool needCommaBetweenValues = false; |
1599 | bool needCommaBeforeOperands = false; |
1600 | |
1601 | if (failed(result: parser.parseOptionalLParen())) { |
1602 | // Gang only keyword |
1603 | gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( |
1604 | parser.getContext(), mlir::acc::DeviceType::None)); |
1605 | gangOnlyDeviceType = |
1606 | ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes); |
1607 | return success(); |
1608 | } |
1609 | |
1610 | // Parse gang only attributes |
1611 | if (succeeded(result: parser.parseOptionalLSquare())) { |
1612 | // Parse gang only attributes |
1613 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
1614 | if (parser.parseAttribute( |
1615 | result&: gangOnlyDeviceTypeAttributes.emplace_back())) |
1616 | return failure(); |
1617 | return success(); |
1618 | }))) |
1619 | return failure(); |
1620 | if (parser.parseRSquare()) |
1621 | return failure(); |
1622 | needCommaBeforeOperands = true; |
1623 | } |
1624 | |
1625 | auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(), |
1626 | mlir::acc::GangArgType::Num); |
1627 | auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(), |
1628 | mlir::acc::GangArgType::Dim); |
1629 | auto argStatic = mlir::acc::GangArgTypeAttr::get( |
1630 | parser.getContext(), mlir::acc::GangArgType::Static); |
1631 | |
1632 | do { |
1633 | if (needCommaBeforeOperands) { |
1634 | needCommaBeforeOperands = false; |
1635 | continue; |
1636 | } |
1637 | |
1638 | if (failed(result: parser.parseLBrace())) |
1639 | return failure(); |
1640 | |
1641 | int32_t crtOperandsSize = gangOperands.size(); |
1642 | while (true) { |
1643 | bool newValue = false; |
1644 | bool needValue = false; |
1645 | if (needCommaBetweenValues) { |
1646 | if (succeeded(result: parser.parseOptionalComma())) |
1647 | needValue = true; // expect a new value after comma. |
1648 | else |
1649 | break; |
1650 | } |
1651 | |
1652 | if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(), |
1653 | gangOperands, gangOperandsType, |
1654 | gangArgTypeAttributes, argNum, |
1655 | needCommaBetweenValues, newValue))) |
1656 | return failure(); |
1657 | if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(), |
1658 | gangOperands, gangOperandsType, |
1659 | gangArgTypeAttributes, argDim, |
1660 | needCommaBetweenValues, newValue))) |
1661 | return failure(); |
1662 | if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(), |
1663 | gangOperands, gangOperandsType, |
1664 | gangArgTypeAttributes, argStatic, |
1665 | needCommaBetweenValues, newValue))) |
1666 | return failure(); |
1667 | |
1668 | if (!newValue && needValue) { |
1669 | parser.emitError(loc: parser.getCurrentLocation(), |
1670 | message: "new value expected after comma" ); |
1671 | return failure(); |
1672 | } |
1673 | |
1674 | if (!newValue) |
1675 | break; |
1676 | } |
1677 | |
1678 | if (gangOperands.empty()) |
1679 | return parser.emitError( |
1680 | loc: parser.getCurrentLocation(), |
1681 | message: "expect at least one of num, dim or static values" ); |
1682 | |
1683 | if (failed(result: parser.parseRBrace())) |
1684 | return failure(); |
1685 | |
1686 | if (succeeded(result: parser.parseOptionalLSquare())) { |
1687 | if (parser.parseAttribute(result&: deviceTypeAttributes.emplace_back()) || |
1688 | parser.parseRSquare()) |
1689 | return failure(); |
1690 | } else { |
1691 | deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get( |
1692 | parser.getContext(), mlir::acc::DeviceType::None)); |
1693 | } |
1694 | |
1695 | seg.push_back(Elt: gangOperands.size() - crtOperandsSize); |
1696 | |
1697 | } while (succeeded(result: parser.parseOptionalComma())); |
1698 | |
1699 | if (failed(result: parser.parseRParen())) |
1700 | return failure(); |
1701 | |
1702 | llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(), |
1703 | gangArgTypeAttributes.end()); |
1704 | gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr); |
1705 | deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes); |
1706 | |
1707 | llvm::SmallVector<mlir::Attribute> gangOnlyAttr( |
1708 | gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end()); |
1709 | gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr); |
1710 | |
1711 | segments = DenseI32ArrayAttr::get(parser.getContext(), seg); |
1712 | return success(); |
1713 | } |
1714 | |
1715 | void printGangClause(OpAsmPrinter &p, Operation *op, |
1716 | mlir::OperandRange operands, mlir::TypeRange types, |
1717 | std::optional<mlir::ArrayAttr> gangArgTypes, |
1718 | std::optional<mlir::ArrayAttr> deviceTypes, |
1719 | std::optional<mlir::DenseI32ArrayAttr> segments, |
1720 | std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) { |
1721 | |
1722 | if (operands.begin() == operands.end() && |
1723 | hasOnlyDeviceTypeNone(attrs: gangOnlyDeviceTypes)) { |
1724 | return; |
1725 | } |
1726 | |
1727 | p << "(" ; |
1728 | |
1729 | printDeviceTypes(p, deviceTypes: gangOnlyDeviceTypes); |
1730 | |
1731 | if (hasDeviceTypeValues(arrayAttr: gangOnlyDeviceTypes) && |
1732 | hasDeviceTypeValues(arrayAttr: deviceTypes)) |
1733 | p << ", " ; |
1734 | |
1735 | if (hasDeviceTypeValues(arrayAttr: deviceTypes)) { |
1736 | unsigned opIdx = 0; |
1737 | llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) { |
1738 | p << "{" ; |
1739 | llvm::interleaveComma( |
1740 | llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) { |
1741 | auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>( |
1742 | (*gangArgTypes)[opIdx]); |
1743 | if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num) |
1744 | p << LoopOp::getGangNumKeyword(); |
1745 | else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim) |
1746 | p << LoopOp::getGangDimKeyword(); |
1747 | else if (gangArgTypeAttr.getValue() == |
1748 | mlir::acc::GangArgType::Static) |
1749 | p << LoopOp::getGangStaticKeyword(); |
1750 | p << "=" << operands[opIdx] << " : " << operands[opIdx].getType(); |
1751 | ++opIdx; |
1752 | }); |
1753 | p << "}" ; |
1754 | printSingleDeviceType(p, it.value()); |
1755 | }); |
1756 | } |
1757 | p << ")" ; |
1758 | } |
1759 | |
1760 | bool hasDuplicateDeviceTypes( |
1761 | std::optional<mlir::ArrayAttr> segments, |
1762 | llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) { |
1763 | if (!segments) |
1764 | return false; |
1765 | for (auto attr : *segments) { |
1766 | auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); |
1767 | if (deviceTypes.contains(deviceTypeAttr.getValue())) |
1768 | return true; |
1769 | deviceTypes.insert(deviceTypeAttr.getValue()); |
1770 | } |
1771 | return false; |
1772 | } |
1773 | |
1774 | /// Check for duplicates in the DeviceType array attribute. |
1775 | LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes) { |
1776 | llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes; |
1777 | if (!deviceTypes) |
1778 | return success(); |
1779 | for (auto attr : deviceTypes) { |
1780 | auto deviceTypeAttr = |
1781 | mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr); |
1782 | if (!deviceTypeAttr) |
1783 | return failure(); |
1784 | if (crtDeviceTypes.contains(deviceTypeAttr.getValue())) |
1785 | return failure(); |
1786 | crtDeviceTypes.insert(deviceTypeAttr.getValue()); |
1787 | } |
1788 | return success(); |
1789 | } |
1790 | |
1791 | LogicalResult acc::LoopOp::verify() { |
1792 | if (!getUpperbound().empty() && getInclusiveUpperbound() && |
1793 | (getUpperbound().size() != getInclusiveUpperbound()->size())) |
1794 | return emitError() << "inclusiveUpperbound size is expected to be the same" |
1795 | << " as upperbound size" ; |
1796 | |
1797 | // Check collapse |
1798 | if (getCollapseAttr() && !getCollapseDeviceTypeAttr()) |
1799 | return emitOpError() << "collapse device_type attr must be define when" |
1800 | << " collapse attr is present" ; |
1801 | |
1802 | if (getCollapseAttr() && getCollapseDeviceTypeAttr() && |
1803 | getCollapseAttr().getValue().size() != |
1804 | getCollapseDeviceTypeAttr().getValue().size()) |
1805 | return emitOpError() << "collapse attribute count must match collapse" |
1806 | << " device_type count" ; |
1807 | if (failed(checkDeviceTypes(getCollapseDeviceTypeAttr()))) |
1808 | return emitOpError() |
1809 | << "duplicate device_type found in collapseDeviceType attribute" ; |
1810 | |
1811 | // Check gang |
1812 | if (!getGangOperands().empty()) { |
1813 | if (!getGangOperandsArgType()) |
1814 | return emitOpError() << "gangOperandsArgType attribute must be defined" |
1815 | << " when gang operands are present" ; |
1816 | |
1817 | if (getGangOperands().size() != |
1818 | getGangOperandsArgTypeAttr().getValue().size()) |
1819 | return emitOpError() << "gangOperandsArgType attribute count must match" |
1820 | << " gangOperands count" ; |
1821 | } |
1822 | if (getGangAttr() && failed(checkDeviceTypes(getGangAttr()))) |
1823 | return emitOpError() << "duplicate device_type found in gang attribute" ; |
1824 | |
1825 | if (failed(verifyDeviceTypeAndSegmentCountMatch( |
1826 | *this, getGangOperands(), getGangOperandsSegmentsAttr(), |
1827 | getGangOperandsDeviceTypeAttr(), "gang" ))) |
1828 | return failure(); |
1829 | |
1830 | // Check worker |
1831 | if (failed(checkDeviceTypes(getWorkerAttr()))) |
1832 | return emitOpError() << "duplicate device_type found in worker attribute" ; |
1833 | if (failed(checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))) |
1834 | return emitOpError() << "duplicate device_type found in " |
1835 | "workerNumOperandsDeviceType attribute" ; |
1836 | if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(), |
1837 | getWorkerNumOperandsDeviceTypeAttr(), |
1838 | "worker" ))) |
1839 | return failure(); |
1840 | |
1841 | // Check vector |
1842 | if (failed(checkDeviceTypes(getVectorAttr()))) |
1843 | return emitOpError() << "duplicate device_type found in vector attribute" ; |
1844 | if (failed(checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))) |
1845 | return emitOpError() << "duplicate device_type found in " |
1846 | "vectorOperandsDeviceType attribute" ; |
1847 | if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(), |
1848 | getVectorOperandsDeviceTypeAttr(), |
1849 | "vector" ))) |
1850 | return failure(); |
1851 | |
1852 | if (failed(verifyDeviceTypeAndSegmentCountMatch( |
1853 | *this, getTileOperands(), getTileOperandsSegmentsAttr(), |
1854 | getTileOperandsDeviceTypeAttr(), "tile" ))) |
1855 | return failure(); |
1856 | |
1857 | // auto, independent and seq attribute are mutually exclusive. |
1858 | llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes; |
1859 | if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) || |
1860 | hasDuplicateDeviceTypes(getIndependent(), deviceTypes) || |
1861 | hasDuplicateDeviceTypes(getSeq(), deviceTypes)) { |
1862 | return emitError() << "only one of \"" << acc::LoopOp::getAutoAttrStrName() |
1863 | << "\", " << getIndependentAttrName() << ", " |
1864 | << getSeqAttrName() |
1865 | << " can be present at the same time" ; |
1866 | } |
1867 | |
1868 | // Gang, worker and vector are incompatible with seq. |
1869 | if (getSeqAttr()) { |
1870 | for (auto attr : getSeqAttr()) { |
1871 | auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); |
1872 | if (hasVector(deviceTypeAttr.getValue()) || |
1873 | getVectorValue(deviceTypeAttr.getValue()) || |
1874 | hasWorker(deviceTypeAttr.getValue()) || |
1875 | getWorkerValue(deviceTypeAttr.getValue()) || |
1876 | hasGang(deviceTypeAttr.getValue()) || |
1877 | getGangValue(mlir::acc::GangArgType::Num, |
1878 | deviceTypeAttr.getValue()) || |
1879 | getGangValue(mlir::acc::GangArgType::Dim, |
1880 | deviceTypeAttr.getValue()) || |
1881 | getGangValue(mlir::acc::GangArgType::Static, |
1882 | deviceTypeAttr.getValue())) |
1883 | return emitError() |
1884 | << "gang, worker or vector cannot appear with the seq attr" ; |
1885 | } |
1886 | } |
1887 | |
1888 | if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>( |
1889 | *this, getPrivatizations(), getPrivateOperands(), "private" , |
1890 | "privatizations" , false))) |
1891 | return failure(); |
1892 | |
1893 | if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>( |
1894 | *this, getReductionRecipes(), getReductionOperands(), "reduction" , |
1895 | "reductions" , false))) |
1896 | return failure(); |
1897 | |
1898 | if (getCombined().has_value() && |
1899 | (getCombined().value() != acc::CombinedConstructsType::ParallelLoop && |
1900 | getCombined().value() != acc::CombinedConstructsType::KernelsLoop && |
1901 | getCombined().value() != acc::CombinedConstructsType::SerialLoop)) { |
1902 | return emitError("unexpected combined constructs attribute" ); |
1903 | } |
1904 | |
1905 | // Check non-empty body(). |
1906 | if (getRegion().empty()) |
1907 | return emitError("expected non-empty body." ); |
1908 | |
1909 | return success(); |
1910 | } |
1911 | |
1912 | unsigned LoopOp::getNumDataOperands() { |
1913 | return getReductionOperands().size() + getPrivateOperands().size(); |
1914 | } |
1915 | |
1916 | Value LoopOp::getDataOperand(unsigned i) { |
1917 | unsigned numOptional = |
1918 | getLowerbound().size() + getUpperbound().size() + getStep().size(); |
1919 | numOptional += getGangOperands().size(); |
1920 | numOptional += getVectorOperands().size(); |
1921 | numOptional += getWorkerNumOperands().size(); |
1922 | numOptional += getTileOperands().size(); |
1923 | numOptional += getCacheOperands().size(); |
1924 | return getOperand(numOptional + i); |
1925 | } |
1926 | |
1927 | bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); } |
1928 | |
1929 | bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) { |
1930 | return hasDeviceType(getAuto_(), deviceType); |
1931 | } |
1932 | |
1933 | bool LoopOp::hasIndependent() { |
1934 | return hasIndependent(mlir::acc::DeviceType::None); |
1935 | } |
1936 | |
1937 | bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) { |
1938 | return hasDeviceType(getIndependent(), deviceType); |
1939 | } |
1940 | |
1941 | bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } |
1942 | |
1943 | bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) { |
1944 | return hasDeviceType(getSeq(), deviceType); |
1945 | } |
1946 | |
1947 | mlir::Value LoopOp::getVectorValue() { |
1948 | return getVectorValue(mlir::acc::DeviceType::None); |
1949 | } |
1950 | |
1951 | mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) { |
1952 | return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(), |
1953 | getVectorOperands(), deviceType); |
1954 | } |
1955 | |
1956 | bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } |
1957 | |
1958 | bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) { |
1959 | return hasDeviceType(getVector(), deviceType); |
1960 | } |
1961 | |
1962 | mlir::Value LoopOp::getWorkerValue() { |
1963 | return getWorkerValue(mlir::acc::DeviceType::None); |
1964 | } |
1965 | |
1966 | mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) { |
1967 | return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(), |
1968 | getWorkerNumOperands(), deviceType); |
1969 | } |
1970 | |
1971 | bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } |
1972 | |
1973 | bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) { |
1974 | return hasDeviceType(getWorker(), deviceType); |
1975 | } |
1976 | |
1977 | mlir::Operation::operand_range LoopOp::getTileValues() { |
1978 | return getTileValues(mlir::acc::DeviceType::None); |
1979 | } |
1980 | |
1981 | mlir::Operation::operand_range |
1982 | LoopOp::getTileValues(mlir::acc::DeviceType deviceType) { |
1983 | return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(), |
1984 | getTileOperandsSegments(), deviceType); |
1985 | } |
1986 | |
1987 | std::optional<int64_t> LoopOp::getCollapseValue() { |
1988 | return getCollapseValue(mlir::acc::DeviceType::None); |
1989 | } |
1990 | |
1991 | std::optional<int64_t> |
1992 | LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) { |
1993 | if (!getCollapseAttr()) |
1994 | return std::nullopt; |
1995 | if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) { |
1996 | auto intAttr = |
1997 | mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]); |
1998 | return intAttr.getValue().getZExtValue(); |
1999 | } |
2000 | return std::nullopt; |
2001 | } |
2002 | |
2003 | mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) { |
2004 | return getGangValue(gangArgType, mlir::acc::DeviceType::None); |
2005 | } |
2006 | |
2007 | mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType, |
2008 | mlir::acc::DeviceType deviceType) { |
2009 | if (getGangOperands().empty()) |
2010 | return {}; |
2011 | if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) { |
2012 | int32_t nbOperandsBefore = 0; |
2013 | for (unsigned i = 0; i < *pos; ++i) |
2014 | nbOperandsBefore += (*getGangOperandsSegments())[i]; |
2015 | mlir::Operation::operand_range values = |
2016 | getGangOperands() |
2017 | .drop_front(nbOperandsBefore) |
2018 | .take_front((*getGangOperandsSegments())[*pos]); |
2019 | |
2020 | int32_t argTypeIdx = nbOperandsBefore; |
2021 | for (auto value : values) { |
2022 | auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>( |
2023 | (*getGangOperandsArgType())[argTypeIdx]); |
2024 | if (gangArgTypeAttr.getValue() == gangArgType) |
2025 | return value; |
2026 | ++argTypeIdx; |
2027 | } |
2028 | } |
2029 | return {}; |
2030 | } |
2031 | |
2032 | bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } |
2033 | |
2034 | bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) { |
2035 | return hasDeviceType(getGang(), deviceType); |
2036 | } |
2037 | |
2038 | llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() { |
2039 | return {&getRegion()}; |
2040 | } |
2041 | |
2042 | /// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=` |
2043 | /// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step` |
2044 | /// `(` ssa-id-and-type-list `)` |
2045 | /// region |
2046 | ParseResult |
2047 | parseLoopControl(OpAsmParser &parser, Region ®ion, |
2048 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerbound, |
2049 | SmallVectorImpl<Type> &lowerboundType, |
2050 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperbound, |
2051 | SmallVectorImpl<Type> &upperboundType, |
2052 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &step, |
2053 | SmallVectorImpl<Type> &stepType) { |
2054 | |
2055 | SmallVector<OpAsmParser::Argument> inductionVars; |
2056 | if (succeeded( |
2057 | parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) { |
2058 | if (parser.parseLParen() || |
2059 | parser.parseArgumentList(result&: inductionVars, delimiter: OpAsmParser::Delimiter::None, |
2060 | /*allowType=*/true) || |
2061 | parser.parseRParen() || parser.parseEqual() || parser.parseLParen() || |
2062 | parser.parseOperandList(result&: lowerbound, requiredOperandCount: inductionVars.size(), |
2063 | delimiter: OpAsmParser::Delimiter::None) || |
2064 | parser.parseColonTypeList(result&: lowerboundType) || parser.parseRParen() || |
2065 | parser.parseKeyword(keyword: "to" ) || parser.parseLParen() || |
2066 | parser.parseOperandList(result&: upperbound, requiredOperandCount: inductionVars.size(), |
2067 | delimiter: OpAsmParser::Delimiter::None) || |
2068 | parser.parseColonTypeList(result&: upperboundType) || parser.parseRParen() || |
2069 | parser.parseKeyword(keyword: "step" ) || parser.parseLParen() || |
2070 | parser.parseOperandList(result&: step, requiredOperandCount: inductionVars.size(), |
2071 | delimiter: OpAsmParser::Delimiter::None) || |
2072 | parser.parseColonTypeList(result&: stepType) || parser.parseRParen()) |
2073 | return failure(); |
2074 | } |
2075 | return parser.parseRegion(region, arguments: inductionVars); |
2076 | } |
2077 | |
2078 | void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, |
2079 | ValueRange lowerbound, TypeRange lowerboundType, |
2080 | ValueRange upperbound, TypeRange upperboundType, |
2081 | ValueRange steps, TypeRange stepType) { |
2082 | ValueRange regionArgs = region.front().getArguments(); |
2083 | if (!regionArgs.empty()) { |
2084 | p << acc::LoopOp::getControlKeyword() << "(" ; |
2085 | llvm::interleaveComma(c: regionArgs, os&: p, |
2086 | each_fn: [&p](Value v) { p << v << " : " << v.getType(); }); |
2087 | p << ") = (" << lowerbound << " : " << lowerboundType << ") to (" |
2088 | << upperbound << " : " << upperboundType << ") " |
2089 | << " step (" << steps << " : " << stepType << ") " ; |
2090 | } |
2091 | p.printRegion(blocks&: region, /*printEntryBlockArgs=*/false); |
2092 | } |
2093 | |
2094 | //===----------------------------------------------------------------------===// |
2095 | // DataOp |
2096 | //===----------------------------------------------------------------------===// |
2097 | |
2098 | LogicalResult acc::DataOp::verify() { |
2099 | // 2.6.5. Data Construct restriction |
2100 | // At least one copy, copyin, copyout, create, no_create, present, deviceptr, |
2101 | // attach, or default clause must appear on a data construct. |
2102 | if (getOperands().empty() && !getDefaultAttr()) |
2103 | return emitError("at least one operand or the default attribute " |
2104 | "must appear on the data operation" ); |
2105 | |
2106 | for (mlir::Value operand : getDataClauseOperands()) |
2107 | if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp, |
2108 | acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp, |
2109 | acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>( |
2110 | operand.getDefiningOp())) |
2111 | return emitError("expect data entry/exit operation or acc.getdeviceptr " |
2112 | "as defining op" ); |
2113 | |
2114 | if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*this))) |
2115 | return failure(); |
2116 | |
2117 | return success(); |
2118 | } |
2119 | |
2120 | unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); } |
2121 | |
2122 | Value DataOp::getDataOperand(unsigned i) { |
2123 | unsigned numOptional = getIfCond() ? 1 : 0; |
2124 | numOptional += getAsyncOperands().size() ? 1 : 0; |
2125 | numOptional += getWaitOperands().size(); |
2126 | return getOperand(numOptional + i); |
2127 | } |
2128 | |
2129 | bool acc::DataOp::hasAsyncOnly() { |
2130 | return hasAsyncOnly(mlir::acc::DeviceType::None); |
2131 | } |
2132 | |
2133 | bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { |
2134 | return hasDeviceType(getAsyncOnly(), deviceType); |
2135 | } |
2136 | |
2137 | mlir::Value DataOp::getAsyncValue() { |
2138 | return getAsyncValue(mlir::acc::DeviceType::None); |
2139 | } |
2140 | |
2141 | mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) { |
2142 | return getValueInDeviceTypeSegment(getAsyncOperandsDeviceType(), |
2143 | getAsyncOperands(), deviceType); |
2144 | } |
2145 | |
2146 | bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); } |
2147 | |
2148 | bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { |
2149 | return hasDeviceType(getWaitOnly(), deviceType); |
2150 | } |
2151 | |
2152 | mlir::Operation::operand_range DataOp::getWaitValues() { |
2153 | return getWaitValues(mlir::acc::DeviceType::None); |
2154 | } |
2155 | |
2156 | mlir::Operation::operand_range |
2157 | DataOp::getWaitValues(mlir::acc::DeviceType deviceType) { |
2158 | return getWaitValuesWithoutDevnum( |
2159 | getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), |
2160 | getHasWaitDevnum(), deviceType); |
2161 | } |
2162 | |
2163 | mlir::Value DataOp::getWaitDevnum() { |
2164 | return getWaitDevnum(mlir::acc::DeviceType::None); |
2165 | } |
2166 | |
2167 | mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { |
2168 | return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), |
2169 | getWaitOperandsSegments(), getHasWaitDevnum(), |
2170 | deviceType); |
2171 | } |
2172 | |
2173 | //===----------------------------------------------------------------------===// |
2174 | // ExitDataOp |
2175 | //===----------------------------------------------------------------------===// |
2176 | |
2177 | LogicalResult acc::ExitDataOp::verify() { |
2178 | // 2.6.6. Data Exit Directive restriction |
2179 | // At least one copyout, delete, or detach clause must appear on an exit data |
2180 | // directive. |
2181 | if (getDataClauseOperands().empty()) |
2182 | return emitError("at least one operand must be present in dataOperands on " |
2183 | "the exit data operation" ); |
2184 | |
2185 | // The async attribute represent the async clause without value. Therefore the |
2186 | // attribute and operand cannot appear at the same time. |
2187 | if (getAsyncOperand() && getAsync()) |
2188 | return emitError("async attribute cannot appear with asyncOperand" ); |
2189 | |
2190 | // The wait attribute represent the wait clause without values. Therefore the |
2191 | // attribute and operands cannot appear at the same time. |
2192 | if (!getWaitOperands().empty() && getWait()) |
2193 | return emitError("wait attribute cannot appear with waitOperands" ); |
2194 | |
2195 | if (getWaitDevnum() && getWaitOperands().empty()) |
2196 | return emitError("wait_devnum cannot appear without waitOperands" ); |
2197 | |
2198 | return success(); |
2199 | } |
2200 | |
2201 | unsigned ExitDataOp::getNumDataOperands() { |
2202 | return getDataClauseOperands().size(); |
2203 | } |
2204 | |
2205 | Value ExitDataOp::getDataOperand(unsigned i) { |
2206 | unsigned numOptional = getIfCond() ? 1 : 0; |
2207 | numOptional += getAsyncOperand() ? 1 : 0; |
2208 | numOptional += getWaitDevnum() ? 1 : 0; |
2209 | return getOperand(getWaitOperands().size() + numOptional + i); |
2210 | } |
2211 | |
2212 | void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2213 | MLIRContext *context) { |
2214 | results.add<RemoveConstantIfCondition<ExitDataOp>>(context); |
2215 | } |
2216 | |
2217 | //===----------------------------------------------------------------------===// |
2218 | // EnterDataOp |
2219 | //===----------------------------------------------------------------------===// |
2220 | |
2221 | LogicalResult acc::EnterDataOp::verify() { |
2222 | // 2.6.6. Data Enter Directive restriction |
2223 | // At least one copyin, create, or attach clause must appear on an enter data |
2224 | // directive. |
2225 | if (getDataClauseOperands().empty()) |
2226 | return emitError("at least one operand must be present in dataOperands on " |
2227 | "the enter data operation" ); |
2228 | |
2229 | // The async attribute represent the async clause without value. Therefore the |
2230 | // attribute and operand cannot appear at the same time. |
2231 | if (getAsyncOperand() && getAsync()) |
2232 | return emitError("async attribute cannot appear with asyncOperand" ); |
2233 | |
2234 | // The wait attribute represent the wait clause without values. Therefore the |
2235 | // attribute and operands cannot appear at the same time. |
2236 | if (!getWaitOperands().empty() && getWait()) |
2237 | return emitError("wait attribute cannot appear with waitOperands" ); |
2238 | |
2239 | if (getWaitDevnum() && getWaitOperands().empty()) |
2240 | return emitError("wait_devnum cannot appear without waitOperands" ); |
2241 | |
2242 | for (mlir::Value operand : getDataClauseOperands()) |
2243 | if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>( |
2244 | operand.getDefiningOp())) |
2245 | return emitError("expect data entry operation as defining op" ); |
2246 | |
2247 | return success(); |
2248 | } |
2249 | |
2250 | unsigned EnterDataOp::getNumDataOperands() { |
2251 | return getDataClauseOperands().size(); |
2252 | } |
2253 | |
2254 | Value EnterDataOp::getDataOperand(unsigned i) { |
2255 | unsigned numOptional = getIfCond() ? 1 : 0; |
2256 | numOptional += getAsyncOperand() ? 1 : 0; |
2257 | numOptional += getWaitDevnum() ? 1 : 0; |
2258 | return getOperand(getWaitOperands().size() + numOptional + i); |
2259 | } |
2260 | |
2261 | void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2262 | MLIRContext *context) { |
2263 | results.add<RemoveConstantIfCondition<EnterDataOp>>(context); |
2264 | } |
2265 | |
2266 | //===----------------------------------------------------------------------===// |
2267 | // AtomicReadOp |
2268 | //===----------------------------------------------------------------------===// |
2269 | |
2270 | LogicalResult AtomicReadOp::verify() { return verifyCommon(); } |
2271 | |
2272 | //===----------------------------------------------------------------------===// |
2273 | // AtomicWriteOp |
2274 | //===----------------------------------------------------------------------===// |
2275 | |
2276 | LogicalResult AtomicWriteOp::verify() { return verifyCommon(); } |
2277 | |
2278 | //===----------------------------------------------------------------------===// |
2279 | // AtomicUpdateOp |
2280 | //===----------------------------------------------------------------------===// |
2281 | |
2282 | LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op, |
2283 | PatternRewriter &rewriter) { |
2284 | if (op.isNoOp()) { |
2285 | rewriter.eraseOp(op); |
2286 | return success(); |
2287 | } |
2288 | |
2289 | if (Value writeVal = op.getWriteOpVal()) { |
2290 | rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal); |
2291 | return success(); |
2292 | } |
2293 | |
2294 | return failure(); |
2295 | } |
2296 | |
2297 | LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); } |
2298 | |
2299 | LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); } |
2300 | |
2301 | //===----------------------------------------------------------------------===// |
2302 | // AtomicCaptureOp |
2303 | //===----------------------------------------------------------------------===// |
2304 | |
2305 | AtomicReadOp AtomicCaptureOp::getAtomicReadOp() { |
2306 | if (auto op = dyn_cast<AtomicReadOp>(getFirstOp())) |
2307 | return op; |
2308 | return dyn_cast<AtomicReadOp>(getSecondOp()); |
2309 | } |
2310 | |
2311 | AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() { |
2312 | if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp())) |
2313 | return op; |
2314 | return dyn_cast<AtomicWriteOp>(getSecondOp()); |
2315 | } |
2316 | |
2317 | AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() { |
2318 | if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp())) |
2319 | return op; |
2320 | return dyn_cast<AtomicUpdateOp>(getSecondOp()); |
2321 | } |
2322 | |
2323 | LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); } |
2324 | |
2325 | //===----------------------------------------------------------------------===// |
2326 | // DeclareEnterOp |
2327 | //===----------------------------------------------------------------------===// |
2328 | |
2329 | template <typename Op> |
2330 | static LogicalResult |
2331 | checkDeclareOperands(Op &op, const mlir::ValueRange &operands, |
2332 | bool requireAtLeastOneOperand = true) { |
2333 | if (operands.empty() && requireAtLeastOneOperand) |
2334 | return emitError( |
2335 | op->getLoc(), |
2336 | "at least one operand must appear on the declare operation" ); |
2337 | |
2338 | for (mlir::Value operand : operands) { |
2339 | if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp, |
2340 | acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp, |
2341 | acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>( |
2342 | operand.getDefiningOp())) |
2343 | return op.emitError( |
2344 | "expect valid declare data entry operation or acc.getdeviceptr " |
2345 | "as defining op" ); |
2346 | |
2347 | mlir::Value varPtr{getVarPtr(accDataClauseOp: operand.getDefiningOp())}; |
2348 | assert(varPtr && "declare operands can only be data entry operations which " |
2349 | "must have varPtr" ); |
2350 | std::optional<mlir::acc::DataClause> dataClauseOptional{ |
2351 | getDataClause(operand.getDefiningOp())}; |
2352 | assert(dataClauseOptional.has_value() && |
2353 | "declare operands can only be data entry operations which must have " |
2354 | "dataClause" ); |
2355 | |
2356 | // If varPtr has no defining op - there is nothing to check further. |
2357 | if (!varPtr.getDefiningOp()) |
2358 | continue; |
2359 | |
2360 | // Check that the varPtr has a declare attribute. |
2361 | auto declareAttribute{ |
2362 | varPtr.getDefiningOp()->getAttr(name: mlir::acc::getDeclareAttrName())}; |
2363 | if (!declareAttribute) |
2364 | return op.emitError( |
2365 | "expect declare attribute on variable in declare operation" ); |
2366 | |
2367 | auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute); |
2368 | if (declAttr.getDataClause().getValue() != dataClauseOptional.value()) |
2369 | return op.emitError( |
2370 | "expect matching declare attribute on variable in declare operation" ); |
2371 | |
2372 | // If the variable is marked with implicit attribute, the matching declare |
2373 | // data action must also be marked implicit. The reverse is not checked |
2374 | // since implicit data action may be inserted to do actions like updating |
2375 | // device copy, in which case the variable is not necessarily implicitly |
2376 | // declare'd. |
2377 | if (declAttr.getImplicit() && |
2378 | declAttr.getImplicit() != acc::getImplicitFlag(accDataEntryOp: operand.getDefiningOp())) |
2379 | return op.emitError( |
2380 | "implicitness must match between declare op and flag on variable" ); |
2381 | } |
2382 | |
2383 | return success(); |
2384 | } |
2385 | |
2386 | LogicalResult acc::DeclareEnterOp::verify() { |
2387 | return checkDeclareOperands(*this, this->getDataClauseOperands()); |
2388 | } |
2389 | |
2390 | //===----------------------------------------------------------------------===// |
2391 | // DeclareExitOp |
2392 | //===----------------------------------------------------------------------===// |
2393 | |
2394 | LogicalResult acc::DeclareExitOp::verify() { |
2395 | if (getToken()) |
2396 | return checkDeclareOperands(*this, this->getDataClauseOperands(), |
2397 | /*requireAtLeastOneOperand=*/false); |
2398 | return checkDeclareOperands(*this, this->getDataClauseOperands()); |
2399 | } |
2400 | |
2401 | //===----------------------------------------------------------------------===// |
2402 | // DeclareOp |
2403 | //===----------------------------------------------------------------------===// |
2404 | |
2405 | LogicalResult acc::DeclareOp::verify() { |
2406 | return checkDeclareOperands(*this, this->getDataClauseOperands()); |
2407 | } |
2408 | |
2409 | //===----------------------------------------------------------------------===// |
2410 | // RoutineOp |
2411 | //===----------------------------------------------------------------------===// |
2412 | |
2413 | static unsigned getParallelismForDeviceType(acc::RoutineOp op, |
2414 | acc::DeviceType dtype) { |
2415 | unsigned parallelism = 0; |
2416 | parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0; |
2417 | parallelism += op.hasWorker(dtype) ? 1 : 0; |
2418 | parallelism += op.hasVector(dtype) ? 1 : 0; |
2419 | parallelism += op.hasSeq(dtype) ? 1 : 0; |
2420 | return parallelism; |
2421 | } |
2422 | |
2423 | LogicalResult acc::RoutineOp::verify() { |
2424 | unsigned baseParallelism = |
2425 | getParallelismForDeviceType(*this, acc::DeviceType::None); |
2426 | |
2427 | if (baseParallelism > 1) |
2428 | return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " |
2429 | "be present at the same time" ; |
2430 | |
2431 | for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType(); |
2432 | ++dtypeInt) { |
2433 | auto dtype = static_cast<acc::DeviceType>(dtypeInt); |
2434 | if (dtype == acc::DeviceType::None) |
2435 | continue; |
2436 | unsigned parallelism = getParallelismForDeviceType(*this, dtype); |
2437 | |
2438 | if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1)) |
2439 | return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can " |
2440 | "be present at the same time" ; |
2441 | } |
2442 | |
2443 | return success(); |
2444 | } |
2445 | |
2446 | static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, |
2447 | mlir::ArrayAttr &deviceTypes) { |
2448 | llvm::SmallVector<mlir::Attribute> bindNameAttrs; |
2449 | llvm::SmallVector<mlir::Attribute> deviceTypeAttrs; |
2450 | |
2451 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
2452 | if (parser.parseAttribute(result&: bindNameAttrs.emplace_back())) |
2453 | return failure(); |
2454 | if (failed(result: parser.parseOptionalLSquare())) { |
2455 | deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( |
2456 | parser.getContext(), mlir::acc::DeviceType::None)); |
2457 | } else { |
2458 | if (parser.parseAttribute(result&: deviceTypeAttrs.emplace_back()) || |
2459 | parser.parseRSquare()) |
2460 | return failure(); |
2461 | } |
2462 | return success(); |
2463 | }))) |
2464 | return failure(); |
2465 | |
2466 | bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs); |
2467 | deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs); |
2468 | |
2469 | return success(); |
2470 | } |
2471 | |
2472 | static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, |
2473 | std::optional<mlir::ArrayAttr> bindName, |
2474 | std::optional<mlir::ArrayAttr> deviceTypes) { |
2475 | llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p, |
2476 | [&](const auto &pair) { |
2477 | p << std::get<0>(pair); |
2478 | printSingleDeviceType(p, std::get<1>(pair)); |
2479 | }); |
2480 | } |
2481 | |
2482 | static ParseResult parseRoutineGangClause(OpAsmParser &parser, |
2483 | mlir::ArrayAttr &gang, |
2484 | mlir::ArrayAttr &gangDim, |
2485 | mlir::ArrayAttr &gangDimDeviceTypes) { |
2486 | |
2487 | llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs, |
2488 | gangDimDeviceTypeAttrs; |
2489 | bool needCommaBeforeOperands = false; |
2490 | |
2491 | // Gang keyword only |
2492 | if (failed(result: parser.parseOptionalLParen())) { |
2493 | gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get( |
2494 | parser.getContext(), mlir::acc::DeviceType::None)); |
2495 | gang = ArrayAttr::get(parser.getContext(), gangAttrs); |
2496 | return success(); |
2497 | } |
2498 | |
2499 | // Parse keyword only attributes |
2500 | if (succeeded(result: parser.parseOptionalLSquare())) { |
2501 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
2502 | if (parser.parseAttribute(result&: gangAttrs.emplace_back())) |
2503 | return failure(); |
2504 | return success(); |
2505 | }))) |
2506 | return failure(); |
2507 | if (parser.parseRSquare()) |
2508 | return failure(); |
2509 | needCommaBeforeOperands = true; |
2510 | } |
2511 | |
2512 | if (needCommaBeforeOperands && failed(result: parser.parseComma())) |
2513 | return failure(); |
2514 | |
2515 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
2516 | if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) || |
2517 | parser.parseColon() || |
2518 | parser.parseAttribute(gangDimAttrs.emplace_back())) |
2519 | return failure(); |
2520 | if (succeeded(result: parser.parseOptionalLSquare())) { |
2521 | if (parser.parseAttribute(result&: gangDimDeviceTypeAttrs.emplace_back()) || |
2522 | parser.parseRSquare()) |
2523 | return failure(); |
2524 | } else { |
2525 | gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get( |
2526 | parser.getContext(), mlir::acc::DeviceType::None)); |
2527 | } |
2528 | return success(); |
2529 | }))) |
2530 | return failure(); |
2531 | |
2532 | if (failed(result: parser.parseRParen())) |
2533 | return failure(); |
2534 | |
2535 | gang = ArrayAttr::get(parser.getContext(), gangAttrs); |
2536 | gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs); |
2537 | gangDimDeviceTypes = |
2538 | ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs); |
2539 | |
2540 | return success(); |
2541 | } |
2542 | |
2543 | void printRoutineGangClause(OpAsmPrinter &p, Operation *op, |
2544 | std::optional<mlir::ArrayAttr> gang, |
2545 | std::optional<mlir::ArrayAttr> gangDim, |
2546 | std::optional<mlir::ArrayAttr> gangDimDeviceTypes) { |
2547 | |
2548 | if (!hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes) && hasDeviceTypeValues(arrayAttr: gang) && |
2549 | gang->size() == 1) { |
2550 | auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]); |
2551 | if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) |
2552 | return; |
2553 | } |
2554 | |
2555 | p << "(" ; |
2556 | |
2557 | printDeviceTypes(p, deviceTypes: gang); |
2558 | |
2559 | if (hasDeviceTypeValues(arrayAttr: gang) && hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes)) |
2560 | p << ", " ; |
2561 | |
2562 | if (hasDeviceTypeValues(arrayAttr: gangDimDeviceTypes)) |
2563 | llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p, |
2564 | [&](const auto &pair) { |
2565 | p << acc::RoutineOp::getGangDimKeyword() << ": " ; |
2566 | p << std::get<0>(pair); |
2567 | printSingleDeviceType(p, std::get<1>(pair)); |
2568 | }); |
2569 | |
2570 | p << ")" ; |
2571 | } |
2572 | |
2573 | static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, |
2574 | mlir::ArrayAttr &deviceTypes) { |
2575 | llvm::SmallVector<mlir::Attribute> attributes; |
2576 | // Keyword only |
2577 | if (failed(result: parser.parseOptionalLParen())) { |
2578 | attributes.push_back(mlir::acc::DeviceTypeAttr::get( |
2579 | parser.getContext(), mlir::acc::DeviceType::None)); |
2580 | deviceTypes = ArrayAttr::get(parser.getContext(), attributes); |
2581 | return success(); |
2582 | } |
2583 | |
2584 | // Parse device type attributes |
2585 | if (succeeded(result: parser.parseOptionalLSquare())) { |
2586 | if (failed(result: parser.parseCommaSeparatedList(parseElementFn: [&]() { |
2587 | if (parser.parseAttribute(result&: attributes.emplace_back())) |
2588 | return failure(); |
2589 | return success(); |
2590 | }))) |
2591 | return failure(); |
2592 | if (parser.parseRSquare() || parser.parseRParen()) |
2593 | return failure(); |
2594 | } |
2595 | deviceTypes = ArrayAttr::get(parser.getContext(), attributes); |
2596 | return success(); |
2597 | } |
2598 | |
2599 | static void |
2600 | printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, |
2601 | std::optional<mlir::ArrayAttr> deviceTypes) { |
2602 | |
2603 | if (hasDeviceTypeValues(arrayAttr: deviceTypes) && deviceTypes->size() == 1) { |
2604 | auto deviceTypeAttr = |
2605 | mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]); |
2606 | if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None) |
2607 | return; |
2608 | } |
2609 | |
2610 | if (!hasDeviceTypeValues(arrayAttr: deviceTypes)) |
2611 | return; |
2612 | |
2613 | p << "([" ; |
2614 | llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) { |
2615 | auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr); |
2616 | p << dTypeAttr; |
2617 | }); |
2618 | p << "])" ; |
2619 | } |
2620 | |
2621 | bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); } |
2622 | |
2623 | bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) { |
2624 | return hasDeviceType(getWorker(), deviceType); |
2625 | } |
2626 | |
2627 | bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); } |
2628 | |
2629 | bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) { |
2630 | return hasDeviceType(getVector(), deviceType); |
2631 | } |
2632 | |
2633 | bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); } |
2634 | |
2635 | bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) { |
2636 | return hasDeviceType(getSeq(), deviceType); |
2637 | } |
2638 | |
2639 | std::optional<llvm::StringRef> RoutineOp::getBindNameValue() { |
2640 | return getBindNameValue(mlir::acc::DeviceType::None); |
2641 | } |
2642 | |
2643 | std::optional<llvm::StringRef> |
2644 | RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) { |
2645 | if (!hasDeviceTypeValues(getBindNameDeviceType())) |
2646 | return std::nullopt; |
2647 | if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) { |
2648 | auto attr = (*getBindName())[*pos]; |
2649 | auto stringAttr = dyn_cast<mlir::StringAttr>(attr); |
2650 | return stringAttr.getValue(); |
2651 | } |
2652 | return std::nullopt; |
2653 | } |
2654 | |
2655 | bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); } |
2656 | |
2657 | bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) { |
2658 | return hasDeviceType(getGang(), deviceType); |
2659 | } |
2660 | |
2661 | std::optional<int64_t> RoutineOp::getGangDimValue() { |
2662 | return getGangDimValue(mlir::acc::DeviceType::None); |
2663 | } |
2664 | |
2665 | std::optional<int64_t> |
2666 | RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) { |
2667 | if (!hasDeviceTypeValues(getGangDimDeviceType())) |
2668 | return std::nullopt; |
2669 | if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) { |
2670 | auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]); |
2671 | return intAttr.getInt(); |
2672 | } |
2673 | return std::nullopt; |
2674 | } |
2675 | |
2676 | //===----------------------------------------------------------------------===// |
2677 | // InitOp |
2678 | //===----------------------------------------------------------------------===// |
2679 | |
2680 | LogicalResult acc::InitOp::verify() { |
2681 | Operation *currOp = *this; |
2682 | while ((currOp = currOp->getParentOp())) |
2683 | if (isComputeOperation(currOp)) |
2684 | return emitOpError("cannot be nested in a compute operation" ); |
2685 | return success(); |
2686 | } |
2687 | |
2688 | //===----------------------------------------------------------------------===// |
2689 | // ShutdownOp |
2690 | //===----------------------------------------------------------------------===// |
2691 | |
2692 | LogicalResult acc::ShutdownOp::verify() { |
2693 | Operation *currOp = *this; |
2694 | while ((currOp = currOp->getParentOp())) |
2695 | if (isComputeOperation(currOp)) |
2696 | return emitOpError("cannot be nested in a compute operation" ); |
2697 | return success(); |
2698 | } |
2699 | |
2700 | //===----------------------------------------------------------------------===// |
2701 | // SetOp |
2702 | //===----------------------------------------------------------------------===// |
2703 | |
2704 | LogicalResult acc::SetOp::verify() { |
2705 | Operation *currOp = *this; |
2706 | while ((currOp = currOp->getParentOp())) |
2707 | if (isComputeOperation(currOp)) |
2708 | return emitOpError("cannot be nested in a compute operation" ); |
2709 | if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum()) |
2710 | return emitOpError("at least one default_async, device_num, or device_type " |
2711 | "operand must appear" ); |
2712 | return success(); |
2713 | } |
2714 | |
2715 | //===----------------------------------------------------------------------===// |
2716 | // UpdateOp |
2717 | //===----------------------------------------------------------------------===// |
2718 | |
2719 | LogicalResult acc::UpdateOp::verify() { |
2720 | // At least one of host or device should have a value. |
2721 | if (getDataClauseOperands().empty()) |
2722 | return emitError("at least one value must be present in dataOperands" ); |
2723 | |
2724 | if (failed(verifyDeviceTypeCountMatch(*this, getAsyncOperands(), |
2725 | getAsyncOperandsDeviceTypeAttr(), |
2726 | "async" ))) |
2727 | return failure(); |
2728 | |
2729 | if (failed(verifyDeviceTypeAndSegmentCountMatch( |
2730 | *this, getWaitOperands(), getWaitOperandsSegmentsAttr(), |
2731 | getWaitOperandsDeviceTypeAttr(), "wait" ))) |
2732 | return failure(); |
2733 | |
2734 | if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*this))) |
2735 | return failure(); |
2736 | |
2737 | for (mlir::Value operand : getDataClauseOperands()) |
2738 | if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>( |
2739 | operand.getDefiningOp())) |
2740 | return emitError("expect data entry/exit operation or acc.getdeviceptr " |
2741 | "as defining op" ); |
2742 | |
2743 | return success(); |
2744 | } |
2745 | |
2746 | unsigned UpdateOp::getNumDataOperands() { |
2747 | return getDataClauseOperands().size(); |
2748 | } |
2749 | |
2750 | Value UpdateOp::getDataOperand(unsigned i) { |
2751 | unsigned numOptional = getAsyncOperands().size(); |
2752 | numOptional += getIfCond() ? 1 : 0; |
2753 | return getOperand(getWaitOperands().size() + numOptional + i); |
2754 | } |
2755 | |
2756 | void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2757 | MLIRContext *context) { |
2758 | results.add<RemoveConstantIfCondition<UpdateOp>>(context); |
2759 | } |
2760 | |
2761 | bool UpdateOp::hasAsyncOnly() { |
2762 | return hasAsyncOnly(mlir::acc::DeviceType::None); |
2763 | } |
2764 | |
2765 | bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) { |
2766 | return hasDeviceType(getAsync(), deviceType); |
2767 | } |
2768 | |
2769 | mlir::Value UpdateOp::getAsyncValue() { |
2770 | return getAsyncValue(mlir::acc::DeviceType::None); |
2771 | } |
2772 | |
2773 | mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) { |
2774 | if (!hasDeviceTypeValues(getAsyncOperandsDeviceType())) |
2775 | return {}; |
2776 | |
2777 | if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType)) |
2778 | return getAsyncOperands()[*pos]; |
2779 | |
2780 | return {}; |
2781 | } |
2782 | |
2783 | bool UpdateOp::hasWaitOnly() { |
2784 | return hasWaitOnly(mlir::acc::DeviceType::None); |
2785 | } |
2786 | |
2787 | bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) { |
2788 | return hasDeviceType(getWaitOnly(), deviceType); |
2789 | } |
2790 | |
2791 | mlir::Operation::operand_range UpdateOp::getWaitValues() { |
2792 | return getWaitValues(mlir::acc::DeviceType::None); |
2793 | } |
2794 | |
2795 | mlir::Operation::operand_range |
2796 | UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) { |
2797 | return getWaitValuesWithoutDevnum( |
2798 | getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(), |
2799 | getHasWaitDevnum(), deviceType); |
2800 | } |
2801 | |
2802 | mlir::Value UpdateOp::getWaitDevnum() { |
2803 | return getWaitDevnum(mlir::acc::DeviceType::None); |
2804 | } |
2805 | |
2806 | mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) { |
2807 | return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(), |
2808 | getWaitOperandsSegments(), getHasWaitDevnum(), |
2809 | deviceType); |
2810 | } |
2811 | |
2812 | //===----------------------------------------------------------------------===// |
2813 | // WaitOp |
2814 | //===----------------------------------------------------------------------===// |
2815 | |
2816 | LogicalResult acc::WaitOp::verify() { |
2817 | // The async attribute represent the async clause without value. Therefore the |
2818 | // attribute and operand cannot appear at the same time. |
2819 | if (getAsyncOperand() && getAsync()) |
2820 | return emitError("async attribute cannot appear with asyncOperand" ); |
2821 | |
2822 | if (getWaitDevnum() && getWaitOperands().empty()) |
2823 | return emitError("wait_devnum cannot appear without waitOperands" ); |
2824 | |
2825 | return success(); |
2826 | } |
2827 | |
2828 | #define GET_OP_CLASSES |
2829 | #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc" |
2830 | |
2831 | #define GET_ATTRDEF_CLASSES |
2832 | #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc" |
2833 | |
2834 | #define GET_TYPEDEF_CLASSES |
2835 | #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc" |
2836 | |
2837 | //===----------------------------------------------------------------------===// |
2838 | // acc dialect utilities |
2839 | //===----------------------------------------------------------------------===// |
2840 | |
2841 | mlir::Value mlir::acc::getVarPtr(mlir::Operation *accDataClauseOp) { |
2842 | auto varPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp) |
2843 | .Case<ACC_DATA_ENTRY_OPS>( |
2844 | [&](auto entry) { return entry.getVarPtr(); }) |
2845 | .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>( |
2846 | [&](auto exit) { return exit.getVarPtr(); }) |
2847 | .Default([&](mlir::Operation *) { return mlir::Value(); })}; |
2848 | return varPtr; |
2849 | } |
2850 | |
2851 | mlir::Value mlir::acc::getAccPtr(mlir::Operation *accDataClauseOp) { |
2852 | auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp) |
2853 | .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>( |
2854 | [&](auto dataClause) { return dataClause.getAccPtr(); }) |
2855 | .Default([&](mlir::Operation *) { return mlir::Value(); })}; |
2856 | return accPtr; |
2857 | } |
2858 | |
2859 | mlir::Value mlir::acc::getVarPtrPtr(mlir::Operation *accDataClauseOp) { |
2860 | auto varPtrPtr{ |
2861 | llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp) |
2862 | .Case<ACC_DATA_ENTRY_OPS>( |
2863 | [&](auto dataClause) { return dataClause.getVarPtrPtr(); }) |
2864 | .Default([&](mlir::Operation *) { return mlir::Value(); })}; |
2865 | return varPtrPtr; |
2866 | } |
2867 | |
2868 | mlir::SmallVector<mlir::Value> |
2869 | mlir::acc::getBounds(mlir::Operation *accDataClauseOp) { |
2870 | mlir::SmallVector<mlir::Value> bounds{ |
2871 | llvm::TypeSwitch<mlir::Operation *, mlir::SmallVector<mlir::Value>>( |
2872 | accDataClauseOp) |
2873 | .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) { |
2874 | return mlir::SmallVector<mlir::Value>( |
2875 | dataClause.getBounds().begin(), dataClause.getBounds().end()); |
2876 | }) |
2877 | .Default([&](mlir::Operation *) { |
2878 | return mlir::SmallVector<mlir::Value, 0>(); |
2879 | })}; |
2880 | return bounds; |
2881 | } |
2882 | |
2883 | std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) { |
2884 | auto name{ |
2885 | llvm::TypeSwitch<mlir::Operation *, std::optional<llvm::StringRef>>(accOp) |
2886 | .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); }) |
2887 | .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> { |
2888 | return {}; |
2889 | })}; |
2890 | return name; |
2891 | } |
2892 | |
2893 | std::optional<mlir::acc::DataClause> |
2894 | mlir::acc::getDataClause(mlir::Operation *accDataEntryOp) { |
2895 | auto dataClause{ |
2896 | llvm::TypeSwitch<mlir::Operation *, std::optional<mlir::acc::DataClause>>( |
2897 | accDataEntryOp) |
2898 | .Case<ACC_DATA_ENTRY_OPS>( |
2899 | [&](auto entry) { return entry.getDataClause(); }) |
2900 | .Default([&](mlir::Operation *) { return std::nullopt; })}; |
2901 | return dataClause; |
2902 | } |
2903 | |
2904 | bool mlir::acc::getImplicitFlag(mlir::Operation *accDataEntryOp) { |
2905 | auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp) |
2906 | .Case<ACC_DATA_ENTRY_OPS>( |
2907 | [&](auto entry) { return entry.getImplicit(); }) |
2908 | .Default([&](mlir::Operation *) { return false; })}; |
2909 | return implicit; |
2910 | } |
2911 | |
2912 | mlir::ValueRange mlir::acc::getDataOperands(mlir::Operation *accOp) { |
2913 | auto dataOperands{ |
2914 | llvm::TypeSwitch<mlir::Operation *, mlir::ValueRange>(accOp) |
2915 | .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>( |
2916 | [&](auto entry) { return entry.getDataClauseOperands(); }) |
2917 | .Default([&](mlir::Operation *) { return mlir::ValueRange(); })}; |
2918 | return dataOperands; |
2919 | } |
2920 | |
2921 | mlir::MutableOperandRange |
2922 | mlir::acc::getMutableDataOperands(mlir::Operation *accOp) { |
2923 | auto dataOperands{ |
2924 | llvm::TypeSwitch<mlir::Operation *, mlir::MutableOperandRange>(accOp) |
2925 | .Case<ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS>( |
2926 | [&](auto entry) { return entry.getDataClauseOperandsMutable(); }) |
2927 | .Default([&](mlir::Operation *) { return nullptr; })}; |
2928 | return dataOperands; |
2929 | } |
2930 | |