1use crate::crypto::hash;
2use crate::msgs::codec::Codec;
3use crate::msgs::enums::HashAlgorithm;
4use crate::msgs::handshake::HandshakeMessagePayload;
5use crate::msgs::message::{Message, MessagePayload};
6
7use alloc::boxed::Box;
8use alloc::vec::Vec;
9use core::mem;
10
/// Pre-hash accumulation of handshake payload bytes.
///
/// Until the hash algorithm that will verify the handshake is known, the encoded
/// handshake messages are simply collected here. A HelloRetryRequest may restart the
/// transcript mid-handshake, in which case a `HandshakeHash` is turned back into a
/// `HandshakeHashBuffer`.
pub(crate) struct HandshakeHashBuffer {
    /// Raw bytes of every handshake message seen so far, in arrival order.
    buffer: Vec<u8>,
    /// When set, `start_hash` hands `buffer` over to the resulting `HandshakeHash`
    /// instead of discarding it.
    client_auth_enabled: bool,
}
20
21impl HandshakeHashBuffer {
22 pub(crate) fn new() -> Self {
23 Self {
24 buffer: Vec::new(),
25 client_auth_enabled: false,
26 }
27 }
28
29 /// We might be doing client auth, so need to keep a full
30 /// log of the handshake.
31 pub(crate) fn set_client_auth_enabled(&mut self) {
32 self.client_auth_enabled = true;
33 }
34
35 /// Hash/buffer a handshake message.
36 pub(crate) fn add_message(&mut self, m: &Message) {
37 if let MessagePayload::Handshake { encoded, .. } = &m.payload {
38 self.buffer
39 .extend_from_slice(&encoded.0);
40 }
41 }
42
43 /// Hash or buffer a byte slice.
44 #[cfg(all(test, any(feature = "ring", feature = "aws_lc_rs")))]
45 fn update_raw(&mut self, buf: &[u8]) {
46 self.buffer.extend_from_slice(buf);
47 }
48
49 /// Get the hash value if we were to hash `extra` too.
50 pub(crate) fn get_hash_given(
51 &self,
52 provider: &'static dyn hash::Hash,
53 extra: &[u8],
54 ) -> hash::Output {
55 let mut ctx = provider.start();
56 ctx.update(&self.buffer);
57 ctx.update(extra);
58 ctx.finish()
59 }
60
61 /// We now know what hash function the verify_data will use.
62 pub(crate) fn start_hash(self, provider: &'static dyn hash::Hash) -> HandshakeHash {
63 let mut ctx = provider.start();
64 ctx.update(&self.buffer);
65 HandshakeHash {
66 provider,
67 ctx,
68 client_auth: match self.client_auth_enabled {
69 true => Some(self.buffer),
70 false => None,
71 },
72 }
73 }
74}
75
76/// This deals with keeping a running hash of the handshake
77/// payloads. This is computed by buffering initially. Once
78/// we know what hash function we need to use we switch to
79/// incremental hashing.
80///
81/// For client auth, we also need to buffer all the messages.
82/// This is disabled in cases where client auth is not possible.
83pub(crate) struct HandshakeHash {
84 provider: &'static dyn hash::Hash,
85 ctx: Box<dyn hash::Context>,
86
87 /// buffer for client-auth.
88 client_auth: Option<Vec<u8>>,
89}
90
91impl HandshakeHash {
92 /// We decided not to do client auth after all, so discard
93 /// the transcript.
94 pub(crate) fn abandon_client_auth(&mut self) {
95 self.client_auth = None;
96 }
97
98 /// Hash/buffer a handshake message.
99 pub(crate) fn add_message(&mut self, m: &Message) -> &mut Self {
100 if let MessagePayload::Handshake { encoded, .. } = &m.payload {
101 self.update_raw(&encoded.0);
102 }
103 self
104 }
105
106 /// Hash or buffer a byte slice.
107 fn update_raw(&mut self, buf: &[u8]) -> &mut Self {
108 self.ctx.update(buf);
109
110 if let Some(buffer) = &mut self.client_auth {
111 buffer.extend_from_slice(buf);
112 }
113
114 self
115 }
116
117 /// Get the hash value if we were to hash `extra` too,
118 /// using hash function `hash`.
119 pub(crate) fn get_hash_given(&self, extra: &[u8]) -> hash::Output {
120 let mut ctx = self.ctx.fork();
121 ctx.update(extra);
122 ctx.finish()
123 }
124
125 pub(crate) fn into_hrr_buffer(self) -> HandshakeHashBuffer {
126 let old_hash = self.ctx.finish();
127 let old_handshake_hash_msg =
128 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
129
130 HandshakeHashBuffer {
131 client_auth_enabled: self.client_auth.is_some(),
132 buffer: old_handshake_hash_msg.get_encoding(),
133 }
134 }
135
136 /// Take the current hash value, and encapsulate it in a
137 /// 'handshake_hash' handshake message. Start this hash
138 /// again, with that message at the front.
139 pub(crate) fn rollup_for_hrr(&mut self) {
140 let ctx = &mut self.ctx;
141
142 let old_ctx = mem::replace(ctx, self.provider.start());
143 let old_hash = old_ctx.finish();
144 let old_handshake_hash_msg =
145 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
146
147 self.update_raw(&old_handshake_hash_msg.get_encoding());
148 }
149
150 /// Get the current hash value.
151 pub(crate) fn get_current_hash(&self) -> hash::Output {
152 self.ctx.fork_finish()
153 }
154
155 /// Takes this object's buffer containing all handshake messages
156 /// so far. This method only works once; it resets the buffer
157 /// to empty.
158 #[cfg(feature = "tls12")]
159 pub(crate) fn take_handshake_buf(&mut self) -> Option<Vec<u8>> {
160 self.client_auth.take()
161 }
162
163 /// The hashing algorithm
164 pub(crate) fn algorithm(&self) -> HashAlgorithm {
165 self.provider.algorithm()
166 }
167}
168
#[cfg(all(test, any(feature = "ring", feature = "aws_lc_rs")))]
mod tests {
    use super::{HandshakeHash, HandshakeHashBuffer};
    use crate::test_provider::hash::SHA256;

