1//====----- OutlineShapeComputation.cpp -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
11#include "mlir/Dialect/Shape/IR/Shape.h"
12#include "mlir/Dialect/Shape/Transforms/Passes.h"
13#include "mlir/Dialect/Tensor/IR/Tensor.h"
14#include "mlir/IR/IRMapping.h"
15#include "mlir/IR/Matchers.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Transforms/DialectConversion.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/ADT/DenseSet.h"
20#include "llvm/Support/Debug.h"
21#include <queue>
22#include <unordered_set>
23#include <vector>
24
25namespace mlir {
26#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATION
27#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
28} // namespace mlir
29
30#define DEBUG_TYPE "outline-shape-computation"
31
32using namespace mlir;
33
34namespace {
35
36// A Value is an input of the cluster if it is an operand of an operation in the
37// cluster and its defining operation is not in the cluster.
38SmallVector<Value, 4>
39getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
40 SmallVector<Value, 4> inputs;
41 llvm::SmallDenseSet<Value> inputSet;
42 llvm::SmallDenseSet<Operation *> opSet;
43 for (Operation *op : cluster) {
44 bool inserted = opSet.insert(op).second;
45 (void)inserted;
46 assert(inserted && "cluster contains duplicate operations");
47 }
48
49 for (Operation *op : cluster) {
50 for (Value operand : op->getOperands()) {
51 Operation *operandOp = operand.getDefiningOp();
52 if (opSet.contains(operandOp)) {
53 // Skip if defining op is in the cluster.
54 continue;
55 }
56 if (inputSet.insert(operand).second)
57 inputs.push_back(operand);
58 }
59 }
60 return inputs;
61}
62
63// Create a shape.func representing the shape computation for `shape`.
64std::pair<shape::FuncOp, SmallVector<Value>>
65createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
66 Value shape, StringRef fnName, Location loc) {
67 SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
68 auto fnType =
69 cluster.empty()
70 ? b.getFunctionType(shape.getType(), shape.getType())
71 : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType());
72 shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType);
73 Block *block = fnOp.addEntryBlock();
74 b.setInsertionPoint(block, block->end());
75 IRMapping bvm;
76 if (cluster.empty()) {
77 bvm.map(shape, fnOp.getArgument(0));
78 } else {
79 for (auto inputAndArg : llvm::zip(inputs, fnOp.getArguments()))
80 bvm.map(std::get<0>(inputAndArg), std::get<1>(inputAndArg));
81 }
82
83 for (Operation *op : cluster)
84 b.clone(*op, bvm);
85 llvm::SmallVector<Value, 4> fnReturns;
86 fnReturns.push_back(bvm.lookupOrDefault(shape));
87
88 b.create<shape::ReturnOp>(loc, fnReturns);
89 fnOp.setPrivate();
90 return std::make_pair(fnOp, inputs);
91}
92
93// The operations in the cluster might be unsorted, which could be inconvenient
94// when creating shape.func op.
95DenseMap<Value, SmallVector<Operation *, 8>>
96getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
97 func::FuncOp funcOp) {
98 // Compute all clusters that each operation is in
99 DenseMap<Operation *, SmallVector<Value>> op2Shapes;
100 for (const auto &it : clusters) {
101 Value shape = it.first;
102 const DenseSet<Operation *> &cluster = it.second;
103 for (Operation *cOp : cluster)
104 op2Shapes[cOp].push_back(shape);
105 }
106
107 // Iterate through all operations in order. Get all the clusters `cOp` belongs
108 // to and construct the new ordered cluster as it traverses.
109 DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
110 funcOp.walk([&](Operation *op) {
111 auto it = op2Shapes.find(op);
112 if (it != op2Shapes.end()) {
113 Operation *cOp = it->first;
114 for (Value shape : it->second)
115 orderedClusters[shape].push_back(cOp);
116 }
117 });
118
119 return orderedClusters;
120}
121
122void constructShapeFunc(
123 const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
124 DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
125 SymbolTable &symbolTable,
126 DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
127 func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
128 std::string shapeCalculationNamePrefix = "shape_cal_";
129 int shapeCalculationNameIdx = 0;
130 OpBuilder builder(context);
131
132 // Construct a shape function
133 for (shape::WithOp withOp : allWithOps) {
134 Value value = withOp.getOperand();
135 Value shape = withOp.getShape();
136 RankedTensorType rankedType = dyn_cast<RankedTensorType>(value.getType());
137 if (rankedType == nullptr)
138 continue;
139
140 const SmallVector<Operation *, 8> &cluster = clusters[shape];
141 shape::ShapeMappingValue shapeMappingValue;
142 auto it = dynShape2ShapeFunc.find(shape);
143 if (it == dynShape2ShapeFunc.end()) {
144 std::string name = shapeCalculationNamePrefix +
145 std::to_string(shapeCalculationNameIdx++);
146 Location loc = value.getLoc();
147 builder.setInsertionPointAfter(funcOp);
148 auto pair = createFuncFromCluster(builder, cluster, shape, name, loc);
149 const SmallVector<Value> &inputs = pair.second;
150 shape::FuncOp shapeFuncOp = pair.first;
151 StringAttr insertedName = symbolTable.insert(shapeFuncOp);
152 auto symbol = FlatSymbolRefAttr::get(context, insertedName);
153
154 shapeMappingValue.funcSymbol = symbol;
155 shapeMappingValue.inputs = inputs;
156 } else {
157 shapeMappingValue = it->second;
158 }
159 dynShape2ShapeFunc[shape] = shapeMappingValue;
160 shapeMappingAnalysis.shapeMapping.insert(
161 std::make_pair(value, shapeMappingValue));
162 }
163}
164
// Pass that outlines the computation of each dynamic shape fed into a
// shape.with_shape op into a standalone, private shape.func, and records the
// resulting value-to-shape-function mapping in ShapeMappingAnalysis.
struct OutlineShapeComputationPass
    : public impl::OutlineShapeComputationBase<OutlineShapeComputationPass> {

  void runOnOperation() override;

private:
  // Returns true if `op` is a shape.with_shape reached through its shape
  // operand (`prevOutput`), or if all transitive users of `op` eventually
  // are. Positive results are cached in `onlyUsedByWithShapes`.
  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);

  // Computes the backward slice of shape-only operations producing `shape`
  // and stores it as `clusters[shape]`.
  void getClusterFromValue(Value shape,
                           DenseMap<Value, DenseSet<Operation *>> &clusters);

  // Builds an ordered operation cluster for the shape operand of every
  // shape.with_shape op in `funcOp`.
  DenseMap<Value, SmallVector<Operation *, 8>>
  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
                                func::FuncOp funcOp);

  // Operations whose results are (transitively) used only as the shape
  // operand of shape.with_shape ops; recomputed per function in
  // runOnOperation.
  DenseSet<Operation *> onlyUsedByWithShapes;
};
182
183class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
184 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
185
186 LogicalResult matchAndRewrite(tensor::DimOp op,
187 PatternRewriter &rewriter) const override {
188 auto shapeOf =
189 rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource());
190 rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf,
191 op.getIndex());
192 return success();
193 }
194};
195
// Pass driver: normalizes tensor.dim to shape-dialect ops, outlines each
// shape computation into a shape.func, redirects shape.value_of users past
// the now-redundant shape.with_shape ops, and finally runs a greedy cleanup
// that DCEs the dead shape-only ops.
void OutlineShapeComputationPass::runOnOperation() {
  ModuleOp moduleOp = getOperation();
  SymbolTable symbolTable(moduleOp);
  // Shared across functions so identical shape values reuse one shape.func.
  DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
  auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
  // TODO: This is as we populate this analysis during a pass that mutates. This
  // pass currently requires 1 single module being compiled.
  shapeMappingAnalysis.shapeMapping.clear();
  // Keep the freshly populated analysis alive for downstream passes.
  markAnalysesPreserved<shape::ShapeMappingAnalysis>();

  moduleOp.walk([&](func::FuncOp funcOp) {
    MLIRContext *context = funcOp.getContext();
    // First convert tensor.dim into shape.shape_of + shape.get_extent so
    // dimension queries join the shape computation clusters.
    RewritePatternSet prevPatterns(context);
    prevPatterns.insert<TensorDimOpRewriter>(context);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
      return signalPassFailure();

    // initialize class member `onlyUsedByWithShapes`
    onlyUsedByWithShapes.clear();
    funcOp.walk([&](Operation *op) {
      calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
    });
    LLVM_DEBUG({
      llvm::dbgs() << "onlyUsedByWithShapes table: \n";
      for (auto it : onlyUsedByWithShapes)
        llvm::dbgs() << *it << "\n";
    });

    // collect all the shape.with_shape ops.
    std::vector<shape::WithOp> allWithOps;
    funcOp.walk([&](shape::WithOp withOp) { allWithOps.push_back(withOp); });

    // Cluster the shape-only ops per shape value and outline each cluster
    // into a private shape.func recorded in the analysis.
    DenseMap<Value, SmallVector<Operation *, 8>> clusters =
        constructClustersForEachShape(allWithOps, funcOp);
    constructShapeFunc(allWithOps, context, clusters, symbolTable,
                       dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);

    // Bypass the with_shape ops for shape.value_of users, which only care
    // about the value operand.
    for (shape::WithOp withOp : allWithOps) {
      Value value = withOp.getOperand();
      for (Operation *user :
           llvm::make_early_inc_range(withOp.getResult().getUsers())) {
        if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(user)) {
          // For pattern like
          //  %1 = shape.with_shape %arg1, %0
          //  %2 = shape.value_of %1
          // because shape.value_of doesn't care about the shape, the
          // shape.with_shape is redundant.
          // If %arg1 and %2 have the same type, just replace %2 with %arg1.
          // If %arg1 has a different type like !shape.value_shape, transform
          // into
          //  %2 = shape.value_of %arg1
          if (valueOf.getType() == value.getType())
            valueOf.replaceAllUsesWith(value);
          else
            valueOf.setOperand(value);
        }
      }
    }

    // Apply patterns, note this also performs DCE.
    if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
      return signalPassFailure();
  });
}
261
262DenseMap<Value, SmallVector<Operation *, 8>>
263OutlineShapeComputationPass::constructClustersForEachShape(
264 const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
265 DenseMap<Value, DenseSet<Operation *>> clusters;
266 for (shape::WithOp withOp : allWithOps) {
267 Value shape = withOp.getShape();
268 if (clusters.count(shape) == 0)
269 getClusterFromValue(shape, clusters);
270 }
271 return getOrderedClusters(clusters, funcOp);
272}
273
274// The output of a cluster is the `shape`, and the inputs are the outputs of
275// operations who are not in `onlyUsedByWithShapes`
276void OutlineShapeComputationPass::getClusterFromValue(
277 Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
278 DenseSet<Operation *> cluster;
279
280 DenseSet<Operation *> visited;
281 std::queue<Operation *> queue;
282
283 // defOp == nullptr means shape is the argument of the func op
284 if (Operation *defOp = shape.getDefiningOp()) {
285 visited.insert(defOp);
286 queue.push(defOp);
287 }
288 while (!queue.empty()) {
289 Operation *op = queue.front();
290 queue.pop();
291 if (onlyUsedByWithShapes.contains(op)) {
292 cluster.insert(op);
293 for (Value inp : op->getOperands()) {
294 Operation *inpDefOp = inp.getDefiningOp();
295 if (nullptr != inpDefOp && !visited.contains(inpDefOp)) {
296 visited.insert(inpDefOp);
297 queue.push(inpDefOp);
298 }
299 }
300 }
301 }
302
303 clusters[shape] = std::move(cluster);
304}
305
306// Returns whether `op` is a shape.with_shape, or all the users' of `op`
307// eventually point to the shape operand of shape.with_shape ops
308bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
309 Operation *op, Value prevOutput) {
310 if (onlyUsedByWithShapes.contains(op))
311 return true;
312
313 if (auto withOp = llvm::dyn_cast<shape::WithOp>(op))
314 return withOp.getShape() == prevOutput;
315
316 if (op->use_empty())
317 return false;
318
319 for (Value oup : op->getResults())
320 for (Operation *user : oup.getUsers())
321 if (!calOnlyUsedByWithShapesRecursively(user, oup))
322 return false;
323
324 onlyUsedByWithShapes.insert(op);
325 return true;
326}
327
328} // namespace
329
330std::unique_ptr<OperationPass<ModuleOp>>
331mlir::createOutlineShapeComputationPass() {
332 return std::make_unique<OutlineShapeComputationPass>();
333}
334

// Source: mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp