1//===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
10
11#include "mlir/Dialect/OpenACC/OpenACC.h"
12#include "mlir/Dialect/SCF/IR/SCF.h"
13#include "mlir/IR/Matchers.h"
14#include "mlir/Transforms/DialectConversion.h"
15
16namespace mlir {
17#define GEN_PASS_DEF_CONVERTOPENACCTOSCFPASS
18#include "mlir/Conversion/Passes.h.inc"
19} // namespace mlir
20
21using namespace mlir;
22
23//===----------------------------------------------------------------------===//
24// Conversion patterns
25//===----------------------------------------------------------------------===//
26
27namespace {
28/// Pattern to transform the `getIfCond` on operation without region into a
29/// scf.if and move the operation into the `then` region.
30template <typename OpTy>
31class ExpandIfCondition : public OpRewritePattern<OpTy> {
32 using OpRewritePattern<OpTy>::OpRewritePattern;
33
34 LogicalResult matchAndRewrite(OpTy op,
35 PatternRewriter &rewriter) const override {
36 // Early exit if there is no condition.
37 if (!op.getIfCond())
38 return failure();
39
40 IntegerAttr constAttr;
41 if (!matchPattern(op.getIfCond(), m_Constant(bind_value: &constAttr))) {
42 auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
43 op.getIfCond(), false);
44 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
45 auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
46 thenBodyBuilder.clone(*op.getOperation());
47 rewriter.eraseOp(op);
48 } else {
49 if (constAttr.getInt())
50 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
51 else
52 rewriter.eraseOp(op);
53 }
54 return success();
55 }
56};
57} // namespace
58
59void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
60 patterns.add<ExpandIfCondition<acc::EnterDataOp>>(arg: patterns.getContext());
61 patterns.add<ExpandIfCondition<acc::ExitDataOp>>(arg: patterns.getContext());
62 patterns.add<ExpandIfCondition<acc::UpdateOp>>(arg: patterns.getContext());
63}
64
65namespace {
66struct ConvertOpenACCToSCFPass
67 : public impl::ConvertOpenACCToSCFPassBase<ConvertOpenACCToSCFPass> {
68 void runOnOperation() override;
69};
70} // namespace
71
72void ConvertOpenACCToSCFPass::runOnOperation() {
73 auto op = getOperation();
74 auto *context = op.getContext();
75
76 RewritePatternSet patterns(context);
77 ConversionTarget target(*context);
78 populateOpenACCToSCFConversionPatterns(patterns);
79
80 target.addLegalDialect<scf::SCFDialect>();
81 target.addLegalDialect<acc::OpenACCDialect>();
82
83 target.addDynamicallyLegalOp<acc::EnterDataOp>(
84 callback: [](acc::EnterDataOp op) { return !op.getIfCond(); });
85
86 target.addDynamicallyLegalOp<acc::ExitDataOp>(
87 callback: [](acc::ExitDataOp op) { return !op.getIfCond(); });
88
89 target.addDynamicallyLegalOp<acc::UpdateOp>(
90 callback: [](acc::UpdateOp op) { return !op.getIfCond(); });
91
92 if (failed(Result: applyPartialConversion(op, target, patterns: std::move(patterns))))
93 signalPassFailure();
94}
95

source code of mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp