1//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
10#include "TransformsDetail.h"
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13#include "mlir/IR/BuiltinTypeInterfaces.h"
14#include "mlir/IR/ImplicitLocOpBuilder.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/IR/SymbolTable.h"
17#include "mlir/Support/LogicalResult.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVector.h"
20#include <numeric>
21#include <utility>
22
23namespace mlir {
24namespace mesh {
25
26void populateSimplificationPatterns(
27 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
28 populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
29 patterns, ReductionKind::Sum);
30 populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
31 patterns, ReductionKind::Sum);
32
33 populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
34 patterns, ReductionKind::Min);
35 populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
36 patterns, ReductionKind::Min);
37 populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
38 patterns, ReductionKind::Min);
39
40 populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
41 patterns, ReductionKind::Max);
42 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
43 patterns, ReductionKind::Max);
44 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
45 patterns, ReductionKind::Max);
46
47 // TODO: add simplifications for all-gather and other collectives.
48
49 populateFoldingPatterns(patterns, symbolTableCollection);
50}
51
52namespace {
53
54// This folding can not be done with an operation's fold method or
55// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
56// symbol tables.
57// We can't use DialectFoldInterface since the cache may be invalidated by some
58// pass changing the referenced MeshOp ops.
59struct MeshShapeFolder
60 : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
61 using OpRewritePatternWithSymbolTableCollection::
62 OpRewritePatternWithSymbolTableCollection;
63 LogicalResult matchAndRewrite(MeshShapeOp op,
64 PatternRewriter &rewriter) const override {
65 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
66 MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
67 op.getOperation(), op.getMeshAttr());
68 if (!mesh) {
69 return failure();
70 }
71 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
72 SmallVector<MeshAxis> opAxesIota;
73 if (opMeshAxes.empty()) {
74 opAxesIota.resize(mesh.getRank());
75 std::iota(first: opAxesIota.begin(), last: opAxesIota.end(), value: 0);
76 opMeshAxes = opAxesIota;
77 }
78 if (llvm::all_of(Range&: opMeshAxes, P: [&mesh](MeshAxis axis) {
79 return ShapedType::isDynamic(mesh.getShape()[axis]);
80 })) {
81 // All mesh dimensions are dynamic. Nothing to fold.
82 return failure();
83 }
84
85 SmallVector<Value> newResults(op->getResults().size());
86 SmallVector<MeshAxis> newShapeOpMeshAxes;
87 SmallVector<size_t> newToOldResultsIndexMap;
88
89 for (size_t i = 0; i < opMeshAxes.size(); ++i) {
90 auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
91 if (ShapedType::isDynamic(meshAxisSize)) {
92 newToOldResultsIndexMap.push_back(Elt: i);
93 newShapeOpMeshAxes.push_back(Elt: opMeshAxes[i]);
94 } else {
95 // Fold static mesh axes.
96 newResults[i] = builder.create<arith::ConstantOp>(
97 builder.getIndexAttr(meshAxisSize));
98 }
99 }
100
101 // Leave only the dynamic mesh axes to be queried.
102 if (!newShapeOpMeshAxes.empty()) {
103 MeshShapeOp newShapeOp =
104 builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
105 for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
106 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
107 }
108 }
109 rewriter.replaceOp(op, newResults);
110
111 return success();
112 }
113};
114
115} // namespace
116
117void populateFoldingPatterns(RewritePatternSet &patterns,
118 SymbolTableCollection &symbolTableCollection) {
119 patterns.add<MeshShapeFolder>(arg&: symbolTableCollection, args: patterns.getContext());
120}
121
122} // namespace mesh
123} // namespace mlir
124

source code of mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp