1//===- GenericLoopConversion.cpp ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "flang/Support/OpenMP-utils.h"
10
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
13#include "mlir/IR/IRMapping.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/DialectConversion.h"
16
17#include <memory>
18#include <optional>
19#include <type_traits>
20
namespace flangomp {
// Pull in the auto-generated pass base class
// (`GenericLoopConversionPassBase`) produced by tablegen from Passes.td.
#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp
25
26namespace {
27
/// A conversion pattern to handle various combined forms of `omp.loop`. For how
/// combined/composite directives are handled see:
/// https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986.
class GenericLoopConversionPattern
    : public mlir::OpConversionPattern<mlir::omp::LoopOp> {
public:
  /// The combined construct a `loop` op takes part in, derived from its direct
  /// parent op: none (`Standalone`), `teams loop`, or `parallel loop`.
  enum class GenericLoopCombinedInfo { Standalone, TeamsLoop, ParallelLoop };

  using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern;

  explicit GenericLoopConversionPattern(mlir::MLIRContext *ctx)
      : mlir::OpConversionPattern<mlir::omp::LoopOp>{ctx} {
    // Enable rewrite recursion to make sure nested `loop` directives are
    // handled.
    this->setHasBoundedRewriteRecursion(true);
  }

  /// Rewrites \p loopOp to the equivalent specialized construct(s) depending
  /// on the combined form it participates in, then erases the original op.
  /// The pass's legality rules ensure only supported `loop` ops reach here.
  mlir::LogicalResult
  matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // No-op in release builds; legality (see the pass below) already
    // guarantees the op only carries supported clauses.
    assert(mlir::succeeded(checkLoopConversionSupportStatus(loopOp)));

    GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);

    switch (combinedInfo) {
    case GenericLoopCombinedInfo::Standalone:
      rewriteStandaloneLoop(loopOp, rewriter);
      break;
    case GenericLoopCombinedInfo::ParallelLoop:
      rewriteToWsloop(loopOp, rewriter);
      break;
    case GenericLoopCombinedInfo::TeamsLoop:
      if (teamsLoopCanBeParallelFor(loopOp)) {
        rewriteToDistributeParallelDo(loopOp, rewriter);
      } else {
        // `teams loop` must degrade to `teams distribute`. The `distribute`
        // leaf does not accept reductions, so redirect the `loop` op's
        // reduction block args to the parent `teams` op's reduction args
        // (hoisted there earlier by ReductionsHoistingPattern) and strip the
        // reduction clause from the `loop` op before rewriting.
        auto teamsOp = llvm::cast<mlir::omp::TeamsOp>(loopOp->getParentOp());
        auto teamsBlockArgIface =
            llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*teamsOp);
        auto loopBlockArgIface =
            llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);

        for (unsigned i = 0; i < loopBlockArgIface.numReductionBlockArgs();
             ++i) {
          mlir::BlockArgument loopRedBlockArg =
              loopBlockArgIface.getReductionBlockArgs()[i];
          mlir::BlockArgument teamsRedBlockArg =
              teamsBlockArgIface.getReductionBlockArgs()[i];
          rewriter.replaceAllUsesWith(loopRedBlockArg, teamsRedBlockArg);
        }

        // Erase the now-unused reduction block args. Repeatedly erasing at
        // the start index is correct because later args shift down each time.
        for (unsigned i = 0; i < loopBlockArgIface.numReductionBlockArgs();
             ++i) {
          loopOp.getRegion().eraseArgument(
              loopBlockArgIface.getReductionBlockArgsStart());
        }

        // Drop all reduction clause info so the `loop` op satisfies the
        // `rewriteToDistribute` precondition.
        loopOp.removeReductionModAttr();
        loopOp.getReductionVarsMutable().clear();
        loopOp.removeReductionByrefAttr();
        loopOp.removeReductionSymsAttr();

        rewriteToDistribute(loopOp, rewriter);
      }

