1 | use serde::{Deserialize, Serialize}; |
2 | use static_assertions::assert_impl_all; |
3 | use std::num::NonZeroU32; |
4 | use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName}; |
5 | use zvariant::{ObjectPath, Signature, Type}; |
6 | |
7 | use crate::{ |
8 | message::{Field, FieldCode, Header, Message}, |
9 | Result, |
10 | }; |
11 | |
/// Capacity pre-allocated for the field list of a message.
///
/// The real upper bound is about 10 (in practice even fewer); it is simply rounded up to the
/// next multiple of 8.
const MAX_FIELDS_IN_MESSAGE: usize = 16;
14 | |
15 | /// A collection of [`Field`] instances. |
16 | /// |
17 | /// [`Field`]: enum.Field.html |
18 | #[derive (Debug, Clone, Serialize, Deserialize, Type)] |
19 | pub(crate) struct Fields<'m>(#[serde(borrow)] Vec<Field<'m>>); |
20 | |
21 | assert_impl_all!(Fields<'_>: Send, Sync, Unpin); |
22 | |
23 | impl<'m> Fields<'m> { |
24 | /// Creates an empty collection of fields. |
25 | pub fn new() -> Self { |
26 | Self::default() |
27 | } |
28 | |
29 | /// Appends a [`Field`] to the collection of fields in the message. |
30 | /// |
31 | /// [`Field`]: enum.Field.html |
32 | pub fn add<'f: 'm>(&mut self, field: Field<'f>) { |
33 | self.0.push(field); |
34 | } |
35 | |
36 | /// Replaces a [`Field`] from the collection of fields with one with the same code, |
37 | /// returning the old value if present. |
38 | /// |
39 | /// [`Field`]: enum.Field.html |
40 | pub fn replace<'f: 'm>(&mut self, field: Field<'f>) -> Option<Field<'m>> { |
41 | let code = field.code(); |
42 | if let Some(found) = self.0.iter_mut().find(|f| f.code() == code) { |
43 | return Some(std::mem::replace(found, field)); |
44 | } |
45 | self.add(field); |
46 | None |
47 | } |
48 | |
49 | /// Returns a slice with all the [`Field`] in the message. |
50 | /// |
51 | /// [`Field`]: enum.Field.html |
52 | pub fn get(&self) -> &[Field<'m>] { |
53 | &self.0 |
54 | } |
55 | |
56 | /// Gets a reference to a specific [`Field`] by its code. |
57 | /// |
58 | /// Returns `None` if the message has no such field. |
59 | /// |
60 | /// [`Field`]: enum.Field.html |
61 | pub fn get_field(&self, code: FieldCode) -> Option<&Field<'m>> { |
62 | self.0.iter().find(|f| f.code() == code) |
63 | } |
64 | |
65 | /// Remove the field matching the `code`. |
66 | /// |
67 | /// Returns `true` if a field was found and removed, `false` otherwise. |
68 | pub(crate) fn remove(&mut self, code: FieldCode) -> bool { |
69 | match self.0.iter().enumerate().find(|(_, f)| f.code() == code) { |
70 | Some((i, _)) => { |
71 | self.0.remove(i); |
72 | |
73 | true |
74 | } |
75 | None => false, |
76 | } |
77 | } |
78 | } |
79 | |
/// The byte range (`start..end`) that a string field occupies inside a message buffer, as
/// used by [`QuickFields`].
///
/// Encodings with `end == 0` are reserved as sentinels: `{ start: 0, end: 0 }` (the derived
/// `Default`, "not cached") and `{ start: 1, end: 0 }` ("not present"). [`FieldPos::read`]
/// yields `None` for both.
#[derive(Clone, Copy, Debug, Default)]
pub(crate) struct FieldPos {
    start: u32,
    end: u32,
}

impl FieldPos {
    /// Returns the sentinel that marks a field as absent from the message.
    pub fn new_not_present() -> Self {
        Self { start: 1, end: 0 }
    }

    /// Works out where `field_buf` sits inside `msg_buf` by comparing their addresses.
    ///
    /// Returns `None` when `field_buf` does not lie entirely within `msg_buf`, or when one
    /// of the resulting offsets does not fit into a `u32`.
    pub fn build(msg_buf: &[u8], field_buf: &str) -> Option<Self> {
        let base_addr = msg_buf.as_ptr() as usize;
        let field_addr = field_buf.as_ptr() as usize;

        // A field that starts before the buffer cannot be part of it.
        let begin = field_addr.checked_sub(base_addr)?;
        if begin > msg_buf.len() {
            return None;
        }

        let finish = begin + field_buf.len();
        if finish > msg_buf.len() {
            return None;
        }

        Some(Self {
            start: u32::try_from(begin).ok()?,
            end: u32::try_from(finish).ok()?,
        })
    }

    /// Locates an optional field inside `msg_buf`.
    ///
    /// Yields the "not present" sentinel when `field` is `None` or when the field cannot be
    /// located within `msg_buf` (see [`FieldPos::build`]).
    pub fn new<T>(msg_buf: &[u8], field: Option<&T>) -> Self
    where
        T: std::ops::Deref<Target = str>,
    {
        let not_present = Self::new_not_present();
        field.map_or(not_present, |value| {
            Self::build(msg_buf, value.deref()).unwrap_or(not_present)
        })
    }

    /// Rebuilds the typed value of a field from its cached position.
    ///
    /// Returns `None` for the sentinel encodings.
    ///
    /// # Panics
    ///
    /// The caller must pass the same `msg_buf` that [`FieldPos::build`] was called with.
    /// Otherwise this panics if the range is out of bounds, the bytes are not valid UTF-8,
    /// or the conversion into `T` fails.
    pub fn read<'m, T>(&self, msg_buf: &'m [u8]) -> Option<T>
    where
        T: TryFrom<&'m str>,
        T::Error: std::fmt::Debug,
    {
        // `end == 0` combined with a `start` of 0 or 1 marks a sentinel: nothing to read.
        if self.end == 0 && self.start <= 1 {
            return None;
        }

        let range = (self.start as usize)..(self.end as usize);
        let text = std::str::from_utf8(&msg_buf[range])
            .expect("Invalid utf8 when reconstructing string");
        // The field was already validated when this position was constructed.
        Some(T::try_from(text).expect("Invalid field reconstruction"))
    }
}
142 | |
143 | /// A cache of the Message header fields. |
144 | #[derive (Debug, Default, Copy, Clone)] |
145 | pub(crate) struct QuickFields { |
146 | path: FieldPos, |
147 | interface: FieldPos, |
148 | member: FieldPos, |
149 | error_name: FieldPos, |
150 | reply_serial: Option<NonZeroU32>, |
151 | destination: FieldPos, |
152 | sender: FieldPos, |
153 | signature: FieldPos, |
154 | unix_fds: Option<u32>, |
155 | } |
156 | |
157 | impl QuickFields { |
158 | pub fn new(buf: &[u8], header: &Header<'_>) -> Result<Self> { |
159 | Ok(Self { |
160 | path: FieldPos::new(buf, header.path()), |
161 | interface: FieldPos::new(buf, header.interface()), |
162 | member: FieldPos::new(buf, header.member()), |
163 | error_name: FieldPos::new(buf, header.error_name()), |
164 | reply_serial: header.reply_serial(), |
165 | destination: FieldPos::new(buf, header.destination()), |
166 | sender: FieldPos::new(buf, header.sender()), |
167 | signature: FieldPos::new(buf, header.signature()), |
168 | unix_fds: header.unix_fds(), |
169 | }) |
170 | } |
171 | |
172 | pub fn path<'m>(&self, msg: &'m Message) -> Option<ObjectPath<'m>> { |
173 | self.path.read(msg.data()) |
174 | } |
175 | |
176 | pub fn interface<'m>(&self, msg: &'m Message) -> Option<InterfaceName<'m>> { |
177 | self.interface.read(msg.data()) |
178 | } |
179 | |
180 | pub fn member<'m>(&self, msg: &'m Message) -> Option<MemberName<'m>> { |
181 | self.member.read(msg.data()) |
182 | } |
183 | |
184 | pub fn error_name<'m>(&self, msg: &'m Message) -> Option<ErrorName<'m>> { |
185 | self.error_name.read(msg.data()) |
186 | } |
187 | |
188 | pub fn reply_serial(&self) -> Option<NonZeroU32> { |
189 | self.reply_serial |
190 | } |
191 | |
192 | pub fn destination<'m>(&self, msg: &'m Message) -> Option<BusName<'m>> { |
193 | self.destination.read(msg.data()) |
194 | } |
195 | |
196 | pub fn sender<'m>(&self, msg: &'m Message) -> Option<UniqueName<'m>> { |
197 | self.sender.read(msg.data()) |
198 | } |
199 | |
200 | pub fn signature<'m>(&self, msg: &'m Message) -> Option<Signature<'m>> { |
201 | self.signature.read(msg.data()) |
202 | } |
203 | |
204 | pub fn unix_fds(&self) -> Option<u32> { |
205 | self.unix_fds |
206 | } |
207 | } |
208 | |
209 | impl<'m> Default for Fields<'m> { |
210 | fn default() -> Self { |
211 | Self(Vec::with_capacity(MAX_FIELDS_IN_MESSAGE)) |
212 | } |
213 | } |
214 | |
215 | impl<'m> std::ops::Deref for Fields<'m> { |
216 | type Target = [Field<'m>]; |
217 | |
218 | fn deref(&self) -> &Self::Target { |
219 | self.get() |
220 | } |
221 | } |
222 | |
#[cfg(test)]
mod tests {
    use super::{Field, Fields};

    /// `add` always appends, whereas `replace` keeps a single field per code.
    #[test]
    fn test() {
        // Adding two fields with the same code stores both of them.
        let mut appended = Fields::new();
        assert_eq!(appended.len(), 0);
        appended.add(Field::ReplySerial(42.try_into().unwrap()));
        assert_eq!(appended.len(), 1);
        appended.add(Field::ReplySerial(43.try_into().unwrap()));
        assert_eq!(appended.len(), 2);

        // Replacing inserts the first time and overwrites afterwards.
        let mut replaced = Fields::new();
        assert_eq!(replaced.len(), 0);
        replaced.replace(Field::ReplySerial(42.try_into().unwrap()));
        assert_eq!(replaced.len(), 1);
        replaced.replace(Field::ReplySerial(43.try_into().unwrap()));
        assert_eq!(replaced.len(), 1);
    }
}
244 | |