1 | use crate::crypto::hash; |
2 | use crate::msgs::codec::Codec; |
3 | use crate::msgs::enums::HashAlgorithm; |
4 | use crate::msgs::handshake::HandshakeMessagePayload; |
5 | use crate::msgs::message::{Message, MessagePayload}; |
6 | |
7 | use alloc::boxed::Box; |
8 | use alloc::vec::Vec; |
9 | use core::mem; |
10 | |
/// Pre-hash storage for the handshake transcript.
///
/// Until the negotiated hash function is known, the encoded bytes of each
/// handshake message are simply appended to an in-memory buffer. A
/// HelloRetryRequest can also turn a running `HandshakeHash` back into one of
/// these (via `HandshakeHash::into_hrr_buffer`).
pub(crate) struct HandshakeHashBuffer {
    /// Whether the full transcript must still be retained once hashing starts,
    /// because client authentication may need it.
    client_auth_enabled: bool,
    /// Concatenated encoded bytes of every handshake message buffered so far.
    buffer: Vec<u8>,
}
20 | |
21 | impl HandshakeHashBuffer { |
22 | pub(crate) fn new() -> Self { |
23 | Self { |
24 | buffer: Vec::new(), |
25 | client_auth_enabled: false, |
26 | } |
27 | } |
28 | |
29 | /// We might be doing client auth, so need to keep a full |
30 | /// log of the handshake. |
31 | pub(crate) fn set_client_auth_enabled(&mut self) { |
32 | self.client_auth_enabled = true; |
33 | } |
34 | |
35 | /// Hash/buffer a handshake message. |
36 | pub(crate) fn add_message(&mut self, m: &Message) { |
37 | if let MessagePayload::Handshake { encoded, .. } = &m.payload { |
38 | self.buffer |
39 | .extend_from_slice(&encoded.0); |
40 | } |
41 | } |
42 | |
43 | /// Hash or buffer a byte slice. |
44 | #[cfg (all(test, any(feature = "ring" , feature = "aws_lc_rs" )))] |
45 | fn update_raw(&mut self, buf: &[u8]) { |
46 | self.buffer.extend_from_slice(buf); |
47 | } |
48 | |
49 | /// Get the hash value if we were to hash `extra` too. |
50 | pub(crate) fn get_hash_given( |
51 | &self, |
52 | provider: &'static dyn hash::Hash, |
53 | extra: &[u8], |
54 | ) -> hash::Output { |
55 | let mut ctx = provider.start(); |
56 | ctx.update(&self.buffer); |
57 | ctx.update(extra); |
58 | ctx.finish() |
59 | } |
60 | |
61 | /// We now know what hash function the verify_data will use. |
62 | pub(crate) fn start_hash(self, provider: &'static dyn hash::Hash) -> HandshakeHash { |
63 | let mut ctx = provider.start(); |
64 | ctx.update(&self.buffer); |
65 | HandshakeHash { |
66 | provider, |
67 | ctx, |
68 | client_auth: match self.client_auth_enabled { |
69 | true => Some(self.buffer), |
70 | false => None, |
71 | }, |
72 | } |
73 | } |
74 | } |
75 | |
76 | /// This deals with keeping a running hash of the handshake |
77 | /// payloads. This is computed by buffering initially. Once |
78 | /// we know what hash function we need to use we switch to |
79 | /// incremental hashing. |
80 | /// |
81 | /// For client auth, we also need to buffer all the messages. |
82 | /// This is disabled in cases where client auth is not possible. |
83 | pub(crate) struct HandshakeHash { |
84 | provider: &'static dyn hash::Hash, |
85 | ctx: Box<dyn hash::Context>, |
86 | |
87 | /// buffer for client-auth. |
88 | client_auth: Option<Vec<u8>>, |
89 | } |
90 | |
91 | impl HandshakeHash { |
92 | /// We decided not to do client auth after all, so discard |
93 | /// the transcript. |
94 | pub(crate) fn abandon_client_auth(&mut self) { |
95 | self.client_auth = None; |
96 | } |
97 | |
98 | /// Hash/buffer a handshake message. |
99 | pub(crate) fn add_message(&mut self, m: &Message) -> &mut Self { |
100 | if let MessagePayload::Handshake { encoded, .. } = &m.payload { |
101 | self.update_raw(&encoded.0); |
102 | } |
103 | self |
104 | } |
105 | |
106 | /// Hash or buffer a byte slice. |
107 | fn update_raw(&mut self, buf: &[u8]) -> &mut Self { |
108 | self.ctx.update(buf); |
109 | |
110 | if let Some(buffer) = &mut self.client_auth { |
111 | buffer.extend_from_slice(buf); |
112 | } |
113 | |
114 | self |
115 | } |
116 | |
117 | /// Get the hash value if we were to hash `extra` too, |
118 | /// using hash function `hash`. |
119 | pub(crate) fn get_hash_given(&self, extra: &[u8]) -> hash::Output { |
120 | let mut ctx = self.ctx.fork(); |
121 | ctx.update(extra); |
122 | ctx.finish() |
123 | } |
124 | |
125 | pub(crate) fn into_hrr_buffer(self) -> HandshakeHashBuffer { |
126 | let old_hash = self.ctx.finish(); |
127 | let old_handshake_hash_msg = |
128 | HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref()); |
129 | |
130 | HandshakeHashBuffer { |
131 | client_auth_enabled: self.client_auth.is_some(), |
132 | buffer: old_handshake_hash_msg.get_encoding(), |
133 | } |
134 | } |
135 | |
136 | /// Take the current hash value, and encapsulate it in a |
137 | /// 'handshake_hash' handshake message. Start this hash |
138 | /// again, with that message at the front. |
139 | pub(crate) fn rollup_for_hrr(&mut self) { |
140 | let ctx = &mut self.ctx; |
141 | |
142 | let old_ctx = mem::replace(ctx, self.provider.start()); |
143 | let old_hash = old_ctx.finish(); |
144 | let old_handshake_hash_msg = |
145 | HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref()); |
146 | |
147 | self.update_raw(&old_handshake_hash_msg.get_encoding()); |
148 | } |
149 | |
150 | /// Get the current hash value. |
151 | pub(crate) fn get_current_hash(&self) -> hash::Output { |
152 | self.ctx.fork_finish() |
153 | } |
154 | |
155 | /// Takes this object's buffer containing all handshake messages |
156 | /// so far. This method only works once; it resets the buffer |
157 | /// to empty. |
158 | #[cfg (feature = "tls12" )] |
159 | pub(crate) fn take_handshake_buf(&mut self) -> Option<Vec<u8>> { |
160 | self.client_auth.take() |
161 | } |
162 | |
163 | /// The hashing algorithm |
164 | pub(crate) fn algorithm(&self) -> HashAlgorithm { |
165 | self.provider.algorithm() |
166 | } |
167 | } |
168 | |
#[cfg(all(test, any(feature = "ring", feature = "aws_lc_rs")))]
mod tests {
    use super::{HandshakeHash, HandshakeHashBuffer};
    use crate::test_provider::hash::SHA256;