    /// First four bytes expected from the hash after feeding "hello" then "world".
    const EXPECTED_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// Build a buffer holding "hello", optionally with client auth enabled.
    fn hello_buffer(with_client_auth: bool) -> HandshakeHashBuffer {
        let mut hhb = HandshakeHashBuffer::new();
        if with_client_auth {
            hhb.set_client_auth_enabled();
        }
        hhb.update_raw(b"hello");
        assert_eq!(hhb.buffer.len(), 5);
        hhb
    }

    /// Length of the retained client-auth log, if there is one.
    fn logged_len(hh: &HandshakeHash) -> Option<usize> {
        hh.client_auth.as_ref().map(Vec::len)
    }

    /// Check the leading bytes of the current hash value.
    fn assert_expected_digest(hh: &HandshakeHash) {
        let digest = hh.get_current_hash();
        let bytes: &[u8] = digest.as_ref();
        assert_eq!(&bytes[..4], &EXPECTED_PREFIX[..]);
    }

    #[test]
    fn hashes_correctly() {
        let mut hh = hello_buffer(false).start_hash(&SHA256);
        assert!(hh.client_auth.is_none());
        hh.update_raw(b"world");
        assert_expected_digest(&hh);
    }

    #[cfg(feature = "tls12")]
    #[test]
    fn buffers_correctly() {
        let mut hh = hello_buffer(true).start_hash(&SHA256);
        assert_eq!(logged_len(&hh), Some(5));
        hh.update_raw(b"world");
        assert_eq!(logged_len(&hh), Some(10));
        assert_expected_digest(&hh);
        let buf = hh.take_handshake_buf();
        assert_eq!(Some(b"helloworld".to_vec()), buf);
    }

    #[test]
    fn abandon() {
        let mut hh = hello_buffer(true).start_hash(&SHA256);
        assert_eq!(logged_len(&hh), Some(5));
        hh.abandon_client_auth();
        assert_eq!(hh.client_auth, None);
        hh.update_raw(b"world");
        assert_eq!(hh.client_auth, None);
        assert_expected_digest(&hh);
    }
}
246