1 | // SPDX-License-Identifier: GPL-2.0 |
2 | #include <kunit/test.h> |
3 | |
4 | #include "protocol.h" |
5 | |
6 | static struct mptcp_subflow_request_sock *build_req_sock(struct kunit *test) |
7 | { |
8 | struct mptcp_subflow_request_sock *req; |
9 | |
10 | req = kunit_kzalloc(test, size: sizeof(struct mptcp_subflow_request_sock), |
11 | GFP_USER); |
12 | KUNIT_EXPECT_NOT_ERR_OR_NULL(test, req); |
13 | mptcp_token_init_request(req: (struct request_sock *)req); |
14 | sock_net_set(sk: (struct sock *)req, net: &init_net); |
15 | return req; |
16 | } |
17 | |
18 | static void mptcp_token_test_req_basic(struct kunit *test) |
19 | { |
20 | struct mptcp_subflow_request_sock *req = build_req_sock(test); |
21 | struct mptcp_sock *null_msk = NULL; |
22 | |
23 | KUNIT_ASSERT_EQ(test, 0, |
24 | mptcp_token_new_request((struct request_sock *)req)); |
25 | KUNIT_EXPECT_NE(test, 0, (int)req->token); |
26 | KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, req->token)); |
27 | |
28 | /* cleanup */ |
29 | mptcp_token_destroy_request(req: (struct request_sock *)req); |
30 | } |
31 | |
32 | static struct inet_connection_sock *build_icsk(struct kunit *test) |
33 | { |
34 | struct inet_connection_sock *icsk; |
35 | |
36 | icsk = kunit_kzalloc(test, size: sizeof(struct inet_connection_sock), |
37 | GFP_USER); |
38 | KUNIT_EXPECT_NOT_ERR_OR_NULL(test, icsk); |
39 | return icsk; |
40 | } |
41 | |
42 | static struct mptcp_subflow_context *build_ctx(struct kunit *test) |
43 | { |
44 | struct mptcp_subflow_context *ctx; |
45 | |
46 | ctx = kunit_kzalloc(test, size: sizeof(struct mptcp_subflow_context), |
47 | GFP_USER); |
48 | KUNIT_EXPECT_NOT_ERR_OR_NULL(test, ctx); |
49 | return ctx; |
50 | } |
51 | |
52 | static struct mptcp_sock *build_msk(struct kunit *test) |
53 | { |
54 | struct mptcp_sock *msk; |
55 | |
56 | msk = kunit_kzalloc(test, size: sizeof(struct mptcp_sock), GFP_USER); |
57 | KUNIT_EXPECT_NOT_ERR_OR_NULL(test, msk); |
58 | refcount_set(r: &((struct sock *)msk)->sk_refcnt, n: 1); |
59 | sock_net_set(sk: (struct sock *)msk, net: &init_net); |
60 | |
61 | /* be sure the token helpers can dereference sk->sk_prot */ |
62 | ((struct sock *)msk)->sk_prot = &tcp_prot; |
63 | return msk; |
64 | } |
65 | |
66 | static void mptcp_token_test_msk_basic(struct kunit *test) |
67 | { |
68 | struct inet_connection_sock *icsk = build_icsk(test); |
69 | struct mptcp_subflow_context *ctx = build_ctx(test); |
70 | struct mptcp_sock *msk = build_msk(test); |
71 | struct mptcp_sock *null_msk = NULL; |
72 | struct sock *sk; |
73 | |
74 | rcu_assign_pointer(icsk->icsk_ulp_data, ctx); |
75 | ctx->conn = (struct sock *)msk; |
76 | sk = (struct sock *)msk; |
77 | |
78 | KUNIT_ASSERT_EQ(test, 0, |
79 | mptcp_token_new_connect((struct sock *)icsk)); |
80 | KUNIT_EXPECT_NE(test, 0, (int)ctx->token); |
81 | KUNIT_EXPECT_EQ(test, ctx->token, msk->token); |
82 | KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, ctx->token)); |
83 | KUNIT_EXPECT_EQ(test, 2, (int)refcount_read(&sk->sk_refcnt)); |
84 | |
85 | mptcp_token_destroy(msk); |
86 | KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, ctx->token)); |
87 | } |
88 | |
89 | static void mptcp_token_test_accept(struct kunit *test) |
90 | { |
91 | struct mptcp_subflow_request_sock *req = build_req_sock(test); |
92 | struct mptcp_sock *msk = build_msk(test); |
93 | |
94 | KUNIT_ASSERT_EQ(test, 0, |
95 | mptcp_token_new_request((struct request_sock *)req)); |
96 | msk->token = req->token; |
97 | mptcp_token_accept(r: req, msk); |
98 | KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token)); |
99 | |
100 | /* this is now a no-op */ |
101 | mptcp_token_destroy_request(req: (struct request_sock *)req); |
102 | KUNIT_EXPECT_PTR_EQ(test, msk, mptcp_token_get_sock(&init_net, msk->token)); |
103 | |
104 | /* cleanup */ |
105 | mptcp_token_destroy(msk); |
106 | } |
107 | |
108 | static void mptcp_token_test_destroyed(struct kunit *test) |
109 | { |
110 | struct mptcp_subflow_request_sock *req = build_req_sock(test); |
111 | struct mptcp_sock *msk = build_msk(test); |
112 | struct mptcp_sock *null_msk = NULL; |
113 | struct sock *sk; |
114 | |
115 | sk = (struct sock *)msk; |
116 | |
117 | KUNIT_ASSERT_EQ(test, 0, |
118 | mptcp_token_new_request((struct request_sock *)req)); |
119 | msk->token = req->token; |
120 | mptcp_token_accept(r: req, msk); |
121 | |
122 | /* simulate race on removal */ |
123 | refcount_set(r: &sk->sk_refcnt, n: 0); |
124 | KUNIT_EXPECT_PTR_EQ(test, null_msk, mptcp_token_get_sock(&init_net, msk->token)); |
125 | |
126 | /* cleanup */ |
127 | mptcp_token_destroy(msk); |
128 | } |
129 | |
130 | static struct kunit_case mptcp_token_test_cases[] = { |
131 | KUNIT_CASE(mptcp_token_test_req_basic), |
132 | KUNIT_CASE(mptcp_token_test_msk_basic), |
133 | KUNIT_CASE(mptcp_token_test_accept), |
134 | KUNIT_CASE(mptcp_token_test_destroyed), |
135 | {} |
136 | }; |
137 | |
138 | static struct kunit_suite mptcp_token_suite = { |
139 | .name = "mptcp-token" , |
140 | .test_cases = mptcp_token_test_cases, |
141 | }; |
142 | |
143 | kunit_test_suite(mptcp_token_suite); |
144 | |
145 | MODULE_LICENSE("GPL" ); |
146 | |