1//! Helper for implementing `RequestConnection::extension_information()`.
2
3use std::collections::{hash_map::Entry as HashMapEntry, HashMap};
4
5use crate::connection::RequestConnection;
6use crate::cookie::Cookie;
7use crate::errors::{ConnectionError, ReplyError};
8use crate::protocol::xproto::{ConnectionExt, QueryExtensionReply};
9use crate::x11_utils::{ExtInfoProvider, ExtensionInformation};
10
11use x11rb_protocol::SequenceNumber;
12
13/// Helper for implementing `RequestConnection::extension_information()`.
14///
15/// This helps with implementing `RequestConnection`. Most likely, you do not need this in your own
16/// code, unless you really want to implement your own X11 connection.
17#[derive(Debug, Default)]
18pub struct ExtensionManager(HashMap<&'static str, CheckState>);
19
20#[derive(Debug)]
21enum CheckState {
22 Prefetched(SequenceNumber),
23 Present(ExtensionInformation),
24 Missing,
25 Error,
26}
27
28impl ExtensionManager {
29 /// If the extension has not prefetched yet, sends a `QueryExtension`
30 /// requests, adds a field to the hash map and returns a reference to it.
31 fn prefetch_extension_information_aux<C: RequestConnection>(
32 &mut self,
33 conn: &C,
34 extension_name: &'static str,
35 ) -> Result<&mut CheckState, ConnectionError> {
36 match self.0.entry(extension_name) {
37 // Extension already checked, return the cached value
38 HashMapEntry::Occupied(entry) => Ok(entry.into_mut()),
39 HashMapEntry::Vacant(entry) => {
40 crate::debug!(
41 "Prefetching information about '{}' extension",
42 extension_name
43 );
44 let cookie = conn.query_extension(extension_name.as_bytes())?;
45 Ok(entry.insert(CheckState::Prefetched(cookie.into_sequence_number())))
46 }
47 }
48 }
49
50 /// Prefetchs an extension sending a `QueryExtension` without waiting for
51 /// the reply.
52 pub fn prefetch_extension_information<C: RequestConnection>(
53 &mut self,
54 conn: &C,
55 extension_name: &'static str,
56 ) -> Result<(), ConnectionError> {
57 // We are not interested on the reference to the entry.
58 let _ = self.prefetch_extension_information_aux(conn, extension_name)?;
59 Ok(())
60 }
61
62 /// Insert an extension if you already have the information.
63 pub fn insert_extension_information(
64 &mut self,
65 extension_name: &'static str,
66 info: Option<ExtensionInformation>,
67 ) {
68 crate::debug!(
69 "Inserting '{}' extension information directly: {:?}",
70 extension_name,
71 info
72 );
73 let state = match info {
74 Some(info) => CheckState::Present(info),
75 None => CheckState::Missing,
76 };
77
78 let _ = self.0.insert(extension_name, state);
79 }
80
81 /// An implementation of `RequestConnection::extension_information()`.
82 ///
83 /// The given connection is used for sending a `QueryExtension` request if needed.
84 pub fn extension_information<C: RequestConnection>(
85 &mut self,
86 conn: &C,
87 extension_name: &'static str,
88 ) -> Result<Option<ExtensionInformation>, ConnectionError> {
89 let _guard = crate::debug_span!("extension_information", extension_name).entered();
90 let entry = self.prefetch_extension_information_aux(conn, extension_name)?;
91 match entry {
92 CheckState::Prefetched(sequence_number) => {
93 crate::debug!(
94 "Waiting for QueryInfo reply for '{}' extension",
95 extension_name
96 );
97 match Cookie::<C, QueryExtensionReply>::new(conn, *sequence_number).reply() {
98 Err(err) => {
99 crate::warning!(
100 "Got error {:?} for QueryInfo reply for '{}' extension",
101 err,
102 extension_name
103 );
104 *entry = CheckState::Error;
105 match err {
106 ReplyError::ConnectionError(e) => Err(e),
107 // The X11 protocol specification does not specify any error
108 // for the QueryExtension request, so this should not happen.
109 ReplyError::X11Error(_) => Err(ConnectionError::UnknownError),
110 }
111 }
112 Ok(info) => {
113 if info.present {
114 let info = ExtensionInformation {
115 major_opcode: info.major_opcode,
116 first_event: info.first_event,
117 first_error: info.first_error,
118 };
119 crate::debug!("Extension '{}' is present: {:?}", extension_name, info);
120 *entry = CheckState::Present(info);
121 Ok(Some(info))
122 } else {
123 crate::debug!("Extension '{}' is not present", extension_name);
124 *entry = CheckState::Missing;
125 Ok(None)
126 }
127 }
128 }
129 }
130 CheckState::Present(info) => Ok(Some(*info)),
131 CheckState::Missing => Ok(None),
132 CheckState::Error => Err(ConnectionError::UnknownError),
133 }
134 }
135}
136
137impl ExtInfoProvider for ExtensionManager {
138 fn get_from_major_opcode(&self, major_opcode: u8) -> Option<(&str, ExtensionInformation)> {
139 self.0
140 .iter()
141 .filter_map(|(name, state)| {
142 if let CheckState::Present(info) = state {
143 Some((*name, *info))
144 } else {
145 None
146 }
147 })
148 .find(|(_, info)| info.major_opcode == major_opcode)
149 }
150
151 fn get_from_event_code(&self, event_code: u8) -> Option<(&str, ExtensionInformation)> {
152 self.0
153 .iter()
154 .filter_map(|(name, state)| {
155 if let CheckState::Present(info) = state {
156 if info.first_event <= event_code {
157 Some((*name, *info))
158 } else {
159 None
160 }
161 } else {
162 None
163 }
164 })
165 .max_by_key(|(_, info)| info.first_event)
166 }
167
168 fn get_from_error_code(&self, error_code: u8) -> Option<(&str, ExtensionInformation)> {
169 self.0
170 .iter()
171 .filter_map(|(name, state)| {
172 if let CheckState::Present(info) = state {
173 if info.first_error <= error_code {
174 Some((*name, *info))
175 } else {
176 None
177 }
178 } else {
179 None
180 }
181 })
182 .max_by_key(|(_, info)| info.first_error)
183 }
184}
185
#[cfg(test)]
mod test {
    use std::cell::RefCell;
    use std::io::IoSlice;

    use crate::connection::{BufWithFds, ReplyOrError, RequestConnection, RequestKind};
    use crate::cookie::{Cookie, CookieWithFds, VoidCookie};
    use crate::errors::{ConnectionError, ParseError};
    use crate::utils::RawFdContainer;
    use crate::x11_utils::{ExtInfoProvider, ExtensionInformation, TryParse, TryParseFd};
    use x11rb_protocol::{DiscardMode, SequenceNumber};

    use super::{CheckState, ExtensionManager};

    /// A minimal `RequestConnection` for the tests below.
    ///
    /// Every request with a reply gets sequence number 1. Waiting for any reply fails with
    /// `ConnectionError::UnknownError`. The `RefCell` holds the last sequence number that
    /// was awaited, so that waiting twice for the same reply trips an assertion.
    struct FakeConnection(RefCell<SequenceNumber>);

    impl RequestConnection for FakeConnection {
        type Buf = Vec<u8>;

        fn send_request_with_reply<R>(
            &self,
            _bufs: &[IoSlice<'_>],
            _fds: Vec<RawFdContainer>,
        ) -> Result<Cookie<'_, Self, R>, ConnectionError>
        where
            R: TryParse,
        {
            Ok(Cookie::new(self, 1))
        }

        fn send_request_with_reply_with_fds<R>(
            &self,
            _bufs: &[IoSlice<'_>],
            _fds: Vec<RawFdContainer>,
        ) -> Result<CookieWithFds<'_, Self, R>, ConnectionError>
        where
            R: TryParseFd,
        {
            unimplemented!()
        }

        fn send_request_without_reply(
            &self,
            _bufs: &[IoSlice<'_>],
            _fds: Vec<RawFdContainer>,
        ) -> Result<VoidCookie<'_, Self>, ConnectionError> {
            unimplemented!()
        }

        fn discard_reply(&self, _sequence: SequenceNumber, _kind: RequestKind, _mode: DiscardMode) {
            unimplemented!()
        }

        fn prefetch_extension_information(
            &self,
            _extension_name: &'static str,
        ) -> Result<(), ConnectionError> {
            unimplemented!();
        }

        fn extension_information(
            &self,
            _extension_name: &'static str,
        ) -> Result<Option<ExtensionInformation>, ConnectionError> {
            unimplemented!()
        }

        fn wait_for_reply_or_raw_error(
            &self,
            sequence: SequenceNumber,
        ) -> Result<ReplyOrError<Vec<u8>>, ConnectionError> {
            // Code should only ask once for the reply to a request. Check that this is the case
            // (by requiring monotonically increasing sequence numbers here).
            let mut last = self.0.borrow_mut();
            assert!(
                *last < sequence,
                "Last sequence number that was awaited was {}, but now {}",
                *last,
                sequence
            );
            *last = sequence;
            // Then return an error, because that's what the #[test] below needs.
            Err(ConnectionError::UnknownError)
        }

        fn wait_for_reply(
            &self,
            _sequence: SequenceNumber,
        ) -> Result<Option<Vec<u8>>, ConnectionError> {
            unimplemented!()
        }

        fn wait_for_reply_with_fds_raw(
            &self,
            _sequence: SequenceNumber,
        ) -> Result<ReplyOrError<BufWithFds<Vec<u8>>, Vec<u8>>, ConnectionError> {
            unimplemented!()
        }

        fn check_for_raw_error(
            &self,
            _sequence: SequenceNumber,
        ) -> Result<Option<Vec<u8>>, ConnectionError> {
            unimplemented!()
        }

        fn maximum_request_bytes(&self) -> usize {
            0
        }

        fn prefetch_maximum_request_bytes(&self) {
            unimplemented!()
        }

        fn parse_error(&self, _error: &[u8]) -> Result<crate::x11_utils::X11Error, ParseError> {
            unimplemented!()
        }

        fn parse_event(&self, _event: &[u8]) -> Result<crate::protocol::Event, ParseError> {
            unimplemented!()
        }
    }

    #[test]
    fn test_double_await() {
        let conn = FakeConnection(RefCell::new(0));
        let mut ext_info = ExtensionManager::default();

        // Ask for an extension info. FakeConnection will return an error.
        match ext_info.extension_information(&conn, "whatever") {
            Err(ConnectionError::UnknownError) => {}
            r => panic!("Unexpected result: {:?}", r),
        }

        // Ask again for the extension information. ExtensionInformation should not try to get the
        // reply again, because that would just hang. Once upon a time, this caused a hang.
        match ext_info.extension_information(&conn, "whatever") {
            Err(ConnectionError::UnknownError) => {}
            r => panic!("Unexpected result: {:?}", r),
        }
    }

    #[test]
    fn test_info_provider() {
        let info = ExtensionInformation {
            major_opcode: 4,
            first_event: 5,
            first_error: 6,
        };

        let mut ext_info = ExtensionManager::default();
        let _ = ext_info.0.insert("prefetched", CheckState::Prefetched(42));
        let _ = ext_info.0.insert("present", CheckState::Present(info));
        let _ = ext_info.0.insert("missing", CheckState::Missing);
        let _ = ext_info.0.insert("error", CheckState::Error);

        // Exact hits resolve to the only present extension.
        assert_eq!(ext_info.get_from_major_opcode(4), Some(("present", info)));
        assert_eq!(ext_info.get_from_event_code(5), Some(("present", info)));
        assert_eq!(ext_info.get_from_error_code(6), Some(("present", info)));

        // Codes above the first event/error still belong to that extension.
        assert_eq!(ext_info.get_from_event_code(10), Some(("present", info)));
        assert_eq!(ext_info.get_from_error_code(10), Some(("present", info)));

        // Unknown opcodes and codes below the first event/error match nothing.
        assert_eq!(ext_info.get_from_major_opcode(5), None);
        assert_eq!(ext_info.get_from_event_code(4), None);
        assert_eq!(ext_info.get_from_error_code(5), None);
    }

    #[test]
    fn test_insert_extension_information() {
        // The entries are inserted up front, so no request should reach the connection.
        let conn = FakeConnection(RefCell::new(0));
        let info = ExtensionInformation {
            major_opcode: 4,
            first_event: 5,
            first_error: 6,
        };

        let mut ext_info = ExtensionManager::default();
        ext_info.insert_extension_information("present", Some(info));
        ext_info.insert_extension_information("missing", None);

        match ext_info.extension_information(&conn, "present") {
            Ok(Some(got)) => assert_eq!(got, info),
            r => panic!("Unexpected result: {:?}", r),
        }
        match ext_info.extension_information(&conn, "missing") {
            Ok(None) => {}
            r => panic!("Unexpected result: {:?}", r),
        }

        // Directly inserted information is visible through the info provider, too.
        assert_eq!(ext_info.get_from_major_opcode(4), Some(("present", info)));
    }
}
347