1//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
10#include "TransformsDetail.h"
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13#include "mlir/IR/BuiltinTypeInterfaces.h"
14#include "mlir/IR/ImplicitLocOpBuilder.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/IR/SymbolTable.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SmallVector.h"
19#include <numeric>
20
21namespace mlir {
22namespace mesh {
23
24void populateSimplificationPatterns(
25 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
26 populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
27 patterns, reduction: ReductionKind::Sum);
28 populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
29 patterns, reduction: ReductionKind::Sum);
30
31 populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
32 patterns, reduction: ReductionKind::Min);
33 populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
34 patterns, reduction: ReductionKind::Min);
35 populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
36 patterns, reduction: ReductionKind::Min);
37
38 populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
39 patterns, reduction: ReductionKind::Max);
40 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
41 patterns, reduction: ReductionKind::Max);
42 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
43 patterns, reduction: ReductionKind::Max);
44
45 // TODO: add simplifications for all-gather and other collectives.
46
47 populateFoldingPatterns(patterns, symbolTableCollection);
48}
49
50namespace {
51
52// This folding can not be done with an operation's fold method or
53// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
54// symbol tables.
55// We can't use DialectFoldInterface since the cache may be invalidated by some
56// pass changing the referenced MeshOp ops.
57struct MeshShapeFolder
58 : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
59 using OpRewritePatternWithSymbolTableCollection::
60 OpRewritePatternWithSymbolTableCollection;
61 LogicalResult matchAndRewrite(MeshShapeOp op,
62 PatternRewriter &rewriter) const override {
63 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
64 MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
65 from: op.getOperation(), symbol: op.getMeshAttr());
66 if (!mesh) {
67 return failure();
68 }
69 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
70 SmallVector<MeshAxis> opAxesIota;
71 if (opMeshAxes.empty()) {
72 opAxesIota.resize(N: mesh.getRank());
73 std::iota(first: opAxesIota.begin(), last: opAxesIota.end(), value: 0);
74 opMeshAxes = opAxesIota;
75 }
76 if (llvm::all_of(Range&: opMeshAxes, P: [&mesh](MeshAxis axis) {
77 return ShapedType::isDynamic(dValue: mesh.getShape()[axis]);
78 })) {
79 // All mesh dimensions are dynamic. Nothing to fold.
80 return failure();
81 }
82
83 SmallVector<Value> newResults(op->getResults().size());
84 SmallVector<MeshAxis> newShapeOpMeshAxes;
85 SmallVector<size_t> newToOldResultsIndexMap;
86
87 for (size_t i = 0; i < opMeshAxes.size(); ++i) {
88 auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
89 if (ShapedType::isDynamic(dValue: meshAxisSize)) {
90 newToOldResultsIndexMap.push_back(Elt: i);
91 newShapeOpMeshAxes.push_back(Elt: opMeshAxes[i]);
92 } else {
93 // Fold static mesh axes.
94 newResults[i] = builder.create<arith::ConstantOp>(
95 args: builder.getIndexAttr(value: meshAxisSize));
96 }
97 }
98
99 // Leave only the dynamic mesh axes to be queried.
100 if (!newShapeOpMeshAxes.empty()) {
101 MeshShapeOp newShapeOp =
102 builder.create<MeshShapeOp>(args: mesh.getSymName(), args&: newShapeOpMeshAxes);
103 for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
104 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
105 }
106 }
107 rewriter.replaceOp(op, newValues: newResults);
108
109 return success();
110 }
111};
112
113} // namespace
114
115void populateFoldingPatterns(RewritePatternSet &patterns,
116 SymbolTableCollection &symbolTableCollection) {
117 patterns.add<MeshShapeFolder>(arg&: symbolTableCollection, args: patterns.getContext());
118}
119
120} // namespace mesh
121} // namespace mlir
122

source code of mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp