1//===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/MPI/IR/MPI.h"
10#include "mlir/Dialect/MemRef/IR/MemRef.h"
11#include "mlir/IR/Builders.h"
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/PatternMatch.h"
14
15using namespace mlir;
16using namespace mlir::mpi;
17
18namespace {
19
20// If input memref has dynamic shape and is a cast and if the cast's input has
21// static shape, fold the cast's static input into the given operation.
22template <typename OpT>
23struct FoldCast final : public mlir::OpRewritePattern<OpT> {
24 using mlir::OpRewritePattern<OpT>::OpRewritePattern;
25
26 LogicalResult matchAndRewrite(OpT op,
27 mlir::PatternRewriter &b) const override {
28 auto mRef = op.getRef();
29 if (mRef.getType().hasStaticShape()) {
30 return mlir::failure();
31 }
32 auto defOp = mRef.getDefiningOp();
33 if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
34 return mlir::failure();
35 }
36 auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
37 if (!src.getType().hasStaticShape()) {
38 return mlir::failure();
39 }
40 op.getRefMutable().assign(src);
41 return mlir::success();
42 }
43};
44} // namespace
45
46void mlir::mpi::SendOp::getCanonicalizationPatterns(
47 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
48 results.add<FoldCast<mlir::mpi::SendOp>>(context);
49}
50
51void mlir::mpi::RecvOp::getCanonicalizationPatterns(
52 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
53 results.add<FoldCast<mlir::mpi::RecvOp>>(context);
54}
55
56void mlir::mpi::ISendOp::getCanonicalizationPatterns(
57 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
58 results.add<FoldCast<mlir::mpi::ISendOp>>(context);
59}
60
61void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
62 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
63 results.add<FoldCast<mlir::mpi::IRecvOp>>(context);
64}
65
66//===----------------------------------------------------------------------===//
67// TableGen'd op method definitions
68//===----------------------------------------------------------------------===//
69
70#define GET_OP_CLASSES
71#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
72

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/MPI/IR/MPIOps.cpp