1//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Mesh/Transforms/Simplifications.h"
10#include "TransformsDetail.h"
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/Dialect/Mesh/IR/MeshOps.h"
13#include "mlir/IR/BuiltinTypeInterfaces.h"
14#include "mlir/IR/ImplicitLocOpBuilder.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/IR/SymbolTable.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SmallVector.h"
19#include <numeric>
20#include <utility>
21
22namespace mlir {
23namespace mesh {
24
25void populateSimplificationPatterns(
26 RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) {
27 populateAllReduceEndomorphismSimplificationPatterns<arith::AddFOp>(
28 patterns, ReductionKind::Sum);
29 populateAllReduceEndomorphismSimplificationPatterns<arith::AddIOp>(
30 patterns, ReductionKind::Sum);
31
32 populateAllReduceEndomorphismSimplificationPatterns<arith::MinimumFOp>(
33 patterns, ReductionKind::Min);
34 populateAllReduceEndomorphismSimplificationPatterns<arith::MinSIOp>(
35 patterns, ReductionKind::Min);
36 populateAllReduceEndomorphismSimplificationPatterns<arith::MinUIOp>(
37 patterns, ReductionKind::Min);
38
39 populateAllReduceEndomorphismSimplificationPatterns<arith::MaximumFOp>(
40 patterns, ReductionKind::Max);
41 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxSIOp>(
42 patterns, ReductionKind::Max);
43 populateAllReduceEndomorphismSimplificationPatterns<arith::MaxUIOp>(
44 patterns, ReductionKind::Max);
45
46 // TODO: add simplifications for all-gather and other collectives.
47
48 populateFoldingPatterns(patterns, symbolTableCollection);
49}
50
51namespace {
52
53// This folding can not be done with an operation's fold method or
54// DialectFoldInterface, because it needs a SymbolTableCollection to cache the
55// symbol tables.
56// We can't use DialectFoldInterface since the cache may be invalidated by some
57// pass changing the referenced MeshOp ops.
58struct MeshShapeFolder
59 : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> {
60 using OpRewritePatternWithSymbolTableCollection::
61 OpRewritePatternWithSymbolTableCollection;
62 LogicalResult matchAndRewrite(MeshShapeOp op,
63 PatternRewriter &rewriter) const override {
64 ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
65 MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>(
66 op.getOperation(), op.getMeshAttr());
67 if (!mesh) {
68 return failure();
69 }
70 ArrayRef<MeshAxis> opMeshAxes = op.getAxes();
71 SmallVector<MeshAxis> opAxesIota;
72 if (opMeshAxes.empty()) {
73 opAxesIota.resize(mesh.getRank());
74 std::iota(first: opAxesIota.begin(), last: opAxesIota.end(), value: 0);
75 opMeshAxes = opAxesIota;
76 }
77 if (llvm::all_of(Range&: opMeshAxes, P: [&mesh](MeshAxis axis) {
78 return ShapedType::isDynamic(mesh.getShape()[axis]);
79 })) {
80 // All mesh dimensions are dynamic. Nothing to fold.
81 return failure();
82 }
83
84 SmallVector<Value> newResults(op->getResults().size());
85 SmallVector<MeshAxis> newShapeOpMeshAxes;
86 SmallVector<size_t> newToOldResultsIndexMap;
87
88 for (size_t i = 0; i < opMeshAxes.size(); ++i) {
89 auto meshAxisSize = mesh.getShape()[opMeshAxes[i]];
90 if (ShapedType::isDynamic(meshAxisSize)) {
91 newToOldResultsIndexMap.push_back(Elt: i);
92 newShapeOpMeshAxes.push_back(Elt: opMeshAxes[i]);
93 } else {
94 // Fold static mesh axes.
95 newResults[i] = builder.create<arith::ConstantOp>(
96 builder.getIndexAttr(meshAxisSize));
97 }
98 }
99
100 // Leave only the dynamic mesh axes to be queried.
101 if (!newShapeOpMeshAxes.empty()) {
102 MeshShapeOp newShapeOp =
103 builder.create<MeshShapeOp>(mesh.getSymName(), newShapeOpMeshAxes);
104 for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) {
105 newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i];
106 }
107 }
108 rewriter.replaceOp(op, newResults);
109
110 return success();
111 }
112};
113
114} // namespace
115
116void populateFoldingPatterns(RewritePatternSet &patterns,
117 SymbolTableCollection &symbolTableCollection) {
118 patterns.add<MeshShapeFolder>(arg&: symbolTableCollection, args: patterns.getContext());
119}
120
121} // namespace mesh
122} // namespace mlir
123

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp