1//===- RemoveDeadValues.cpp - Remove Dead Values --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// The goal of this pass is optimization (reducing runtime) by removing
10// unnecessary instructions. Unlike other passes that rely on local information
11// gathered from patterns to accomplish optimization, this pass uses a full
12// analysis of the IR, specifically, liveness analysis, and is thus more
13// powerful.
14//
15// Currently, this pass performs the following optimizations:
16// (A) Removes function arguments that are not live,
17// (B) Removes function return values that are not live across all callers of
18// the function,
19// (C) Removes unnecessary operands, results, region arguments, and region
20// terminator operands of region branch ops, and,
21// (D) Removes simple and region branch ops that have all non-live results and
22// don't affect memory in any way,
23//
24// iff
25//
26// the IR doesn't have any non-function symbol ops, non-call symbol user ops and
27// branch ops.
28//
29// Here, a "simple op" refers to an op that isn't a symbol op, symbol-user op,
30// region branch op, branch op, region branch terminator op, or return-like.
31//
32//===----------------------------------------------------------------------===//
33
34#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
35#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
36#include "mlir/IR/Attributes.h"
37#include "mlir/IR/Builders.h"
38#include "mlir/IR/BuiltinAttributes.h"
39#include "mlir/IR/Dialect.h"
40#include "mlir/IR/IRMapping.h"
41#include "mlir/IR/OperationSupport.h"
42#include "mlir/IR/SymbolTable.h"
43#include "mlir/IR/Value.h"
44#include "mlir/IR/ValueRange.h"
45#include "mlir/IR/Visitors.h"
46#include "mlir/Interfaces/CallInterfaces.h"
47#include "mlir/Interfaces/ControlFlowInterfaces.h"
48#include "mlir/Interfaces/FunctionInterfaces.h"
49#include "mlir/Interfaces/SideEffectInterfaces.h"
50#include "mlir/Pass/Pass.h"
51#include "mlir/Support/LLVM.h"
52#include "mlir/Transforms/FoldUtils.h"
53#include "mlir/Transforms/Passes.h"
54#include "llvm/ADT/STLExtras.h"
55#include <cassert>
56#include <cstddef>
57#include <memory>
58#include <optional>
59#include <vector>
60
61namespace mlir {
62#define GEN_PASS_DEF_REMOVEDEADVALUES
63#include "mlir/Transforms/Passes.h.inc"
64} // namespace mlir
65
66using namespace mlir;
67using namespace mlir::dataflow;
68
69//===----------------------------------------------------------------------===//
70// RemoveDeadValues Pass
71//===----------------------------------------------------------------------===//
72
73namespace {
74
75// Set of structures below to be filled with operations and arguments to erase.
76// This is done to separate analysis and tree modification phases,
77// otherwise analysis is operating on half-deleted tree which is incorrect.
78
79struct FunctionToCleanUp {
80 FunctionOpInterface funcOp;
81 BitVector nonLiveArgs;
82 BitVector nonLiveRets;
83};
84
85struct OperationToCleanup {
86 Operation *op;
87 BitVector nonLive;
88};
89
90struct BlockArgsToCleanup {
91 Block *b;
92 BitVector nonLiveArgs;
93};
94
95struct SuccessorOperandsToCleanup {
96 BranchOpInterface branch;
97 unsigned successorIndex;
98 BitVector nonLiveOperands;
99};
100
101struct RDVFinalCleanupList {
102 SmallVector<Operation *> operations;
103 SmallVector<Value> values;
104 SmallVector<FunctionToCleanUp> functions;
105 SmallVector<OperationToCleanup> operands;
106 SmallVector<OperationToCleanup> results;
107 SmallVector<BlockArgsToCleanup> blocks;
108 SmallVector<SuccessorOperandsToCleanup> successorOperands;
109};
110
111// Some helper functions...
112
113/// Return true iff at least one value in `values` is live, given the liveness
114/// information in `la`.
115static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
116 RunLivenessAnalysis &la) {
117 for (Value value : values) {
118 if (nonLiveSet.contains(value))
119 continue;
120
121 const Liveness *liveness = la.getLiveness(val: value);
122 if (!liveness || liveness->isLive)
123 return true;
124 }
125 return false;
126}
127
128/// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
129/// i-th value in `values` is live, given the liveness information in `la`.
130static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
131 RunLivenessAnalysis &la) {
132 BitVector lives(values.size(), true);
133
134 for (auto [index, value] : llvm::enumerate(values)) {
135 if (nonLiveSet.contains(value)) {
136 lives.reset(index);
137 continue;
138 }
139
140 const Liveness *liveness = la.getLiveness(value);
141 // It is important to note that when `liveness` is null, we can't tell if
142 // `value` is live or not. So, the safe option is to consider it live. Also,
143 // the execution of this pass might create new SSA values when erasing some
144 // of the results of an op and we know that these new values are live
145 // (because they weren't erased) and also their liveness is null because
146 // liveness analysis ran before their creation.
147 if (liveness && !liveness->isLive)
148 lives.reset(index);
149 }
150
151 return lives;
152}
153
154/// Collects values marked as "non-live" in the provided range and inserts them
155/// into the nonLiveSet. A value is considered "non-live" if the corresponding
156/// index in the `nonLive` bit vector is set.
157static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
158 const BitVector &nonLive) {
159 for (auto [index, result] : llvm::enumerate(range)) {
160 if (!nonLive[index])
161 continue;
162 nonLiveSet.insert(result);
163 }
164}
165
166/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
167/// is 1.
168static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
169 assert(op->getNumResults() == toErase.size() &&
170 "expected the number of results in `op` and the size of `toErase` to "
171 "be the same");
172
173 std::vector<Type> newResultTypes;
174 for (OpResult result : op->getResults())
175 if (!toErase[result.getResultNumber()])
176 newResultTypes.push_back(result.getType());
177 OpBuilder builder(op);
178 builder.setInsertionPointAfter(op);
179 OperationState state(op->getLoc(), op->getName().getStringRef(),
180 op->getOperands(), newResultTypes, op->getAttrs());
181 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i)
182 state.addRegion();
183 Operation *newOp = builder.create(state);
184 for (const auto &[index, region] : llvm::enumerate(op->getRegions())) {
185 Region &newRegion = newOp->getRegion(index);
186 // Move all blocks of `region` into `newRegion`.
187 Block *temp = new Block();
188 newRegion.push_back(temp);
189 while (!region.empty())
190 region.front().moveBefore(temp);
191 temp->erase();
192 }
193
194 unsigned indexOfNextNewCallOpResultToReplace = 0;
195 for (auto [index, result] : llvm::enumerate(op->getResults())) {
196 assert(result && "expected result to be non-null");
197 if (toErase[index]) {
198 result.dropAllUses();
199 } else {
200 result.replaceAllUsesWith(
201 newOp->getResult(indexOfNextNewCallOpResultToReplace++));
202 }
203 }
204 op->erase();
205}
206
207/// Convert a list of `Operand`s to a list of `OpOperand`s.
208static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
209 OpOperand *values = operands.getBase();
210 SmallVector<OpOperand *> opOperands;
211 for (unsigned i = 0, e = operands.size(); i < e; i++)
212 opOperands.push_back(&values[i]);
213 return opOperands;
214}
215
216/// Process a simple operation `op` using the liveness analysis `la`.
217/// If the operation has no memory effects and none of its results are live:
218/// 1. Add the operation to a list for future removal, and
219/// 2. Mark all its results as non-live values
220///
221/// The operation `op` is assumed to be simple. A simple operation is one that
222/// is NOT:
223/// - Function-like
224/// - Call-like
225/// - A region branch operation
226/// - A branch operation
227/// - A region branch terminator
228/// - Return-like
229static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
230 DenseSet<Value> &nonLiveSet,
231 RDVFinalCleanupList &cl) {
232 if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la))
233 return;
234
235 cl.operations.push_back(op);
236 collectNonLiveValues(nonLiveSet, op->getResults(),
237 BitVector(op->getNumResults(), true));
238}
239
240/// Process a function-like operation `funcOp` using the liveness analysis `la`
241/// and the IR in `module`. If it is not public or external:
242/// (1) Adding its non-live arguments to a list for future removal.
243/// (2) Marking their corresponding operands in its callers for removal.
244/// (3) Identifying and enqueueing unnecessary terminator operands
245/// (return values that are non-live across all callers) for removal.
246/// (4) Enqueueing the non-live arguments and return values for removal.
247/// (5) Collecting the uses of these return values in its callers for future
248/// removal.
249/// (6) Marking all its results as non-live values.
250static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
251 RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet,
252 RDVFinalCleanupList &cl) {
253 if (funcOp.isPublic() || funcOp.isExternal())
254 return;
255
256 // Get the list of unnecessary (non-live) arguments in `nonLiveArgs`.
257 SmallVector<Value> arguments(funcOp.getArguments());
258 BitVector nonLiveArgs = markLives(arguments, nonLiveSet, la);
259 nonLiveArgs = nonLiveArgs.flip();
260
261 // Do (1).
262 for (auto [index, arg] : llvm::enumerate(arguments))
263 if (arg && nonLiveArgs[index]) {
264 cl.values.push_back(arg);
265 nonLiveSet.insert(arg);
266 }
267
268 // Do (2).
269 SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
270 for (SymbolTable::SymbolUse use : uses) {
271 Operation *callOp = use.getUser();
272 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
273 // The number of operands in the call op may not match the number of
274 // arguments in the func op.
275 BitVector nonLiveCallOperands(callOp->getNumOperands(), false);
276 SmallVector<OpOperand *> callOpOperands =
277 operandsToOpOperands(cast<CallOpInterface>(callOp).getArgOperands());
278 for (int index : nonLiveArgs.set_bits())
279 nonLiveCallOperands.set(callOpOperands[index]->getOperandNumber());
280 cl.operands.push_back({callOp, nonLiveCallOperands});
281 }
282
283 // Do (3).
284 // Get the list of unnecessary terminator operands (return values that are
285 // non-live across all callers) in `nonLiveRets`. There is a very important
286 // subtlety here. Unnecessary terminator operands are NOT the operands of the
287 // terminator that are non-live. Instead, these are the return values of the
288 // callers such that a given return value is non-live across all callers. Such
289 // corresponding operands in the terminator could be live. An example to
290 // demonstrate this:
291 // func.func private @f(%arg0: memref<i32>) -> (i32, i32) {
292 // %c0_i32 = arith.constant 0 : i32
293 // %0 = arith.addi %c0_i32, %c0_i32 : i32
294 // memref.store %0, %arg0[] : memref<i32>
295 // return %c0_i32, %0 : i32, i32
296 // }
297 // func.func @main(%arg0: i32, %arg1: memref<i32>) -> (i32) {
298 // %1:2 = call @f(%arg1) : (memref<i32>) -> i32
299 // return %1#0 : i32
300 // }
301 // Here, we can see that %1#1 is never used. It is non-live. Thus, @f doesn't
302 // need to return %0. But, %0 is live. And, still, we want to stop it from
303 // being returned, in order to optimize our IR. So, this demonstrates how we
304 // can make our optimization strong by even removing a live return value (%0),
305 // since it forwards only to non-live value(s) (%1#1).
306 Operation *lastReturnOp = funcOp.back().getTerminator();
307 size_t numReturns = lastReturnOp->getNumOperands();
308 if (numReturns == 0)
309 return;
310 BitVector nonLiveRets(numReturns, true);
311 for (SymbolTable::SymbolUse use : uses) {
312 Operation *callOp = use.getUser();
313 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
314 BitVector liveCallRets = markLives(callOp->getResults(), nonLiveSet, la);
315 nonLiveRets &= liveCallRets.flip();
316 }
317
318 // Note that in the absence of control flow ops forcing the control to go from
319 // the entry (first) block to the other blocks, the control never reaches any
320 // block other than the entry block, because every block has a terminator.
321 for (Block &block : funcOp.getBlocks()) {
322 Operation *returnOp = block.getTerminator();
323 if (returnOp && returnOp->getNumOperands() == numReturns)
324 cl.operands.push_back({returnOp, nonLiveRets});
325 }
326
327 // Do (4).
328 cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets});
329
330 // Do (5) and (6).
331 for (SymbolTable::SymbolUse use : uses) {
332 Operation *callOp = use.getUser();
333 assert(isa<CallOpInterface>(callOp) && "expected a call-like user");
334 cl.results.push_back({callOp, nonLiveRets});
335 collectNonLiveValues(nonLiveSet, callOp->getResults(), nonLiveRets);
336 }
337}
338
339/// Process a region branch operation `regionBranchOp` using the liveness
340/// information in `la`. The processing involves two scenarios:
341///
342/// Scenario 1: If the operation has no memory effects and none of its results
343/// are live:
344/// (1') Enqueue all its uses for deletion.
345/// (2') Enqueue the branch itself for deletion.
346///
347/// Scenario 2: Otherwise:
348/// (1) Collect its unnecessary operands (operands forwarded to unnecessary
349/// results or arguments).
350/// (2) Process each of its regions.
351/// (3) Collect the uses of its unnecessary results (results forwarded from
352/// unnecessary operands
353/// or terminator operands).
354/// (4) Add these results to the deletion list.
355///
356/// Processing a region includes:
357/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
358/// from unnecessary operands
359/// or terminator operands).
360/// (b) Collecting these unnecessary arguments.
361/// (c) Collecting its unnecessary terminator operands (terminator operands
362/// forwarded to unnecessary results
363/// or arguments).
364///
365/// Value Flow Note: In this operation, values flow as follows:
366/// - From operands and terminator operands (successor operands)
367/// - To arguments and results (successor inputs).
368static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
369 RunLivenessAnalysis &la,
370 DenseSet<Value> &nonLiveSet,
371 RDVFinalCleanupList &cl) {
372 // Mark live results of `regionBranchOp` in `liveResults`.
373 auto markLiveResults = [&](BitVector &liveResults) {
374 liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
375 };
376
377 // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
378 auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
379 for (Region &region : regionBranchOp->getRegions()) {
380 if (region.empty())
381 continue;
382 SmallVector<Value> arguments(region.front().getArguments());
383 BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
384 liveArgs[&region] = regionLiveArgs;
385 }
386 };
387
388 // Return the successors of `region` if the latter is not null. Else return
389 // the successors of `regionBranchOp`.
390 auto getSuccessors = [&](Region *region = nullptr) {
391 auto point = region ? region : RegionBranchPoint::parent();
392 SmallVector<RegionSuccessor> successors;
393 regionBranchOp.getSuccessorRegions(point, successors);
394 return successors;
395 };
396
397 // Return the operands of `terminator` that are forwarded to `successor` if
398 // the former is not null. Else return the operands of `regionBranchOp`
399 // forwarded to `successor`.
400 auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
401 Operation *terminator = nullptr) {
402 OperandRange operands =
403 terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
404 .getSuccessorOperands(successor)
405 : regionBranchOp.getEntrySuccessorOperands(successor);
406 SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
407 return opOperands;
408 };
409
410 // Mark the non-forwarded operands of `regionBranchOp` in
411 // `nonForwardedOperands`.
412 auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
413 nonForwardedOperands.resize(N: regionBranchOp->getNumOperands(), t: true);
414 for (const RegionSuccessor &successor : getSuccessors()) {
415 for (OpOperand *opOperand : getForwardedOpOperands(successor))
416 nonForwardedOperands.reset(opOperand->getOperandNumber());
417 }
418 };
419
420 // Mark the non-forwarded terminator operands of the various regions of
421 // `regionBranchOp` in `nonForwardedRets`.
422 auto markNonForwardedReturnValues =
423 [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
424 for (Region &region : regionBranchOp->getRegions()) {
425 if (region.empty())
426 continue;
427 Operation *terminator = region.front().getTerminator();
428 nonForwardedRets[terminator] =
429 BitVector(terminator->getNumOperands(), true);
430 for (const RegionSuccessor &successor : getSuccessors(&region)) {
431 for (OpOperand *opOperand :
432 getForwardedOpOperands(successor, terminator))
433 nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
434 }
435 }
436 };
437
438 // Update `valuesToKeep` (which is expected to correspond to operands or
439 // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
440 // `region`. When `valuesToKeep` correspond to operands, `region` is null.
441 // Else, `region` is the parent region of the terminator.
442 auto updateOperandsOrTerminatorOperandsToKeep =
443 [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
444 DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
445 Operation *terminator =
446 region ? region->front().getTerminator() : nullptr;
447
448 for (const RegionSuccessor &successor : getSuccessors(region)) {
449 Region *successorRegion = successor.getSuccessor();
450 for (auto [opOperand, input] :
451 llvm::zip(getForwardedOpOperands(successor, terminator),
452 successor.getSuccessorInputs())) {
453 size_t operandNum = opOperand->getOperandNumber();
454 bool updateBasedOn =
455 successorRegion
456 ? argsToKeep[successorRegion]
457 [cast<BlockArgument>(input).getArgNumber()]
458 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
459 valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
460 }
461 }
462 };
463
464 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
465 // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
466 // value is modified, else, false.
467 auto recomputeResultsAndArgsToKeep =
468 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
469 BitVector &operandsToKeep,
470 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
471 bool &resultsOrArgsToKeepChanged) {
472 resultsOrArgsToKeepChanged = false;
473
474 // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
475 for (const RegionSuccessor &successor : getSuccessors()) {
476 Region *successorRegion = successor.getSuccessor();
477 for (auto [opOperand, input] :
478 llvm::zip(getForwardedOpOperands(successor),
479 successor.getSuccessorInputs())) {
480 bool recomputeBasedOn =
481 operandsToKeep[opOperand->getOperandNumber()];
482 bool toRecompute =
483 successorRegion
484 ? argsToKeep[successorRegion]
485 [cast<BlockArgument>(input).getArgNumber()]
486 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
487 if (!toRecompute && recomputeBasedOn)
488 resultsOrArgsToKeepChanged = true;
489 if (successorRegion) {
490 argsToKeep[successorRegion][cast<BlockArgument>(input)
491 .getArgNumber()] =
492 argsToKeep[successorRegion]
493 [cast<BlockArgument>(input).getArgNumber()] |
494 recomputeBasedOn;
495 } else {
496 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
497 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
498 recomputeBasedOn;
499 }
500 }
501 }
502
503 // Recompute `resultsToKeep` and `argsToKeep` based on
504 // `terminatorOperandsToKeep`.
505 for (Region &region : regionBranchOp->getRegions()) {
506 if (region.empty())
507 continue;
508 Operation *terminator = region.front().getTerminator();
509 for (const RegionSuccessor &successor : getSuccessors(&region)) {
510 Region *successorRegion = successor.getSuccessor();
511 for (auto [opOperand, input] :
512 llvm::zip(getForwardedOpOperands(successor, terminator),
513 successor.getSuccessorInputs())) {
514 bool recomputeBasedOn =
515 terminatorOperandsToKeep[region.back().getTerminator()]
516 [opOperand->getOperandNumber()];
517 bool toRecompute =
518 successorRegion
519 ? argsToKeep[successorRegion]
520 [cast<BlockArgument>(input).getArgNumber()]
521 : resultsToKeep[cast<OpResult>(input).getResultNumber()];
522 if (!toRecompute && recomputeBasedOn)
523 resultsOrArgsToKeepChanged = true;
524 if (successorRegion) {
525 argsToKeep[successorRegion][cast<BlockArgument>(input)
526 .getArgNumber()] =
527 argsToKeep[successorRegion]
528 [cast<BlockArgument>(input).getArgNumber()] |
529 recomputeBasedOn;
530 } else {
531 resultsToKeep[cast<OpResult>(input).getResultNumber()] =
532 resultsToKeep[cast<OpResult>(input).getResultNumber()] |
533 recomputeBasedOn;
534 }
535 }
536 }
537 }
538 };
539
540 // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
541 // `operandsToKeep`, and `terminatorOperandsToKeep`.
542 auto markValuesToKeep =
543 [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
544 BitVector &operandsToKeep,
545 DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
546 bool resultsOrArgsToKeepChanged = true;
547 // We keep updating and recomputing the values until we reach a point
548 // where they stop changing.
549 while (resultsOrArgsToKeepChanged) {
550 // Update the operands that need to be kept.
551 updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
552 resultsToKeep, argsToKeep);
553
554 // Update the terminator operands that need to be kept.
555 for (Region &region : regionBranchOp->getRegions()) {
556 if (region.empty())
557 continue;
558 updateOperandsOrTerminatorOperandsToKeep(
559 terminatorOperandsToKeep[region.back().getTerminator()],
560 resultsToKeep, argsToKeep, &region);
561 }
562
563 // Recompute the results and arguments that need to be kept.
564 recomputeResultsAndArgsToKeep(
565 resultsToKeep, argsToKeep, operandsToKeep,
566 terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
567 }
568 };
569
570 // Scenario 1. This is the only case where the entire `regionBranchOp`
571 // is removed. It will not happen in any other scenario. Note that in this
572 // case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
573 // It could never be live because of this op but its liveness could have been
574 // attributed to something else.
575 // Do (1') and (2').
576 if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
577 !hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
578 cl.operations.push_back(regionBranchOp.getOperation());
579 return;
580 }
581
582 // Scenario 2.
583 // At this point, we know that every non-forwarded operand of `regionBranchOp`
584 // is live.
585
586 // Stores the results of `regionBranchOp` that we want to keep.
587 BitVector resultsToKeep;
588 // Stores the mapping from regions of `regionBranchOp` to their arguments that
589 // we want to keep.
590 DenseMap<Region *, BitVector> argsToKeep;
591 // Stores the operands of `regionBranchOp` that we want to keep.
592 BitVector operandsToKeep;
593 // Stores the mapping from region terminators in `regionBranchOp` to their
594 // operands that we want to keep.
595 DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
596
597 // Initializing the above variables...
598
599 // The live results of `regionBranchOp` definitely need to be kept.
600 markLiveResults(resultsToKeep);
601 // Similarly, the live arguments of the regions in `regionBranchOp` definitely
602 // need to be kept.
603 markLiveArgs(argsToKeep);
604 // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
605 // A live forwarded operand can be removed but no non-forwarded operand can be
606 // removed since it "controls" the flow of data in this control flow op.
607 markNonForwardedOperands(operandsToKeep);
608 // Similarly, the non-forwarded terminator operands of the regions in
609 // `regionBranchOp` definitely need to be kept.
610 markNonForwardedReturnValues(terminatorOperandsToKeep);
611
612 // Mark the values (results, arguments, operands, and terminator operands)
613 // that we want to keep.
614 markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
615 terminatorOperandsToKeep);
616
617 // Do (1).
618 cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
619
620 // Do (2.a) and (2.b).
621 for (Region &region : regionBranchOp->getRegions()) {
622 if (region.empty())
623 continue;
624 BitVector argsToRemove = argsToKeep[&region].flip();
625 cl.blocks.push_back({&region.front(), argsToRemove});
626 collectNonLiveValues(nonLiveSet, region.front().getArguments(),
627 argsToRemove);
628 }
629
630 // Do (2.c).
631 for (Region &region : regionBranchOp->getRegions()) {
632 if (region.empty())
633 continue;
634 Operation *terminator = region.front().getTerminator();
635 cl.operands.push_back(
636 {terminator, terminatorOperandsToKeep[terminator].flip()});
637 }
638
639 // Do (3) and (4).
640 BitVector resultsToRemove = resultsToKeep.flip();
641 collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
642 resultsToRemove);
643 cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
644}
645
646/// Steps to process a `BranchOpInterface` operation:
647/// Iterate through each successor block of `branchOp`.
648/// (1) For each successor block, gather all operands from all successors.
649/// (2) Fetch their associated liveness analysis data and collect for future
650/// removal.
651/// (3) Identify and collect the dead operands from the successor block
652/// as well as their corresponding arguments.
653
654static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
655 DenseSet<Value> &nonLiveSet,
656 RDVFinalCleanupList &cl) {
657 unsigned numSuccessors = branchOp->getNumSuccessors();
658
659 for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
660 Block *successorBlock = branchOp->getSuccessor(succIdx);
661
662 // Do (1)
663 SuccessorOperands successorOperands =
664 branchOp.getSuccessorOperands(succIdx);
665 SmallVector<Value> operandValues;
666 for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
667 ++operandIdx) {
668 operandValues.push_back(successorOperands[operandIdx]);
669 }
670
671 // Do (2)
672 BitVector successorNonLive =
673 markLives(operandValues, nonLiveSet, la).flip();
674 collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
675 successorNonLive);
676
677 // Do (3)
678 cl.blocks.push_back({successorBlock, successorNonLive});
679 cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
680 }
681}
682
683/// Removes dead values collected in RDVFinalCleanupList.
684/// To be run once when all dead values have been collected.
685static void cleanUpDeadVals(RDVFinalCleanupList &list) {
686 // 1. Operations
687 for (auto &op : list.operations) {
688 op->dropAllUses();
689 op->erase();
690 }
691
692 // 2. Values
693 for (auto &v : list.values) {
694 v.dropAllUses();
695 }
696
697 // 3. Functions
698 for (auto &f : list.functions) {
699 // Some functions may not allow erasing arguments or results. These calls
700 // return failure in such cases without modifying the function, so it's okay
701 // to proceed.
702 (void)f.funcOp.eraseArguments(f.nonLiveArgs);
703 (void)f.funcOp.eraseResults(f.nonLiveRets);
704 }
705
706 // 4. Operands
707 for (auto &o : list.operands) {
708 o.op->eraseOperands(o.nonLive);
709 }
710
711 // 5. Results
712 for (auto &r : list.results) {
713 dropUsesAndEraseResults(r.op, r.nonLive);
714 }
715
716 // 6. Blocks
717 for (auto &b : list.blocks) {
718 // blocks that are accessed via multiple codepaths processed once
719 if (b.b->getNumArguments() != b.nonLiveArgs.size())
720 continue;
721 // it iterates backwards because erase invalidates all successor indexes
722 for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
723 if (!b.nonLiveArgs[i])
724 continue;
725 b.b->getArgument(i).dropAllUses();
726 b.b->eraseArgument(i);
727 }
728 }
729
730 // 7. Successor Operands
731 for (auto &op : list.successorOperands) {
732 SuccessorOperands successorOperands =
733 op.branch.getSuccessorOperands(op.successorIndex);
734 // blocks that are accessed via multiple codepaths processed once
735 if (successorOperands.size() != op.nonLiveOperands.size())
736 continue;
737 // it iterates backwards because erase invalidates all successor indexes
738 for (int i = successorOperands.size() - 1; i >= 0; --i) {
739 if (!op.nonLiveOperands[i])
740 continue;
741 successorOperands.erase(i);
742 }
743 }
744}
745
746struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
747 void runOnOperation() override;
748};
749} // namespace
750
751void RemoveDeadValues::runOnOperation() {
752 auto &la = getAnalysis<RunLivenessAnalysis>();
753 Operation *module = getOperation();
754
755 // Tracks values eligible for erasure - complements liveness analysis to
756 // identify "droppable" values.
757 DenseSet<Value> deadVals;
758
759 // Maintains a list of Ops, values, branches, etc., slated for cleanup at the
760 // end of this pass.
761 RDVFinalCleanupList finalCleanupList;
762
763 module->walk(callback: [&](Operation *op) {
764 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
765 processFuncOp(funcOp, module, la, deadVals, finalCleanupList);
766 } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
767 processRegionBranchOp(regionBranchOp, la, deadVals, finalCleanupList);
768 } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
769 processBranchOp(branchOp, la, deadVals, finalCleanupList);
770 } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) {
771 // Nothing to do here because this is a terminator op and it should be
772 // honored with respect to its parent
773 } else if (isa<CallOpInterface>(Val: op)) {
774 // Nothing to do because this op is associated with a function op and gets
775 // cleaned when the latter is cleaned.
776 } else {
777 processSimpleOp(op, la, deadVals, finalCleanupList);
778 }
779 });
780
781 cleanUpDeadVals(list&: finalCleanupList);
782}
783
784std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
785 return std::make_unique<RemoveDeadValues>();
786}
787

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Transforms/RemoveDeadValues.cpp