1use std::{
2 collections::HashMap,
3 sync::{self, Arc},
4};
5
6use futures_util::future::poll_fn;
7use tracing::{debug, instrument, trace};
8
9use crate::{
10 async_lock::Mutex, raw::Connection as RawConnection, Executor, MsgBroadcaster, OwnedMatchRule,
11 Socket, Task,
12};
13
14#[derive(Debug)]
15pub(crate) struct SocketReader {
16 raw_conn: Arc<sync::Mutex<RawConnection<Box<dyn Socket>>>>,
17 senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,
18}
19
20impl SocketReader {
21 pub fn new(
22 raw_conn: Arc<sync::Mutex<RawConnection<Box<dyn Socket>>>>,
23 senders: Arc<Mutex<HashMap<Option<OwnedMatchRule>, MsgBroadcaster>>>,
24 ) -> Self {
25 Self { raw_conn, senders }
26 }
27
28 pub fn spawn(self, executor: &Executor<'_>) -> Task<()> {
29 executor.spawn(self.receive_msg(), "socket reader")
30 }
31
32 // Keep receiving messages and put them on the queue.
33 #[instrument(name = "socket reader", skip(self))]
34 async fn receive_msg(self) {
35 loop {
36 trace!("Waiting for message on the socket..");
37 let msg = {
38 poll_fn(|cx| {
39 let mut raw_conn = self.raw_conn.lock().expect("poisoned lock");
40 raw_conn.try_receive_message(cx)
41 })
42 .await
43 .map(Arc::new)
44 };
45 match &msg {
46 Ok(msg) => trace!("Message received on the socket: {:?}", msg),
47 Err(e) => trace!("Error reading from the socket: {:?}", e),
48 };
49
50 let mut senders = self.senders.lock().await;
51 for (rule, sender) in &*senders {
52 if let Ok(msg) = &msg {
53 if let Some(rule) = rule.as_ref() {
54 match rule.matches(msg) {
55 Ok(true) => (),
56 Ok(false) => continue,
57 Err(e) => {
58 debug!("Error matching message against rule: {:?}", e);
59
60 continue;
61 }
62 }
63 }
64 }
65
66 if let Err(e) = sender.broadcast(msg.clone()).await {
67 // An error would be due to either of these:
68 //
69 // 1. the channel is closed.
70 // 2. No active receivers.
71 //
72 // In either case, just log it.
73 trace!(
74 "Error broadcasting message to stream for `{:?}`: {:?}",
75 rule,
76 e
77 );
78 }
79 }
80 trace!("Broadcasted to all streams: {:?}", msg);
81
82 if msg.is_err() {
83 senders.clear();
84 trace!("Socket reading task stopped");
85
86 return;
87 }
88 }
89 }
90}
91