/*
 * Copyright 2021 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//
// Apply more specific subtypes to signature/function types where possible.
//
// This differs from DeadArgumentElimination's refineArgumentTypes() etc. in
// that DAE will modify the type of a function. It can only do that if the
// function's type is not observable, which means it is not taken by reference.
// On the other hand, this pass will modify the signature types themselves,
// which means it can optimize functions whose reference is taken, and it does
// so while considering all users of the type (across all functions sharing that
// type, and all call_refs using it).
//

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "ir/export-utils.h"
#include "ir/find_all.h"
#include "ir/lubs.h"
#include "ir/module-utils.h"
#include "ir/subtypes.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
#include "wasm-type.h"
#include "wasm.h"

40namespace wasm {
41
42namespace {
43
44struct SignatureRefining : public Pass {
45 // Only changes heap types and parameter types (but not locals).
46 bool requiresNonNullableLocalFixups() override { return false; }
47
48 // Maps each heap type to the possible refinement of the types in their
49 // signatures. We will fill this during analysis and then use it while doing
50 // an update of the types. If a type has no improvement that we can find, it
51 // will not appear in this map.
52 std::unordered_map<HeapType, Signature> newSignatures;
53
54 void run(Module* module) override {
55 if (!module->features.hasGC()) {
56 return;
57 }
58
59 if (!module->tables.empty()) {
60 // When there are tables we must also take their types into account, which
61 // would require us to take call_indirect, element segments, etc. into
62 // account. For now, do nothing if there are tables.
63 // TODO
64 return;
65 }
66
67 // First, find all the information we need. Start by collecting inside each
68 // function in parallel.
69
70 struct Info {
71 // The calls and call_refs.
72 std::vector<Call*> calls;
73 std::vector<CallRef*> callRefs;
74
75 // A possibly improved LUB for the results.
76 LUBFinder resultsLUB;
77
78 // Normally we can optimize, but some cases prevent a particular signature
79 // type from being changed at all, see below.
80 bool canModify = true;
81 };
82
83 // This analysis also modifies the wasm as it goes, as the getResultsLUB()
84 // operation has side effects (see comment on header declaration).
85 ModuleUtils::ParallelFunctionAnalysis<Info, Mutable> analysis(
86 *module, [&](Function* func, Info& info) {
87 if (func->imported()) {
88 // Avoid changing the types of imported functions. Spec and VM support
89 // for that is not yet stable.
90 // TODO: optimize this when possible in the future
91 info.canModify = false;
92 return;
93 }
94 info.calls = std::move(FindAll<Call>(func->body).list);
95 info.callRefs = std::move(FindAll<CallRef>(func->body).list);
96 info.resultsLUB = LUB::getResultsLUB(func, *module);
97 });
98
99 // A map of types to all the information combined over all the functions
100 // with that type.
101 std::unordered_map<HeapType, Info> allInfo;
102
103 // Combine all the information we gathered into that map.
104 for (auto& [func, info] : analysis.map) {
105 // For direct calls, add each call to the type of the function being
106 // called.
107 for (auto* call : info.calls) {
108 allInfo[module->getFunction(call->target)->type].calls.push_back(call);
109 }
110
111 // For indirect calls, add each call_ref to the type the call_ref uses.
112 for (auto* callRef : info.callRefs) {
113 auto calledType = callRef->target->type;
114 if (calledType != Type::unreachable) {
115 allInfo[calledType.getHeapType()].callRefs.push_back(callRef);
116 }
117 }
118
119 // Add the function's return LUB to the one for the heap type of that
120 // function.
121 allInfo[func->type].resultsLUB.combine(info.resultsLUB);
122
123 // If one function cannot be modified, that entire type cannot be.
124 if (!info.canModify) {
125 allInfo[func->type].canModify = false;
126 }
127 }
128
129 // We cannot alter the signature of an exported function, as the outside may
130 // notice us doing so. For example, if we turn a parameter from nullable
131 // into non-nullable then callers sending a null will break. Put another
132 // way, we need to see all callers to refine types, and for exports we
133 // cannot do so.
134 // TODO If a function type is passed we should also mark the types used
135 // there, etc., recursively. For now this code just handles the top-
136 // level type, which is enough to keep the fuzzer from erroring. More
137 // generally, we need to decide about adding a "closed-world" flag of
138 // some kind.
139 for (auto* exportedFunc : ExportUtils::getExportedFunctions(*module)) {
140 allInfo[exportedFunc->type].canModify = false;
141 }
142
143 // For now, do not optimize types that have subtypes. When we modify such a
144 // type we need to modify subtypes as well, similar to the analysis in
145 // TypeRefining, and perhaps we can unify this pass with that. TODO
146 SubTypes subTypes(*module);
147 for (auto& [type, info] : allInfo) {
148 if (!subTypes.getImmediateSubTypes(type).empty()) {
149 info.canModify = false;
150 } else if (type.getSuperType()) {
151 // Also avoid modifying types with supertypes, as we do not handle
152 // contravariance here. That is, when we refine parameters we look for
153 // a more refined type, but the type must be *less* refined than the
154 // param type for the parent (or equal) TODO
155 info.canModify = false;
156 }
157 }
158
159 // Compute optimal LUBs.
160 std::unordered_set<HeapType> seen;
161 for (auto& func : module->functions) {
162 auto type = func->type;
163 if (!seen.insert(type).second) {
164 continue;
165 }
166
167 auto& info = allInfo[type];
168 if (!info.canModify) {
169 continue;
170 }
171
172 auto sig = type.getSignature();
173
174 auto numParams = sig.params.size();
175 std::vector<LUBFinder> paramLUBs(numParams);
176
177 auto updateLUBs = [&](const ExpressionList& operands) {
178 for (Index i = 0; i < numParams; i++) {
179 paramLUBs[i].note(operands[i]->type);
180 }
181 };
182
183 for (auto* call : info.calls) {
184 updateLUBs(call->operands);
185 }
186 for (auto* callRef : info.callRefs) {
187 updateLUBs(callRef->operands);
188 }
189
190 // Find the final LUBs, and see if we found an improvement.
191 std::vector<Type> newParamsTypes;
192 for (auto& lub : paramLUBs) {
193 if (!lub.noted()) {
194 break;
195 }
196 newParamsTypes.push_back(lub.getLUB());
197 }
198 Type newParams;
199 if (newParamsTypes.size() < numParams) {
200 // We did not have type information to calculate a LUB (no calls, or
201 // some param is always unreachable), so there is nothing we can improve
202 // here. Other passes might remove the type entirely.
203 newParams = func->getParams();
204 } else {
205 newParams = Type(newParamsTypes);
206 }
207
208 auto& resultsLUB = info.resultsLUB;
209 Type newResults;
210 if (!resultsLUB.noted()) {
211 // We did not have type information to calculate a LUB (no returned
212 // value, or it can return a value but traps instead etc.).
213 newResults = func->getResults();
214 } else {
215 newResults = resultsLUB.getLUB();
216 }
217
218 if (newParams == func->getParams() && newResults == func->getResults()) {
219 continue;
220 }
221
222 // We found an improvement!
223 newSignatures[type] = Signature(newParams, newResults);
224
225 if (newResults != func->getResults()) {
226 // Update the types of calls using the signature.
227 for (auto* call : info.calls) {
228 if (call->type != Type::unreachable) {
229 call->type = newResults;
230 }
231 }
232 for (auto* callRef : info.callRefs) {
233 if (callRef->type != Type::unreachable) {
234 callRef->type = newResults;
235 }
236 }
237 }
238 }
239
240 if (newSignatures.empty()) {
241 // We found nothing to optimize.
242 return;
243 }
244
245 // Update function contents for their new parameter types.
246 struct CodeUpdater : public WalkerPass<PostWalker<CodeUpdater>> {
247 bool isFunctionParallel() override { return true; }
248
249 // Updating parameter types cannot affect validation (only updating var
250 // types types might).
251 bool requiresNonNullableLocalFixups() override { return false; }
252
253 SignatureRefining& parent;
254 Module& wasm;
255
256 CodeUpdater(SignatureRefining& parent, Module& wasm)
257 : parent(parent), wasm(wasm) {}
258
259 std::unique_ptr<Pass> create() override {
260 return std::make_unique<CodeUpdater>(parent, wasm);
261 }
262
263 void doWalkFunction(Function* func) {
264 auto iter = parent.newSignatures.find(func->type);
265 if (iter != parent.newSignatures.end()) {
266 std::vector<Type> newParamsTypes;
267 for (auto param : iter->second.params) {
268 newParamsTypes.push_back(param);
269 }
270 // Do not update local.get/local.tee here, as we will do so in
271 // GlobalTypeRewriter::updateSignatures, below. (Doing an update here
272 // would leave the IR in an inconsistent state of a partial update;
273 // instead, do the full update at the end.)
274 TypeUpdating::updateParamTypes(
275 func,
276 newParamsTypes,
277 wasm,
278 TypeUpdating::LocalUpdatingMode::DoNotUpdate);
279 }
280 }
281 };
282 CodeUpdater(*this, *module).run(getPassRunner(), module);
283
284 // Rewrite the types.
285 GlobalTypeRewriter::updateSignatures(newSignatures, *module);
286
287 // TODO: we could do this only in relevant functions perhaps
288 ReFinalize().run(getPassRunner(), module);
289 }
290};
291
292} // anonymous namespace
293
294Pass* createSignatureRefiningPass() { return new SignatureRefining(); }
295
296} // namespace wasm

// source code of dart_sdk/third_party/binaryen/src/src/passes/SignatureRefining.cpp