//===- EliminateEmptyTensors.cpp - tensor.empty op elimination ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
14#include "mlir/Dialect/Linalg/IR/Linalg.h"
15#include "mlir/Dialect/Tensor/IR/Tensor.h"
16
17using namespace mlir;
18using namespace mlir::bufferization;
19using namespace mlir::linalg;
20
21/// Get an output operand that matches the given input operand and can be used
22/// to eliminate a tensor.empty op.
23static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
24 for (OpOperand &operand : op.getDpsInitsMutable()) {
25 // Operand must be unused.
26 if (op.payloadUsesValueFromOperand(opOperand: &operand))
27 continue;
28 // Types must match.
29 if (operand.get().getType() != in->get().getType())
30 continue;
31 // Indexing maps must match.
32 if (op.getMatchingIndexingMap(opOperand: &operand) != op.getMatchingIndexingMap(opOperand: in))
33 continue;
34 return &operand;
35 }
36 return nullptr;
37}
38
39LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
40 RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
41 OpBuilder::InsertionGuard g(rewriter);
42 DominanceInfo domInfo;
43
44 op->walk(callback: [&](LinalgOp op) {
45 // Only ops with all "parallel" iterator types are supported.
46 if (op.getNumParallelLoops() != op.getNumLoops())
47 return WalkResult::skip();
48
49 for (OpOperand *in : op.getDpsInputOperands()) {
50 // Skip non-tensor operands.
51 if (!isa<RankedTensorType>(Val: in->get().getType()))
52 continue;
53
54 // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
55 // equivalent tensors. I.e., stop when there are ops such as extract_slice
56 // on the path.
57 TraversalConfig config;
58 config.followEquivalentOnly = true;
59 config.alwaysIncludeLeaves = false;
60 SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
61 opOperand: in, /*condition=*/
62 [&](Value val) {
63 return val.getDefiningOp<tensor::EmptyOp>() &&
64 val.getType() == in->get().getType();
65 },
66 config);
67 if (emptyTensors.empty())
68 continue;
69
70 // Find matching out operand.
71 OpOperand *out = getUnusedOutOperand(op, in);
72 if (!out)
73 continue;
74
75 // Check if this transform would violate dominance.
76 if (!llvm::all_of(Range&: emptyTensors, P: [&](Value v) {
77 return domInfo.properlyDominates(a: out->get(), b: v.getDefiningOp());
78 }))
79 continue;
80
81 // Replace all uses of the tensor.empty, but do not delete it yet. It will
82 // fold away later (to not invalidate DominanceInfo).
83 for (Value v : emptyTensors) {
84 assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
85 rewriter.replaceAllUsesWith(from: v, to: out->get());
86 }
87
88 // Turn the "in" into an "out".
89 rewriter.modifyOpInPlace(root: op, callable: [&]() {
90 out->set(in->get());
91 // The original "in" could be removed entirely here (because it will no
92 // longer have any uses in the payload), but we delegate this to
93 // existing cleanup patterns that remove unused operands.
94 in->set(emptyTensors.front());
95 BlockArgument outArg = op.getMatchingBlockArgument(opOperand: out);
96 assert(outArg.getUses().empty() && "expected that out has no uses");
97 BlockArgument inArg = op.getMatchingBlockArgument(opOperand: in);
98 rewriter.replaceAllUsesWith(from: inArg, to: outArg);
99 assert(!op.payloadUsesValueFromOperand(in) &&
100 "expected that the in operand is now unused");
101 });
102
103 state.resetCache();
104 }
105
106 return WalkResult::advance();
107 });
108 return success();
109}

// Source: mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp