1use std::{
2 convert::TryInto,
3 io::{Cursor, Write},
4};
5
6#[cfg(unix)]
7use crate::Fds;
8#[cfg(unix)]
9use std::{
10 os::unix::io::RawFd,
11 sync::{Arc, RwLock},
12};
13
14use enumflags2::BitFlags;
15use zbus_names::{BusName, ErrorName, InterfaceName, MemberName, UniqueName};
16
17use crate::{
18 utils::padding_for_8_bytes,
19 zvariant::{DynamicType, EncodingContext, ObjectPath, Signature},
20 Error, Message, MessageField, MessageFieldCode, MessageFields, MessageFlags, MessageHeader,
21 MessagePrimaryHeader, MessageSequence, MessageType, QuickMessageFields, Result,
22 MAX_MESSAGE_SIZE,
23};
24
/// Value returned by the body-writing callback handed to `MessageBuilder::build_generic`.
///
/// Off Unix nothing besides the written bytes needs to be reported back, so this is the unit
/// type.
#[cfg(not(unix))]
type BuildGenericResult = ();

/// Value returned by the body-writing callback handed to `MessageBuilder::build_generic`.
///
/// On Unix this is the list of raw file descriptors gathered while the body was written; they
/// end up in the built message.
#[cfg(unix)]
type BuildGenericResult = Vec<RawFd>;
30
/// Expands to a native-endian D-Bus [`EncodingContext`].
///
/// `$offset` is the number of bytes that precede the data being (de)serialized. It is
/// forwarded unchanged to `EncodingContext::new_dbus`.
macro_rules! dbus_context {
    ($offset:expr) => {
        EncodingContext::<byteorder::NativeEndian>::new_dbus($offset)
    };
}
36
37/// A builder for [`Message`]
38#[derive(Debug, Clone)]
39pub struct MessageBuilder<'a> {
40 header: MessageHeader<'a>,
41}
42
43impl<'a> MessageBuilder<'a> {
44 fn new(msg_type: MessageType) -> Self {
45 let primary = MessagePrimaryHeader::new(msg_type, 0);
46 let fields = MessageFields::new();
47 let header = MessageHeader::new(primary, fields);
48 Self { header }
49 }
50
51 /// Create a message of type [`MessageType::MethodCall`].
52 pub fn method_call<'p: 'a, 'm: 'a, P, M>(path: P, method_name: M) -> Result<Self>
53 where
54 P: TryInto<ObjectPath<'p>>,
55 M: TryInto<MemberName<'m>>,
56 P::Error: Into<Error>,
57 M::Error: Into<Error>,
58 {
59 Self::new(MessageType::MethodCall)
60 .path(path)?
61 .member(method_name)
62 }
63
64 /// Create a message of type [`MessageType::Signal`].
65 pub fn signal<'p: 'a, 'i: 'a, 'm: 'a, P, I, M>(path: P, interface: I, name: M) -> Result<Self>
66 where
67 P: TryInto<ObjectPath<'p>>,
68 I: TryInto<InterfaceName<'i>>,
69 M: TryInto<MemberName<'m>>,
70 P::Error: Into<Error>,
71 I::Error: Into<Error>,
72 M::Error: Into<Error>,
73 {
74 Self::new(MessageType::Signal)
75 .path(path)?
76 .interface(interface)?
77 .member(name)
78 }
79
80 /// Create a message of type [`MessageType::MethodReturn`].
81 pub fn method_return(reply_to: &MessageHeader<'_>) -> Result<Self> {
82 Self::new(MessageType::MethodReturn).reply_to(reply_to)
83 }
84
85 /// Create a message of type [`MessageType::Error`].
86 pub fn error<'e: 'a, E>(reply_to: &MessageHeader<'_>, name: E) -> Result<Self>
87 where
88 E: TryInto<ErrorName<'e>>,
89 E::Error: Into<Error>,
90 {
91 Self::new(MessageType::Error)
92 .error_name(name)?
93 .reply_to(reply_to)
94 }
95
96 /// Add flags to the message.
97 ///
98 /// See [`MessageFlags`] documentation for the meaning of the flags.
99 ///
100 /// The function will return an error if invalid flags are given for the message type.
101 pub fn with_flags(mut self, flag: MessageFlags) -> Result<Self> {
102 if self.header.message_type()? != MessageType::MethodCall
103 && BitFlags::from_flag(flag).contains(MessageFlags::NoReplyExpected)
104 {
105 return Err(Error::InvalidField);
106 }
107 let flags = self.header.primary().flags() | flag;
108 self.header.primary_mut().set_flags(flags);
109 Ok(self)
110 }
111
112 /// Set the unique name of the sending connection.
113 pub fn sender<'s: 'a, S>(mut self, sender: S) -> Result<Self>
114 where
115 S: TryInto<UniqueName<'s>>,
116 S::Error: Into<Error>,
117 {
118 self.header
119 .fields_mut()
120 .replace(MessageField::Sender(sender.try_into().map_err(Into::into)?));
121 Ok(self)
122 }
123
124 /// Set the object to send a call to, or the object a signal is emitted from.
125 pub fn path<'p: 'a, P>(mut self, path: P) -> Result<Self>
126 where
127 P: TryInto<ObjectPath<'p>>,
128 P::Error: Into<Error>,
129 {
130 self.header
131 .fields_mut()
132 .replace(MessageField::Path(path.try_into().map_err(Into::into)?));
133 Ok(self)
134 }
135
136 /// Set the interface to invoke a method call on, or that a signal is emitted from.
137 pub fn interface<'i: 'a, I>(mut self, interface: I) -> Result<Self>
138 where
139 I: TryInto<InterfaceName<'i>>,
140 I::Error: Into<Error>,
141 {
142 self.header.fields_mut().replace(MessageField::Interface(
143 interface.try_into().map_err(Into::into)?,
144 ));
145 Ok(self)
146 }
147
148 /// Set the member, either the method name or signal name.
149 pub fn member<'m: 'a, M>(mut self, member: M) -> Result<Self>
150 where
151 M: TryInto<MemberName<'m>>,
152 M::Error: Into<Error>,
153 {
154 self.header
155 .fields_mut()
156 .replace(MessageField::Member(member.try_into().map_err(Into::into)?));
157 Ok(self)
158 }
159
160 fn error_name<'e: 'a, E>(mut self, error: E) -> Result<Self>
161 where
162 E: TryInto<ErrorName<'e>>,
163 E::Error: Into<Error>,
164 {
165 self.header.fields_mut().replace(MessageField::ErrorName(
166 error.try_into().map_err(Into::into)?,
167 ));
168 Ok(self)
169 }
170
171 /// Set the name of the connection this message is intended for.
172 pub fn destination<'d: 'a, D>(mut self, destination: D) -> Result<Self>
173 where
174 D: TryInto<BusName<'d>>,
175 D::Error: Into<Error>,
176 {
177 self.header.fields_mut().replace(MessageField::Destination(
178 destination.try_into().map_err(Into::into)?,
179 ));
180 Ok(self)
181 }
182
183 fn reply_to(mut self, reply_to: &MessageHeader<'_>) -> Result<Self> {
184 let serial = reply_to.primary().serial_num().ok_or(Error::MissingField)?;
185 self.header
186 .fields_mut()
187 .replace(MessageField::ReplySerial(*serial));
188
189 if let Some(sender) = reply_to.sender()? {
190 self.destination(sender.to_owned())
191 } else {
192 Ok(self)
193 }
194 }
195
196 /// Build the [`Message`] with the given body.
197 ///
198 /// You may pass `()` as the body if the message has no body.
199 ///
200 /// The caller is currently required to ensure that the resulting message contains the headers
201 /// as compliant with the [specification]. Additional checks may be added to this builder over
202 /// time as needed.
203 ///
204 /// [specification]:
205 /// https://dbus.freedesktop.org/doc/dbus-specification.html#message-protocol-header-fields
206 pub fn build<B>(self, body: &B) -> Result<Message>
207 where
208 B: serde::ser::Serialize + DynamicType,
209 {
210 let ctxt = dbus_context!(0);
211
212 // Note: this iterates the body twice, but we prefer efficient handling of large messages
213 // to efficient handling of ones that are complex to serialize.
214 #[cfg(unix)]
215 let (body_len, fds_len) = zvariant::serialized_size_fds(ctxt, body)?;
216 #[cfg(not(unix))]
217 let body_len = zvariant::serialized_size(ctxt, body)?;
218
219 let signature = body.dynamic_signature();
220
221 self.build_generic(
222 signature,
223 body_len,
224 move |cursor| {
225 #[cfg(unix)]
226 {
227 let (_, fds) = zvariant::to_writer_fds(cursor, ctxt, body)?;
228 Ok::<Vec<RawFd>, Error>(fds)
229 }
230 #[cfg(not(unix))]
231 {
232 zvariant::to_writer(cursor, ctxt, body)?;
233 Ok::<(), Error>(())
234 }
235 },
236 #[cfg(unix)]
237 fds_len,
238 )
239 }
240
241 /// Create a new message from a raw slice of bytes to populate the body with, rather than by
242 /// serializing a value. The message body will be the exact bytes.
243 ///
244 /// # Safety
245 ///
246 /// This method is unsafe because it can be used to build an invalid message.
247 pub unsafe fn build_raw_body<'b, S>(
248 self,
249 body_bytes: &[u8],
250 signature: S,
251 #[cfg(unix)] fds: Vec<RawFd>,
252 ) -> Result<Message>
253 where
254 S: TryInto<Signature<'b>>,
255 S::Error: Into<Error>,
256 {
257 let signature: Signature<'b> = signature.try_into().map_err(Into::into)?;
258 #[cfg(unix)]
259 let fds_len = fds.len();
260
261 self.build_generic(
262 signature,
263 body_bytes.len(),
264 move |cursor: &mut Cursor<&mut Vec<u8>>| {
265 cursor.write_all(body_bytes)?;
266
267 #[cfg(unix)]
268 return Ok::<Vec<RawFd>, Error>(fds);
269
270 #[cfg(not(unix))]
271 return Ok::<(), Error>(());
272 },
273 #[cfg(unix)]
274 fds_len,
275 )
276 }
277
278 fn build_generic<WriteFunc>(
279 self,
280 mut signature: Signature<'_>,
281 body_len: usize,
282 write_body: WriteFunc,
283 #[cfg(unix)] fds_len: usize,
284 ) -> Result<Message>
285 where
286 WriteFunc: FnOnce(&mut Cursor<&mut Vec<u8>>) -> Result<BuildGenericResult>,
287 {
288 let ctxt = dbus_context!(0);
289 let mut header = self.header;
290
291 if !signature.is_empty() {
292 if signature.starts_with(zvariant::STRUCT_SIG_START_STR) {
293 // Remove leading and trailing STRUCT delimiters
294 signature = signature.slice(1..signature.len() - 1);
295 }
296 header.fields_mut().add(MessageField::Signature(signature));
297 }
298
299 let body_len_u32 = body_len.try_into().map_err(|_| Error::ExcessData)?;
300 header.primary_mut().set_body_len(body_len_u32);
301
302 #[cfg(unix)]
303 {
304 let fds_len_u32 = fds_len.try_into().map_err(|_| Error::ExcessData)?;
305 if fds_len != 0 {
306 header.fields_mut().add(MessageField::UnixFDs(fds_len_u32));
307 }
308 }
309
310 let hdr_len = zvariant::serialized_size(ctxt, &header)?;
311 // We need to align the body to 8-byte boundary.
312 let body_padding = padding_for_8_bytes(hdr_len);
313 let body_offset = hdr_len + body_padding;
314 let total_len = body_offset + body_len;
315 if total_len > MAX_MESSAGE_SIZE {
316 return Err(Error::ExcessData);
317 }
318 let mut bytes: Vec<u8> = Vec::with_capacity(total_len);
319 let mut cursor = Cursor::new(&mut bytes);
320
321 zvariant::to_writer(&mut cursor, ctxt, &header)?;
322 for _ in 0..body_padding {
323 cursor.write_all(&[0u8])?;
324 }
325 #[cfg(unix)]
326 let fds = write_body(&mut cursor)?;
327 #[cfg(not(unix))]
328 write_body(&mut cursor)?;
329
330 let primary_header = header.into_primary();
331 let header: MessageHeader<'_> = zvariant::from_slice(&bytes, ctxt)?;
332 let quick_fields = QuickMessageFields::new(&bytes, &header)?;
333
334 Ok(Message {
335 primary_header,
336 quick_fields,
337 bytes,
338 body_offset,
339 #[cfg(unix)]
340 fds: Arc::new(RwLock::new(Fds::Raw(fds))),
341 recv_seq: MessageSequence::default(),
342 })
343 }
344}
345
346impl<'m> From<MessageHeader<'m>> for MessageBuilder<'m> {
347 fn from(mut header: MessageHeader<'m>) -> Self {
348 // Signature and Fds are added by body* methods.
349 let fields: &mut MessageFields<'_> = header.fields_mut();
350 fields.remove(code:MessageFieldCode::Signature);
351 fields.remove(code:MessageFieldCode::UnixFDs);
352
353 Self { header }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::MessageBuilder;
    use crate::Error;
    use test_log::test;

    #[test]
    fn test_raw() -> Result<(), Error> {
        // Encoded `ai` body: a 4-byte array length (16) followed by the i32 values 1..=4.
        // The bytes are little-endian while the builder uses the native byte order, so this
        // presumes a little-endian host.
        let encoded: &[u8] = &[16, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0];

        let builder = MessageBuilder::signal("/", "test.test", "test")?;

        // SAFETY: `encoded` is a well-formed `ai` body, matching the signature passed alongside
        // it.
        let msg = unsafe {
            builder.build_raw_body(
                encoded,
                "ai",
                #[cfg(unix)]
                vec![],
            )?
        };

        let decoded: Vec<i32> = msg.body()?;
        assert_eq!(decoded, vec![1, 2, 3, 4]);

        Ok(())
    }
}
382