| 1 | //===----------------------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Emit OpenACC Loop Stmt node as CIR code. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #include "CIRGenBuilder.h" |
| 14 | #include "CIRGenFunction.h" |
| 15 | |
| 16 | #include "clang/AST/StmtOpenACC.h" |
| 17 | |
| 18 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
| 19 | |
| 20 | using namespace clang; |
| 21 | using namespace clang::CIRGen; |
| 22 | using namespace cir; |
| 23 | using namespace mlir::acc; |
| 24 | |
| 25 | void CIRGenFunction::updateLoopOpParallelism(mlir::acc::LoopOp &op, |
| 26 | bool isOrphan, |
| 27 | OpenACCDirectiveKind dk) { |
| 28 | // Check that at least one of auto, independent, or seq is present |
| 29 | // for the device-independent default clauses. |
| 30 | if (op.hasParallelismFlag(mlir::acc::DeviceType::None)) |
| 31 | return; |
| 32 | |
| 33 | switch (dk) { |
| 34 | default: |
| 35 | llvm_unreachable("Invalid parent directive kind" ); |
| 36 | case OpenACCDirectiveKind::Invalid: |
| 37 | case OpenACCDirectiveKind::Parallel: |
| 38 | case OpenACCDirectiveKind::ParallelLoop: |
| 39 | op.addIndependent(builder.getContext(), {}); |
| 40 | return; |
| 41 | case OpenACCDirectiveKind::Kernels: |
| 42 | case OpenACCDirectiveKind::KernelsLoop: |
| 43 | op.addAuto(builder.getContext(), {}); |
| 44 | return; |
| 45 | case OpenACCDirectiveKind::Serial: |
| 46 | case OpenACCDirectiveKind::SerialLoop: |
| 47 | if (op.hasDefaultGangWorkerVector()) |
| 48 | op.addAuto(builder.getContext(), {}); |
| 49 | else |
| 50 | op.addSeq(builder.getContext(), {}); |
| 51 | return; |
| 52 | }; |
| 53 | } |
| 54 | |
| 55 | mlir::LogicalResult |
| 56 | CIRGenFunction::emitOpenACCLoopConstruct(const OpenACCLoopConstruct &s) { |
| 57 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
| 58 | mlir::Location end = getLoc(s.getSourceRange().getEnd()); |
| 59 | llvm::SmallVector<mlir::Type> retTy; |
| 60 | llvm::SmallVector<mlir::Value> operands; |
| 61 | auto op = builder.create<LoopOp>(start, retTy, operands); |
| 62 | |
| 63 | // TODO(OpenACC): In the future we are going to need to come up with a |
| 64 | // transformation here that can teach the acc.loop how to figure out the |
| 65 | // 'lowerbound', 'upperbound', and 'step'. |
| 66 | // |
| 67 | // -'upperbound' should fortunately be pretty easy as it should be |
| 68 | // in the initialization section of the cir.for loop. In Sema, we limit to |
| 69 | // just the forms 'Var = init', `Type Var = init`, or `Var = init` (where it |
| 70 | // is an operator= call)`. However, as those are all necessary to emit for |
| 71 | // the init section of the for loop, they should be inside the initial |
| 72 | // cir.scope. |
| 73 | // |
| 74 | // -'upperbound' should be somewhat easy to determine. Sema is limiting this |
| 75 | // to: ==, <, >, !=, <=, >= builtin operators, the overloaded 'comparison' |
| 76 | // operations, and member-call expressions. |
| 77 | // |
| 78 | // For the builtin comparison operators, we can pretty well deduce based on |
| 79 | // the comparison what the 'end' object is going to be, and the inclusive |
| 80 | // nature of it. |
| 81 | // |
| 82 | // For the overloaded operators, Sema will ensure that at least one side of |
| 83 | // the operator is the init variable, so we can deduce the comparison there |
| 84 | // too. The standard places no real bounds on WHAT the comparison operators do |
| 85 | // for a `RandomAccessIterator` however, so we'll have to just 'assume' they |
| 86 | // do the right thing? Note that this might be incrementing by a different |
| 87 | // 'object', not an integral, so it isn't really clear to me what we can do to |
| 88 | // determine the other side. |
| 89 | // |
| 90 | // Member-call expressions are the difficult ones. I don't think there is |
| 91 | // anything we can deduce from this to determine the 'end', so we might end up |
| 92 | // having to go back to Sema and make this ill-formed. |
| 93 | // |
| 94 | // HOWEVER: What ACC dialect REALLY cares about is the tripcount, which you |
| 95 | // cannot get (in the case of `RandomAccessIterator`) from JUST 'upperbound' |
| 96 | // and 'lowerbound'. We will likely have to provide a 'recipe' equivalent to |
| 97 | // `std::distance` instead. In the case of integer/pointers, it is fairly |
| 98 | // simple to find: it is just the mathematical subtraction. Howver, in the |
| 99 | // case of `RandomAccessIterator`, we have to enable the use of `operator-`. |
| 100 | // FORTUNATELY the standard requires this to work correctly for |
| 101 | // `RandomAccessIterator`, so we don't have to implement a `std::distance` |
| 102 | // that loops through, like we would for a forward/etc iterator. |
| 103 | // |
| 104 | // 'step': Sema is currently allowing builtin ++,--, +=, -=, *=, /=, and = |
| 105 | // operators. Additionally, it allows the equivalent for the operator-call, as |
| 106 | // well as member-call. |
| 107 | // |
| 108 | // For builtin operators, we perhaps should refine the assignment here. It |
| 109 | // doesn't really help us know the 'step' count at all, but we could perhaps |
| 110 | // do one more step of analysis in Sema to allow something like Var = Var + 1. |
| 111 | // For the others, this should get us the step reasonably well. |
| 112 | // |
| 113 | // For the overloaded operators, we have the same problems as for |
| 114 | // 'upperbound', plus not really knowing what they do. Member-call expressions |
| 115 | // are again difficult, and we might want to reconsider allowing these in |
| 116 | // Sema. |
| 117 | // |
| 118 | |
| 119 | // Emit all clauses. |
| 120 | emitOpenACCClauses(op, s.getDirectiveKind(), s.getDirectiveLoc(), |
| 121 | s.clauses()); |
| 122 | |
| 123 | updateLoopOpParallelism(op&: op, isOrphan: s.isOrphanedLoopConstruct(), |
| 124 | dk: s.getParentComputeConstructKind()); |
| 125 | |
| 126 | mlir::LogicalResult stmtRes = mlir::success(); |
| 127 | // Emit body. |
| 128 | { |
| 129 | mlir::Block &block = op.getRegion().emplaceBlock(); |
| 130 | mlir::OpBuilder::InsertionGuard guardCase(builder); |
| 131 | builder.setInsertionPointToEnd(&block); |
| 132 | LexicalScope ls{*this, start, builder.getInsertionBlock()}; |
| 133 | |
| 134 | stmtRes = emitStmt(s.getLoop(), /*useCurrentScope=*/true); |
| 135 | builder.create<mlir::acc::YieldOp>(end); |
| 136 | } |
| 137 | |
| 138 | return stmtRes; |
| 139 | } |
| 140 | |