1//===- EmptyTensorElimination.cpp - tensor.empty op elimination -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
10
11#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
13#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17
18using namespace mlir;
19using namespace mlir::bufferization;
20using namespace mlir::linalg;
21
22/// Get an output operand that matches the given input operand and can be used
23/// to eliminate a tensor.empty op.
24static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) {
25 for (OpOperand &operand : op.getDpsInitsMutable()) {
26 // Operand must be unused.
27 if (op.payloadUsesValueFromOperand(&operand))
28 continue;
29 // Types must match.
30 if (operand.get().getType() != in->get().getType())
31 continue;
32 // Indexing maps must match.
33 if (op.getMatchingIndexingMap(&operand) != op.getMatchingIndexingMap(in))
34 continue;
35 return &operand;
36 }
37 return nullptr;
38}
39
40LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
41 RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
42 OpBuilder::InsertionGuard g(rewriter);
43 DominanceInfo domInfo;
44
45 op->walk([&](LinalgOp op) {
46 // Only ops with all "parallel" iterator types are supported.
47 if (op.getNumParallelLoops() != op.getNumLoops())
48 return WalkResult::skip();
49
50 for (OpOperand *in : op.getDpsInputOperands()) {
51 // Skip non-tensor operands.
52 if (!isa<RankedTensorType>(in->get().getType()))
53 continue;
54
55 // Find tensor.empty ops on the reverse SSA use-def chain. Only follow
56 // equivalent tensors. I.e., stop when there are ops such as extract_slice
57 // on the path.
58 TraversalConfig config;
59 config.followEquivalentOnly = true;
60 config.alwaysIncludeLeaves = false;
61 SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
62 in->get(), /*condition=*/
63 [&](Value val) {
64 return val.getDefiningOp<tensor::EmptyOp>() &&
65 val.getType() == in->get().getType();
66 },
67 config);
68 if (emptyTensors.empty())
69 continue;
70
71 // Find matching out operand.
72 OpOperand *out = getUnusedOutOperand(op, in);
73 if (!out)
74 continue;
75
76 // Check if this transform would violate dominance.
77 if (!llvm::all_of(emptyTensors, [&](Value v) {
78 return domInfo.properlyDominates(out->get(), v.getDefiningOp());
79 }))
80 continue;
81
82 // Replace all uses of the tensor.empty, but do not delete it yet. It will
83 // fold away later (to not invalidate DominanceInfo).
84 for (Value v : emptyTensors) {
85 assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty");
86 rewriter.replaceAllUsesWith(v, out->get());
87 }
88
89 // Turn the "in" into an "out".
90 rewriter.modifyOpInPlace(op, [&]() {
91 out->set(in->get());
92 // The original "in" could be removed entirely here (because it will no
93 // longer have any uses in the payload), but we delegate this to
94 // existing cleanup patterns that remove unused operands.
95 in->set(emptyTensors.front());
96 BlockArgument outArg = op.getMatchingBlockArgument(out);
97 assert(outArg.getUses().empty() && "expected that out has no uses");
98 BlockArgument inArg = op.getMatchingBlockArgument(in);
99 rewriter.replaceAllUsesWith(inArg, outArg);
100 assert(!op.payloadUsesValueFromOperand(in) &&
101 "expected that the in operand is now unused");
102 });
103
104 state.resetCache();
105 }
106
107 return WalkResult::advance();
108 });
109 return success();
110}
111

// Source: mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp