1 | //===----------------------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Emit OpenACC Stmt nodes as CIR code. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "CIRGenBuilder.h" |
14 | #include "CIRGenFunction.h" |
15 | #include "mlir/Dialect/OpenACC/OpenACC.h" |
16 | #include "clang/AST/OpenACCClause.h" |
17 | #include "clang/AST/StmtOpenACC.h" |
18 | |
19 | using namespace clang; |
20 | using namespace clang::CIRGen; |
21 | using namespace cir; |
22 | using namespace mlir::acc; |
23 | |
24 | template <typename Op, typename TermOp> |
25 | mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt( |
26 | mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind, |
27 | SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses, |
28 | const Stmt *associatedStmt) { |
29 | mlir::LogicalResult res = mlir::success(); |
30 | |
31 | llvm::SmallVector<mlir::Type> retTy; |
32 | llvm::SmallVector<mlir::Value> operands; |
33 | auto op = builder.create<Op>(start, retTy, operands); |
34 | |
35 | emitOpenACCClauses(op, dirKind, dirLoc, clauses); |
36 | |
37 | { |
38 | mlir::Block &block = op.getRegion().emplaceBlock(); |
39 | mlir::OpBuilder::InsertionGuard guardCase(builder); |
40 | builder.setInsertionPointToEnd(&block); |
41 | |
42 | LexicalScope ls{*this, start, builder.getInsertionBlock()}; |
43 | res = emitStmt(associatedStmt, /*useCurrentScope=*/true); |
44 | |
45 | builder.create<TermOp>(end); |
46 | } |
47 | return res; |
48 | } |
49 | |
50 | namespace { |
51 | template <typename Op> struct CombinedType; |
52 | template <> struct CombinedType<ParallelOp> { |
53 | static constexpr mlir::acc::CombinedConstructsType value = |
54 | mlir::acc::CombinedConstructsType::ParallelLoop; |
55 | }; |
56 | template <> struct CombinedType<SerialOp> { |
57 | static constexpr mlir::acc::CombinedConstructsType value = |
58 | mlir::acc::CombinedConstructsType::SerialLoop; |
59 | }; |
60 | template <> struct CombinedType<KernelsOp> { |
61 | static constexpr mlir::acc::CombinedConstructsType value = |
62 | mlir::acc::CombinedConstructsType::KernelsLoop; |
63 | }; |
64 | } // namespace |
65 | |
66 | template <typename Op, typename TermOp> |
67 | mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct( |
68 | mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind, |
69 | SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses, |
70 | const Stmt *loopStmt) { |
71 | mlir::LogicalResult res = mlir::success(); |
72 | |
73 | llvm::SmallVector<mlir::Type> retTy; |
74 | llvm::SmallVector<mlir::Value> operands; |
75 | |
76 | auto computeOp = builder.create<Op>(start, retTy, operands); |
77 | computeOp.setCombinedAttr(builder.getUnitAttr()); |
78 | mlir::acc::LoopOp loopOp; |
79 | |
80 | // First, emit the bodies of both operations, with the loop inside the body of |
81 | // the combined construct. |
82 | { |
83 | mlir::Block &block = computeOp.getRegion().emplaceBlock(); |
84 | mlir::OpBuilder::InsertionGuard guardCase(builder); |
85 | builder.setInsertionPointToEnd(&block); |
86 | |
87 | LexicalScope ls{*this, start, builder.getInsertionBlock()}; |
88 | auto loopOp = builder.create<LoopOp>(start, retTy, operands); |
89 | loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get( |
90 | builder.getContext(), CombinedType<Op>::value)); |
91 | |
92 | { |
93 | mlir::Block &innerBlock = loopOp.getRegion().emplaceBlock(); |
94 | mlir::OpBuilder::InsertionGuard guardCase(builder); |
95 | builder.setInsertionPointToEnd(&innerBlock); |
96 | |
97 | LexicalScope ls{*this, start, builder.getInsertionBlock()}; |
98 | res = emitStmt(loopStmt, /*useCurrentScope=*/true); |
99 | |
100 | builder.create<mlir::acc::YieldOp>(end); |
101 | } |
102 | |
103 | emitOpenACCClauses(computeOp, loopOp, dirKind, dirLoc, clauses); |
104 | |
105 | updateLoopOpParallelism(op&: loopOp, /*isOrphan=*/false, dk: dirKind); |
106 | |
107 | builder.create<TermOp>(end); |
108 | } |
109 | |
110 | return res; |
111 | } |
112 | |
113 | template <typename Op> |
114 | Op CIRGenFunction::emitOpenACCOp( |
115 | mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc, |
116 | llvm::ArrayRef<const OpenACCClause *> clauses) { |
117 | llvm::SmallVector<mlir::Type> retTy; |
118 | llvm::SmallVector<mlir::Value> operands; |
119 | auto op = builder.create<Op>(start, retTy, operands); |
120 | |
121 | emitOpenACCClauses(op, dirKind, dirLoc, clauses); |
122 | return op; |
123 | } |
124 | |
125 | mlir::LogicalResult |
126 | CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) { |
127 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
128 | mlir::Location end = getLoc(s.getSourceRange().getEnd()); |
129 | |
130 | switch (s.getDirectiveKind()) { |
131 | case OpenACCDirectiveKind::Parallel: |
132 | return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>( |
133 | start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), |
134 | s.getStructuredBlock()); |
135 | case OpenACCDirectiveKind::Serial: |
136 | return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>( |
137 | start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), |
138 | s.getStructuredBlock()); |
139 | case OpenACCDirectiveKind::Kernels: |
140 | return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>( |
141 | start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), |
142 | s.getStructuredBlock()); |
143 | default: |
144 | llvm_unreachable("invalid compute construct kind" ); |
145 | } |
146 | } |
147 | |
148 | mlir::LogicalResult |
149 | CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) { |
150 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
151 | mlir::Location end = getLoc(s.getSourceRange().getEnd()); |
152 | |
153 | return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>( |
154 | start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), |
155 | s.getStructuredBlock()); |
156 | } |
157 | |
158 | mlir::LogicalResult |
159 | CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) { |
160 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
161 | emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(), |
162 | s.clauses()); |
163 | return mlir::success(); |
164 | } |
165 | |
166 | mlir::LogicalResult |
167 | CIRGenFunction::emitOpenACCSetConstruct(const OpenACCSetConstruct &s) { |
168 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
169 | emitOpenACCOp<SetOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(), |
170 | s.clauses()); |
171 | return mlir::success(); |
172 | } |
173 | |
174 | mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct( |
175 | const OpenACCShutdownConstruct &s) { |
176 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
177 | emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(), |
178 | s.getDirectiveLoc(), s.clauses()); |
179 | return mlir::success(); |
180 | } |
181 | |
182 | mlir::LogicalResult |
183 | CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) { |
184 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
185 | auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind(), |
186 | s.getDirectiveLoc(), s.clauses()); |
187 | |
188 | auto createIntExpr = [this](const Expr *intExpr) { |
189 | mlir::Value expr = emitScalarExpr(intExpr); |
190 | mlir::Location exprLoc = cgm.getLoc(intExpr->getBeginLoc()); |
191 | |
192 | mlir::IntegerType targetType = mlir::IntegerType::get( |
193 | &getMLIRContext(), getContext().getIntWidth(intExpr->getType()), |
194 | intExpr->getType()->isSignedIntegerOrEnumerationType() |
195 | ? mlir::IntegerType::SignednessSemantics::Signed |
196 | : mlir::IntegerType::SignednessSemantics::Unsigned); |
197 | |
198 | auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>( |
199 | exprLoc, targetType, expr); |
200 | return conversionOp.getResult(0); |
201 | }; |
202 | |
203 | // Emit the correct 'wait' clauses. |
204 | { |
205 | mlir::OpBuilder::InsertionGuard guardCase(builder); |
206 | builder.setInsertionPoint(waitOp); |
207 | |
208 | if (s.hasDevNumExpr()) |
209 | waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr())); |
210 | |
211 | for (Expr *QueueExpr : s.getQueueIdExprs()) |
212 | waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr)); |
213 | } |
214 | |
215 | return mlir::success(); |
216 | } |
217 | |
218 | mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct( |
219 | const OpenACCCombinedConstruct &s) { |
220 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
221 | mlir::Location end = getLoc(s.getSourceRange().getEnd()); |
222 | |
223 | switch (s.getDirectiveKind()) { |
224 | case OpenACCDirectiveKind::ParallelLoop: |
225 | return emitOpenACCOpCombinedConstruct<ParallelOp, mlir::acc::YieldOp>( |
226 | start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), |
227 | s.getLoop()); |
228 | case OpenACCDirectiveKind::SerialLoop: |
229 | return emitOpenACCOpCombinedConstruct<SerialOp, mlir::acc::YieldOp>( |
230 | start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), |
231 | s.getLoop()); |
232 | case OpenACCDirectiveKind::KernelsLoop: |
233 | return emitOpenACCOpCombinedConstruct<KernelsOp, mlir::acc::TerminatorOp>( |
234 | start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), |
235 | s.getLoop()); |
236 | default: |
237 | llvm_unreachable("invalid compute construct kind" ); |
238 | } |
239 | } |
240 | |
241 | mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct( |
242 | const OpenACCHostDataConstruct &s) { |
243 | mlir::Location start = getLoc(s.getSourceRange().getBegin()); |
244 | mlir::Location end = getLoc(s.getSourceRange().getEnd()); |
245 | |
246 | return emitOpenACCOpAssociatedStmt<HostDataOp, mlir::acc::TerminatorOp>( |
247 | start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(), |
248 | s.getStructuredBlock()); |
249 | } |
250 | |
251 | mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct( |
252 | const OpenACCEnterDataConstruct &s) { |
253 | cgm.errorNYI(s.getSourceRange(), "OpenACC EnterData Construct" ); |
254 | return mlir::failure(); |
255 | } |
256 | mlir::LogicalResult CIRGenFunction::emitOpenACCExitDataConstruct( |
257 | const OpenACCExitDataConstruct &s) { |
258 | cgm.errorNYI(s.getSourceRange(), "OpenACC ExitData Construct" ); |
259 | return mlir::failure(); |
260 | } |
261 | mlir::LogicalResult |
262 | CIRGenFunction::emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct &s) { |
263 | cgm.errorNYI(s.getSourceRange(), "OpenACC Update Construct" ); |
264 | return mlir::failure(); |
265 | } |
266 | mlir::LogicalResult |
267 | CIRGenFunction::emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct &s) { |
268 | cgm.errorNYI(s.getSourceRange(), "OpenACC Atomic Construct" ); |
269 | return mlir::failure(); |
270 | } |
271 | mlir::LogicalResult |
272 | CIRGenFunction::emitOpenACCCacheConstruct(const OpenACCCacheConstruct &s) { |
273 | cgm.errorNYI(s.getSourceRange(), "OpenACC Cache Construct" ); |
274 | return mlir::failure(); |
275 | } |
276 | |