1 | //===- llvm-extract.cpp - LLVM function extraction utility ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This utility changes the input module to only contain a single function, |
10 | // which is primarily used for debugging transformations. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/ADT/SetVector.h" |
15 | #include "llvm/ADT/SmallPtrSet.h" |
16 | #include "llvm/Bitcode/BitcodeWriterPass.h" |
17 | #include "llvm/IR/DataLayout.h" |
18 | #include "llvm/IR/IRPrintingPasses.h" |
19 | #include "llvm/IR/Instructions.h" |
20 | #include "llvm/IR/LLVMContext.h" |
21 | #include "llvm/IR/Module.h" |
22 | #include "llvm/IRPrinter/IRPrintingPasses.h" |
23 | #include "llvm/IRReader/IRReader.h" |
24 | #include "llvm/Passes/PassBuilder.h" |
25 | #include "llvm/Support/CommandLine.h" |
26 | #include "llvm/Support/Error.h" |
27 | #include "llvm/Support/FileSystem.h" |
28 | #include "llvm/Support/InitLLVM.h" |
29 | #include "llvm/Support/Regex.h" |
30 | #include "llvm/Support/SourceMgr.h" |
31 | #include "llvm/Support/SystemUtils.h" |
32 | #include "llvm/Support/ToolOutputFile.h" |
33 | #include "llvm/Transforms/IPO.h" |
34 | #include "llvm/Transforms/IPO/BlockExtractor.h" |
35 | #include "llvm/Transforms/IPO/ExtractGV.h" |
36 | #include "llvm/Transforms/IPO/GlobalDCE.h" |
37 | #include "llvm/Transforms/IPO/StripDeadPrototypes.h" |
38 | #include "llvm/Transforms/IPO/StripSymbols.h" |
39 | #include <memory> |
40 | #include <utility> |
41 | |
42 | using namespace llvm; |
43 | |
44 | static cl::OptionCategory ("llvm-extract Options" ); |
45 | |
46 | // InputFilename - The filename to read from. |
47 | static cl::opt<std::string> InputFilename(cl::Positional, |
48 | cl::desc("<input bitcode file>" ), |
49 | cl::init(Val: "-" ), |
50 | cl::value_desc("filename" )); |
51 | |
52 | static cl::opt<std::string> OutputFilename("o" , |
53 | cl::desc("Specify output filename" ), |
54 | cl::value_desc("filename" ), |
55 | cl::init(Val: "-" ), cl::cat(ExtractCat)); |
56 | |
57 | static cl::opt<bool> Force("f" , cl::desc("Enable binary output on terminals" ), |
58 | cl::cat(ExtractCat)); |
59 | |
60 | static cl::opt<bool> DeleteFn("delete" , |
61 | cl::desc("Delete specified Globals from Module" ), |
62 | cl::cat(ExtractCat)); |
63 | |
64 | static cl::opt<bool> KeepConstInit("keep-const-init" , |
65 | cl::desc("Keep initializers of constants" ), |
66 | cl::cat(ExtractCat)); |
67 | |
68 | static cl::opt<bool> |
69 | Recursive("recursive" , cl::desc("Recursively extract all called functions" ), |
70 | cl::cat(ExtractCat)); |
71 | |
72 | // ExtractFuncs - The functions to extract from the module. |
73 | static cl::list<std::string> |
74 | ("func" , cl::desc("Specify function to extract" ), |
75 | cl::value_desc("function" ), cl::cat(ExtractCat)); |
76 | |
77 | // ExtractRegExpFuncs - The functions, matched via regular expression, to |
78 | // extract from the module. |
79 | static cl::list<std::string> |
80 | ("rfunc" , |
81 | cl::desc("Specify function(s) to extract using a " |
82 | "regular expression" ), |
83 | cl::value_desc("rfunction" ), cl::cat(ExtractCat)); |
84 | |
85 | // ExtractBlocks - The blocks to extract from the module. |
86 | static cl::list<std::string> ( |
87 | "bb" , |
88 | cl::desc( |
89 | "Specify <function, basic block1[;basic block2...]> pairs to extract.\n" |
90 | "Each pair will create a function.\n" |
91 | "If multiple basic blocks are specified in one pair,\n" |
92 | "the first block in the sequence should dominate the rest.\n" |
93 | "If an unnamed basic block is to be extracted,\n" |
94 | "'%' should be added before the basic block variable names.\n" |
95 | "eg:\n" |
96 | " --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n" |
97 | " --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one " |
98 | "with bb2.\n" |
99 | " --bb=f:%1 will extract one function with basic block 1;" ), |
100 | cl::value_desc("function:bb1[;bb2...]" ), cl::cat(ExtractCat)); |
101 | |
102 | // ExtractAlias - The alias to extract from the module. |
103 | static cl::list<std::string> |
104 | ("alias" , cl::desc("Specify alias to extract" ), |
105 | cl::value_desc("alias" ), cl::cat(ExtractCat)); |
106 | |
107 | // ExtractRegExpAliases - The aliases, matched via regular expression, to |
108 | // extract from the module. |
109 | static cl::list<std::string> |
110 | ("ralias" , |
111 | cl::desc("Specify alias(es) to extract using a " |
112 | "regular expression" ), |
113 | cl::value_desc("ralias" ), cl::cat(ExtractCat)); |
114 | |
115 | // ExtractGlobals - The globals to extract from the module. |
116 | static cl::list<std::string> |
117 | ("glob" , cl::desc("Specify global to extract" ), |
118 | cl::value_desc("global" ), cl::cat(ExtractCat)); |
119 | |
120 | // ExtractRegExpGlobals - The globals, matched via regular expression, to |
121 | // extract from the module... |
122 | static cl::list<std::string> |
123 | ("rglob" , |
124 | cl::desc("Specify global(s) to extract using a " |
125 | "regular expression" ), |
126 | cl::value_desc("rglobal" ), cl::cat(ExtractCat)); |
127 | |
128 | static cl::opt<bool> OutputAssembly("S" , |
129 | cl::desc("Write output as LLVM assembly" ), |
130 | cl::Hidden, cl::cat(ExtractCat)); |
131 | |
132 | static cl::opt<bool> PreserveBitcodeUseListOrder( |
133 | "preserve-bc-uselistorder" , |
134 | cl::desc("Preserve use-list order when writing LLVM bitcode." ), |
135 | cl::init(Val: true), cl::Hidden, cl::cat(ExtractCat)); |
136 | |
137 | static cl::opt<bool> PreserveAssemblyUseListOrder( |
138 | "preserve-ll-uselistorder" , |
139 | cl::desc("Preserve use-list order when writing LLVM assembly." ), |
140 | cl::init(Val: false), cl::Hidden, cl::cat(ExtractCat)); |
141 | |
142 | int main(int argc, char **argv) { |
143 | InitLLVM X(argc, argv); |
144 | |
145 | LLVMContext Context; |
146 | cl::HideUnrelatedOptions(Category&: ExtractCat); |
147 | cl::ParseCommandLineOptions(argc, argv, Overview: "llvm extractor\n" ); |
148 | |
149 | // Use lazy loading, since we only care about selected global values. |
150 | SMDiagnostic Err; |
151 | std::unique_ptr<Module> M = getLazyIRFileModule(Filename: InputFilename, Err, Context); |
152 | |
153 | if (!M) { |
154 | Err.print(ProgName: argv[0], S&: errs()); |
155 | return 1; |
156 | } |
157 | |
158 | // Use SetVector to avoid duplicates. |
159 | SetVector<GlobalValue *> GVs; |
160 | |
161 | // Figure out which aliases we should extract. |
162 | for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) { |
163 | GlobalAlias *GA = M->getNamedAlias(Name: ExtractAliases[i]); |
164 | if (!GA) { |
165 | errs() << argv[0] << ": program doesn't contain alias named '" |
166 | << ExtractAliases[i] << "'!\n" ; |
167 | return 1; |
168 | } |
169 | GVs.insert(X: GA); |
170 | } |
171 | |
172 | // Extract aliases via regular expression matching. |
173 | for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) { |
174 | std::string Error; |
175 | Regex RegEx(ExtractRegExpAliases[i]); |
176 | if (!RegEx.isValid(Error)) { |
177 | errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' " |
178 | "invalid regex: " << Error; |
179 | } |
180 | bool match = false; |
181 | for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end(); |
182 | GA != E; GA++) { |
183 | if (RegEx.match(String: GA->getName())) { |
184 | GVs.insert(X: &*GA); |
185 | match = true; |
186 | } |
187 | } |
188 | if (!match) { |
189 | errs() << argv[0] << ": program doesn't contain global named '" |
190 | << ExtractRegExpAliases[i] << "'!\n" ; |
191 | return 1; |
192 | } |
193 | } |
194 | |
195 | // Figure out which globals we should extract. |
196 | for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) { |
197 | GlobalValue *GV = M->getNamedGlobal(Name: ExtractGlobals[i]); |
198 | if (!GV) { |
199 | errs() << argv[0] << ": program doesn't contain global named '" |
200 | << ExtractGlobals[i] << "'!\n" ; |
201 | return 1; |
202 | } |
203 | GVs.insert(X: GV); |
204 | } |
205 | |
206 | // Extract globals via regular expression matching. |
207 | for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) { |
208 | std::string Error; |
209 | Regex RegEx(ExtractRegExpGlobals[i]); |
210 | if (!RegEx.isValid(Error)) { |
211 | errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' " |
212 | "invalid regex: " << Error; |
213 | } |
214 | bool match = false; |
215 | for (auto &GV : M->globals()) { |
216 | if (RegEx.match(String: GV.getName())) { |
217 | GVs.insert(X: &GV); |
218 | match = true; |
219 | } |
220 | } |
221 | if (!match) { |
222 | errs() << argv[0] << ": program doesn't contain global named '" |
223 | << ExtractRegExpGlobals[i] << "'!\n" ; |
224 | return 1; |
225 | } |
226 | } |
227 | |
228 | // Figure out which functions we should extract. |
229 | for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) { |
230 | GlobalValue *GV = M->getFunction(Name: ExtractFuncs[i]); |
231 | if (!GV) { |
232 | errs() << argv[0] << ": program doesn't contain function named '" |
233 | << ExtractFuncs[i] << "'!\n" ; |
234 | return 1; |
235 | } |
236 | GVs.insert(X: GV); |
237 | } |
238 | // Extract functions via regular expression matching. |
239 | for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) { |
240 | std::string Error; |
241 | StringRef RegExStr = ExtractRegExpFuncs[i]; |
242 | Regex RegEx(RegExStr); |
243 | if (!RegEx.isValid(Error)) { |
244 | errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' " |
245 | "invalid regex: " << Error; |
246 | } |
247 | bool match = false; |
248 | for (Module::iterator F = M->begin(), E = M->end(); F != E; |
249 | F++) { |
250 | if (RegEx.match(String: F->getName())) { |
251 | GVs.insert(X: &*F); |
252 | match = true; |
253 | } |
254 | } |
255 | if (!match) { |
256 | errs() << argv[0] << ": program doesn't contain global named '" |
257 | << ExtractRegExpFuncs[i] << "'!\n" ; |
258 | return 1; |
259 | } |
260 | } |
261 | |
262 | // Figure out which BasicBlocks we should extract. |
263 | SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap; |
264 | for (StringRef StrPair : ExtractBlocks) { |
265 | SmallVector<StringRef, 16> BBNames; |
266 | auto BBInfo = StrPair.split(Separator: ':'); |
267 | // Get the function. |
268 | Function *F = M->getFunction(Name: BBInfo.first); |
269 | if (!F) { |
270 | errs() << argv[0] << ": program doesn't contain a function named '" |
271 | << BBInfo.first << "'!\n" ; |
272 | return 1; |
273 | } |
274 | // Add the function to the materialize list, and store the basic block names |
275 | // to check after materialization. |
276 | GVs.insert(X: F); |
277 | BBInfo.second.split(A&: BBNames, Separator: ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false); |
278 | BBMap.push_back(Elt: {F, std::move(BBNames)}); |
279 | } |
280 | |
281 | // Use *argv instead of argv[0] to work around a wrong GCC warning. |
282 | ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: " ); |
283 | |
284 | if (Recursive) { |
285 | std::vector<llvm::Function *> Workqueue; |
286 | for (GlobalValue *GV : GVs) { |
287 | if (auto *F = dyn_cast<Function>(Val: GV)) { |
288 | Workqueue.push_back(x: F); |
289 | } |
290 | } |
291 | while (!Workqueue.empty()) { |
292 | Function *F = &*Workqueue.back(); |
293 | Workqueue.pop_back(); |
294 | ExitOnErr(F->materialize()); |
295 | for (auto &BB : *F) { |
296 | for (auto &I : BB) { |
297 | CallBase *CB = dyn_cast<CallBase>(Val: &I); |
298 | if (!CB) |
299 | continue; |
300 | Function *CF = CB->getCalledFunction(); |
301 | if (!CF) |
302 | continue; |
303 | if (CF->isDeclaration() || !GVs.insert(X: CF)) |
304 | continue; |
305 | Workqueue.push_back(x: CF); |
306 | } |
307 | } |
308 | } |
309 | } |
310 | |
311 | auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); }; |
312 | |
313 | // Materialize requisite global values. |
314 | if (!DeleteFn) { |
315 | for (size_t i = 0, e = GVs.size(); i != e; ++i) |
316 | Materialize(*GVs[i]); |
317 | } else { |
318 | // Deleting. Materialize every GV that's *not* in GVs. |
319 | SmallPtrSet<GlobalValue *, 8> GVSet(llvm::from_range, GVs); |
320 | for (auto &F : *M) { |
321 | if (!GVSet.count(Ptr: &F)) |
322 | Materialize(F); |
323 | } |
324 | } |
325 | |
326 | { |
327 | std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end()); |
328 | LoopAnalysisManager LAM; |
329 | FunctionAnalysisManager FAM; |
330 | CGSCCAnalysisManager CGAM; |
331 | ModuleAnalysisManager MAM; |
332 | |
333 | PassBuilder PB; |
334 | |
335 | PB.registerModuleAnalyses(MAM); |
336 | PB.registerCGSCCAnalyses(CGAM); |
337 | PB.registerFunctionAnalyses(FAM); |
338 | PB.registerLoopAnalyses(LAM); |
339 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
340 | |
341 | ModulePassManager PM; |
342 | PM.addPass(Pass: ExtractGVPass(Gvs, DeleteFn, KeepConstInit)); |
343 | PM.run(IR&: *M, AM&: MAM); |
344 | |
345 | // Now that we have all the GVs we want, mark the module as fully |
346 | // materialized. |
347 | // FIXME: should the GVExtractionPass handle this? |
348 | ExitOnErr(M->materializeAll()); |
349 | } |
350 | |
351 | // Extract the specified basic blocks from the module and erase the existing |
352 | // functions. |
353 | if (!ExtractBlocks.empty()) { |
354 | // Figure out which BasicBlocks we should extract. |
355 | std::vector<std::vector<BasicBlock *>> GroupOfBBs; |
356 | for (auto &P : BBMap) { |
357 | std::vector<BasicBlock *> BBs; |
358 | for (StringRef BBName : P.second) { |
359 | // The function has been materialized, so add its matching basic blocks |
360 | // to the block extractor list, or fail if a name is not found. |
361 | auto Res = llvm::find_if(Range&: *P.first, P: [&](const BasicBlock &BB) { |
362 | return BB.getNameOrAsOperand() == BBName; |
363 | }); |
364 | if (Res == P.first->end()) { |
365 | errs() << argv[0] << ": function " << P.first->getName() |
366 | << " doesn't contain a basic block named '" << BBName |
367 | << "'!\n" ; |
368 | return 1; |
369 | } |
370 | BBs.push_back(x: &*Res); |
371 | } |
372 | GroupOfBBs.push_back(x: BBs); |
373 | } |
374 | |
375 | LoopAnalysisManager LAM; |
376 | FunctionAnalysisManager FAM; |
377 | CGSCCAnalysisManager CGAM; |
378 | ModuleAnalysisManager MAM; |
379 | |
380 | PassBuilder PB; |
381 | |
382 | PB.registerModuleAnalyses(MAM); |
383 | PB.registerCGSCCAnalyses(CGAM); |
384 | PB.registerFunctionAnalyses(FAM); |
385 | PB.registerLoopAnalyses(LAM); |
386 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
387 | |
388 | ModulePassManager PM; |
389 | PM.addPass(Pass: BlockExtractorPass(std::move(GroupOfBBs), true)); |
390 | PM.run(IR&: *M, AM&: MAM); |
391 | } |
392 | |
393 | // In addition to deleting all other functions, we also want to spiff it |
394 | // up a little bit. Do this now. |
395 | |
396 | LoopAnalysisManager LAM; |
397 | FunctionAnalysisManager FAM; |
398 | CGSCCAnalysisManager CGAM; |
399 | ModuleAnalysisManager MAM; |
400 | |
401 | PassBuilder PB; |
402 | |
403 | PB.registerModuleAnalyses(MAM); |
404 | PB.registerCGSCCAnalyses(CGAM); |
405 | PB.registerFunctionAnalyses(FAM); |
406 | PB.registerLoopAnalyses(LAM); |
407 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
408 | |
409 | ModulePassManager PM; |
410 | if (!DeleteFn) |
411 | PM.addPass(Pass: GlobalDCEPass()); |
412 | PM.addPass(Pass: StripDeadDebugInfoPass()); |
413 | PM.addPass(Pass: StripDeadPrototypesPass()); |
414 | PM.addPass(Pass: StripDeadCGProfilePass()); |
415 | |
416 | std::error_code EC; |
417 | ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None); |
418 | if (EC) { |
419 | errs() << EC.message() << '\n'; |
420 | return 1; |
421 | } |
422 | |
423 | if (OutputAssembly) |
424 | PM.addPass(Pass: PrintModulePass(Out.os(), "" , PreserveAssemblyUseListOrder)); |
425 | else if (Force || !CheckBitcodeOutputToConsole(stream_to_check&: Out.os())) |
426 | PM.addPass(Pass: BitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder)); |
427 | |
428 | PM.run(IR&: *M, AM&: MAM); |
429 | |
430 | // Declare success. |
431 | Out.keep(); |
432 | |
433 | return 0; |
434 | } |
435 | |