1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Emit OpenACC Stmt nodes as CIR code.
10//
11//===----------------------------------------------------------------------===//
12
13#include "CIRGenBuilder.h"
14#include "CIRGenFunction.h"
15#include "mlir/Dialect/OpenACC/OpenACC.h"
16#include "clang/AST/OpenACCClause.h"
17#include "clang/AST/StmtOpenACC.h"
18
19using namespace clang;
20using namespace clang::CIRGen;
21using namespace cir;
22using namespace mlir::acc;
23
24template <typename Op, typename TermOp>
25mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(
26 mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
27 SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
28 const Stmt *associatedStmt) {
29 mlir::LogicalResult res = mlir::success();
30
31 llvm::SmallVector<mlir::Type> retTy;
32 llvm::SmallVector<mlir::Value> operands;
33 auto op = builder.create<Op>(start, retTy, operands);
34
35 emitOpenACCClauses(op, dirKind, dirLoc, clauses);
36
37 {
38 mlir::Block &block = op.getRegion().emplaceBlock();
39 mlir::OpBuilder::InsertionGuard guardCase(builder);
40 builder.setInsertionPointToEnd(&block);
41
42 LexicalScope ls{*this, start, builder.getInsertionBlock()};
43 res = emitStmt(associatedStmt, /*useCurrentScope=*/true);
44
45 builder.create<TermOp>(end);
46 }
47 return res;
48}
49
50namespace {
51template <typename Op> struct CombinedType;
52template <> struct CombinedType<ParallelOp> {
53 static constexpr mlir::acc::CombinedConstructsType value =
54 mlir::acc::CombinedConstructsType::ParallelLoop;
55};
56template <> struct CombinedType<SerialOp> {
57 static constexpr mlir::acc::CombinedConstructsType value =
58 mlir::acc::CombinedConstructsType::SerialLoop;
59};
60template <> struct CombinedType<KernelsOp> {
61 static constexpr mlir::acc::CombinedConstructsType value =
62 mlir::acc::CombinedConstructsType::KernelsLoop;
63};
64} // namespace
65
66template <typename Op, typename TermOp>
67mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct(
68 mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,
69 SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,
70 const Stmt *loopStmt) {
71 mlir::LogicalResult res = mlir::success();
72
73 llvm::SmallVector<mlir::Type> retTy;
74 llvm::SmallVector<mlir::Value> operands;
75
76 auto computeOp = builder.create<Op>(start, retTy, operands);
77 computeOp.setCombinedAttr(builder.getUnitAttr());
78 mlir::acc::LoopOp loopOp;
79
80 // First, emit the bodies of both operations, with the loop inside the body of
81 // the combined construct.
82 {
83 mlir::Block &block = computeOp.getRegion().emplaceBlock();
84 mlir::OpBuilder::InsertionGuard guardCase(builder);
85 builder.setInsertionPointToEnd(&block);
86
87 LexicalScope ls{*this, start, builder.getInsertionBlock()};
88 auto loopOp = builder.create<LoopOp>(start, retTy, operands);
89 loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(
90 builder.getContext(), CombinedType<Op>::value));
91
92 {
93 mlir::Block &innerBlock = loopOp.getRegion().emplaceBlock();
94 mlir::OpBuilder::InsertionGuard guardCase(builder);
95 builder.setInsertionPointToEnd(&innerBlock);
96
97 LexicalScope ls{*this, start, builder.getInsertionBlock()};
98 res = emitStmt(loopStmt, /*useCurrentScope=*/true);
99
100 builder.create<mlir::acc::YieldOp>(end);
101 }
102
103 emitOpenACCClauses(computeOp, loopOp, dirKind, dirLoc, clauses);
104
105 updateLoopOpParallelism(op&: loopOp, /*isOrphan=*/false, dk: dirKind);
106
107 builder.create<TermOp>(end);
108 }
109
110 return res;
111}
112
113template <typename Op>
114Op CIRGenFunction::emitOpenACCOp(
115 mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,
116 llvm::ArrayRef<const OpenACCClause *> clauses) {
117 llvm::SmallVector<mlir::Type> retTy;
118 llvm::SmallVector<mlir::Value> operands;
119 auto op = builder.create<Op>(start, retTy, operands);
120
121 emitOpenACCClauses(op, dirKind, dirLoc, clauses);
122 return op;
123}
124
125mlir::LogicalResult
126CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {
127 mlir::Location start = getLoc(s.getSourceRange().getBegin());
128 mlir::Location end = getLoc(s.getSourceRange().getEnd());
129
130 switch (s.getDirectiveKind()) {
131 case OpenACCDirectiveKind::Parallel:
132 return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(
133 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
134 s.getStructuredBlock());
135 case OpenACCDirectiveKind::Serial:
136 return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(
137 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
138 s.getStructuredBlock());
139 case OpenACCDirectiveKind::Kernels:
140 return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(
141 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
142 s.getStructuredBlock());
143 default:
144 llvm_unreachable("invalid compute construct kind");
145 }
146}
147
148mlir::LogicalResult
149CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {
150 mlir::Location start = getLoc(s.getSourceRange().getBegin());
151 mlir::Location end = getLoc(s.getSourceRange().getEnd());
152
153 return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(
154 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
155 s.getStructuredBlock());
156}
157
158mlir::LogicalResult
159CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {
160 mlir::Location start = getLoc(s.getSourceRange().getBegin());
161 emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
162 s.clauses());
163 return mlir::success();
164}
165
166mlir::LogicalResult
167CIRGenFunction::emitOpenACCSetConstruct(const OpenACCSetConstruct &s) {
168 mlir::Location start = getLoc(s.getSourceRange().getBegin());
169 emitOpenACCOp<SetOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),
170 s.clauses());
171 return mlir::success();
172}
173
174mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(
175 const OpenACCShutdownConstruct &s) {
176 mlir::Location start = getLoc(s.getSourceRange().getBegin());
177 emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),
178 s.getDirectiveLoc(), s.clauses());
179 return mlir::success();
180}
181
182mlir::LogicalResult
183CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {
184 mlir::Location start = getLoc(s.getSourceRange().getBegin());
185 auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind(),
186 s.getDirectiveLoc(), s.clauses());
187
188 auto createIntExpr = [this](const Expr *intExpr) {
189 mlir::Value expr = emitScalarExpr(intExpr);
190 mlir::Location exprLoc = cgm.getLoc(intExpr->getBeginLoc());
191
192 mlir::IntegerType targetType = mlir::IntegerType::get(
193 &getMLIRContext(), getContext().getIntWidth(intExpr->getType()),
194 intExpr->getType()->isSignedIntegerOrEnumerationType()
195 ? mlir::IntegerType::SignednessSemantics::Signed
196 : mlir::IntegerType::SignednessSemantics::Unsigned);
197
198 auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(
199 exprLoc, targetType, expr);
200 return conversionOp.getResult(0);
201 };
202
203 // Emit the correct 'wait' clauses.
204 {
205 mlir::OpBuilder::InsertionGuard guardCase(builder);
206 builder.setInsertionPoint(waitOp);
207
208 if (s.hasDevNumExpr())
209 waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));
210
211 for (Expr *QueueExpr : s.getQueueIdExprs())
212 waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));
213 }
214
215 return mlir::success();
216}
217
218mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct(
219 const OpenACCCombinedConstruct &s) {
220 mlir::Location start = getLoc(s.getSourceRange().getBegin());
221 mlir::Location end = getLoc(s.getSourceRange().getEnd());
222
223 switch (s.getDirectiveKind()) {
224 case OpenACCDirectiveKind::ParallelLoop:
225 return emitOpenACCOpCombinedConstruct<ParallelOp, mlir::acc::YieldOp>(
226 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
227 s.getLoop());
228 case OpenACCDirectiveKind::SerialLoop:
229 return emitOpenACCOpCombinedConstruct<SerialOp, mlir::acc::YieldOp>(
230 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
231 s.getLoop());
232 case OpenACCDirectiveKind::KernelsLoop:
233 return emitOpenACCOpCombinedConstruct<KernelsOp, mlir::acc::TerminatorOp>(
234 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
235 s.getLoop());
236 default:
237 llvm_unreachable("invalid compute construct kind");
238 }
239}
240
241mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct(
242 const OpenACCHostDataConstruct &s) {
243 mlir::Location start = getLoc(s.getSourceRange().getBegin());
244 mlir::Location end = getLoc(s.getSourceRange().getEnd());
245
246 return emitOpenACCOpAssociatedStmt<HostDataOp, mlir::acc::TerminatorOp>(
247 start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),
248 s.getStructuredBlock());
249}
250
251mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct(
252 const OpenACCEnterDataConstruct &s) {
253 cgm.errorNYI(s.getSourceRange(), "OpenACC EnterData Construct");
254 return mlir::failure();
255}
256mlir::LogicalResult CIRGenFunction::emitOpenACCExitDataConstruct(
257 const OpenACCExitDataConstruct &s) {
258 cgm.errorNYI(s.getSourceRange(), "OpenACC ExitData Construct");
259 return mlir::failure();
260}
261mlir::LogicalResult
262CIRGenFunction::emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct &s) {
263 cgm.errorNYI(s.getSourceRange(), "OpenACC Update Construct");
264 return mlir::failure();
265}
266mlir::LogicalResult
267CIRGenFunction::emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct &s) {
268 cgm.errorNYI(s.getSourceRange(), "OpenACC Atomic Construct");
269 return mlir::failure();
270}
271mlir::LogicalResult
272CIRGenFunction::emitOpenACCCacheConstruct(const OpenACCCacheConstruct &s) {
273 cgm.errorNYI(s.getSourceRange(), "OpenACC Cache Construct");
274 return mlir::failure();
275}
276

source code of clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp