1//===- MPIOps.cpp - MPI dialect ops implementation ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/DLTI/DLTI.h"
10#include "mlir/Dialect/MPI/IR/MPI.h"
11#include "mlir/Dialect/MemRef/IR/MemRef.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/BuiltinAttributes.h"
14#include "mlir/IR/PatternMatch.h"
15
16using namespace mlir;
17using namespace mlir::mpi;
18
19namespace {
20
21// If input memref has dynamic shape and is a cast and if the cast's input has
22// static shape, fold the cast's static input into the given operation.
23template <typename OpT>
24struct FoldCast final : public mlir::OpRewritePattern<OpT> {
25 using mlir::OpRewritePattern<OpT>::OpRewritePattern;
26
27 LogicalResult matchAndRewrite(OpT op,
28 mlir::PatternRewriter &b) const override {
29 auto mRef = op.getRef();
30 if (mRef.getType().hasStaticShape()) {
31 return mlir::failure();
32 }
33 auto defOp = mRef.getDefiningOp();
34 if (!defOp || !mlir::isa<mlir::memref::CastOp>(defOp)) {
35 return mlir::failure();
36 }
37 auto src = mlir::cast<mlir::memref::CastOp>(defOp).getSource();
38 if (!src.getType().hasStaticShape()) {
39 return mlir::failure();
40 }
41 op.getRefMutable().assign(src);
42 return mlir::success();
43 }
44};
45
46struct FoldRank final : public mlir::OpRewritePattern<mlir::mpi::CommRankOp> {
47 using mlir::OpRewritePattern<mlir::mpi::CommRankOp>::OpRewritePattern;
48
49 LogicalResult matchAndRewrite(mlir::mpi::CommRankOp op,
50 mlir::PatternRewriter &b) const override {
51 auto comm = op.getComm();
52 if (!comm.getDefiningOp<mlir::mpi::CommWorldOp>())
53 return mlir::failure();
54
55 // Try to get DLTI attribute for MPI:comm_world_rank
56 // If found, set worldRank to the value of the attribute.
57 auto dltiAttr = dlti::query(op, keys: {"MPI:comm_world_rank"}, emitError: false);
58 if (failed(Result: dltiAttr))
59 return mlir::failure();
60 if (!isa<IntegerAttr>(Val: dltiAttr.value()))
61 return op->emitError()
62 << "Expected an integer attribute for MPI:comm_world_rank";
63 Value res = b.create<arith::ConstantIndexOp>(
64 location: op.getLoc(), args: cast<IntegerAttr>(Val&: dltiAttr.value()).getInt());
65 if (Value retVal = op.getRetval())
66 b.replaceOp(op, newValues: {retVal, res});
67 else
68 b.replaceOp(op, newValues: res);
69 return mlir::success();
70 }
71};
72
73} // namespace
74
75void mlir::mpi::SendOp::getCanonicalizationPatterns(
76 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
77 results.add<FoldCast<mlir::mpi::SendOp>>(arg&: context);
78}
79
80void mlir::mpi::RecvOp::getCanonicalizationPatterns(
81 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
82 results.add<FoldCast<mlir::mpi::RecvOp>>(arg&: context);
83}
84
85void mlir::mpi::ISendOp::getCanonicalizationPatterns(
86 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
87 results.add<FoldCast<mlir::mpi::ISendOp>>(arg&: context);
88}
89
90void mlir::mpi::IRecvOp::getCanonicalizationPatterns(
91 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
92 results.add<FoldCast<mlir::mpi::IRecvOp>>(arg&: context);
93}
94
95void mlir::mpi::CommRankOp::getCanonicalizationPatterns(
96 mlir::RewritePatternSet &results, mlir::MLIRContext *context) {
97 results.add<FoldRank>(arg&: context);
98}
99
100//===----------------------------------------------------------------------===//
101// TableGen'd op method definitions
102//===----------------------------------------------------------------------===//
103
104#define GET_OP_CLASSES
105#include "mlir/Dialect/MPI/IR/MPIOps.cpp.inc"
106

source code of mlir/lib/Dialect/MPI/IR/MPIOps.cpp