1//===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The goal of this pass is optimization (reducing runtime) by removing
10// unnecessary instructions. Unlike other passes that rely on local information
11// gathered from patterns to accomplish optimization, this pass uses a full
12// analysis of the IR, specifically, liveness analysis, and is thus more
13// powerful.
14//
15// Currently, this pass performs the following optimizations:
16// (A) Removes function arguments that are not live,
17// (B) Removes function return values that are not live across all callers of
18// the function,
// (C) Removes unnecessary operands, results, region arguments, and region
20// terminator operands of region branch ops, and,
21// (D) Removes simple and region branch ops that have all non-live results and
22// don't affect memory in any way,
23//
24// iff
25//
26// the IR doesn't have any non-function symbol ops, non-call symbol user ops and
27// branch ops.
28//
29// Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
30// region branch op, branch op, region branch terminator op, or return-like.
31//
32//===----------------------------------------------------------------------===//
33
34#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
35#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
36#include "mlir/IR/Builders.h"
37#include "mlir/IR/BuiltinAttributes.h"
38#include "mlir/IR/Dialect.h"
39#include "mlir/IR/OperationSupport.h"
40#include "mlir/IR/SymbolTable.h"
41#include "mlir/IR/Value.h"
42#include "mlir/IR/ValueRange.h"
43#include "mlir/IR/Visitors.h"
44#include "mlir/Interfaces/CallInterfaces.h"
45#include "mlir/Interfaces/ControlFlowInterfaces.h"
46#include "mlir/Interfaces/FunctionInterfaces.h"
47#include "mlir/Interfaces/SideEffectInterfaces.h"
48#include "mlir/Pass/Pass.h"
49#include "mlir/Support/LLVM.h"
50#include "mlir/Transforms/FoldUtils.h"
51#include "mlir/Transforms/Passes.h"
52#include "llvm/ADT/STLExtras.h"
53#include "llvm/Support/Debug.h"
54#include <cassert>
55#include <cstddef>
56#include <memory>
57#include <optional>
58#include <vector>
59
60#define DEBUG_TYPE "remove-dead-values"
61#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
62#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
63
64namespace mlir {
65#define GEN_PASS_DEF_REMOVEDEADVALUES
66#include "mlir/Transforms/Passes.h.inc"
67} // namespace mlir
68
69using namespace mlir;
70using namespace mlir::dataflow;
71
72//===----------------------------------------------------------------------===//
73// RemoveDeadValues Pass
74//===----------------------------------------------------------------------===//
75
76namespace {
77
78// Set of structures below to be filled with operations and arguments to erase.
79// This is done to separate analysis and tree modification phases,
80// otherwise analysis is operating on half-deleted tree which is incorrect.
81
82struct FunctionToCleanUp {
83 FunctionOpInterface funcOp;
84 BitVector nonLiveArgs;
85 BitVector nonLiveRets;
86};
87
88struct OperationToCleanup {
89 Operation *op;
90 BitVector nonLive;
91};
92
93struct BlockArgsToCleanup {
94 Block *b;
95 BitVector nonLiveArgs;
96};
97
98struct SuccessorOperandsToCleanup {
99 BranchOpInterface branch;
100 unsigned successorIndex;
101 BitVector nonLiveOperands;
102};
103
104struct RDVFinalCleanupList {
105 SmallVector<Operation *> operations;
106 SmallVector<Value> values;
107 SmallVector<FunctionToCleanUp> functions;
108 SmallVector<OperationToCleanup> operands;
109 SmallVector<OperationToCleanup> results;
110 SmallVector<BlockArgsToCleanup> blocks;
111 SmallVector<SuccessorOperandsToCleanup> successorOperands;
112};
113
114// Some helper functions...
115
116/// Return true iff at least one value in `values` is live, given the liveness
117/// information in `la`.
118static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
119 RunLivenessAnalysis &la) {
120 for (Value value : values) {
121 if (nonLiveSet.contains(V: value)) {
122 LDBG("Value " << value << " is already marked non-live (dead)");
123 continue;
124 }
125
126 const Liveness *liveness = la.getLiveness(val: value);
127 if (!liveness) {
128 LDBG("Value " << value
129 << " has no liveness info, conservatively considered live");
130 return true;
131 }
132 if (liveness->isLive) {
133 LDBG("Value " << value << " is live according to liveness analysis");
134 return true;
135 } else {
136 LDBG("Value " << value << " is dead according to liveness analysis");
137 }
138 }
139 return false;
140}
141
142/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
143/// i-th value in `values` is live, given the liveness information in `la`.
144static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
145 RunLivenessAnalysis &la) {
146 BitVector lives(values.size(), true);
147
148 for (auto [index, value] : llvm::enumerate(First&: values)) {
149 if (nonLiveSet.contains(V: value)) {
150 lives.reset(Idx: index);
151 LDBG("Value " << value << " is already marked non-live (dead) at index "
152 << index);
153 continue;
154 }
155
156 const Liveness *liveness = la.getLiveness(val: value);
157 // It is important to note that when `liveness` is null, we can't tell if
158 // `value` is live or not. So, the safe option is to consider it live. Also,
159 // the execution of this pass might create new SSA values when erasing some
160 // of the results of an op and we know that these new values are live
161 // (because they weren't erased) and also their liveness is null because
162 // liveness analysis ran before their creation.
163 if (!liveness) {
164 LDBG("Value " << value << " at index " << index
165 << " has no liveness info, conservatively considered live");
166 continue;
167 }
168 if (!liveness->isLive) {
169 lives.reset(Idx: index);
170 LDBG("Value " << value << " at index " << index
171 << " is dead according to liveness analysis");
172 } else {
173 LDBG("Value " << value << " at index " << index
174 << " is live according to liveness analysis");
175 }
176 }
177
178 return lives;
179}
180
181/// Collects values marked as "non-live" in the provided range and inserts them
182/// into the nonLiveSet. A value is considered "non-live" if the corresponding
183/// index in the `nonLive` bit vector is set.
184static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
185 const BitVector &nonLive) {
186 for (auto [index, result] : llvm::enumerate(First&: range)) {
187 if (!nonLive[index])
188 continue;
189 nonLiveSet.insert(V: result);
190 LDBG("Marking value " << result << " as non-live (dead) at index "
191 << index);
192 }
193}
194
195/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
196/// is 1.
197static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
198 assert(op->getNumResults() == toErase.size() &&
199 "expected the number of results in `op` and the size of `toErase` to "
200 "be the same");
201
202 std::vector<Type> newResultTypes;
203 for (OpResult result : op->getResults())
204 if (!toErase[result.getResultNumber()])
205 newResultTypes.push_back(x: result.getType());
206 OpBuilder builder(op);
207 builder.setInsertionPointAfter(op);
208 OperationState state(op->getLoc(), op->getName().getStringRef(),
209 op->getOperands(), newResultTypes, op->getAttrs());
210 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
211 state.addRegion();
212 Operation *newOp = builder.create(state);
213 for (const auto &[index, region] : llvm::enumerate(First: op->getRegions())) {
214 Region &newRegion = newOp->getRegion(index);
215 // Move all blocks of `region` into `newRegion`.
216 Block *temp = new Block();
217 newRegion.push_back(block: temp);
218 while (!region.empty())
219 region.front().moveBefore(block: temp);
220 temp->erase();
221 }
222
223 unsigned indexOfNextNewCallOpResultToReplace = 0;
224 for (auto [index, result] : llvm::enumerate(First: op->getResults())) {
225 assert(result && "expected result to be non-null");
226 if (toErase[index]) {
227 result.dropAllUses();
228 } else {
229 result.replaceAllUsesWith(
230 newValue: newOp->getResult(idx: indexOfNextNewCallOpResultToReplace++));
231 }
232 }
233 op->erase();
234}
235
236/// Convert a list of `Operand`s to a list of `OpOperand`s.
237static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
238 OpOperand *values = operands.getBase();
239 SmallVector<OpOperand *> opOperands;
240 for (unsigned i = 0, e = operands.size(); i < e; i++)
241 opOperands.push_back(Elt: &values[i]);
242 return opOperands;
243}
244
245/// Process a simple operation `op` using the liveness analysis `la`.
246/// If the operation has no memory effects and none of its results are live:
247/// 1. Add the operation to a list for future removal, and
248/// 2. Mark all its results as non-live values
249///
250/// The operation `op` is assumed to be simple. A simple operation is one that
251/// is NOT:
252/// - Function-like
253/// - Call-like
254/// - A region branch operation
255/// - A branch operation
256/// - A region branch terminator
257/// - Return-like
258static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
259 DenseSet<Value> &nonLiveSet,
260 RDVFinalCleanupList &cl) {
261 if (!isMemoryEffectFree(op) || hasLive(values: op->getResults(), nonLiveSet, la)) {
262 LLVM_DEBUG({
263 llvm::dbgs()
264 << "Simple op is not memory effect free or has live results, "
265 "preserving it: "
266 << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n";
267 });
268 return;
269 }
270
271 LLVM_DEBUG({
272 llvm::dbgs() << "Simple op has all dead results and is memory effect free, "
273 "scheduling "
274 "for removal: "
275 << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n";
276 });
277 cl.operations.push_back(Elt: op);
278 collectNonLiveValues(nonLiveSet, range: op->getResults(),
279 nonLive: BitVector(op->getNumResults(), true));
280}
281
282/// Process a function-like operation `funcOp` using the liveness analysis `la`
283/// and the IR in `module`. If it is not public or external:
284/// (1) Adding its non-live arguments to a list for future removal.
285/// (2) Marking their corresponding operands in its callers for removal.
286/// (3) Identifying and enqueueing unnecessary terminator operands
287/// (return values that are non-live across all callers) for removal.
288/// (4) Enqueueing the non-live arguments and return values for removal.
289/// (5) Collecting the uses of these return values in its callers for future
290/// removal.
291/// (6) Marking all its results as non-live values.
292static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
293 RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
294 RDVFinalCleanupList &cl) {
295 LDBG("Processing function op: " << funcOp.getOperation()->getName());
296 if (funcOp.isPublic() || funcOp.isExternal()) {
297 LDBG("Function is public or external, skipping: "
298 << funcOp.getOperation()->getName());
299 return;
300 }
301
302 // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
303 SmallVector<Value> arguments(funcOp.getArguments());
304 BitVector nonLiveArgs = markLives(values: arguments, nonLiveSet, la);
305 nonLiveArgs = nonLiveArgs.flip();
306
307 // Do (1).
308 for (auto [index, arg] : llvm::enumerate(First&: arguments))
309 if (arg && nonLiveArgs[index]) {
310 cl.values.push_back(Elt: arg);
311 nonLiveSet.insert(V: arg);
312 }
313
314 // Do (2).
315 SymbolTable::UseRange uses = *funcOp.getSymbolUses(from: module);
316 for (SymbolTable::SymbolUse use : uses) {
317 Operation *callOp = use.getUser();
318 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
319 // The number of operands in the call op may not match the number of
320 // arguments in the func op.
321 BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
322 SmallVector<OpOperand *> callOpOperands =
323 operandsToOpOperands(operands: cast<CallOpInterface>(Val: callOp).getArgOperands());
324 for (int index : nonLiveArgs.set_bits())
325 nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
326 cl.operands.push_back(Elt: {.op: callOp, .nonLive: nonLiveCallOperands});
327 }
328
329 // Do (3).
330 // Get the list of unnecessary terminator operands (return values that are
331 // non-live across all callers) in `nonLiveRets`. There is a very important
332 // subtlety here. Unnecessary terminator operands are NOT the operands of the
333 // terminator that are non-live. Instead, these are the return values of the
334 // callers such that a given return value is non-live across all callers. Such
335 // corresponding operands in the terminator could be live. An example to
336 // demonstrate this:
337 // func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
338 // %c0_i32 = arith.constant 0 : i32
339 // %0 = arith.addi %c0_i32, %c0_i32 : i32
340 // memref.store %0, %arg0[] : memref<i32>
341 // return %c0_i32, %0 : i32, i32
342 // }
343 // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
344 // %1:2 = call @f(%arg1) : (memref<i32>) -> i32
345 // return %1#0 : i32
346 // }
347 // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
348 // need to return %0. But, %0 is live. And, still, we want to stop it from
349 // being returned, in order to optimize our IR. So, this demonstrates how we
350 // can make our optimization strong by even removing a live return value (%0),
351 // since it forwards only to non-live value(s) (%1#1).
352 Operation *lastReturnOp = funcOp.back().getTerminator();
353 size_t numReturns = lastReturnOp->getNumOperands();
354 BitVector nonLiveRets(numReturns, true);
355 for (SymbolTable::SymbolUse use : uses) {
356 Operation *callOp = use.getUser();
357 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
358 BitVector liveCallRets = markLives(values: callOp->getResults(), nonLiveSet, la);
359 nonLiveRets &= liveCallRets.flip();
360 }
361
362 // Note that in the absence of control flow ops forcing the control to go from
363 // the entry (first) block to the other blocks, the control never reaches any
364 // block other than the entry block, because every block has a terminator.
365 for (Block &block : funcOp.getBlocks()) {
366 Operation *returnOp = block.getTerminator();
367 if (returnOp && returnOp->getNumOperands() == numReturns)
368 cl.operands.push_back(Elt: {.op: returnOp, .nonLive: nonLiveRets});
369 }
370
371 // Do (4).
372 cl.functions.push_back(Elt: {.funcOp: funcOp, .nonLiveArgs: nonLiveArgs, .nonLiveRets: nonLiveRets});
373
374 // Do (5) and (6).
375 if (numReturns == 0)
376 return;
377 for (SymbolTable::SymbolUse use : uses) {
378 Operation *callOp = use.getUser();
379 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
380 cl.results.push_back(Elt: {.op: callOp, .nonLive: nonLiveRets});
381 collectNonLiveValues(nonLiveSet, range: callOp->getResults(), nonLive: nonLiveRets);
382 }
383}
384
385/// Process a region branch operation `regionBranchOp` using the liveness
386/// information in `la`. The processing involves two scenarios:
387///
388/// Scenario 1: If the operation has no memory effects and none of its results
389/// are live:
390/// (1') Enqueue all its uses for deletion.
391/// (2') Enqueue the branch itself for deletion.
392///
393/// Scenario 2: Otherwise:
394/// (1) Collect its unnecessary operands (operands forwarded to unnecessary
395/// results or arguments).
396/// (2) Process each of its regions.
397/// (3) Collect the uses of its unnecessary results (results forwarded from
398/// unnecessary operands
399/// or terminator operands).
400/// (4) Add these results to the deletion list.
401///
402/// Processing a region includes:
403/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
404/// from unnecessary operands
405/// or terminator operands).
406/// (b) Collecting these unnecessary arguments.
407/// (c) Collecting its unnecessary terminator operands (terminator operands
408/// forwarded to unnecessary results
409/// or arguments).
410///
411/// Value Flow Note: In this operation, values flow as follows:
412/// - From operands and terminator operands (successor operands)
413/// - To arguments and results (successor inputs).
414static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
415 RunLivenessAnalysis &la,
416 DenseSet<Value> &nonLiveSet,
417 RDVFinalCleanupList &cl) {
418 LLVM_DEBUG(DBGS() << "Processing region branch op: "; regionBranchOp->print(
419 llvm::dbgs(), OpPrintingFlags().skipRegions());
420 llvm::dbgs() << "\n");
421 // Mark live results of `regionBranchOp` in `liveResults`.
422 auto markLiveResults = [&](BitVector &liveResults) {
423 liveResults = markLives(values: regionBranchOp->getResults(), nonLiveSet, la);
424 };
425
426 // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
427 auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
428 for (Region &region : regionBranchOp->getRegions()) {
429 if (region.empty())
430 continue;
431 SmallVector<Value> arguments(region.front().getArguments());
432 BitVector regionLiveArgs = markLives(values: arguments, nonLiveSet, la);
433 liveArgs[&region] = regionLiveArgs;
434 }
435 };
436
437 // Return the successors of `region` if the latter is not null. Else return
438 // the successors of `regionBranchOp`.
439 auto getSuccessors = [&](Region *region = nullptr) {
440 auto point = region ? region : RegionBranchPoint::parent();
441 SmallVector<RegionSuccessor> successors;
442 regionBranchOp.getSuccessorRegions(point, regions&: successors);
443 return successors;
444 };
445
446 // Return the operands of `terminator` that are forwarded to `successor` if
447 // the former is not null. Else return the operands of `regionBranchOp`
448 // forwarded to `successor`.
449 auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
450 Operation *terminator = nullptr) {
451 OperandRange operands =
452 terminator ? cast<RegionBranchTerminatorOpInterface>(Val: terminator)
453 .getSuccessorOperands(point: successor)
454 : regionBranchOp.getEntrySuccessorOperands(point: successor);
455 SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
456 return opOperands;
457 };
458
459 // Mark the non-forwarded operands of `regionBranchOp` in
460 // `nonForwardedOperands`.
461 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
462 nonForwardedOperands.resize(N: regionBranchOp->getNumOperands(), t: true);
463 for (const RegionSuccessor &successor : getSuccessors()) {
464 for (OpOperand *opOperand : getForwardedOpOperands(successor))
465 nonForwardedOperands.reset(Idx: opOperand->getOperandNumber());
466 }
467 };
468
469 // Mark the non-forwarded terminator operands of the various regions of
470 // `regionBranchOp` in `nonForwardedRets`.
471 auto markNonForwardedReturnValues =
472 [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
473 for (Region &region : regionBranchOp->getRegions()) {
474 if (region.empty())
475 continue;
476 Operation *terminator = region.front().getTerminator();
477 nonForwardedRets[terminator] =
478 BitVector(terminator->getNumOperands(), true);
479 for (const RegionSuccessor &successor : getSuccessors(&region)) {
480 for (OpOperand *opOperand :
481 getForwardedOpOperands(successor, terminator))
482 nonForwardedRets[terminator].reset(Idx: opOperand->getOperandNumber());
483 }
484 }
485 };
486
487 // Update `valuesToKeep` (which is expected to correspond to operands or
488 // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
489 // `region`. When `valuesToKeep` correspond to operands, `region` is null.
490 // Else, `region` is the parent region of the terminator.
491 auto updateOperandsOrTerminatorOperandsToKeep =
492 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
493 DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
494 Operation *terminator =
495 region ? region->front().getTerminator() : nullptr;
496
497 for (const RegionSuccessor &successor : getSuccessors(region)) {
498 Region *successorRegion = successor.getSuccessor();
499 for (auto [opOperand, input] :
500 llvm::zip(t: getForwardedOpOperands(successor, terminator),
501 u: successor.getSuccessorInputs())) {
502 size_t operandNum = opOperand->getOperandNumber();
503 bool updateBasedOn =
504 successorRegion
505 ? argsToKeep[successorRegion]
506 [cast<BlockArgument>(Val&: input).getArgNumber()]
507 : resultsToKeep[cast<OpResult>(Val&: input).getResultNumber()];
508 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
509 }
510 }
511 };
512
513 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
514 // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
515 // value is modified, else, false.
516 auto recomputeResultsAndArgsToKeep =
517 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
518 BitVector &operandsToKeep,
519 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
520 bool &resultsOrArgsToKeepChanged) {
521 resultsOrArgsToKeepChanged = false;
522
523 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
524 for (const RegionSuccessor &successor : getSuccessors()) {
525 Region *successorRegion = successor.getSuccessor();
526 for (auto [opOperand, input] :
527 llvm::zip(t: getForwardedOpOperands(successor),
528 u: successor.getSuccessorInputs())) {
529 bool recomputeBasedOn =
530 operandsToKeep[opOperand->getOperandNumber()];
531 bool toRecompute =
532 successorRegion
533 ? argsToKeep[successorRegion]
534 [cast<BlockArgument>(Val&: input).getArgNumber()]
535 : resultsToKeep[cast<OpResult>(Val&: input).getResultNumber()];
536 if (!toRecompute && recomputeBasedOn)
537 resultsOrArgsToKeepChanged = true;
538 if (successorRegion) {
539 argsToKeep[successorRegion][cast<BlockArgument>(Val&: input)
540 .getArgNumber()] =
541 argsToKeep[successorRegion]
542 [cast<BlockArgument>(Val&: input).getArgNumber()] |
543 recomputeBasedOn;
544 } else {
545 resultsToKeep[cast<OpResult>(Val&: input).getResultNumber()] =
546 resultsToKeep[cast<OpResult>(Val&: input).getResultNumber()] |
547 recomputeBasedOn;
548 }
549 }
550 }
551
552 // Recompute `resultsToKeep` and `argsToKeep` based on
553 // `terminatorOperandsToKeep`.
554 for (Region &region : regionBranchOp->getRegions()) {
555 if (region.empty())
556 continue;
557 Operation *terminator = region.front().getTerminator();
558 for (const RegionSuccessor &successor : getSuccessors(&region)) {
559 Region *successorRegion = successor.getSuccessor();
560 for (auto [opOperand, input] :
561 llvm::zip(t: getForwardedOpOperands(successor, terminator),
562 u: successor.getSuccessorInputs())) {
563 bool recomputeBasedOn =
564 terminatorOperandsToKeep[region.back().getTerminator()]
565 [opOperand->getOperandNumber()];
566 bool toRecompute =
567 successorRegion
568 ? argsToKeep[successorRegion]
569 [cast<BlockArgument>(Val&: input).getArgNumber()]
570 : resultsToKeep[cast<OpResult>(Val&: input).getResultNumber()];
571 if (!toRecompute && recomputeBasedOn)
572 resultsOrArgsToKeepChanged = true;
573 if (successorRegion) {
574 argsToKeep[successorRegion][cast<BlockArgument>(Val&: input)
575 .getArgNumber()] =
576 argsToKeep[successorRegion]
577 [cast<BlockArgument>(Val&: input).getArgNumber()] |
578 recomputeBasedOn;
579 } else {
580 resultsToKeep[cast<OpResult>(Val&: input).getResultNumber()] =
581 resultsToKeep[cast<OpResult>(Val&: input).getResultNumber()] |
582 recomputeBasedOn;
583 }
584 }
585 }
586 }
587 };
588
589 // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
590 // `operandsToKeep`, and `terminatorOperandsToKeep`.
591 auto markValuesToKeep =
592 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
593 BitVector &operandsToKeep,
594 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
595 bool resultsOrArgsToKeepChanged = true;
596 // We keep updating and recomputing the values until we reach a point
597 // where they stop changing.
598 while (resultsOrArgsToKeepChanged) {
599 // Update the operands that need to be kept.
600 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
601 resultsToKeep, argsToKeep);
602
603 // Update the terminator operands that need to be kept.
604 for (Region &region : regionBranchOp->getRegions()) {
605 if (region.empty())
606 continue;
607 updateOperandsOrTerminatorOperandsToKeep(
608 terminatorOperandsToKeep[region.back().getTerminator()],
609 resultsToKeep, argsToKeep, &region);
610 }
611
612 // Recompute the results and arguments that need to be kept.
613 recomputeResultsAndArgsToKeep(
614 resultsToKeep, argsToKeep, operandsToKeep,
615 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
616 }
617 };
618
619 // Scenario 1. This is the only case where the entire `regionBranchOp`
620 // is removed. It will not happen in any other scenario. Note that in this
621 // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
622 // It could never be live because of this op but its liveness could have been
623 // attributed to something else.
624 // Do (1') and (2').
625 if (isMemoryEffectFree(op: regionBranchOp.getOperation()) &&
626 !hasLive(values: regionBranchOp->getResults(), nonLiveSet, la)) {
627 cl.operations.push_back(Elt: regionBranchOp.getOperation());
628 return;
629 }
630
631 // Scenario 2.
632 // At this point, we know that every non-forwarded operand of `regionBranchOp`
633 // is live.
634
635 // Stores the results of `regionBranchOp` that we want to keep.
636 BitVector resultsToKeep;
637 // Stores the mapping from regions of `regionBranchOp` to their arguments that
638 // we want to keep.
639 DenseMap<Region *, BitVector> argsToKeep;
640 // Stores the operands of `regionBranchOp` that we want to keep.
641 BitVector operandsToKeep;
642 // Stores the mapping from region terminators in `regionBranchOp` to their
643 // operands that we want to keep.
644 DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
645
646 // Initializing the above variables...
647
648 // The live results of `regionBranchOp` definitely need to be kept.
649 markLiveResults(resultsToKeep);
650 // Similarly, the live arguments of the regions in `regionBranchOp` definitely
651 // need to be kept.
652 markLiveArgs(argsToKeep);
653 // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
654 // A live forwarded operand can be removed but no non-forwarded operand can be
655 // removed since it "controls" the flow of data in this control flow op.
656 markNonForwardedOperands(operandsToKeep);
657 // Similarly, the non-forwarded terminator operands of the regions in
658 // `regionBranchOp` definitely need to be kept.
659 markNonForwardedReturnValues(terminatorOperandsToKeep);
660
661 // Mark the values (results, arguments, operands, and terminator operands)
662 // that we want to keep.
663 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
664 terminatorOperandsToKeep);
665
666 // Do (1).
667 cl.operands.push_back(Elt: {.op: regionBranchOp, .nonLive: operandsToKeep.flip()});
668
669 // Do (2.a) and (2.b).
670 for (Region &region : regionBranchOp->getRegions()) {
671 if (region.empty())
672 continue;
673 BitVector argsToRemove = argsToKeep[&region].flip();
674 cl.blocks.push_back(Elt: {.b: &region.front(), .nonLiveArgs: argsToRemove});
675 collectNonLiveValues(nonLiveSet, range: region.front().getArguments(),
676 nonLive: argsToRemove);
677 }
678
679 // Do (2.c).
680 for (Region &region : regionBranchOp->getRegions()) {
681 if (region.empty())
682 continue;
683 Operation *terminator = region.front().getTerminator();
684 cl.operands.push_back(
685 Elt: {.op: terminator, .nonLive: terminatorOperandsToKeep[terminator].flip()});
686 }
687
688 // Do (3) and (4).
689 BitVector resultsToRemove = resultsToKeep.flip();
690 collectNonLiveValues(nonLiveSet, range: regionBranchOp.getOperation()->getResults(),
691 nonLive: resultsToRemove);
692 cl.results.push_back(Elt: {.op: regionBranchOp.getOperation(), .nonLive: resultsToRemove});
693}
694
695/// Steps to process a `BranchOpInterface` operation:
696/// Iterate through each successor block of `branchOp`.
697/// (1) For each successor block, gather all operands from all successors.
698/// (2) Fetch their associated liveness analysis data and collect for future
699/// removal.
700/// (3) Identify and collect the dead operands from the successor block
701/// as well as their corresponding arguments.
702
703static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
704 DenseSet<Value> &nonLiveSet,
705 RDVFinalCleanupList &cl) {
706 LDBG("Processing branch op: " << *branchOp);
707 unsigned numSuccessors = branchOp->getNumSuccessors();
708
709 for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
710 Block *successorBlock = branchOp->getSuccessor(index: succIdx);
711
712 // Do (1)
713 SuccessorOperands successorOperands =
714 branchOp.getSuccessorOperands(index: succIdx);
715 SmallVector<Value> operandValues;
716 for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
717 ++operandIdx) {
718 operandValues.push_back(Elt: successorOperands[operandIdx]);
719 }
720
721 // Do (2)
722 BitVector successorNonLive =
723 markLives(values: operandValues, nonLiveSet, la).flip();
724 collectNonLiveValues(nonLiveSet, range: successorBlock->getArguments(),
725 nonLive: successorNonLive);
726
727 // Do (3)
728 cl.blocks.push_back(Elt: {.b: successorBlock, .nonLiveArgs: successorNonLive});
729 cl.successorOperands.push_back(Elt: {.branch: branchOp, .successorIndex: succIdx, .nonLiveOperands: successorNonLive});
730 }
731}
732
733/// Removes dead values collected in RDVFinalCleanupList.
734/// To be run once when all dead values have been collected.
735static void cleanUpDeadVals(RDVFinalCleanupList &list) {
736 LLVM_DEBUG({ llvm::dbgs() << "Starting cleanup of dead values...\n"; });
737
738 // 1. Operations
739 LLVM_DEBUG({
740 llvm::dbgs() << "Cleaning up " << list.operations.size() << " operations"
741 << "\n";
742 });
743 for (auto &op : list.operations) {
744 LLVM_DEBUG({
745 llvm::dbgs() << "Erasing operation: "
746 << OpWithFlags(op, OpPrintingFlags().skipRegions()) << "\n";
747 });
748 op->dropAllUses();
749 op->erase();
750 }
751
752 // 2. Values
753 LLVM_DEBUG({
754 llvm::dbgs() << "Cleaning up " << list.values.size() << " values"
755 << "\n";
756 });
757 for (auto &v : list.values) {
758 LLVM_DEBUG(
759 { llvm::dbgs() << "Dropping all uses of value: " << v << "\n"; });
760 v.dropAllUses();
761 }
762
763 // 3. Functions
764 LLVM_DEBUG({
765 llvm::dbgs() << "Cleaning up " << list.functions.size() << " functions"
766 << "\n";
767 });
768 for (auto &f : list.functions) {
769 LLVM_DEBUG({
770 llvm::dbgs() << "Cleaning up function: "
771 << f.funcOp.getOperation()->getName() << "\n";
772 });
773 LLVM_DEBUG({
774 llvm::dbgs() << " Erasing " << f.nonLiveArgs.count()
775 << " non-live arguments"
776 << "\n";
777 });
778 LLVM_DEBUG({
779 llvm::dbgs() << " Erasing " << f.nonLiveRets.count()
780 << " non-live return values"
781 << "\n";
782 });
783 // Some functions may not allow erasing arguments or results. These calls
784 // return failure in such cases without modifying the function, so it's okay
785 // to proceed.
786 (void)f.funcOp.eraseArguments(argIndices: f.nonLiveArgs);
787 (void)f.funcOp.eraseResults(resultIndices: f.nonLiveRets);
788 }
789
790 // 4. Operands
791 LLVM_DEBUG({
792 llvm::dbgs() << "Cleaning up " << list.operands.size() << " operand lists"
793 << "\n";
794 });
795 for (OperationToCleanup &o : list.operands) {
796 if (o.op->getNumOperands() > 0) {
797 LLVM_DEBUG({
798 llvm::dbgs() << "Erasing " << o.nonLive.count()
799 << " non-live operands from operation: "
800 << OpWithFlags(o.op, OpPrintingFlags().skipRegions())
801 << "\n";
802 });
803 o.op->eraseOperands(eraseIndices: o.nonLive);
804 }
805 }
806
807 // 5. Results
808 LLVM_DEBUG({
809 llvm::dbgs() << "Cleaning up " << list.results.size() << " result lists"
810 << "\n";
811 });
812 for (auto &r : list.results) {
813 LLVM_DEBUG({
814 llvm::dbgs() << "Erasing " << r.nonLive.count()
815 << " non-live results from operation: "
816 << OpWithFlags(r.op, OpPrintingFlags().skipRegions())
817 << "\n";
818 });
819 dropUsesAndEraseResults(op: r.op, toErase: r.nonLive);
820 }
821
822 // 6. Blocks
823 LLVM_DEBUG({
824 llvm::dbgs() << "Cleaning up " << list.blocks.size()
825 << " block argument lists"
826 << "\n";
827 });
828 for (auto &b : list.blocks) {
829 // blocks that are accessed via multiple codepaths processed once
830 if (b.b->getNumArguments() != b.nonLiveArgs.size())
831 continue;
832 LLVM_DEBUG({
833 llvm::dbgs() << "Erasing " << b.nonLiveArgs.count()
834 << " non-live arguments from block: " << b.b << "\n";
835 });
836 // it iterates backwards because erase invalidates all successor indexes
837 for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
838 if (!b.nonLiveArgs[i])
839 continue;
840 LLVM_DEBUG({
841 llvm::dbgs() << " Erasing block argument " << i << ": "
842 << b.b->getArgument(i) << "\n";
843 });
844 b.b->getArgument(i).dropAllUses();
845 b.b->eraseArgument(index: i);
846 }
847 }
848
849 // 7. Successor Operands
850 LLVM_DEBUG({
851 llvm::dbgs() << "Cleaning up " << list.successorOperands.size()
852 << " successor operand lists"
853 << "\n";
854 });
855 for (auto &op : list.successorOperands) {
856 SuccessorOperands successorOperands =
857 op.branch.getSuccessorOperands(index: op.successorIndex);
858 // blocks that are accessed via multiple codepaths processed once
859 if (successorOperands.size() != op.nonLiveOperands.size())
860 continue;
861 LLVM_DEBUG({
862 llvm::dbgs() << "Erasing " << op.nonLiveOperands.count()
863 << " non-live successor operands from successor "
864 << op.successorIndex << " of branch: "
865 << OpWithFlags(op.branch, OpPrintingFlags().skipRegions())
866 << "\n";
867 });
868 // it iterates backwards because erase invalidates all successor indexes
869 for (int i = successorOperands.size() - 1; i >= 0; --i) {
870 if (!op.nonLiveOperands[i])
871 continue;
872 LLVM_DEBUG({
873 llvm::dbgs() << " Erasing successor operand " << i << ": "
874 << successorOperands[i] << "\n";
875 });
876 successorOperands.erase(subStart: i);
877 }
878 }
879
880 LLVM_DEBUG({
881 llvm::dbgs() << "Finished cleanup of dead values"
882 << "\n";
883 });
884}
885
886struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
887 void runOnOperation() override;
888};
889} // namespace
890
891void RemoveDeadValues::runOnOperation() {
892 auto &la = getAnalysis<RunLivenessAnalysis>();
893 Operation *module = getOperation();
894
895 // Tracks values eligible for erasure - complements liveness analysis to
896 // identify "droppable" values.
897 DenseSet<Value> deadVals;
898
899 // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
900 // end of this pass.
901 RDVFinalCleanupList finalCleanupList;
902
903 module->walk(callback: [&](Operation *op) {
904 if (auto funcOp = dyn_cast<FunctionOpInterface>(Val: op)) {
905 processFuncOp(funcOp, module, la, nonLiveSet&: deadVals, cl&: finalCleanupList);
906 } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(Val: op)) {
907 processRegionBranchOp(regionBranchOp, la, nonLiveSet&: deadVals, cl&: finalCleanupList);
908 } else if (auto branchOp = dyn_cast<BranchOpInterface>(Val: op)) {
909 processBranchOp(branchOp, la, nonLiveSet&: deadVals, cl&: finalCleanupList);
910 } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
911 // Nothing to do here because this is a terminator op and it should be
912 // honored with respect to its parent
913 } else if (isa<CallOpInterface>(Val: op)) {
914 // Nothing to do because this op is associated with a function op and gets
915 // cleaned when the latter is cleaned.
916 } else {
917 processSimpleOp(op, la, nonLiveSet&: deadVals, cl&: finalCleanupList);
918 }
919 });
920
921 cleanUpDeadVals(list&: finalCleanupList);
922}
923
924std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
925 return std::make_unique<RemoveDeadValues>();
926}
927

source code of mlir/lib/Transforms/RemoveDeadValues.cpp