1//====----- OutlineShapeComputation.cpp -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Func/IR/FuncOps.h"
10#include "mlir/Dialect/Shape/Analysis/ShapeMappingAnalysis.h"
11#include "mlir/Dialect/Shape/IR/Shape.h"
12#include "mlir/Dialect/Shape/Transforms/Passes.h"
13#include "mlir/Dialect/Tensor/IR/Tensor.h"
14#include "mlir/IR/IRMapping.h"
15#include "mlir/Transforms/DialectConversion.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17#include "llvm/ADT/DenseSet.h"
18#include "llvm/Support/Debug.h"
19#include <queue>
20#include <vector>
21
22namespace mlir {
23#define GEN_PASS_DEF_OUTLINESHAPECOMPUTATIONPASS
24#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
25} // namespace mlir
26
27#define DEBUG_TYPE "outline-shape-computation"
28
29using namespace mlir;
30
31namespace {
32
33// A Value is an input of the cluster if it is an operand of an operation in the
34// cluster and its defining operation is not in the cluster.
35SmallVector<Value, 4>
36getInputsOfCluster(const llvm::SmallVector<Operation *, 8> &cluster) {
37 SmallVector<Value, 4> inputs;
38 llvm::SmallDenseSet<Value> inputSet;
39 llvm::SmallDenseSet<Operation *> opSet;
40 for (Operation *op : cluster) {
41 bool inserted = opSet.insert(V: op).second;
42 (void)inserted;
43 assert(inserted && "cluster contains duplicate operations");
44 }
45
46 for (Operation *op : cluster) {
47 for (Value operand : op->getOperands()) {
48 Operation *operandOp = operand.getDefiningOp();
49 if (opSet.contains(V: operandOp)) {
50 // Skip if defining op is in the cluster.
51 continue;
52 }
53 if (inputSet.insert(V: operand).second)
54 inputs.push_back(Elt: operand);
55 }
56 }
57 return inputs;
58}
59
60// Create a shape.func representing the shape computation for `shape`.
61std::pair<shape::FuncOp, SmallVector<Value>>
62createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster,
63 Value shape, StringRef fnName, Location loc) {
64 SmallVector<Value, 4> inputs = getInputsOfCluster(cluster);
65 auto fnType =
66 cluster.empty()
67 ? b.getFunctionType(inputs: shape.getType(), results: shape.getType())
68 : b.getFunctionType(inputs: ValueRange(inputs).getTypes(), results: shape.getType());
69 shape::FuncOp fnOp = b.create<shape::FuncOp>(location: loc, args&: fnName, args&: fnType);
70 Block *block = fnOp.addEntryBlock();
71 b.setInsertionPointToEnd(block);
72 IRMapping bvm;
73 if (cluster.empty()) {
74 bvm.map(from: shape, to: fnOp.getArgument(idx: 0));
75 } else {
76 for (auto inputAndArg : llvm::zip(t&: inputs, u: fnOp.getArguments()))
77 bvm.map(from: std::get<0>(t&: inputAndArg), to: std::get<1>(t&: inputAndArg));
78 }
79
80 for (Operation *op : cluster)
81 b.clone(op&: *op, mapper&: bvm);
82 llvm::SmallVector<Value, 4> fnReturns;
83 fnReturns.push_back(Elt: bvm.lookupOrDefault(from: shape));
84
85 b.create<shape::ReturnOp>(location: loc, args&: fnReturns);
86 fnOp.setPrivate();
87 return std::make_pair(x&: fnOp, y&: inputs);
88}
89
90// The operations in the cluster might be unsorted, which could be inconvenient
91// when creating shape.func op.
92DenseMap<Value, SmallVector<Operation *, 8>>
93getOrderedClusters(const DenseMap<Value, DenseSet<Operation *>> &clusters,
94 func::FuncOp funcOp) {
95 // Compute all clusters that each operation is in
96 DenseMap<Operation *, SmallVector<Value>> op2Shapes;
97 for (const auto &it : clusters) {
98 Value shape = it.first;
99 const DenseSet<Operation *> &cluster = it.second;
100 for (Operation *cOp : cluster)
101 op2Shapes[cOp].push_back(Elt: shape);
102 }
103
104 // Iterate through all operations in order. Get all the clusters `cOp` belongs
105 // to and construct the new ordered cluster as it traverses.
106 DenseMap<Value, SmallVector<Operation *, 8>> orderedClusters;
107 funcOp.walk(callback: [&](Operation *op) {
108 auto it = op2Shapes.find(Val: op);
109 if (it != op2Shapes.end()) {
110 Operation *cOp = it->first;
111 for (Value shape : it->second)
112 orderedClusters[shape].push_back(Elt: cOp);
113 }
114 });
115
116 return orderedClusters;
117}
118
119void constructShapeFunc(
120 const std::vector<shape::WithOp> &allWithOps, MLIRContext *context,
121 DenseMap<Value, SmallVector<Operation *, 8>> &clusters,
122 SymbolTable &symbolTable,
123 DenseMap<Value, shape::ShapeMappingValue> &dynShape2ShapeFunc,
124 func::FuncOp funcOp, shape::ShapeMappingAnalysis &shapeMappingAnalysis) {
125 std::string shapeCalculationNamePrefix = "shape_cal_";
126 int shapeCalculationNameIdx = 0;
127 OpBuilder builder(context);
128
129 // Construct a shape function
130 for (shape::WithOp withOp : allWithOps) {
131 Value value = withOp.getOperand();
132 Value shape = withOp.getShape();
133 RankedTensorType rankedType = dyn_cast<RankedTensorType>(Val: value.getType());
134 if (rankedType == nullptr)
135 continue;
136
137 const SmallVector<Operation *, 8> &cluster = clusters[shape];
138 shape::ShapeMappingValue shapeMappingValue;
139 auto it = dynShape2ShapeFunc.find(Val: shape);
140 if (it == dynShape2ShapeFunc.end()) {
141 std::string name = shapeCalculationNamePrefix +
142 std::to_string(val: shapeCalculationNameIdx++);
143 Location loc = value.getLoc();
144 builder.setInsertionPointAfter(funcOp);
145 auto pair = createFuncFromCluster(b&: builder, cluster, shape, fnName: name, loc);
146 const SmallVector<Value> &inputs = pair.second;
147 shape::FuncOp shapeFuncOp = pair.first;
148 StringAttr insertedName = symbolTable.insert(symbol: shapeFuncOp);
149 auto symbol = FlatSymbolRefAttr::get(ctx: context, value: insertedName);
150
151 shapeMappingValue.funcSymbol = symbol;
152 shapeMappingValue.inputs = inputs;
153 } else {
154 shapeMappingValue = it->second;
155 }
156 dynShape2ShapeFunc[shape] = shapeMappingValue;
157 shapeMappingAnalysis.shapeMapping.insert(
158 KV: std::make_pair(x&: value, y&: shapeMappingValue));
159 }
160}
161
// Pass that outlines the shape computation feeding each shape.with_shape op
// into a standalone (private) shape.func, and records the resulting
// value-to-shape-function mapping in a ShapeMappingAnalysis.
struct OutlineShapeComputationPass
    : public impl::OutlineShapeComputationPassBase<
          OutlineShapeComputationPass> {

  void runOnOperation() override;

private:
  // Returns true if `op` is a shape.with_shape taking `prevOutput` as its
  // shape operand, or if all of `op`'s users transitively satisfy this.
  // Memoized via `onlyUsedByWithShapes`.
  bool calOnlyUsedByWithShapesRecursively(Operation *op, Value prevOutput);

  // Collects the backward slice of `shape`, restricted to operations in
  // `onlyUsedByWithShapes`, into clusters[shape].
  void getClusterFromValue(Value shape,
                           DenseMap<Value, DenseSet<Operation *>> &clusters);

  // Builds, for the shape of each with_shape op, the list of its defining
  // operations in program order.
  DenseMap<Value, SmallVector<Operation *, 8>>
  constructClustersForEachShape(const std::vector<shape::WithOp> &allWithOps,
                                func::FuncOp funcOp);

  // Cache of operations whose results are (transitively) consumed only as the
  // shape operand of shape.with_shape ops; reset per function.
  DenseSet<Operation *> onlyUsedByWithShapes;
};
180
181class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> {
182 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
183
184 LogicalResult matchAndRewrite(tensor::DimOp op,
185 PatternRewriter &rewriter) const override {
186 auto shapeOf =
187 rewriter.create<shape::ShapeOfOp>(location: op.getLoc(), args: op.getSource());
188 rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, args: op.getType(), args&: shapeOf,
189 args: op.getIndex());
190 return success();
191 }
192};
193
194void OutlineShapeComputationPass::runOnOperation() {
195 ModuleOp moduleOp = getOperation();
196 SymbolTable symbolTable(moduleOp);
197 DenseMap<Value, shape::ShapeMappingValue> dynShape2ShapeFunc;
198 auto &shapeMappingAnalysis = getAnalysis<shape::ShapeMappingAnalysis>();
199 // TODO: This is as we populate this analysis during a pass that mutates. This
200 // pass currently requires 1 single module being compiled.
201 shapeMappingAnalysis.shapeMapping.clear();
202 markAnalysesPreserved<shape::ShapeMappingAnalysis>();
203
204 moduleOp.walk(callback: [&](func::FuncOp funcOp) {
205 MLIRContext *context = funcOp.getContext();
206 RewritePatternSet prevPatterns(context);
207 prevPatterns.insert<TensorDimOpRewriter>(arg&: context);
208 if (failed(Result: applyPatternsGreedily(op: funcOp, patterns: std::move(prevPatterns))))
209 return signalPassFailure();
210
211 // initialize class member `onlyUsedByWithShapes`
212 onlyUsedByWithShapes.clear();
213 funcOp.walk(callback: [&](Operation *op) {
214 calOnlyUsedByWithShapesRecursively(op, /*prevOutput=*/nullptr);
215 });
216 LLVM_DEBUG({
217 llvm::dbgs() << "onlyUsedByWithShapes table: \n";
218 for (auto it : onlyUsedByWithShapes)
219 llvm::dbgs() << *it << "\n";
220 });
221
222 // collect all the shape.with_shape ops.
223 std::vector<shape::WithOp> allWithOps;
224 funcOp.walk(callback: [&](shape::WithOp withOp) { allWithOps.push_back(x: withOp); });
225
226 DenseMap<Value, SmallVector<Operation *, 8>> clusters =
227 constructClustersForEachShape(allWithOps, funcOp);
228 constructShapeFunc(allWithOps, context, clusters, symbolTable,
229 dynShape2ShapeFunc, funcOp, shapeMappingAnalysis);
230
231 for (shape::WithOp withOp : allWithOps) {
232 Value value = withOp.getOperand();
233 for (Operation *user :
234 llvm::make_early_inc_range(Range: withOp.getResult().getUsers())) {
235 if (auto valueOf = llvm::dyn_cast<shape::ValueOfOp>(Val: user)) {
236 // For pattern like
237 // %1 = shape.with_shape %arg1, %0
238 // %2 = shape.value_of %1
239 // because shape.value doesn't care the shape, the shape.with_shape is
240 // redundant.
241 // If type of %arg1 and %2 has same type, just
242 // replaced %2 with %arg1.
243 // If type of %arg1 has different type like !shape.value_shape,
244 // transform into
245 // %2 = shape.value_of %arg1
246 if (valueOf.getType() == value.getType())
247 valueOf.replaceAllUsesWith(newValue: value);
248 else
249 valueOf.setOperand(value);
250 }
251 }
252 }
253
254 // Apply patterns, note this also performs DCE.
255 if (failed(Result: applyPatternsGreedily(op: funcOp, patterns: {})))
256 return signalPassFailure();
257 });
258}
259
260DenseMap<Value, SmallVector<Operation *, 8>>
261OutlineShapeComputationPass::constructClustersForEachShape(
262 const std::vector<shape::WithOp> &allWithOps, func::FuncOp funcOp) {
263 DenseMap<Value, DenseSet<Operation *>> clusters;
264 for (shape::WithOp withOp : allWithOps) {
265 Value shape = withOp.getShape();
266 if (clusters.count(Val: shape) == 0)
267 getClusterFromValue(shape, clusters);
268 }
269 return getOrderedClusters(clusters, funcOp);
270}
271
272// The output of a cluster is the `shape`, and the inputs are the outputs of
273// operations who are not in `onlyUsedByWithShapes`
274void OutlineShapeComputationPass::getClusterFromValue(
275 Value shape, DenseMap<Value, DenseSet<Operation *>> &clusters) {
276 DenseSet<Operation *> cluster;
277
278 DenseSet<Operation *> visited;
279 std::queue<Operation *> queue;
280
281 // defOp == nullptr means shape is the argument of the func op
282 if (Operation *defOp = shape.getDefiningOp()) {
283 visited.insert(V: defOp);
284 queue.push(x: defOp);
285 }
286 while (!queue.empty()) {
287 Operation *op = queue.front();
288 queue.pop();
289 if (onlyUsedByWithShapes.contains(V: op)) {
290 cluster.insert(V: op);
291 for (Value inp : op->getOperands()) {
292 Operation *inpDefOp = inp.getDefiningOp();
293 if (nullptr != inpDefOp && visited.insert(V: inpDefOp).second)
294 queue.push(x: inpDefOp);
295 }
296 }
297 }
298
299 clusters[shape] = std::move(cluster);
300}
301
302// Returns whether `op` is a shape.with_shape, or all the users' of `op`
303// eventually point to the shape operand of shape.with_shape ops
304bool OutlineShapeComputationPass::calOnlyUsedByWithShapesRecursively(
305 Operation *op, Value prevOutput) {
306 if (onlyUsedByWithShapes.contains(V: op))
307 return true;
308
309 if (auto withOp = llvm::dyn_cast<shape::WithOp>(Val: op))
310 return withOp.getShape() == prevOutput;
311
312 if (op->use_empty())
313 return false;
314
315 for (Value oup : op->getResults())
316 for (Operation *user : oup.getUsers())
317 if (!calOnlyUsedByWithShapesRecursively(op: user, prevOutput: oup))
318 return false;
319
320 onlyUsedByWithShapes.insert(V: op);
321 return true;
322}
323
324} // namespace
325

source code of mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp