1// SPDX-License-Identifier: GPL-2.0
2/*
3 * Copyright (C) 2015-2019 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4 */
5
6#include "noise.h"
7#include "device.h"
8#include "peer.h"
9#include "messages.h"
10#include "queueing.h"
11#include "peerlookup.h"
12
13#include <linux/rcupdate.h>
14#include <linux/slab.h>
15#include <linux/bitmap.h>
16#include <linux/scatterlist.h>
17#include <linux/highmem.h>
18#include <crypto/utils.h>
19
20/* This implements Noise_IKpsk2:
21 *
22 * <- s
23 * ******
24 * -> e, es, s, ss, {t}
25 * <- e, ee, se, psk, {}
26 */
27
28static const u8 handshake_name[37] = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s";
29static const u8 identifier_name[34] = "WireGuard v1 zx2c4 Jason@zx2c4.com";
30static u8 handshake_init_hash[NOISE_HASH_LEN] __ro_after_init;
31static u8 handshake_init_chaining_key[NOISE_HASH_LEN] __ro_after_init;
32static atomic64_t keypair_counter = ATOMIC64_INIT(0);
33
34void __init wg_noise_init(void)
35{
36 struct blake2s_state blake;
37
38 blake2s(out: handshake_init_chaining_key, in: handshake_name, NULL,
39 outlen: NOISE_HASH_LEN, inlen: sizeof(handshake_name), keylen: 0);
40 blake2s_init(state: &blake, outlen: NOISE_HASH_LEN);
41 blake2s_update(state: &blake, in: handshake_init_chaining_key, inlen: NOISE_HASH_LEN);
42 blake2s_update(state: &blake, in: identifier_name, inlen: sizeof(identifier_name));
43 blake2s_final(state: &blake, out: handshake_init_hash);
44}
45
46/* Must hold peer->handshake.static_identity->lock */
47void wg_noise_precompute_static_static(struct wg_peer *peer)
48{
49 down_write(sem: &peer->handshake.lock);
50 if (!peer->handshake.static_identity->has_identity ||
51 !curve25519(mypublic: peer->handshake.precomputed_static_static,
52 secret: peer->handshake.static_identity->static_private,
53 basepoint: peer->handshake.remote_static))
54 memset(peer->handshake.precomputed_static_static, 0,
55 NOISE_PUBLIC_KEY_LEN);
56 up_write(sem: &peer->handshake.lock);
57}
58
59void wg_noise_handshake_init(struct noise_handshake *handshake,
60 struct noise_static_identity *static_identity,
61 const u8 peer_public_key[NOISE_PUBLIC_KEY_LEN],
62 const u8 peer_preshared_key[NOISE_SYMMETRIC_KEY_LEN],
63 struct wg_peer *peer)
64{
65 memset(handshake, 0, sizeof(*handshake));
66 init_rwsem(&handshake->lock);
67 handshake->entry.type = INDEX_HASHTABLE_HANDSHAKE;
68 handshake->entry.peer = peer;
69 memcpy(handshake->remote_static, peer_public_key, NOISE_PUBLIC_KEY_LEN);
70 if (peer_preshared_key)
71 memcpy(handshake->preshared_key, peer_preshared_key,
72 NOISE_SYMMETRIC_KEY_LEN);
73 handshake->static_identity = static_identity;
74 handshake->state = HANDSHAKE_ZEROED;
75 wg_noise_precompute_static_static(peer);
76}
77
78static void handshake_zero(struct noise_handshake *handshake)
79{
80 memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
81 memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
82 memset(&handshake->hash, 0, NOISE_HASH_LEN);
83 memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
84 handshake->remote_index = 0;
85 handshake->state = HANDSHAKE_ZEROED;
86}
87
88void wg_noise_handshake_clear(struct noise_handshake *handshake)
89{
90 down_write(sem: &handshake->lock);
91 wg_index_hashtable_remove(
92 table: handshake->entry.peer->device->index_hashtable,
93 entry: &handshake->entry);
94 handshake_zero(handshake);
95 up_write(sem: &handshake->lock);
96}
97
98static struct noise_keypair *keypair_create(struct wg_peer *peer)
99{
100 struct noise_keypair *keypair = kzalloc(size: sizeof(*keypair), GFP_KERNEL);
101
102 if (unlikely(!keypair))
103 return NULL;
104 spin_lock_init(&keypair->receiving_counter.lock);
105 keypair->internal_id = atomic64_inc_return(v: &keypair_counter);
106 keypair->entry.type = INDEX_HASHTABLE_KEYPAIR;
107 keypair->entry.peer = peer;
108 kref_init(kref: &keypair->refcount);
109 return keypair;
110}
111
112static void keypair_free_rcu(struct rcu_head *rcu)
113{
114 kfree_sensitive(container_of(rcu, struct noise_keypair, rcu));
115}
116
117static void keypair_free_kref(struct kref *kref)
118{
119 struct noise_keypair *keypair =
120 container_of(kref, struct noise_keypair, refcount);
121
122 net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n",
123 keypair->entry.peer->device->dev->name,
124 keypair->internal_id,
125 keypair->entry.peer->internal_id);
126 wg_index_hashtable_remove(table: keypair->entry.peer->device->index_hashtable,
127 entry: &keypair->entry);
128 call_rcu(head: &keypair->rcu, func: keypair_free_rcu);
129}
130
131void wg_noise_keypair_put(struct noise_keypair *keypair, bool unreference_now)
132{
133 if (unlikely(!keypair))
134 return;
135 if (unlikely(unreference_now))
136 wg_index_hashtable_remove(
137 table: keypair->entry.peer->device->index_hashtable,
138 entry: &keypair->entry);
139 kref_put(kref: &keypair->refcount, release: keypair_free_kref);
140}
141
142struct noise_keypair *wg_noise_keypair_get(struct noise_keypair *keypair)
143{
144 RCU_LOCKDEP_WARN(!rcu_read_lock_bh_held(),
145 "Taking noise keypair reference without holding the RCU BH read lock");
146 if (unlikely(!keypair || !kref_get_unless_zero(&keypair->refcount)))
147 return NULL;
148 return keypair;
149}
150
151void wg_noise_keypairs_clear(struct noise_keypairs *keypairs)
152{
153 struct noise_keypair *old;
154
155 spin_lock_bh(lock: &keypairs->keypair_update_lock);
156
157 /* We zero the next_keypair before zeroing the others, so that
158 * wg_noise_received_with_keypair returns early before subsequent ones
159 * are zeroed.
160 */
161 old = rcu_dereference_protected(keypairs->next_keypair,
162 lockdep_is_held(&keypairs->keypair_update_lock));
163 RCU_INIT_POINTER(keypairs->next_keypair, NULL);
164 wg_noise_keypair_put(keypair: old, unreference_now: true);
165
166 old = rcu_dereference_protected(keypairs->previous_keypair,
167 lockdep_is_held(&keypairs->keypair_update_lock));
168 RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
169 wg_noise_keypair_put(keypair: old, unreference_now: true);
170
171 old = rcu_dereference_protected(keypairs->current_keypair,
172 lockdep_is_held(&keypairs->keypair_update_lock));
173 RCU_INIT_POINTER(keypairs->current_keypair, NULL);
174 wg_noise_keypair_put(keypair: old, unreference_now: true);
175
176 spin_unlock_bh(lock: &keypairs->keypair_update_lock);
177}
178
179void wg_noise_expire_current_peer_keypairs(struct wg_peer *peer)
180{
181 struct noise_keypair *keypair;
182
183 wg_noise_handshake_clear(handshake: &peer->handshake);
184 wg_noise_reset_last_sent_handshake(handshake_ns: &peer->last_sent_handshake);
185
186 spin_lock_bh(lock: &peer->keypairs.keypair_update_lock);
187 keypair = rcu_dereference_protected(peer->keypairs.next_keypair,
188 lockdep_is_held(&peer->keypairs.keypair_update_lock));
189 if (keypair)
190 keypair->sending.is_valid = false;
191 keypair = rcu_dereference_protected(peer->keypairs.current_keypair,
192 lockdep_is_held(&peer->keypairs.keypair_update_lock));
193 if (keypair)
194 keypair->sending.is_valid = false;
195 spin_unlock_bh(lock: &peer->keypairs.keypair_update_lock);
196}
197
198static void add_new_keypair(struct noise_keypairs *keypairs,
199 struct noise_keypair *new_keypair)
200{
201 struct noise_keypair *previous_keypair, *next_keypair, *current_keypair;
202
203 spin_lock_bh(lock: &keypairs->keypair_update_lock);
204 previous_keypair = rcu_dereference_protected(keypairs->previous_keypair,
205 lockdep_is_held(&keypairs->keypair_update_lock));
206 next_keypair = rcu_dereference_protected(keypairs->next_keypair,
207 lockdep_is_held(&keypairs->keypair_update_lock));
208 current_keypair = rcu_dereference_protected(keypairs->current_keypair,
209 lockdep_is_held(&keypairs->keypair_update_lock));
210 if (new_keypair->i_am_the_initiator) {
211 /* If we're the initiator, it means we've sent a handshake, and
212 * received a confirmation response, which means this new
213 * keypair can now be used.
214 */
215 if (next_keypair) {
216 /* If there already was a next keypair pending, we
217 * demote it to be the previous keypair, and free the
218 * existing current. Note that this means KCI can result
219 * in this transition. It would perhaps be more sound to
220 * always just get rid of the unused next keypair
221 * instead of putting it in the previous slot, but this
222 * might be a bit less robust. Something to think about
223 * for the future.
224 */
225 RCU_INIT_POINTER(keypairs->next_keypair, NULL);
226 rcu_assign_pointer(keypairs->previous_keypair,
227 next_keypair);
228 wg_noise_keypair_put(keypair: current_keypair, unreference_now: true);
229 } else /* If there wasn't an existing next keypair, we replace
230 * the previous with the current one.
231 */
232 rcu_assign_pointer(keypairs->previous_keypair,
233 current_keypair);
234 /* At this point we can get rid of the old previous keypair, and
235 * set up the new keypair.
236 */
237 wg_noise_keypair_put(keypair: previous_keypair, unreference_now: true);
238 rcu_assign_pointer(keypairs->current_keypair, new_keypair);
239 } else {
240 /* If we're the responder, it means we can't use the new keypair
241 * until we receive confirmation via the first data packet, so
242 * we get rid of the existing previous one, the possibly
243 * existing next one, and slide in the new next one.
244 */
245 rcu_assign_pointer(keypairs->next_keypair, new_keypair);
246 wg_noise_keypair_put(keypair: next_keypair, unreference_now: true);
247 RCU_INIT_POINTER(keypairs->previous_keypair, NULL);
248 wg_noise_keypair_put(keypair: previous_keypair, unreference_now: true);
249 }
250 spin_unlock_bh(lock: &keypairs->keypair_update_lock);
251}
252
253bool wg_noise_received_with_keypair(struct noise_keypairs *keypairs,
254 struct noise_keypair *received_keypair)
255{
256 struct noise_keypair *old_keypair;
257 bool key_is_new;
258
259 /* We first check without taking the spinlock. */
260 key_is_new = received_keypair ==
261 rcu_access_pointer(keypairs->next_keypair);
262 if (likely(!key_is_new))
263 return false;
264
265 spin_lock_bh(lock: &keypairs->keypair_update_lock);
266 /* After locking, we double check that things didn't change from
267 * beneath us.
268 */
269 if (unlikely(received_keypair !=
270 rcu_dereference_protected(keypairs->next_keypair,
271 lockdep_is_held(&keypairs->keypair_update_lock)))) {
272 spin_unlock_bh(lock: &keypairs->keypair_update_lock);
273 return false;
274 }
275
276 /* When we've finally received the confirmation, we slide the next
277 * into the current, the current into the previous, and get rid of
278 * the old previous.
279 */
280 old_keypair = rcu_dereference_protected(keypairs->previous_keypair,
281 lockdep_is_held(&keypairs->keypair_update_lock));
282 rcu_assign_pointer(keypairs->previous_keypair,
283 rcu_dereference_protected(keypairs->current_keypair,
284 lockdep_is_held(&keypairs->keypair_update_lock)));
285 wg_noise_keypair_put(keypair: old_keypair, unreference_now: true);
286 rcu_assign_pointer(keypairs->current_keypair, received_keypair);
287 RCU_INIT_POINTER(keypairs->next_keypair, NULL);
288
289 spin_unlock_bh(lock: &keypairs->keypair_update_lock);
290 return true;
291}
292
293/* Must hold static_identity->lock */
294void wg_noise_set_static_identity_private_key(
295 struct noise_static_identity *static_identity,
296 const u8 private_key[NOISE_PUBLIC_KEY_LEN])
297{
298 memcpy(static_identity->static_private, private_key,
299 NOISE_PUBLIC_KEY_LEN);
300 curve25519_clamp_secret(secret: static_identity->static_private);
301 static_identity->has_identity = curve25519_generate_public(
302 pub: static_identity->static_public, secret: private_key);
303}
304
305static void hmac(u8 *out, const u8 *in, const u8 *key, const size_t inlen, const size_t keylen)
306{
307 struct blake2s_state state;
308 u8 x_key[BLAKE2S_BLOCK_SIZE] __aligned(__alignof__(u32)) = { 0 };
309 u8 i_hash[BLAKE2S_HASH_SIZE] __aligned(__alignof__(u32));
310 int i;
311
312 if (keylen > BLAKE2S_BLOCK_SIZE) {
313 blake2s_init(state: &state, outlen: BLAKE2S_HASH_SIZE);
314 blake2s_update(state: &state, in: key, inlen: keylen);
315 blake2s_final(state: &state, out: x_key);
316 } else
317 memcpy(x_key, key, keylen);
318
319 for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
320 x_key[i] ^= 0x36;
321
322 blake2s_init(state: &state, outlen: BLAKE2S_HASH_SIZE);
323 blake2s_update(state: &state, in: x_key, inlen: BLAKE2S_BLOCK_SIZE);
324 blake2s_update(state: &state, in, inlen);
325 blake2s_final(state: &state, out: i_hash);
326
327 for (i = 0; i < BLAKE2S_BLOCK_SIZE; ++i)
328 x_key[i] ^= 0x5c ^ 0x36;
329
330 blake2s_init(state: &state, outlen: BLAKE2S_HASH_SIZE);
331 blake2s_update(state: &state, in: x_key, inlen: BLAKE2S_BLOCK_SIZE);
332 blake2s_update(state: &state, in: i_hash, inlen: BLAKE2S_HASH_SIZE);
333 blake2s_final(state: &state, out: i_hash);
334
335 memcpy(out, i_hash, BLAKE2S_HASH_SIZE);
336 memzero_explicit(s: x_key, count: BLAKE2S_BLOCK_SIZE);
337 memzero_explicit(s: i_hash, count: BLAKE2S_HASH_SIZE);
338}
339
340/* This is Hugo Krawczyk's HKDF:
341 * - https://eprint.iacr.org/2010/264.pdf
342 * - https://tools.ietf.org/html/rfc5869
343 */
344static void kdf(u8 *first_dst, u8 *second_dst, u8 *third_dst, const u8 *data,
345 size_t first_len, size_t second_len, size_t third_len,
346 size_t data_len, const u8 chaining_key[NOISE_HASH_LEN])
347{
348 u8 output[BLAKE2S_HASH_SIZE + 1];
349 u8 secret[BLAKE2S_HASH_SIZE];
350
351 WARN_ON(IS_ENABLED(DEBUG) &&
352 (first_len > BLAKE2S_HASH_SIZE ||
353 second_len > BLAKE2S_HASH_SIZE ||
354 third_len > BLAKE2S_HASH_SIZE ||
355 ((second_len || second_dst || third_len || third_dst) &&
356 (!first_len || !first_dst)) ||
357 ((third_len || third_dst) && (!second_len || !second_dst))));
358
359 /* Extract entropy from data into secret */
360 hmac(out: secret, in: data, key: chaining_key, inlen: data_len, keylen: NOISE_HASH_LEN);
361
362 if (!first_dst || !first_len)
363 goto out;
364
365 /* Expand first key: key = secret, data = 0x1 */
366 output[0] = 1;
367 hmac(out: output, in: output, key: secret, inlen: 1, keylen: BLAKE2S_HASH_SIZE);
368 memcpy(first_dst, output, first_len);
369
370 if (!second_dst || !second_len)
371 goto out;
372
373 /* Expand second key: key = secret, data = first-key || 0x2 */
374 output[BLAKE2S_HASH_SIZE] = 2;
375 hmac(out: output, in: output, key: secret, inlen: BLAKE2S_HASH_SIZE + 1, keylen: BLAKE2S_HASH_SIZE);
376 memcpy(second_dst, output, second_len);
377
378 if (!third_dst || !third_len)
379 goto out;
380
381 /* Expand third key: key = secret, data = second-key || 0x3 */
382 output[BLAKE2S_HASH_SIZE] = 3;
383 hmac(out: output, in: output, key: secret, inlen: BLAKE2S_HASH_SIZE + 1, keylen: BLAKE2S_HASH_SIZE);
384 memcpy(third_dst, output, third_len);
385
386out:
387 /* Clear sensitive data from stack */
388 memzero_explicit(s: secret, count: BLAKE2S_HASH_SIZE);
389 memzero_explicit(s: output, count: BLAKE2S_HASH_SIZE + 1);
390}
391
392static void derive_keys(struct noise_symmetric_key *first_dst,
393 struct noise_symmetric_key *second_dst,
394 const u8 chaining_key[NOISE_HASH_LEN])
395{
396 u64 birthdate = ktime_get_coarse_boottime_ns();
397 kdf(first_dst: first_dst->key, second_dst: second_dst->key, NULL, NULL,
398 first_len: NOISE_SYMMETRIC_KEY_LEN, second_len: NOISE_SYMMETRIC_KEY_LEN, third_len: 0, data_len: 0,
399 chaining_key);
400 first_dst->birthdate = second_dst->birthdate = birthdate;
401 first_dst->is_valid = second_dst->is_valid = true;
402}
403
404static bool __must_check mix_dh(u8 chaining_key[NOISE_HASH_LEN],
405 u8 key[NOISE_SYMMETRIC_KEY_LEN],
406 const u8 private[NOISE_PUBLIC_KEY_LEN],
407 const u8 public[NOISE_PUBLIC_KEY_LEN])
408{
409 u8 dh_calculation[NOISE_PUBLIC_KEY_LEN];
410
411 if (unlikely(!curve25519(dh_calculation, private, public)))
412 return false;
413 kdf(first_dst: chaining_key, second_dst: key, NULL, data: dh_calculation, first_len: NOISE_HASH_LEN,
414 second_len: NOISE_SYMMETRIC_KEY_LEN, third_len: 0, data_len: NOISE_PUBLIC_KEY_LEN, chaining_key);
415 memzero_explicit(s: dh_calculation, count: NOISE_PUBLIC_KEY_LEN);
416 return true;
417}
418
419static bool __must_check mix_precomputed_dh(u8 chaining_key[NOISE_HASH_LEN],
420 u8 key[NOISE_SYMMETRIC_KEY_LEN],
421 const u8 precomputed[NOISE_PUBLIC_KEY_LEN])
422{
423 static u8 zero_point[NOISE_PUBLIC_KEY_LEN];
424 if (unlikely(!crypto_memneq(precomputed, zero_point, NOISE_PUBLIC_KEY_LEN)))
425 return false;
426 kdf(first_dst: chaining_key, second_dst: key, NULL, data: precomputed, first_len: NOISE_HASH_LEN,
427 second_len: NOISE_SYMMETRIC_KEY_LEN, third_len: 0, data_len: NOISE_PUBLIC_KEY_LEN,
428 chaining_key);
429 return true;
430}
431
432static void mix_hash(u8 hash[NOISE_HASH_LEN], const u8 *src, size_t src_len)
433{
434 struct blake2s_state blake;
435
436 blake2s_init(state: &blake, outlen: NOISE_HASH_LEN);
437 blake2s_update(state: &blake, in: hash, inlen: NOISE_HASH_LEN);
438 blake2s_update(state: &blake, in: src, inlen: src_len);
439 blake2s_final(state: &blake, out: hash);
440}
441
442static void mix_psk(u8 chaining_key[NOISE_HASH_LEN], u8 hash[NOISE_HASH_LEN],
443 u8 key[NOISE_SYMMETRIC_KEY_LEN],
444 const u8 psk[NOISE_SYMMETRIC_KEY_LEN])
445{
446 u8 temp_hash[NOISE_HASH_LEN];
447
448 kdf(first_dst: chaining_key, second_dst: temp_hash, third_dst: key, data: psk, first_len: NOISE_HASH_LEN, second_len: NOISE_HASH_LEN,
449 third_len: NOISE_SYMMETRIC_KEY_LEN, data_len: NOISE_SYMMETRIC_KEY_LEN, chaining_key);
450 mix_hash(hash, src: temp_hash, src_len: NOISE_HASH_LEN);
451 memzero_explicit(s: temp_hash, count: NOISE_HASH_LEN);
452}
453
454static void handshake_init(u8 chaining_key[NOISE_HASH_LEN],
455 u8 hash[NOISE_HASH_LEN],
456 const u8 remote_static[NOISE_PUBLIC_KEY_LEN])
457{
458 memcpy(hash, handshake_init_hash, NOISE_HASH_LEN);
459 memcpy(chaining_key, handshake_init_chaining_key, NOISE_HASH_LEN);
460 mix_hash(hash, src: remote_static, src_len: NOISE_PUBLIC_KEY_LEN);
461}
462
463static void message_encrypt(u8 *dst_ciphertext, const u8 *src_plaintext,
464 size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
465 u8 hash[NOISE_HASH_LEN])
466{
467 chacha20poly1305_encrypt(dst: dst_ciphertext, src: src_plaintext, src_len, ad: hash,
468 ad_len: NOISE_HASH_LEN,
469 nonce: 0 /* Always zero for Noise_IK */, key);
470 mix_hash(hash, src: dst_ciphertext, noise_encrypted_len(src_len));
471}
472
473static bool message_decrypt(u8 *dst_plaintext, const u8 *src_ciphertext,
474 size_t src_len, u8 key[NOISE_SYMMETRIC_KEY_LEN],
475 u8 hash[NOISE_HASH_LEN])
476{
477 if (!chacha20poly1305_decrypt(dst: dst_plaintext, src: src_ciphertext, src_len,
478 ad: hash, ad_len: NOISE_HASH_LEN,
479 nonce: 0 /* Always zero for Noise_IK */, key))
480 return false;
481 mix_hash(hash, src: src_ciphertext, src_len);
482 return true;
483}
484
485static void message_ephemeral(u8 ephemeral_dst[NOISE_PUBLIC_KEY_LEN],
486 const u8 ephemeral_src[NOISE_PUBLIC_KEY_LEN],
487 u8 chaining_key[NOISE_HASH_LEN],
488 u8 hash[NOISE_HASH_LEN])
489{
490 if (ephemeral_dst != ephemeral_src)
491 memcpy(ephemeral_dst, ephemeral_src, NOISE_PUBLIC_KEY_LEN);
492 mix_hash(hash, src: ephemeral_src, src_len: NOISE_PUBLIC_KEY_LEN);
493 kdf(first_dst: chaining_key, NULL, NULL, data: ephemeral_src, first_len: NOISE_HASH_LEN, second_len: 0, third_len: 0,
494 data_len: NOISE_PUBLIC_KEY_LEN, chaining_key);
495}
496
497static void tai64n_now(u8 output[NOISE_TIMESTAMP_LEN])
498{
499 struct timespec64 now;
500
501 ktime_get_real_ts64(tv: &now);
502
503 /* In order to prevent some sort of infoleak from precise timers, we
504 * round down the nanoseconds part to the closest rounded-down power of
505 * two to the maximum initiations per second allowed anyway by the
506 * implementation.
507 */
508 now.tv_nsec = ALIGN_DOWN(now.tv_nsec,
509 rounddown_pow_of_two(NSEC_PER_SEC / INITIATIONS_PER_SECOND));
510
511 /* https://cr.yp.to/libtai/tai64.html */
512 *(__be64 *)output = cpu_to_be64(0x400000000000000aULL + now.tv_sec);
513 *(__be32 *)(output + sizeof(__be64)) = cpu_to_be32(now.tv_nsec);
514}
515
516bool
517wg_noise_handshake_create_initiation(struct message_handshake_initiation *dst,
518 struct noise_handshake *handshake)
519{
520 u8 timestamp[NOISE_TIMESTAMP_LEN];
521 u8 key[NOISE_SYMMETRIC_KEY_LEN];
522 bool ret = false;
523
524 /* We need to wait for crng _before_ taking any locks, since
525 * curve25519_generate_secret uses get_random_bytes_wait.
526 */
527 wait_for_random_bytes();
528
529 down_read(sem: &handshake->static_identity->lock);
530 down_write(sem: &handshake->lock);
531
532 if (unlikely(!handshake->static_identity->has_identity))
533 goto out;
534
535 dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_INITIATION);
536
537 handshake_init(chaining_key: handshake->chaining_key, hash: handshake->hash,
538 remote_static: handshake->remote_static);
539
540 /* e */
541 curve25519_generate_secret(secret: handshake->ephemeral_private);
542 if (!curve25519_generate_public(pub: dst->unencrypted_ephemeral,
543 secret: handshake->ephemeral_private))
544 goto out;
545 message_ephemeral(ephemeral_dst: dst->unencrypted_ephemeral,
546 ephemeral_src: dst->unencrypted_ephemeral, chaining_key: handshake->chaining_key,
547 hash: handshake->hash);
548
549 /* es */
550 if (!mix_dh(chaining_key: handshake->chaining_key, key, private: handshake->ephemeral_private,
551 public: handshake->remote_static))
552 goto out;
553
554 /* s */
555 message_encrypt(dst_ciphertext: dst->encrypted_static,
556 src_plaintext: handshake->static_identity->static_public,
557 src_len: NOISE_PUBLIC_KEY_LEN, key, hash: handshake->hash);
558
559 /* ss */
560 if (!mix_precomputed_dh(chaining_key: handshake->chaining_key, key,
561 precomputed: handshake->precomputed_static_static))
562 goto out;
563
564 /* {t} */
565 tai64n_now(output: timestamp);
566 message_encrypt(dst_ciphertext: dst->encrypted_timestamp, src_plaintext: timestamp,
567 src_len: NOISE_TIMESTAMP_LEN, key, hash: handshake->hash);
568
569 dst->sender_index = wg_index_hashtable_insert(
570 table: handshake->entry.peer->device->index_hashtable,
571 entry: &handshake->entry);
572
573 handshake->state = HANDSHAKE_CREATED_INITIATION;
574 ret = true;
575
576out:
577 up_write(sem: &handshake->lock);
578 up_read(sem: &handshake->static_identity->lock);
579 memzero_explicit(s: key, count: NOISE_SYMMETRIC_KEY_LEN);
580 return ret;
581}
582
583struct wg_peer *
584wg_noise_handshake_consume_initiation(struct message_handshake_initiation *src,
585 struct wg_device *wg)
586{
587 struct wg_peer *peer = NULL, *ret_peer = NULL;
588 struct noise_handshake *handshake;
589 bool replay_attack, flood_attack;
590 u8 key[NOISE_SYMMETRIC_KEY_LEN];
591 u8 chaining_key[NOISE_HASH_LEN];
592 u8 hash[NOISE_HASH_LEN];
593 u8 s[NOISE_PUBLIC_KEY_LEN];
594 u8 e[NOISE_PUBLIC_KEY_LEN];
595 u8 t[NOISE_TIMESTAMP_LEN];
596 u64 initiation_consumption;
597
598 down_read(sem: &wg->static_identity.lock);
599 if (unlikely(!wg->static_identity.has_identity))
600 goto out;
601
602 handshake_init(chaining_key, hash, remote_static: wg->static_identity.static_public);
603
604 /* e */
605 message_ephemeral(ephemeral_dst: e, ephemeral_src: src->unencrypted_ephemeral, chaining_key, hash);
606
607 /* es */
608 if (!mix_dh(chaining_key, key, private: wg->static_identity.static_private, public: e))
609 goto out;
610
611 /* s */
612 if (!message_decrypt(dst_plaintext: s, src_ciphertext: src->encrypted_static,
613 src_len: sizeof(src->encrypted_static), key, hash))
614 goto out;
615
616 /* Lookup which peer we're actually talking to */
617 peer = wg_pubkey_hashtable_lookup(table: wg->peer_hashtable, pubkey: s);
618 if (!peer)
619 goto out;
620 handshake = &peer->handshake;
621
622 /* ss */
623 if (!mix_precomputed_dh(chaining_key, key,
624 precomputed: handshake->precomputed_static_static))
625 goto out;
626
627 /* {t} */
628 if (!message_decrypt(dst_plaintext: t, src_ciphertext: src->encrypted_timestamp,
629 src_len: sizeof(src->encrypted_timestamp), key, hash))
630 goto out;
631
632 down_read(sem: &handshake->lock);
633 replay_attack = memcmp(p: t, q: handshake->latest_timestamp,
634 size: NOISE_TIMESTAMP_LEN) <= 0;
635 flood_attack = (s64)handshake->last_initiation_consumption +
636 NSEC_PER_SEC / INITIATIONS_PER_SECOND >
637 (s64)ktime_get_coarse_boottime_ns();
638 up_read(sem: &handshake->lock);
639 if (replay_attack || flood_attack)
640 goto out;
641
642 /* Success! Copy everything to peer */
643 down_write(sem: &handshake->lock);
644 memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
645 if (memcmp(p: t, q: handshake->latest_timestamp, size: NOISE_TIMESTAMP_LEN) > 0)
646 memcpy(handshake->latest_timestamp, t, NOISE_TIMESTAMP_LEN);
647 memcpy(handshake->hash, hash, NOISE_HASH_LEN);
648 memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
649 handshake->remote_index = src->sender_index;
650 initiation_consumption = ktime_get_coarse_boottime_ns();
651 if ((s64)(handshake->last_initiation_consumption - initiation_consumption) < 0)
652 handshake->last_initiation_consumption = initiation_consumption;
653 handshake->state = HANDSHAKE_CONSUMED_INITIATION;
654 up_write(sem: &handshake->lock);
655 ret_peer = peer;
656
657out:
658 memzero_explicit(s: key, count: NOISE_SYMMETRIC_KEY_LEN);
659 memzero_explicit(s: hash, count: NOISE_HASH_LEN);
660 memzero_explicit(s: chaining_key, count: NOISE_HASH_LEN);
661 up_read(sem: &wg->static_identity.lock);
662 if (!ret_peer)
663 wg_peer_put(peer);
664 return ret_peer;
665}
666
667bool wg_noise_handshake_create_response(struct message_handshake_response *dst,
668 struct noise_handshake *handshake)
669{
670 u8 key[NOISE_SYMMETRIC_KEY_LEN];
671 bool ret = false;
672
673 /* We need to wait for crng _before_ taking any locks, since
674 * curve25519_generate_secret uses get_random_bytes_wait.
675 */
676 wait_for_random_bytes();
677
678 down_read(sem: &handshake->static_identity->lock);
679 down_write(sem: &handshake->lock);
680
681 if (handshake->state != HANDSHAKE_CONSUMED_INITIATION)
682 goto out;
683
684 dst->header.type = cpu_to_le32(MESSAGE_HANDSHAKE_RESPONSE);
685 dst->receiver_index = handshake->remote_index;
686
687 /* e */
688 curve25519_generate_secret(secret: handshake->ephemeral_private);
689 if (!curve25519_generate_public(pub: dst->unencrypted_ephemeral,
690 secret: handshake->ephemeral_private))
691 goto out;
692 message_ephemeral(ephemeral_dst: dst->unencrypted_ephemeral,
693 ephemeral_src: dst->unencrypted_ephemeral, chaining_key: handshake->chaining_key,
694 hash: handshake->hash);
695
696 /* ee */
697 if (!mix_dh(chaining_key: handshake->chaining_key, NULL, private: handshake->ephemeral_private,
698 public: handshake->remote_ephemeral))
699 goto out;
700
701 /* se */
702 if (!mix_dh(chaining_key: handshake->chaining_key, NULL, private: handshake->ephemeral_private,
703 public: handshake->remote_static))
704 goto out;
705
706 /* psk */
707 mix_psk(chaining_key: handshake->chaining_key, hash: handshake->hash, key,
708 psk: handshake->preshared_key);
709
710 /* {} */
711 message_encrypt(dst_ciphertext: dst->encrypted_nothing, NULL, src_len: 0, key, hash: handshake->hash);
712
713 dst->sender_index = wg_index_hashtable_insert(
714 table: handshake->entry.peer->device->index_hashtable,
715 entry: &handshake->entry);
716
717 handshake->state = HANDSHAKE_CREATED_RESPONSE;
718 ret = true;
719
720out:
721 up_write(sem: &handshake->lock);
722 up_read(sem: &handshake->static_identity->lock);
723 memzero_explicit(s: key, count: NOISE_SYMMETRIC_KEY_LEN);
724 return ret;
725}
726
727struct wg_peer *
728wg_noise_handshake_consume_response(struct message_handshake_response *src,
729 struct wg_device *wg)
730{
731 enum noise_handshake_state state = HANDSHAKE_ZEROED;
732 struct wg_peer *peer = NULL, *ret_peer = NULL;
733 struct noise_handshake *handshake;
734 u8 key[NOISE_SYMMETRIC_KEY_LEN];
735 u8 hash[NOISE_HASH_LEN];
736 u8 chaining_key[NOISE_HASH_LEN];
737 u8 e[NOISE_PUBLIC_KEY_LEN];
738 u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
739 u8 static_private[NOISE_PUBLIC_KEY_LEN];
740 u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN];
741
742 down_read(sem: &wg->static_identity.lock);
743
744 if (unlikely(!wg->static_identity.has_identity))
745 goto out;
746
747 handshake = (struct noise_handshake *)wg_index_hashtable_lookup(
748 table: wg->index_hashtable, type_mask: INDEX_HASHTABLE_HANDSHAKE,
749 index: src->receiver_index, peer: &peer);
750 if (unlikely(!handshake))
751 goto out;
752
753 down_read(sem: &handshake->lock);
754 state = handshake->state;
755 memcpy(hash, handshake->hash, NOISE_HASH_LEN);
756 memcpy(chaining_key, handshake->chaining_key, NOISE_HASH_LEN);
757 memcpy(ephemeral_private, handshake->ephemeral_private,
758 NOISE_PUBLIC_KEY_LEN);
759 memcpy(preshared_key, handshake->preshared_key,
760 NOISE_SYMMETRIC_KEY_LEN);
761 up_read(sem: &handshake->lock);
762
763 if (state != HANDSHAKE_CREATED_INITIATION)
764 goto fail;
765
766 /* e */
767 message_ephemeral(ephemeral_dst: e, ephemeral_src: src->unencrypted_ephemeral, chaining_key, hash);
768
769 /* ee */
770 if (!mix_dh(chaining_key, NULL, private: ephemeral_private, public: e))
771 goto fail;
772
773 /* se */
774 if (!mix_dh(chaining_key, NULL, private: wg->static_identity.static_private, public: e))
775 goto fail;
776
777 /* psk */
778 mix_psk(chaining_key, hash, key, psk: preshared_key);
779
780 /* {} */
781 if (!message_decrypt(NULL, src_ciphertext: src->encrypted_nothing,
782 src_len: sizeof(src->encrypted_nothing), key, hash))
783 goto fail;
784
785 /* Success! Copy everything to peer */
786 down_write(sem: &handshake->lock);
787 /* It's important to check that the state is still the same, while we
788 * have an exclusive lock.
789 */
790 if (handshake->state != state) {
791 up_write(sem: &handshake->lock);
792 goto fail;
793 }
794 memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
795 memcpy(handshake->hash, hash, NOISE_HASH_LEN);
796 memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
797 handshake->remote_index = src->sender_index;
798 handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
799 up_write(sem: &handshake->lock);
800 ret_peer = peer;
801 goto out;
802
803fail:
804 wg_peer_put(peer);
805out:
806 memzero_explicit(s: key, count: NOISE_SYMMETRIC_KEY_LEN);
807 memzero_explicit(s: hash, count: NOISE_HASH_LEN);
808 memzero_explicit(s: chaining_key, count: NOISE_HASH_LEN);
809 memzero_explicit(s: ephemeral_private, count: NOISE_PUBLIC_KEY_LEN);
810 memzero_explicit(s: static_private, count: NOISE_PUBLIC_KEY_LEN);
811 memzero_explicit(s: preshared_key, count: NOISE_SYMMETRIC_KEY_LEN);
812 up_read(sem: &wg->static_identity.lock);
813 return ret_peer;
814}
815
816bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
817 struct noise_keypairs *keypairs)
818{
819 struct noise_keypair *new_keypair;
820 bool ret = false;
821
822 down_write(sem: &handshake->lock);
823 if (handshake->state != HANDSHAKE_CREATED_RESPONSE &&
824 handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
825 goto out;
826
827 new_keypair = keypair_create(peer: handshake->entry.peer);
828 if (!new_keypair)
829 goto out;
830 new_keypair->i_am_the_initiator = handshake->state ==
831 HANDSHAKE_CONSUMED_RESPONSE;
832 new_keypair->remote_index = handshake->remote_index;
833
834 if (new_keypair->i_am_the_initiator)
835 derive_keys(first_dst: &new_keypair->sending, second_dst: &new_keypair->receiving,
836 chaining_key: handshake->chaining_key);
837 else
838 derive_keys(first_dst: &new_keypair->receiving, second_dst: &new_keypair->sending,
839 chaining_key: handshake->chaining_key);
840
841 handshake_zero(handshake);
842 rcu_read_lock_bh();
843 if (likely(!READ_ONCE(container_of(handshake, struct wg_peer,
844 handshake)->is_dead))) {
845 add_new_keypair(keypairs, new_keypair);
846 net_dbg_ratelimited("%s: Keypair %llu created for peer %llu\n",
847 handshake->entry.peer->device->dev->name,
848 new_keypair->internal_id,
849 handshake->entry.peer->internal_id);
850 ret = wg_index_hashtable_replace(
851 table: handshake->entry.peer->device->index_hashtable,
852 old: &handshake->entry, new: &new_keypair->entry);
853 } else {
854 kfree_sensitive(objp: new_keypair);
855 }
856 rcu_read_unlock_bh();
857
858out:
859 up_write(sem: &handshake->lock);
860 return ret;
861}
862

source code of linux/drivers/net/wireguard/noise.c