1 | use super::*; |
2 | use crate::ComInterface; |
3 | use std::marker::PhantomData; |
4 | use std::sync::atomic::{AtomicPtr, Ordering}; |
5 | |
/// Cache for the activation factory `I` of the runtime class `C`, consulted by
/// [`FactoryCache::call`].
///
/// Only factories that successfully cast to `IAgileObject` are ever stored. Once stored, the
/// factory reference is `mem::forget`-ed and kept for the lifetime of the cache; nothing in this
/// type releases it.
#[doc(hidden)]
pub struct FactoryCache<C, I> {
    // Raw interface pointer of the cached (agile) factory, or null while nothing is cached.
    shared: AtomicPtr<std::ffi::c_void>,
    // Zero-sized markers tying the cache to class `C` and factory interface `I` without storing either.
    _c: PhantomData<C>,
    _i: PhantomData<I>,
}
12 | |
13 | impl<C, I> FactoryCache<C, I> { |
14 | pub const fn new() -> Self { |
15 | Self { shared: AtomicPtr::new(std::ptr::null_mut()), _c: PhantomData, _i: PhantomData } |
16 | } |
17 | } |
18 | |
19 | impl<C: crate::RuntimeName, I: crate::ComInterface> FactoryCache<C, I> { |
20 | pub fn call<R, F: FnOnce(&I) -> crate::Result<R>>(&self, callback: F) -> crate::Result<R> { |
21 | loop { |
22 | // Attempt to load a previously cached factory pointer. |
23 | let ptr = self.shared.load(Ordering::Relaxed); |
24 | |
25 | // If a pointer is found, the cache is primed and we're good to go. |
26 | if !ptr.is_null() { |
27 | return callback(unsafe { std::mem::transmute(&ptr) }); |
28 | } |
29 | |
30 | // Otherwise, we load the factory the usual way. |
31 | let factory = factory::<C, I>()?; |
32 | |
33 | // If the factory is agile, we can safely cache it. |
34 | if factory.cast::<IAgileObject>().is_ok() { |
35 | if self.shared.compare_exchange_weak(std::ptr::null_mut(), factory.as_raw(), Ordering::Relaxed, Ordering::Relaxed).is_ok() { |
36 | std::mem::forget(factory); |
37 | } |
38 | } else { |
39 | // Otherwise, for non-agile factories we simply use the factory |
40 | // and discard after use as it is not safe to cache. |
41 | return callback(&factory); |
42 | } |
43 | } |
44 | } |
45 | } |
46 | |
// SAFETY: The only state this type holds is `shared`, an `AtomicPtr`, plus zero-sized markers.
// `FactoryCache::call` stores a pointer in `shared` only after the factory successfully casts to
// `IAgileObject`. So only agile factory pointers are ever cached, and those are safe to cache and
// share between threads. Non-agile factories are used by the calling thread and released
// without ever being stored.
unsafe impl<C, I> std::marker::Sync for FactoryCache<C, I> {}
49 | |
50 | /// Attempts to load the factory object for the given WinRT class. |
51 | /// This can be used to access COM interfaces implemented on a Windows Runtime class factory. |
52 | pub fn factory<C: crate::RuntimeName, I: crate::ComInterface>() -> crate::Result<I> { |
53 | let mut factory: Option<I> = None; |
54 | let name = crate::HSTRING::from(C::NAME); |
55 | |
56 | let code = if let Some(function) = unsafe { delay_load::<RoGetActivationFactory>(crate::s!("combase.dll" ), crate::s!("RoGetActivationFactory" )) } { |
57 | unsafe { |
58 | let mut code = function(std::mem::transmute_copy(&name), &I::IID, &mut factory as *mut _ as *mut _); |
59 | |
60 | // If RoGetActivationFactory fails because combase hasn't been loaded yet then load combase |
61 | // automatically so that it "just works" for apartment-agnostic code. |
62 | if code == CO_E_NOTINITIALIZED { |
63 | if let Some(mta) = delay_load::<CoIncrementMTAUsage>(crate::s!("ole32.dll" ), crate::s!("CoIncrementMTAUsage" )) { |
64 | let mut cookie = std::ptr::null_mut(); |
65 | let _ = mta(&mut cookie); |
66 | } |
67 | |
68 | // Now try a second time to get the activation factory via the OS. |
69 | code = function(std::mem::transmute_copy(&name), &I::IID, &mut factory as *mut _ as *mut _); |
70 | } |
71 | |
72 | code |
73 | } |
74 | } else { |
75 | CLASS_E_CLASSNOTAVAILABLE |
76 | }; |
77 | |
78 | // If this succeeded then return the resulting factory interface. |
79 | if code.is_ok() { |
80 | return code.and_some(factory); |
81 | } |
82 | |
83 | // If not, first capture the error information from the failure above so that we |
84 | // can ultimately return this error information if all else fails. |
85 | let original: crate::Error = code.into(); |
86 | |
87 | // Now attempt to find the factory's implementation heuristically. |
88 | if let Some(i) = search_path(C::NAME, |library| unsafe { get_activation_factory(library, &name) }) { |
89 | i.cast() |
90 | } else { |
91 | Err(original) |
92 | } |
93 | } |
94 | |
95 | // Remove the suffix until a match is found appending `.dll\0` at the end |
96 | /// |
97 | /// For example, if the class name is |
98 | /// "A.B.TypeName" then the attempted load order will be: |
99 | /// 1. A.B.dll |
100 | /// 2. A.dll |
101 | fn search_path<F, R>(mut path: &str, mut callback: F) -> Option<R> |
102 | where |
103 | F: FnMut(crate::PCSTR) -> crate::Result<R>, |
104 | { |
105 | let suffix: &[u8; 5] = b".dll \0" ; |
106 | let mut library: Vec = vec![0; path.len() + suffix.len()]; |
107 | while let Some(pos: usize) = path.rfind('.' ) { |
108 | path = &path[..pos]; |
109 | library.truncate(len:path.len() + suffix.len()); |
110 | library[..path.len()].copy_from_slice(src:path.as_bytes()); |
111 | library[path.len()..].copy_from_slice(src:suffix); |
112 | |
113 | if let Ok(r: R) = callback(crate::PCSTR::from_raw(library.as_ptr())) { |
114 | return Some(r); |
115 | } |
116 | } |
117 | |
118 | None |
119 | } |
120 | |
121 | unsafe fn get_activation_factory(library: crate::PCSTR, name: &crate::HSTRING) -> crate::Result<IGenericFactory> { |
122 | let function: fn(*mut c_void, *mut *mut …) -> … = delay_load::<DllGetActivationFactory>(library, crate::s!("DllGetActivationFactory" )).ok_or_else(err:crate::Error::from_win32)?; |
123 | let mut abi: *mut c_void = std::ptr::null_mut(); |
124 | function(std::mem::transmute_copy(src:name), &mut abi).from_abi(abi) |
125 | } |
126 | |
127 | type CoIncrementMTAUsage = extern "system" fn(cookie: *mut *mut std::ffi::c_void) -> crate::HRESULT; |
128 | type RoGetActivationFactory = extern "system" fn(hstring: *mut std::ffi::c_void, interface: &crate::GUID, result: *mut *mut std::ffi::c_void) -> crate::HRESULT; |
129 | type DllGetActivationFactory = extern "system" fn(name: *mut std::ffi::c_void, factory: *mut *mut std::ffi::c_void) -> crate::HRESULT; |
130 | |
#[cfg(test)]
mod tests {
    use super::*;

