1 | //===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
10 | |
11 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
12 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
13 | #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
14 | #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" |
15 | #include "mlir/Dialect/Linalg/IR/Linalg.h" |
16 | #include "mlir/Dialect/Tensor/IR/Tensor.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace mlir::bufferization; |
20 | using namespace mlir::linalg; |
21 | |
22 | /// Get an output operand that matches the given input operand and can be used |
23 | /// to eliminate a tensor.empty op. |
24 | static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) { |
25 | for (OpOperand &operand : op.getDpsInitsMutable()) { |
26 | // Operand must be unused. |
27 | if (op.payloadUsesValueFromOperand(&operand)) |
28 | continue; |
29 | // Types must match. |
30 | if (operand.get().getType() != in->get().getType()) |
31 | continue; |
32 | // Indexing maps must match. |
33 | if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in)) |
34 | continue; |
35 | return &operand; |
36 | } |
37 | return nullptr; |
38 | } |
39 | |
40 | LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep( |
41 | RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { |
42 | OpBuilder::InsertionGuard g(rewriter); |
43 | DominanceInfo domInfo; |
44 | |
45 | op->walk([&](LinalgOp op) { |
46 | // Only ops with all "parallel" iterator types are supported. |
47 | if (op.getNumParallelLoops() != op.getNumLoops()) |
48 | return WalkResult::skip(); |
49 | |
50 | for (OpOperand *in : op.getDpsInputOperands()) { |
51 | // Skip non-tensor operands. |
52 | if (!isa<RankedTensorType>(in->get().getType())) |
53 | continue; |
54 | |
55 | // Find tensor.empty ops on the reverse SSA use-def chain. Only follow |
56 | // equivalent tensors. I.e., stop when there are ops such as extract_slice |
57 | // on the path. |
58 | TraversalConfig config; |
59 | config.followEquivalentOnly = true; |
60 | config.alwaysIncludeLeaves = false; |
61 | SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( |
62 | in->get(), /*condition=*/ |
63 | [&](Value val) { |
64 | return val.getDefiningOp<tensor::EmptyOp>() && |
65 | val.getType() == in->get().getType(); |
66 | }, |
67 | config); |
68 | if (emptyTensors.empty()) |
69 | continue; |
70 | |
71 | // Find matching out operand. |
72 | OpOperand *out = getUnusedOutOperand(op, in); |
73 | if (!out) |
74 | continue; |
75 | |
76 | // Check if this transform would violate dominance. |
77 | if (!llvm::all_of(emptyTensors, [&](Value v) { |
78 | return domInfo.properlyDominates(out->get(), v.getDefiningOp()); |
79 | })) |
80 | continue; |
81 | |
82 | // Replace all uses of the tensor.empty, but do not delete it yet. It will |
83 | // fold away later (to not invalidate DominanceInfo). |
84 | for (Value v : emptyTensors) { |
85 | assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty" ); |
86 | rewriter.replaceAllUsesWith(v, out->get()); |
87 | } |
88 | |
89 | // Turn the "in" into an "out". |
90 | rewriter.modifyOpInPlace(op, [&]() { |
91 | out->set(in->get()); |
92 | // The original "in" could be removed entirely here (because it will no |
93 | // longer have any uses in the payload), but we delegate this to |
94 | // existing cleanup patterns that remove unused operands. |
95 | in->set(emptyTensors.front()); |
96 | BlockArgument outArg = op.getMatchingBlockArgument(out); |
97 | assert(outArg.getUses().empty() && "expected that out has no uses" ); |
98 | BlockArgument inArg = op.getMatchingBlockArgument(in); |
99 | rewriter.replaceAllUsesWith(inArg, outArg); |
100 | assert(!op.payloadUsesValueFromOperand(in) && |
101 | "expected that the in operand is now unused" ); |
102 | }); |
103 | |
104 | state.resetCache(); |
105 | } |
106 | |
107 | return WalkResult::advance(); |
108 | }); |
109 | return success(); |
110 | } |
111 | |