      break;
    }

    rewriter.eraseOp(loopOp);
    return mlir::success();
  }

  /// Returns success if \p loopOp carries only clauses this pattern knows how
  /// to convert; otherwise emits a "not yet implemented" error and fails.
  /// Also used by the pass to mark unsupported `loop` ops as legal (i.e. left
  /// untouched).
  static mlir::LogicalResult
  checkLoopConversionSupportStatus(mlir::omp::LoopOp loopOp) {
    auto todo = [&loopOp](mlir::StringRef clauseName) {
      return loopOp.emitError()
             << "not yet implemented: Unhandled clause " << clauseName << " in "
             << loopOp->getName() << " operation";
    };

    if (loopOp.getOrder())
      return todo("order");

    return mlir::success();
  }

private:
  /// Classifies \p loopOp by inspecting its direct parent op: `omp.teams` =>
  /// TeamsLoop, `omp.parallel` => ParallelLoop, anything else => Standalone.
  static GenericLoopCombinedInfo
  findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) {
    mlir::Operation *parentOp = loopOp->getParentOp();
    GenericLoopCombinedInfo result = GenericLoopCombinedInfo::Standalone;

    if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp))
      result = GenericLoopCombinedInfo::TeamsLoop;

    if (auto parallelOp =
            mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp))
      result = GenericLoopCombinedInfo::ParallelLoop;

    return result;
  }

  /// Checks whether a `teams loop` construct can be rewritten to `teams
  /// distribute parallel do` or it has to be converted to `teams distribute`.
  ///
  /// This checks similar constraints to what is checked by `TeamsLoopChecker`
  /// in SemaOpenMP.cpp in clang.
  static bool teamsLoopCanBeParallelFor(mlir::omp::LoopOp loopOp) {
    bool canBeParallelFor =
        !loopOp
             .walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *nestedOp) {
               // The walk visits `loopOp` itself first; skip it.
               if (nestedOp == loopOp)
                 return mlir::WalkResult::advance();

               if (auto nestedLoopOp =
                       mlir::dyn_cast<mlir::omp::LoopOp>(nestedOp)) {
                 GenericLoopCombinedInfo combinedInfo =
                     findGenericLoopCombineInfo(nestedLoopOp);

                 // Worksharing loops cannot be nested inside each other.
                 // Therefore, if the current `loop` directive nests another
                 // `loop` whose `bind` modifier is `parallel`, this `loop`
                 // directive cannot be mapped to `distribute parallel for`
                 // but rather only to `distribute`.
                 if (combinedInfo == GenericLoopCombinedInfo::Standalone &&
                     nestedLoopOp.getBindKind() &&
                     *nestedLoopOp.getBindKind() ==
                         mlir::omp::ClauseBindKind::Parallel)
                   return mlir::WalkResult::interrupt();

                 if (combinedInfo == GenericLoopCombinedInfo::ParallelLoop)
                   return mlir::WalkResult::interrupt();

               } else if (auto callOp =
                              mlir::dyn_cast<mlir::CallOpInterface>(nestedOp)) {
                 // Calls to non-OpenMP API runtime functions inhibits
                 // transformation to `teams distribute parallel do` since the
                 // called functions might have nested parallelism themselves.
                 bool isOpenMPAPI = false;
                 mlir::CallInterfaceCallable callable =
                     callOp.getCallableForCallee();

                 // Heuristic: treat any direct callee whose root symbol starts
                 // with "omp_" as an OpenMP API routine.
                 if (auto callableSymRef =
                         mlir::dyn_cast<mlir::SymbolRefAttr>(callable))
                   isOpenMPAPI =
                       callableSymRef.getRootReference().strref().starts_with(
                           "omp_");

                 if (!isOpenMPAPI)
                   return mlir::WalkResult::interrupt();
               }

               return mlir::WalkResult::advance();
             })
             .wasInterrupted();

    return canBeParallelFor;
  }

  /// Dispatches a standalone `loop` op based on its `bind` clause: no clause
  /// or `bind(thread)` => `simd`, `bind(parallel)` => `wsloop`,
  /// `bind(teams)` => `distribute`.
  void rewriteStandaloneLoop(mlir::omp::LoopOp loopOp,
                             mlir::ConversionPatternRewriter &rewriter) const {
    using namespace mlir::omp;
    std::optional<ClauseBindKind> bindKind = loopOp.getBindKind();

    if (!bindKind.has_value())
      return rewriteToSimdLoop(loopOp, rewriter);

    switch (*loopOp.getBindKind()) {
    case ClauseBindKind::Parallel:
      return rewriteToWsloop(loopOp, rewriter);
    case ClauseBindKind::Teams:
      return rewriteToDistribute(loopOp, rewriter);
    case ClauseBindKind::Thread:
      return rewriteToSimdLoop(loopOp, rewriter);
    }
  }

  /// Rewrites standalone `loop` (without a `bind` clause or with
  /// `bind(thread)`) directives to equivalent `simd` constructs.
  ///
  /// The reasoning behind this decision is that according to the spec (version
  /// 5.2, section 11.7.1):
  ///
  /// "If the bind clause is not specified on a construct for which it may be
  /// specified and the construct is closely nested inside a teams or parallel
  /// construct, the effect is as if binding is teams or parallel. If none of
  /// those conditions hold, the binding region is not defined."
  ///
  /// which means that standalone `loop` directives have undefined binding
  /// region. Moreover, the spec says (in the next paragraph):
  ///
  /// "The specified binding region determines the binding thread set.
  /// Specifically, if the binding region is a teams region, then the binding
  /// thread set is the set of initial threads that are executing that region
  /// while if the binding region is a parallel region, then the binding thread
  /// set is the team of threads that are executing that region. If the binding
  /// region is not defined, then the binding thread set is the encountering
  /// thread."
  ///
  /// which means that the binding thread set for a standalone `loop` directive
  /// is only the encountering thread.
  ///
  /// Since the encountering thread is the binding thread (set) for a
  /// standalone `loop` directive, the best we can do in such case is to "simd"
  /// the directive.
  void rewriteToSimdLoop(mlir::omp::LoopOp loopOp,
                         mlir::ConversionPatternRewriter &rewriter) const {
    loopOp.emitWarning(
        "Detected standalone OpenMP `loop` directive with thread binding, "
        "the associated loop will be rewritten to `simd`.");
    rewriteToSingleWrapperOp<mlir::omp::SimdOp, mlir::omp::SimdOperands>(
        loopOp, rewriter);
  }

  /// Rewrites \p loopOp to an `omp.distribute` wrapper. Pre-condition: any
  /// reduction clause must already have been stripped/hoisted (`distribute`
  /// takes no reductions).
  void rewriteToDistribute(mlir::omp::LoopOp loopOp,
                           mlir::ConversionPatternRewriter &rewriter) const {
    assert(loopOp.getReductionVars().empty());
    rewriteToSingleWrapperOp<mlir::omp::DistributeOp,
                             mlir::omp::DistributeOperands>(loopOp, rewriter);
  }

  /// Rewrites \p loopOp to an `omp.wsloop` wrapper (worksharing loop).
  void rewriteToWsloop(mlir::omp::LoopOp loopOp,
                       mlir::ConversionPatternRewriter &rewriter) const {
    rewriteToSingleWrapperOp<mlir::omp::WsloopOp, mlir::omp::WsloopOperands>(
        loopOp, rewriter);
  }

  // TODO Suggestion by Sergio: tag auto-generated operations for constructs
  // that weren't part of the original program, that would be useful
  // information for debugging purposes later on. This new attribute could be
  // used for `omp.loop`, but also for `do concurrent` transformations,
  // `workshare`, `workdistribute`, etc. The tag could be used for all kinds of
  // auto-generated operations using a dialect attribute (named something like
  // `omp.origin` or `omp.derived`) and perhaps hold the name of the operation
  // it was derived from, the reason it was transformed or something like that
  // we could use when emitting any messages related to it later on.
  //
  /// Generic helper: wraps the body of \p loopOp in a single new wrapper op of
  /// type \p OpTy (simd/wsloop/distribute), carrying over the privatization
  /// clauses and — except for `distribute`, which does not permit them — the
  /// reduction clauses. The old region's block args are remapped to the new
  /// entry block's args before cloning.
  template <typename OpTy, typename OpOperandsTy>
  void
  rewriteToSingleWrapperOp(mlir::omp::LoopOp loopOp,
                           mlir::ConversionPatternRewriter &rewriter) const {
    OpOperandsTy clauseOps;
    clauseOps.privateVars = loopOp.getPrivateVars();

    auto privateSyms = loopOp.getPrivateSyms();
    if (privateSyms)
      clauseOps.privateSyms.assign(privateSyms->begin(), privateSyms->end());

    Fortran::common::openmp::EntryBlockArgs args;
    args.priv.vars = clauseOps.privateVars;

    // `DistributeOperands` has no reduction members; compile-time branch.
    if constexpr (!std::is_same_v<OpOperandsTy,
                                  mlir::omp::DistributeOperands>) {
      populateReductionClauseOps(loopOp, clauseOps);
      args.reduction.vars = clauseOps.reductionVars;
    }

    auto wrapperOp = rewriter.create<OpTy>(loopOp.getLoc(), clauseOps);
    mlir::Block *opBlock = genEntryBlock(rewriter, args, wrapperOp.getRegion());

    mlir::IRMapping mapper;
    mlir::Block &loopBlock = *loopOp.getRegion().begin();

    for (auto [loopOpArg, opArg] :
         llvm::zip_equal(loopBlock.getArguments(), opBlock->getArguments()))
      mapper.map(loopOpArg, opArg);

    rewriter.clone(*loopOp.begin(), mapper);
  }

  /// Rewrites a `teams loop` body to the composite
  /// `parallel` -> `distribute` -> `wsloop` nest. Privatization clauses go on
  /// the `parallel` leaf; reduction clauses go on the `wsloop` leaf.
  void rewriteToDistributeParallelDo(
      mlir::omp::LoopOp loopOp,
      mlir::ConversionPatternRewriter &rewriter) const {
    mlir::omp::ParallelOperands parallelClauseOps;
    parallelClauseOps.privateVars = loopOp.getPrivateVars();

    auto privateSyms = loopOp.getPrivateSyms();
    if (privateSyms)
      parallelClauseOps.privateSyms.assign(privateSyms->begin(),
                                           privateSyms->end());

    Fortran::common::openmp::EntryBlockArgs parallelArgs;
    parallelArgs.priv.vars = parallelClauseOps.privateVars;

    auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
                                                             parallelClauseOps);
    genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
    parallelOp.setComposite(true);
    // Create the region terminator up front, then insert the rest of the
    // nest before it.
    rewriter.setInsertionPoint(
        rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));

    mlir::omp::DistributeOperands distributeClauseOps;
    auto distributeOp = rewriter.create<mlir::omp::DistributeOp>(
        loopOp.getLoc(), distributeClauseOps);
    distributeOp.setComposite(true);
    rewriter.createBlock(&distributeOp.getRegion());

    mlir::omp::WsloopOperands wsloopClauseOps;
    populateReductionClauseOps(loopOp, wsloopClauseOps);
    Fortran::common::openmp::EntryBlockArgs wsloopArgs;
    wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;

    auto wsloopOp =
        rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
    wsloopOp.setComposite(true);
    genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion());

    mlir::IRMapping mapper;

    auto loopBlockInterface =
        llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*loopOp);
    auto parallelBlockInterface =
        llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*parallelOp);
    auto wsloopBlockInterface =
        llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*wsloopOp);

    // Private block args of the old `loop` op map to the new `parallel` op's.
    for (auto [loopOpArg, parallelOpArg] :
         llvm::zip_equal(loopBlockInterface.getPrivateBlockArgs(),
                         parallelBlockInterface.getPrivateBlockArgs()))
      mapper.map(loopOpArg, parallelOpArg);

    // Reduction block args of the old `loop` op map to the new `wsloop` op's.
    for (auto [loopOpArg, wsloopOpArg] :
         llvm::zip_equal(loopBlockInterface.getReductionBlockArgs(),
                         wsloopBlockInterface.getReductionBlockArgs()))
      mapper.map(loopOpArg, wsloopOpArg);

    rewriter.clone(*loopOp.begin(), mapper);
  }

  /// Copies the reduction clause info (modifier, vars, syms, byref flags) from
  /// \p loopOp into \p clauseOps.
  void
  populateReductionClauseOps(mlir::omp::LoopOp loopOp,
                             mlir::omp::ReductionClauseOps &clauseOps) const {
    clauseOps.reductionMod = loopOp.getReductionModAttr();
    clauseOps.reductionVars = loopOp.getReductionVars();

    std::optional<mlir::ArrayAttr> reductionSyms = loopOp.getReductionSyms();
    if (reductionSyms)
      clauseOps.reductionSyms.assign(reductionSyms->begin(),
                                     reductionSyms->end());

    std::optional<llvm::ArrayRef<bool>> reductionByref =
        loopOp.getReductionByref();
    if (reductionByref)
      clauseOps.reductionByref.assign(reductionByref->begin(),
                                      reductionByref->end());
  }
};
373
374/// According to the spec (v5.2, p340, 36):
375///
376/// ```
377/// The effect of the reduction clause is as if it is applied to all leaf
378/// constructs that permit the clause, except for the following constructs:
379/// * ....
380/// * The teams construct, when combined with the loop construct.
381/// ```
382///
383/// Therefore, for a combined directive similar to: `!$omp teams loop
384/// reduction(...)`, the earlier stages of the compiler assign the `reduction`
385/// clauses only to the `loop` leaf and not to the `teams` leaf.
386///
387/// On the other hand, if we have a combined construct similar to: `!$omp teams
388/// distribute parallel do`, the `reduction` clauses are assigned both to the
389/// `teams` and the `do` leaves. We need to match this behavior when we convert
390/// `teams` op with a nested `loop` op since the target set of constructs/ops
391/// will be incorrect without moving the reductions up to the `teams` op as
392/// well.
393///
394/// This pattern does exactly this. Given the following input:
395/// ```
396/// omp.teams {
397/// omp.loop reduction(@red_sym %red_op -> %red_arg : !fir.ref<i32>) {
398/// omp.loop_nest ... {
399/// ...
400/// }
401/// }
402/// }
403/// ```
404/// this pattern updates the `omp.teams` op in-place to:
405/// ```
406/// omp.teams reduction(@red_sym %red_op -> %teams_red_arg : !fir.ref<i32>) {
407/// omp.loop reduction(@red_sym %teams_red_arg -> %red_arg : !fir.ref<i32>) {
408/// omp.loop_nest ... {
409/// ...
410/// }
411/// }
412/// }
413/// ```
414///
415/// Note the following:
416/// * The nested `omp.loop` is not rewritten by this pattern, this happens
417/// through `GenericLoopConversionPattern`.
418/// * The reduction info are cloned from the nested `omp.loop` op to the parent
419/// `omp.teams` op.
420/// * The reduction operand of the `omp.loop` op is updated to be the **new**
421/// reduction block argument of the `omp.teams` op.
422class ReductionsHoistingPattern
423 : public mlir::OpConversionPattern<mlir::omp::TeamsOp> {
424public:
425 using mlir::OpConversionPattern<mlir::omp::TeamsOp>::OpConversionPattern;
426
427 static mlir::omp::LoopOp
428 tryToFindNestedLoopWithReduction(mlir::omp::TeamsOp teamsOp) {
429 if (teamsOp.getRegion().getBlocks().size() != 1)
430 return nullptr;
431
432 mlir::Block &teamsBlock = *teamsOp.getRegion().begin();
433 auto loopOpIter = llvm::find_if(teamsBlock, [](mlir::Operation &op) {
434 auto nestedLoopOp = llvm::dyn_cast<mlir::omp::LoopOp>(&op);
435
436 if (!nestedLoopOp)
437 return false;
438
439 return !nestedLoopOp.getReductionVars().empty();
440 });
441
442 if (loopOpIter == teamsBlock.end())
443 return nullptr;
444
445 // TODO return error if more than one loop op is nested. We need to
446 // coalesce reductions in this case.
447 return llvm::cast<mlir::omp::LoopOp>(loopOpIter);
448 }
449
450 mlir::LogicalResult
451 matchAndRewrite(mlir::omp::TeamsOp teamsOp, OpAdaptor adaptor,
452 mlir::ConversionPatternRewriter &rewriter) const override {
453 mlir::omp::LoopOp nestedLoopOp = tryToFindNestedLoopWithReduction(teamsOp);
454
455 rewriter.modifyOpInPlace(teamsOp, [&]() {
456 teamsOp.setReductionMod(nestedLoopOp.getReductionMod());
457 teamsOp.getReductionVarsMutable().assign(nestedLoopOp.getReductionVars());
458 teamsOp.setReductionByref(nestedLoopOp.getReductionByref());
459 teamsOp.setReductionSymsAttr(nestedLoopOp.getReductionSymsAttr());
460
461 auto blockArgIface =
462 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*teamsOp);
463 unsigned reductionArgsStart = blockArgIface.getPrivateBlockArgsStart() +
464 blockArgIface.numPrivateBlockArgs();
465 llvm::SmallVector<mlir::Value> newLoopOpReductionOperands;
466
467 for (auto [idx, reductionVar] :
468 llvm::enumerate(nestedLoopOp.getReductionVars())) {
469 mlir::BlockArgument newTeamsOpReductionBlockArg =
470 teamsOp.getRegion().insertArgument(reductionArgsStart + idx,
471 reductionVar.getType(),
472 reductionVar.getLoc());
473 newLoopOpReductionOperands.push_back(newTeamsOpReductionBlockArg);
474 }
475
476 nestedLoopOp.getReductionVarsMutable().assign(newLoopOpReductionOperands);
477 });
478
479 return mlir::success();
480 }
481};
482
483class GenericLoopConversionPass
484 : public flangomp::impl::GenericLoopConversionPassBase<
485 GenericLoopConversionPass> {
486public:
487 GenericLoopConversionPass() = default;
488
489 void runOnOperation() override {
490 mlir::func::FuncOp func = getOperation();
491
492 if (func.isDeclaration())
493 return;
494
495 mlir::MLIRContext *context = &getContext();
496 mlir::RewritePatternSet patterns(context);
497 patterns.insert<ReductionsHoistingPattern, GenericLoopConversionPattern>(
498 context);
499 mlir::ConversionTarget target(*context);
500
501 target.markUnknownOpDynamicallyLegal(
502 [](mlir::Operation *) { return true; });
503
504 target.addDynamicallyLegalOp<mlir::omp::TeamsOp>(
505 [](mlir::omp::TeamsOp teamsOp) {
506 // If teamsOp's reductions are already populated, then the op is
507 // legal. Additionally, the op is legal if it does not nest a LoopOp
508 // with reductions.
509 return !teamsOp.getReductionVars().empty() ||
510 ReductionsHoistingPattern::tryToFindNestedLoopWithReduction(
511 teamsOp) == nullptr;
512 });
513
514 target.addDynamicallyLegalOp<mlir::omp::LoopOp>(
515 [](mlir::omp::LoopOp loopOp) {
516 return mlir::failed(
517 GenericLoopConversionPattern::checkLoopConversionSupportStatus(
518 loopOp));
519 });
520
521 if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
522 std::move(patterns)))) {
523 mlir::emitError(func.getLoc(), "error in converting `omp.loop` op");
524 signalPassFailure();
525 }
526 }
527};
528} // namespace
529

source code of flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp