1//===- OpenACCToSCF.cpp - OpenACC condition to SCF if conversion ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/OpenACC/OpenACC.h"
13#include "mlir/Dialect/SCF/IR/SCF.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace mlir {
19#define GEN_PASS_DEF_CONVERTOPENACCTOSCF
20#include "mlir/Conversion/Passes.h.inc"
21} // namespace mlir
22
23using namespace mlir;
24
25//===----------------------------------------------------------------------===//
26// Conversion patterns
27//===----------------------------------------------------------------------===//
28
29namespace {
30/// Pattern to transform the `getIfCond` on operation without region into a
31/// scf.if and move the operation into the `then` region.
32template <typename OpTy>
33class ExpandIfCondition : public OpRewritePattern<OpTy> {
34 using OpRewritePattern<OpTy>::OpRewritePattern;
35
36 LogicalResult matchAndRewrite(OpTy op,
37 PatternRewriter &rewriter) const override {
38 // Early exit if there is no condition.
39 if (!op.getIfCond())
40 return failure();
41
42 IntegerAttr constAttr;
43 if (!matchPattern(op.getIfCond(), m_Constant(&constAttr))) {
44 auto ifOp = rewriter.create<scf::IfOp>(op.getLoc(), TypeRange(),
45 op.getIfCond(), false);
46 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
47 auto thenBodyBuilder = ifOp.getThenBodyBuilder(rewriter.getListener());
48 thenBodyBuilder.clone(*op.getOperation());
49 rewriter.eraseOp(op);
50 } else {
51 if (constAttr.getInt())
52 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
53 else
54 rewriter.eraseOp(op);
55 }
56 return success();
57 }
58};
59} // namespace
60
61void mlir::populateOpenACCToSCFConversionPatterns(RewritePatternSet &patterns) {
62 patterns.add<ExpandIfCondition<acc::EnterDataOp>>(patterns.getContext());
63 patterns.add<ExpandIfCondition<acc::ExitDataOp>>(patterns.getContext());
64 patterns.add<ExpandIfCondition<acc::UpdateOp>>(patterns.getContext());
65}
66
67namespace {
68struct ConvertOpenACCToSCFPass
69 : public impl::ConvertOpenACCToSCFBase<ConvertOpenACCToSCFPass> {
70 void runOnOperation() override;
71};
72} // namespace
73
74void ConvertOpenACCToSCFPass::runOnOperation() {
75 auto op = getOperation();
76 auto *context = op.getContext();
77
78 RewritePatternSet patterns(context);
79 ConversionTarget target(*context);
80 populateOpenACCToSCFConversionPatterns(patterns);
81
82 target.addLegalDialect<scf::SCFDialect>();
83 target.addLegalDialect<acc::OpenACCDialect>();
84
85 target.addDynamicallyLegalOp<acc::EnterDataOp>(
86 [](acc::EnterDataOp op) { return !op.getIfCond(); });
87
88 target.addDynamicallyLegalOp<acc::ExitDataOp>(
89 [](acc::ExitDataOp op) { return !op.getIfCond(); });
90
91 target.addDynamicallyLegalOp<acc::UpdateOp>(
92 [](acc::UpdateOp op) { return !op.getIfCond(); });
93
94 if (failed(applyPartialConversion(op, target, std::move(patterns))))
95 signalPassFailure();
96}
97
98std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenACCToSCFPass() {
99 return std::make_unique<ConvertOpenACCToSCFPass>();
100}
101

// Source: mlir/lib/Conversion/OpenACCToSCF/OpenACCToSCF.cpp