1 | use serde::{Deserialize, Serialize}; |
2 | use static_assertions::assert_impl_all; |
3 | use std::convert::{TryFrom, TryInto}; |
4 | use zbus_names::{InterfaceName, MemberName}; |
5 | use zvariant::{ObjectPath, Type}; |
6 | |
7 | use crate::{Message, MessageField, MessageFieldCode, MessageHeader, Result}; |
8 | |
/// Number of header fields a fresh `MessageFields` preallocates room for.
///
/// The real maximum is 10 (and in practice fewer); the value is rounded up to 16,
/// the next multiple of 8.
const MAX_FIELDS_IN_MESSAGE: usize = 16;
11 | |
12 | /// A collection of [`MessageField`] instances. |
13 | /// |
14 | /// [`MessageField`]: enum.MessageField.html |
15 | #[derive (Debug, Clone, Serialize, Deserialize, Type)] |
16 | pub struct MessageFields<'m>(#[serde(borrow)] Vec<MessageField<'m>>); |
17 | |
18 | assert_impl_all!(MessageFields<'_>: Send, Sync, Unpin); |
19 | |
20 | impl<'m> MessageFields<'m> { |
21 | /// Creates an empty collection of fields. |
22 | pub fn new() -> Self { |
23 | Self::default() |
24 | } |
25 | |
26 | /// Appends a [`MessageField`] to the collection of fields in the message. |
27 | /// |
28 | /// [`MessageField`]: enum.MessageField.html |
29 | pub fn add<'f: 'm>(&mut self, field: MessageField<'f>) { |
30 | self.0.push(field); |
31 | } |
32 | |
33 | /// Replaces a [`MessageField`] from the collection of fields with one with the same code, |
34 | /// returning the old value if present. |
35 | /// |
36 | /// [`MessageField`]: enum.MessageField.html |
37 | pub fn replace<'f: 'm>(&mut self, field: MessageField<'f>) -> Option<MessageField<'m>> { |
38 | let code = field.code(); |
39 | if let Some(found) = self.0.iter_mut().find(|f| f.code() == code) { |
40 | return Some(std::mem::replace(found, field)); |
41 | } |
42 | self.add(field); |
43 | None |
44 | } |
45 | |
46 | /// Returns a slice with all the [`MessageField`] in the message. |
47 | /// |
48 | /// [`MessageField`]: enum.MessageField.html |
49 | pub fn get(&self) -> &[MessageField<'m>] { |
50 | &self.0 |
51 | } |
52 | |
53 | /// Gets a reference to a specific [`MessageField`] by its code. |
54 | /// |
55 | /// Returns `None` if the message has no such field. |
56 | /// |
57 | /// [`MessageField`]: enum.MessageField.html |
58 | pub fn get_field(&self, code: MessageFieldCode) -> Option<&MessageField<'m>> { |
59 | self.0.iter().find(|f| f.code() == code) |
60 | } |
61 | |
62 | /// Consumes the `MessageFields` and returns a specific [`MessageField`] by its code. |
63 | /// |
64 | /// Returns `None` if the message has no such field. |
65 | /// |
66 | /// [`MessageField`]: enum.MessageField.html |
67 | pub fn into_field(self, code: MessageFieldCode) -> Option<MessageField<'m>> { |
68 | self.0.into_iter().find(|f| f.code() == code) |
69 | } |
70 | |
71 | /// Remove the field matching the `code`. |
72 | /// |
73 | /// Returns `true` if a field was found and removed, `false` otherwise. |
74 | pub(crate) fn remove(&mut self, code: MessageFieldCode) -> bool { |
75 | match self.0.iter().enumerate().find(|(_, f)| f.code() == code) { |
76 | Some((i, _)) => { |
77 | self.0.remove(i); |
78 | |
79 | true |
80 | } |
81 | None => false, |
82 | } |
83 | } |
84 | } |
85 | |
/// A byte range of a field in a Message, used in [`QuickMessageFields`].
///
/// Two otherwise-invalid encodings (both with `end == 0`) act as sentinels:
/// - `start: 0, end: 0` — the `Default` value, meaning "not cached";
/// - `start: 1, end: 0` — produced by [`FieldPos::new_not_present`], meaning "not present".
///
/// [`FieldPos::read`] reports both sentinels as `None`.
#[derive(Debug, Default, Clone, Copy)]
pub(crate) struct FieldPos {
    /// Offset of the field's first byte within the message buffer.
    start: u32,
    /// Offset one past the field's last byte within the message buffer.
    end: u32,
}

impl FieldPos {
    /// Returns the sentinel meaning "this field is absent from the message".
    pub fn new_not_present() -> Self {
        Self { start: 1, end: 0 }
    }

    /// Locates `field_buf` inside `msg_buf` and records its byte range.
    ///
    /// Returns `None` when `field_buf` does not lie entirely within `msg_buf`, or when an
    /// offset does not fit in a `u32`.
    pub fn build(msg_buf: &[u8], field_buf: &str) -> Option<Self> {
        let buf_start = msg_buf.as_ptr() as usize;
        let field_start = field_buf.as_ptr() as usize;
        // The field must begin at or after the start of the message buffer...
        let offset = field_start.checked_sub(buf_start)?;
        // ...and finish no later than the buffer does. `checked_add` makes the bound check
        // overflow-proof on its own, for pointers that lie far beyond the buffer.
        let end = offset.checked_add(field_buf.len())?;
        if end > msg_buf.len() {
            return None;
        }

        Some(Self {
            start: offset.try_into().ok()?,
            end: end.try_into().ok()?,
        })
    }

    /// Records the position of an optional header field within `msg_buf`.
    ///
    /// Yields the "not present" sentinel when `field` is `None` or cannot be located inside
    /// `msg_buf` (see [`FieldPos::build`]).
    pub fn new<T>(msg_buf: &[u8], field: Option<&T>) -> Self
    where
        T: std::ops::Deref<Target = str>,
    {
        field
            .and_then(|f| Self::build(msg_buf, f.deref()))
            .unwrap_or_else(Self::new_not_present)
    }

    /// Reassembles a previously cached field from the message bytes.
    ///
    /// Returns `None` for the "not cached" and "not present" sentinels.
    ///
    /// # Panics
    ///
    /// The caller must ensure that `msg_buf` is the same buffer `build` (or `new`) was called
    /// with. Otherwise this panics if any of the following holds:
    /// - the recorded range is out of bounds;
    /// - the range is not valid UTF-8;
    /// - `T::try_from` rejects the string.
    pub fn read<'m, T>(&self, msg_buf: &'m [u8]) -> Option<T>
    where
        T: TryFrom<&'m str>,
        T::Error: std::fmt::Debug,
    {
        match self {
            Self {
                start: 0..=1,
                end: 0,
            } => None,
            Self { start, end } => {
                let range = (*start as usize)..(*end as usize);
                let s = std::str::from_utf8(&msg_buf[range])
                    .expect("Invalid utf8 when reconstructing string");
                // The field was already validated when `Self` was constructed, so the
                // conversion is expected to succeed for the original buffer.
                T::try_from(s)
                    .map(Some)
                    .expect("Invalid field reconstruction")
            }
        }
    }
}
148 | |
149 | /// A cache of some commonly-used fields of the header of a Message. |
150 | #[derive (Debug, Default, Copy, Clone)] |
151 | pub(crate) struct QuickMessageFields { |
152 | path: FieldPos, |
153 | interface: FieldPos, |
154 | member: FieldPos, |
155 | reply_serial: Option<u32>, |
156 | } |
157 | |
158 | impl QuickMessageFields { |
159 | pub fn new(buf: &[u8], header: &MessageHeader<'_>) -> Result<Self> { |
160 | Ok(Self { |
161 | path: FieldPos::new(buf, header.path()?), |
162 | interface: FieldPos::new(buf, header.interface()?), |
163 | member: FieldPos::new(buf, header.member()?), |
164 | reply_serial: header.reply_serial()?, |
165 | }) |
166 | } |
167 | |
168 | pub fn path<'m>(&self, msg: &'m Message) -> Option<ObjectPath<'m>> { |
169 | self.path.read(msg.as_bytes()) |
170 | } |
171 | |
172 | pub fn interface<'m>(&self, msg: &'m Message) -> Option<InterfaceName<'m>> { |
173 | self.interface.read(msg.as_bytes()) |
174 | } |
175 | |
176 | pub fn member<'m>(&self, msg: &'m Message) -> Option<MemberName<'m>> { |
177 | self.member.read(msg.as_bytes()) |
178 | } |
179 | |
180 | pub fn reply_serial(&self) -> Option<u32> { |
181 | self.reply_serial |
182 | } |
183 | } |
184 | |
185 | impl<'m> Default for MessageFields<'m> { |
186 | fn default() -> Self { |
187 | Self(Vec::with_capacity(MAX_FIELDS_IN_MESSAGE)) |
188 | } |
189 | } |
190 | |
191 | impl<'m> std::ops::Deref for MessageFields<'m> { |
192 | type Target = [MessageField<'m>]; |
193 | |
194 | fn deref(&self) -> &Self::Target { |
195 | self.get() |
196 | } |
197 | } |
198 | |
#[cfg(test)]
mod tests {
    use super::{MessageField, MessageFieldCode, MessageFields};

    /// `add` always appends, even when a field with the same code already exists.
    #[test]
    fn add_appends_duplicates() {
        let mut mf = MessageFields::new();
        assert_eq!(mf.len(), 0);
        mf.add(MessageField::ReplySerial(42));
        assert_eq!(mf.len(), 1);
        mf.add(MessageField::ReplySerial(43));
        assert_eq!(mf.len(), 2);
    }

    /// `replace` inserts when the code is absent and swaps (returning the old field) otherwise.
    #[test]
    fn replace_inserts_then_swaps() {
        let mut mf = MessageFields::new();
        assert_eq!(mf.len(), 0);
        assert!(mf.replace(MessageField::ReplySerial(42)).is_none());
        assert_eq!(mf.len(), 1);

        match mf.replace(MessageField::ReplySerial(43)) {
            Some(MessageField::ReplySerial(42)) => (),
            other => panic!("unexpected previous field: {:?}", other),
        }
        assert_eq!(mf.len(), 1);

        match mf.get_field(MessageFieldCode::ReplySerial) {
            Some(MessageField::ReplySerial(43)) => (),
            other => panic!("unexpected current field: {:?}", other),
        }
    }

    /// `remove` drops a matching field and reports whether anything was removed.
    #[test]
    fn remove_reports_whether_removed() {
        let mut mf = MessageFields::new();
        mf.add(MessageField::ReplySerial(42));
        assert!(mf.remove(MessageFieldCode::ReplySerial));
        assert_eq!(mf.len(), 0);
        assert!(!mf.remove(MessageFieldCode::ReplySerial));
    }
}
220 | |