    /// `search_path` probes namespace prefixes from longest to shortest and stops at the first
    /// candidate the callback accepts.
    #[test]
    fn dll_search() {
        let path = "A.B.TypeName";

        // Test library successfully found
        let mut results = Vec::new();
        let end_result = search_path(path, |library| {
            results.push(unsafe { library.to_string().unwrap() });
            if unsafe { library.as_bytes() } == &b"A.dll"[..] {
                Ok(42)
            } else {
                Err(crate::Error::OK)
            }
        });
        assert!(matches!(end_result, Some(42)));
        assert_eq!(results, vec!["A.B.dll", "A.dll"]);

        // Test library never successfully found
        let mut results = Vec::new();
        let end_result = search_path(path, |library| {
            results.push(unsafe { library.to_string().unwrap() });
            crate::Result::<()>::Err(crate::Error::OK)
        });
        assert!(end_result.is_none());
        assert_eq!(results, vec!["A.B.dll", "A.dll"]);
    }

    /// A class name without any `.` has no namespace prefix, so no DLL is probed at all.
    #[test]
    fn dll_search_without_namespace() {
        let mut calls = 0;
        let end_result = search_path("TypeName", |_| {
            calls += 1;
            crate::Result::<()>::Err(crate::Error::OK)
        });
        assert!(end_result.is_none());
        assert_eq!(calls, 0);
    }
}
162 | |