1 | //===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/AsmParser/Parser.h" |
10 | #include "llvm/IR/Module.h" |
11 | #include "llvm/Passes/PassBuilder.h" |
12 | #include "llvm/Support/SourceMgr.h" |
13 | #include "llvm/Testing/Support/Error.h" |
14 | #include "llvm/Transforms/Coroutines/CoroSplit.h" |
15 | #include "gtest/gtest.h" |
16 | |
17 | using namespace llvm; |
18 | |
19 | namespace { |
20 | |
21 | struct : public testing::Test { |
22 | LLVMContext ; |
23 | ModulePassManager ; |
24 | PassBuilder ; |
25 | LoopAnalysisManager ; |
26 | FunctionAnalysisManager ; |
27 | CGSCCAnalysisManager ; |
28 | ModuleAnalysisManager ; |
29 | LLVMContext ; |
30 | std::unique_ptr<Module> ; |
31 | |
32 | () { |
33 | PB.registerModuleAnalyses(MAM); |
34 | PB.registerCGSCCAnalyses(CGAM); |
35 | PB.registerFunctionAnalyses(FAM); |
36 | PB.registerLoopAnalyses(LAM); |
37 | PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); |
38 | } |
39 | |
40 | BasicBlock *(Function *F, StringRef Name) const { |
41 | for (BasicBlock &BB : *F) { |
42 | if (BB.getName() == Name) |
43 | return &BB; |
44 | } |
45 | return nullptr; |
46 | } |
47 | |
48 | CallInst *(BasicBlock *BB, StringRef Name) const { |
49 | for (Instruction &I : *BB) { |
50 | if (CallInst *CI = dyn_cast<CallInst>(Val: &I)) |
51 | if (CI->getCalledFunction()->getName() == Name) |
52 | return CI; |
53 | } |
54 | return nullptr; |
55 | } |
56 | |
57 | void (const StringRef IR) { |
58 | SMDiagnostic Error; |
59 | M = parseAssemblyString(AsmString: IR, Err&: Error, Context); |
60 | std::string errMsg; |
61 | raw_string_ostream os(errMsg); |
62 | Error.print(ProgName: "" , S&: os); |
63 | |
64 | // A failure here means that the test itself is buggy. |
65 | if (!M) |
66 | report_fatal_error(reason: os.str().c_str()); |
67 | } |
68 | }; |
69 | |
70 | StringRef Text = R"( |
71 | define ptr @f(i32 %n) presplitcoroutine { |
72 | entry: |
73 | %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) |
74 | %size = call i32 @llvm.coro.size.i32() |
75 | %alloc = call ptr @malloc(i32 %size) |
76 | %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc) |
77 | |
78 | %inc1 = add i32 %n, 1 |
79 | %val2 = call i32 @should.remat(i32 %inc1) |
80 | %sp1 = call i8 @llvm.coro.suspend(token none, i1 false) |
81 | switch i8 %sp1, label %suspend [i8 0, label %resume1 |
82 | i8 1, label %cleanup] |
83 | resume1: |
84 | %inc2 = add i32 %val2, 1 |
85 | %sp2 = call i8 @llvm.coro.suspend(token none, i1 false) |
86 | switch i8 %sp1, label %suspend [i8 0, label %resume2 |
87 | i8 1, label %cleanup] |
88 | |
89 | resume2: |
90 | call void @print(i32 %val2) |
91 | call void @print(i32 %inc2) |
92 | br label %cleanup |
93 | |
94 | cleanup: |
95 | %mem = call ptr @llvm.coro.free(token %id, ptr %hdl) |
96 | call void @free(ptr %mem) |
97 | br label %suspend |
98 | suspend: |
99 | call i1 @llvm.coro.end(ptr %hdl, i1 0) |
100 | ret ptr %hdl |
101 | } |
102 | |
103 | declare ptr @llvm.coro.free(token, ptr) |
104 | declare i32 @llvm.coro.size.i32() |
105 | declare i8 @llvm.coro.suspend(token, i1) |
106 | declare void @llvm.coro.resume(ptr) |
107 | declare void @llvm.coro.destroy(ptr) |
108 | |
109 | declare token @llvm.coro.id(i32, ptr, ptr, ptr) |
110 | declare i1 @llvm.coro.alloc(token) |
111 | declare ptr @llvm.coro.begin(token, ptr) |
112 | declare i1 @llvm.coro.end(ptr, i1) |
113 | |
114 | declare i32 @should.remat(i32) |
115 | |
116 | declare noalias ptr @malloc(i32) |
117 | declare void @print(i32) |
118 | declare void @free(ptr) |
119 | )" ; |
120 | |
121 | // Materializable callback with extra rematerialization |
122 | bool (Instruction &I) { |
123 | if (isa<CastInst>(Val: &I) || isa<GetElementPtrInst>(Val: &I) || |
124 | isa<BinaryOperator>(Val: &I) || isa<CmpInst>(Val: &I) || isa<SelectInst>(Val: &I)) |
125 | return true; |
126 | |
127 | if (auto *CI = dyn_cast<CallInst>(Val: &I)) { |
128 | auto *CalledFunc = CI->getCalledFunction(); |
129 | if (CalledFunc && CalledFunc->getName().starts_with(Prefix: "should.remat" )) |
130 | return true; |
131 | } |
132 | |
133 | return false; |
134 | } |
135 | |
136 | TEST_F(ExtraRematTest, TestCoroRematDefault) { |
137 | ParseAssembly(IR: Text); |
138 | |
139 | ASSERT_TRUE(M); |
140 | |
141 | CGSCCPassManager CGPM; |
142 | CGPM.addPass(Pass: CoroSplitPass()); |
143 | MPM.addPass(Pass: createModuleToPostOrderCGSCCPassAdaptor(Pass: std::move(CGPM))); |
144 | MPM.run(IR&: *M, AM&: MAM); |
145 | |
146 | // Verify that extra rematerializable instruction has been rematerialized |
147 | Function *F = M->getFunction(Name: "f.resume" ); |
148 | ASSERT_TRUE(F) << "could not find split function f.resume" ; |
149 | |
150 | BasicBlock *Resume1 = getBasicBlockByName(F, Name: "resume1" ); |
151 | ASSERT_TRUE(Resume1) |
152 | << "could not find expected BB resume1 in split function" ; |
153 | |
154 | // With default materialization the intrinsic should not have been |
155 | // rematerialized |
156 | CallInst *CI = getCallByName(BB: Resume1, Name: "should.remat" ); |
157 | ASSERT_FALSE(CI); |
158 | } |
159 | |
160 | TEST_F(ExtraRematTest, TestCoroRematWithCallback) { |
161 | ParseAssembly(IR: Text); |
162 | |
163 | ASSERT_TRUE(M); |
164 | |
165 | CGSCCPassManager CGPM; |
166 | CGPM.addPass( |
167 | Pass: CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable))); |
168 | MPM.addPass(Pass: createModuleToPostOrderCGSCCPassAdaptor(Pass: std::move(CGPM))); |
169 | MPM.run(IR&: *M, AM&: MAM); |
170 | |
171 | // Verify that extra rematerializable instruction has been rematerialized |
172 | Function *F = M->getFunction(Name: "f.resume" ); |
173 | ASSERT_TRUE(F) << "could not find split function f.resume" ; |
174 | |
175 | BasicBlock *Resume1 = getBasicBlockByName(F, Name: "resume1" ); |
176 | ASSERT_TRUE(Resume1) |
177 | << "could not find expected BB resume1 in split function" ; |
178 | |
179 | // With callback the extra rematerialization of the function should have |
180 | // happened |
181 | CallInst *CI = getCallByName(BB: Resume1, Name: "should.remat" ); |
182 | ASSERT_TRUE(CI); |
183 | } |
184 | } // namespace |
185 | |