1//===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Transforms/SROA.h"
10#include "mlir/Analysis/DataLayoutAnalysis.h"
11#include "mlir/Analysis/SliceAnalysis.h"
12#include "mlir/Interfaces/MemorySlotInterfaces.h"
13#include "mlir/Transforms/Passes.h"
14
15namespace mlir {
16#define GEN_PASS_DEF_SROA
17#include "mlir/Transforms/Passes.h.inc"
18} // namespace mlir
19
20#define DEBUG_TYPE "sroa"
21
22using namespace mlir;
23
24namespace {
25
26/// Information computed by destructurable memory slot analysis used to perform
27/// actual destructuring of the slot. This struct is only constructed if
28/// destructuring is possible, and contains the necessary data to perform it.
29struct MemorySlotDestructuringInfo {
30 /// Set of the indices that are actually used when accessing the subelements.
31 SmallPtrSet<Attribute, 8> usedIndices;
32 /// Blocking uses of a given user of the memory slot that must be eliminated.
33 DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
34 /// List of potentially indirect accessors of the memory slot that need
35 /// rewiring.
36 SmallVector<DestructurableAccessorOpInterface> accessors;
37};
38
39} // namespace
40
41/// Computes information for slot destructuring. This will compute whether this
42/// slot can be destructured and data to perform the destructuring. Returns
43/// nothing if the slot cannot be destructured or if there is no useful work to
44/// be done.
45static std::optional<MemorySlotDestructuringInfo>
46computeDestructuringInfo(DestructurableMemorySlot &slot,
47 const DataLayout &dataLayout) {
48 assert(isa<DestructurableTypeInterface>(slot.elemType));
49
50 if (slot.ptr.use_empty())
51 return {};
52
53 MemorySlotDestructuringInfo info;
54
55 SmallVector<MemorySlot> usedSafelyWorklist;
56
57 auto scheduleAsBlockingUse = [&](OpOperand &use) {
58 SmallPtrSetImpl<OpOperand *> &blockingUses =
59 info.userToBlockingUses.getOrInsertDefault(use.getOwner());
60 blockingUses.insert(Ptr: &use);
61 };
62
63 // Initialize the analysis with the immediate users of the slot.
64 for (OpOperand &use : slot.ptr.getUses()) {
65 if (auto accessor =
66 dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
67 if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist,
68 dataLayout)) {
69 info.accessors.push_back(accessor);
70 continue;
71 }
72 }
73
74 // If it cannot be shown that the operation uses the slot safely, maybe it
75 // can be promoted out of using the slot?
76 scheduleAsBlockingUse(use);
77 }
78
79 SmallPtrSet<OpOperand *, 16> visited;
80 while (!usedSafelyWorklist.empty()) {
81 MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val();
82 for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) {
83 if (!visited.insert(Ptr: &subslotUse).second)
84 continue;
85 Operation *subslotUser = subslotUse.getOwner();
86
87 if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
88 if (succeeded(memOp.ensureOnlySafeAccesses(
89 mustBeUsedSafely, usedSafelyWorklist, dataLayout)))
90 continue;
91
92 // If it cannot be shown that the operation uses the slot safely, maybe it
93 // can be promoted out of using the slot?
94 scheduleAsBlockingUse(subslotUse);
95 }
96 }
97
98 SetVector<Operation *> forwardSlice;
99 mlir::getForwardSlice(root: slot.ptr, forwardSlice: &forwardSlice);
100 for (Operation *user : forwardSlice) {
101 // If the next operation has no blocking uses, everything is fine.
102 if (!info.userToBlockingUses.contains(user))
103 continue;
104
105 SmallPtrSet<OpOperand *, 4> &blockingUses = info.userToBlockingUses[user];
106 auto promotable = dyn_cast<PromotableOpInterface>(user);
107
108 // An operation that has blocking uses must be promoted. If it is not
109 // promotable, destructuring must fail.
110 if (!promotable)
111 return {};
112
113 SmallVector<OpOperand *> newBlockingUses;
114 // If the operation decides it cannot deal with removing the blocking uses,
115 // destructuring must fail.
116 if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses, dataLayout))
117 return {};
118
119 // Then, register any new blocking uses for coming operations.
120 for (OpOperand *blockingUse : newBlockingUses) {
121 assert(llvm::is_contained(user->getResults(), blockingUse->get()));
122
123 SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
124 info.userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
125 newUserBlockingUseSet.insert(Ptr: blockingUse);
126 }
127 }
128
129 return info;
130}
131
132/// Performs the destructuring of a destructible slot given associated
133/// destructuring information. The provided slot will be destructured in
134/// subslots as specified by its allocator.
135static void destructureSlot(DestructurableMemorySlot &slot,
136 DestructurableAllocationOpInterface allocator,
137 RewriterBase &rewriter,
138 const DataLayout &dataLayout,
139 MemorySlotDestructuringInfo &info,
140 const SROAStatistics &statistics) {
141 RewriterBase::InsertionGuard guard(rewriter);
142
143 rewriter.setInsertionPointToStart(slot.ptr.getParentBlock());
144 DenseMap<Attribute, MemorySlot> subslots =
145 allocator.destructure(slot, info.usedIndices, rewriter);
146
147 if (statistics.slotsWithMemoryBenefit &&
148 slot.elementPtrs.size() != info.usedIndices.size())
149 (*statistics.slotsWithMemoryBenefit)++;
150
151 if (statistics.maxSubelementAmount)
152 statistics.maxSubelementAmount->updateMax(V: slot.elementPtrs.size());
153
154 SetVector<Operation *> usersToRewire;
155 for (Operation *user : llvm::make_first_range(info.userToBlockingUses))
156 usersToRewire.insert(user);
157 for (DestructurableAccessorOpInterface accessor : info.accessors)
158 usersToRewire.insert(accessor);
159 usersToRewire = mlir::topologicalSort(toSort: usersToRewire);
160
161 llvm::SmallVector<Operation *> toErase;
162 for (Operation *toRewire : llvm::reverse(C&: usersToRewire)) {
163 rewriter.setInsertionPointAfter(toRewire);
164 if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
165 if (accessor.rewire(slot, subslots, rewriter, dataLayout) ==
166 DeletionKind::Delete)
167 toErase.push_back(Elt: accessor);
168 continue;
169 }
170
171 auto promotable = cast<PromotableOpInterface>(toRewire);
172 if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
173 rewriter) == DeletionKind::Delete)
174 toErase.push_back(Elt: promotable);
175 }
176
177 for (Operation *toEraseOp : toErase)
178 rewriter.eraseOp(op: toEraseOp);
179
180 assert(slot.ptr.use_empty() && "after destructuring, the original slot "
181 "pointer should no longer be used");
182
183 LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr
184 << "\n");
185
186 if (statistics.destructuredAmount)
187 (*statistics.destructuredAmount)++;
188
189 allocator.handleDestructuringComplete(slot, rewriter);
190}
191
192LogicalResult mlir::tryToDestructureMemorySlots(
193 ArrayRef<DestructurableAllocationOpInterface> allocators,
194 RewriterBase &rewriter, const DataLayout &dataLayout,
195 SROAStatistics statistics) {
196 bool destructuredAny = false;
197
198 for (DestructurableAllocationOpInterface allocator : allocators) {
199 for (DestructurableMemorySlot slot : allocator.getDestructurableSlots()) {
200 std::optional<MemorySlotDestructuringInfo> info =
201 computeDestructuringInfo(slot, dataLayout);
202 if (!info)
203 continue;
204
205 destructureSlot(slot, allocator, rewriter, dataLayout, *info, statistics);
206 destructuredAny = true;
207 }
208 }
209
210 return success(isSuccess: destructuredAny);
211}
212
213namespace {
214
215struct SROA : public impl::SROABase<SROA> {
216 using impl::SROABase<SROA>::SROABase;
217
218 void runOnOperation() override {
219 Operation *scopeOp = getOperation();
220
221 SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
222 &maxSubelementAmount};
223
224 auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
225 const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(scopeOp);
226 bool changed = false;
227
228 for (Region &region : scopeOp->getRegions()) {
229 if (region.getBlocks().empty())
230 continue;
231
232 OpBuilder builder(&region.front(), region.front().begin());
233 IRRewriter rewriter(builder);
234
235 // Destructuring a slot can allow for further destructuring of other
236 // slots, destructuring is tried until no destructuring succeeds.
237 while (true) {
238 SmallVector<DestructurableAllocationOpInterface> allocators;
239 // Build a list of allocators to attempt to destructure the slots of.
240 // TODO: Update list on the fly to avoid repeated visiting of the same
241 // allocators.
242 region.walk([&](DestructurableAllocationOpInterface allocator) {
243 allocators.emplace_back(allocator);
244 });
245
246 if (failed(tryToDestructureMemorySlots(allocators, rewriter, dataLayout,
247 statistics)))
248 break;
249
250 changed = true;
251 }
252 }
253 if (!changed)
254 markAllAnalysesPreserved();
255 }
256};
257
258} // namespace
259

source code of mlir/lib/Transforms/SROA.cpp