    /// Leading bytes of SHA-256("helloworld"), the input every test ends up hashing.
    const HELLOWORLD_SHA256_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// Asserts that `digest` begins with `HELLOWORLD_SHA256_PREFIX`.
    fn assert_helloworld_digest(digest: &[u8]) {
        assert_eq!(&digest[..4], &HELLOWORLD_SHA256_PREFIX[..]);
    }

    /// Length of the retained client-auth transcript, if one is kept.
    fn client_auth_len(running: &HandshakeHash) -> Option<usize> {
        running
            .client_auth
            .as_ref()
            .map(|buf| buf.len())
    }

    #[test]
    fn hashes_correctly() {
        let mut pending = HandshakeHashBuffer::new();
        pending.update_raw(b"hello");
        assert_eq!(pending.buffer.len(), 5);

        let mut running = pending.start_hash(&SHA256);
        assert!(running.client_auth.is_none());
        running.update_raw(b"world");

        let digest = running.get_current_hash();
        assert_helloworld_digest(digest.as_ref());
    }

    #[cfg(feature = "tls12")]
    #[test]
    fn buffers_correctly() {
        let mut pending = HandshakeHashBuffer::new();
        pending.set_client_auth_enabled();
        pending.update_raw(b"hello");
        assert_eq!(pending.buffer.len(), 5);

        let mut running = pending.start_hash(&SHA256);
        assert_eq!(client_auth_len(&running), Some(5));
        running.update_raw(b"world");
        assert_eq!(client_auth_len(&running), Some(10));

        let digest = running.get_current_hash();
        assert_helloworld_digest(digest.as_ref());

        assert_eq!(running.take_handshake_buf(), Some(b"helloworld".to_vec()));
    }

    #[test]
    fn abandon() {
        let mut pending = HandshakeHashBuffer::new();
        pending.set_client_auth_enabled();
        pending.update_raw(b"hello");
        assert_eq!(pending.buffer.len(), 5);

        let mut running = pending.start_hash(&SHA256);
        assert_eq!(client_auth_len(&running), Some(5));

        running.abandon_client_auth();
        assert!(running.client_auth.is_none());
        running.update_raw(b"world");
        assert!(running.client_auth.is_none());

        let digest = running.get_current_hash();
        assert_helloworld_digest(digest.as_ref());
    }
}
246 | |