1//===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/AsmParser/Parser.h"
10#include "llvm/IR/Module.h"
11#include "llvm/Passes/PassBuilder.h"
12#include "llvm/Support/SourceMgr.h"
13#include "llvm/Testing/Support/Error.h"
14#include "llvm/Transforms/Coroutines/CoroSplit.h"
15#include "gtest/gtest.h"
16
17using namespace llvm;
18
19namespace {
20
21struct ExtraRematTest : public testing::Test {
22 LLVMContext Ctx;
23 ModulePassManager MPM;
24 PassBuilder PB;
25 LoopAnalysisManager LAM;
26 FunctionAnalysisManager FAM;
27 CGSCCAnalysisManager CGAM;
28 ModuleAnalysisManager MAM;
29 LLVMContext Context;
30 std::unique_ptr<Module> M;
31
32 ExtraRematTest() {
33 PB.registerModuleAnalyses(MAM);
34 PB.registerCGSCCAnalyses(CGAM);
35 PB.registerFunctionAnalyses(FAM);
36 PB.registerLoopAnalyses(LAM);
37 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
38 }
39
40 BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const {
41 for (BasicBlock &BB : *F) {
42 if (BB.getName() == Name)
43 return &BB;
44 }
45 return nullptr;
46 }
47
48 CallInst *getCallByName(BasicBlock *BB, StringRef Name) const {
49 for (Instruction &I : *BB) {
50 if (CallInst *CI = dyn_cast<CallInst>(Val: &I))
51 if (CI->getCalledFunction()->getName() == Name)
52 return CI;
53 }
54 return nullptr;
55 }
56
57 void ParseAssembly(const StringRef IR) {
58 SMDiagnostic Error;
59 M = parseAssemblyString(AsmString: IR, Err&: Error, Context);
60 std::string errMsg;
61 raw_string_ostream os(errMsg);
62 Error.print(ProgName: "", S&: os);
63
64 // A failure here means that the test itself is buggy.
65 if (!M)
66 report_fatal_error(reason: os.str().c_str());
67 }
68};
69
70StringRef Text = R"(
71 define ptr @f(i32 %n) presplitcoroutine {
72 entry:
73 %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null)
74 %size = call i32 @llvm.coro.size.i32()
75 %alloc = call ptr @malloc(i32 %size)
76 %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc)
77
78 %inc1 = add i32 %n, 1
79 %val2 = call i32 @should.remat(i32 %inc1)
80 %sp1 = call i8 @llvm.coro.suspend(token none, i1 false)
81 switch i8 %sp1, label %suspend [i8 0, label %resume1
82 i8 1, label %cleanup]
83 resume1:
84 %inc2 = add i32 %val2, 1
85 %sp2 = call i8 @llvm.coro.suspend(token none, i1 false)
86 switch i8 %sp1, label %suspend [i8 0, label %resume2
87 i8 1, label %cleanup]
88
89 resume2:
90 call void @print(i32 %val2)
91 call void @print(i32 %inc2)
92 br label %cleanup
93
94 cleanup:
95 %mem = call ptr @llvm.coro.free(token %id, ptr %hdl)
96 call void @free(ptr %mem)
97 br label %suspend
98 suspend:
99 call i1 @llvm.coro.end(ptr %hdl, i1 0)
100 ret ptr %hdl
101 }
102
103 declare ptr @llvm.coro.free(token, ptr)
104 declare i32 @llvm.coro.size.i32()
105 declare i8 @llvm.coro.suspend(token, i1)
106 declare void @llvm.coro.resume(ptr)
107 declare void @llvm.coro.destroy(ptr)
108
109 declare token @llvm.coro.id(i32, ptr, ptr, ptr)
110 declare i1 @llvm.coro.alloc(token)
111 declare ptr @llvm.coro.begin(token, ptr)
112 declare i1 @llvm.coro.end(ptr, i1)
113
114 declare i32 @should.remat(i32)
115
116 declare noalias ptr @malloc(i32)
117 declare void @print(i32)
118 declare void @free(ptr)
119 )";
120
121// Materializable callback with extra rematerialization
122bool ExtraMaterializable(Instruction &I) {
123 if (isa<CastInst>(Val: &I) || isa<GetElementPtrInst>(Val: &I) ||
124 isa<BinaryOperator>(Val: &I) || isa<CmpInst>(Val: &I) || isa<SelectInst>(Val: &I))
125 return true;
126
127 if (auto *CI = dyn_cast<CallInst>(Val: &I)) {
128 auto *CalledFunc = CI->getCalledFunction();
129 if (CalledFunc && CalledFunc->getName().starts_with(Prefix: "should.remat"))
130 return true;
131 }
132
133 return false;
134}
135
136TEST_F(ExtraRematTest, TestCoroRematDefault) {
137 ParseAssembly(IR: Text);
138
139 ASSERT_TRUE(M);
140
141 CGSCCPassManager CGPM;
142 CGPM.addPass(Pass: CoroSplitPass());
143 MPM.addPass(Pass: createModuleToPostOrderCGSCCPassAdaptor(Pass: std::move(CGPM)));
144 MPM.run(IR&: *M, AM&: MAM);
145
146 // Verify that extra rematerializable instruction has been rematerialized
147 Function *F = M->getFunction(Name: "f.resume");
148 ASSERT_TRUE(F) << "could not find split function f.resume";
149
150 BasicBlock *Resume1 = getBasicBlockByName(F, Name: "resume1");
151 ASSERT_TRUE(Resume1)
152 << "could not find expected BB resume1 in split function";
153
154 // With default materialization the intrinsic should not have been
155 // rematerialized
156 CallInst *CI = getCallByName(BB: Resume1, Name: "should.remat");
157 ASSERT_FALSE(CI);
158}
159
160TEST_F(ExtraRematTest, TestCoroRematWithCallback) {
161 ParseAssembly(IR: Text);
162
163 ASSERT_TRUE(M);
164
165 CGSCCPassManager CGPM;
166 CGPM.addPass(
167 Pass: CoroSplitPass(std::function<bool(Instruction &)>(ExtraMaterializable)));
168 MPM.addPass(Pass: createModuleToPostOrderCGSCCPassAdaptor(Pass: std::move(CGPM)));
169 MPM.run(IR&: *M, AM&: MAM);
170
171 // Verify that extra rematerializable instruction has been rematerialized
172 Function *F = M->getFunction(Name: "f.resume");
173 ASSERT_TRUE(F) << "could not find split function f.resume";
174
175 BasicBlock *Resume1 = getBasicBlockByName(F, Name: "resume1");
176 ASSERT_TRUE(Resume1)
177 << "could not find expected BB resume1 in split function";
178
179 // With callback the extra rematerialization of the function should have
180 // happened
181 CallInst *CI = getCallByName(BB: Resume1, Name: "should.remat");
182 ASSERT_TRUE(CI);
183}
184} // namespace
185

// Source: llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp