1use super::*;
2use crate::ComInterface;
3use std::marker::PhantomData;
4use std::sync::atomic::{AtomicPtr, Ordering};
5
/// Cache slot for the activation factory `I` of the WinRT class `C`.
///
/// Neither `C` nor `I` is stored by value; they only parameterize the type. The cached
/// interface pointer lives in `shared` and starts out null.
#[doc(hidden)]
pub struct FactoryCache<C, I> {
    /// Raw pointer to the cached factory, or null while nothing has been cached yet.
    shared: AtomicPtr<std::ffi::c_void>,
    /// Marker tying the runtime class type to the cache.
    _c: PhantomData<C>,
    /// Marker tying the factory interface type to the cache.
    _i: PhantomData<I>,
}

impl<C, I> FactoryCache<C, I> {
    /// Returns an empty cache (null pointer); `const` so it can initialize a `static`.
    pub const fn new() -> Self {
        let shared = AtomicPtr::new(std::ptr::null_mut());
        Self { _c: PhantomData, _i: PhantomData, shared }
    }
}
18
19impl<C: crate::RuntimeName, I: crate::ComInterface> FactoryCache<C, I> {
20 pub fn call<R, F: FnOnce(&I) -> crate::Result<R>>(&self, callback: F) -> crate::Result<R> {
21 loop {
22 // Attempt to load a previously cached factory pointer.
23 let ptr = self.shared.load(Ordering::Relaxed);
24
25 // If a pointer is found, the cache is primed and we're good to go.
26 if !ptr.is_null() {
27 return callback(unsafe { std::mem::transmute(&ptr) });
28 }
29
30 // Otherwise, we load the factory the usual way.
31 let factory = factory::<C, I>()?;
32
33 // If the factory is agile, we can safely cache it.
34 if factory.cast::<IAgileObject>().is_ok() {
35 if self.shared.compare_exchange_weak(std::ptr::null_mut(), factory.as_raw(), Ordering::Relaxed, Ordering::Relaxed).is_ok() {
36 std::mem::forget(factory);
37 }
38 } else {
39 // Otherwise, for non-agile factories we simply use the factory
40 // and discard after use as it is not safe to cache.
41 return callback(&factory);
42 }
43 }
44 }
45}
46
47// This is safe because `FactoryCache` only holds agile factory pointers, which are safe to cache and share between threads.
48unsafe impl<C, I> std::marker::Sync for FactoryCache<C, I> {}
49
50/// Attempts to load the factory object for the given WinRT class.
51/// This can be used to access COM interfaces implemented on a Windows Runtime class factory.
52pub fn factory<C: crate::RuntimeName, I: crate::ComInterface>() -> crate::Result<I> {
53 let mut factory: Option<I> = None;
54 let name = crate::HSTRING::from(C::NAME);
55
56 let code = if let Some(function) = unsafe { delay_load::<RoGetActivationFactory>(crate::s!("combase.dll"), crate::s!("RoGetActivationFactory")) } {
57 unsafe {
58 let mut code = function(std::mem::transmute_copy(&name), &I::IID, &mut factory as *mut _ as *mut _);
59
60 // If RoGetActivationFactory fails because combase hasn't been loaded yet then load combase
61 // automatically so that it "just works" for apartment-agnostic code.
62 if code == CO_E_NOTINITIALIZED {
63 if let Some(mta) = delay_load::<CoIncrementMTAUsage>(crate::s!("ole32.dll"), crate::s!("CoIncrementMTAUsage")) {
64 let mut cookie = std::ptr::null_mut();
65 let _ = mta(&mut cookie);
66 }
67
68 // Now try a second time to get the activation factory via the OS.
69 code = function(std::mem::transmute_copy(&name), &I::IID, &mut factory as *mut _ as *mut _);
70 }
71
72 code
73 }
74 } else {
75 CLASS_E_CLASSNOTAVAILABLE
76 };
77
78 // If this succeeded then return the resulting factory interface.
79 if code.is_ok() {
80 return code.and_some(factory);
81 }
82
83 // If not, first capture the error information from the failure above so that we
84 // can ultimately return this error information if all else fails.
85 let original: crate::Error = code.into();
86
87 // Now attempt to find the factory's implementation heuristically.
88 if let Some(i) = search_path(C::NAME, |library| unsafe { get_activation_factory(library, &name) }) {
89 i.cast()
90 } else {
91 Err(original)
92 }
93}
94
95// Remove the suffix until a match is found appending `.dll\0` at the end
96///
97/// For example, if the class name is
98/// "A.B.TypeName" then the attempted load order will be:
99/// 1. A.B.dll
100/// 2. A.dll
101fn search_path<F, R>(mut path: &str, mut callback: F) -> Option<R>
102where
103 F: FnMut(crate::PCSTR) -> crate::Result<R>,
104{
105 let suffix: &[u8; 5] = b".dll\0";
106 let mut library: Vec = vec![0; path.len() + suffix.len()];
107 while let Some(pos: usize) = path.rfind('.') {
108 path = &path[..pos];
109 library.truncate(len:path.len() + suffix.len());
110 library[..path.len()].copy_from_slice(src:path.as_bytes());
111 library[path.len()..].copy_from_slice(src:suffix);
112
113 if let Ok(r: R) = callback(crate::PCSTR::from_raw(library.as_ptr())) {
114 return Some(r);
115 }
116 }
117
118 None
119}
120
121unsafe fn get_activation_factory(library: crate::PCSTR, name: &crate::HSTRING) -> crate::Result<IGenericFactory> {
122 let function: fn(*mut c_void, *mut *mut …) -> … = delay_load::<DllGetActivationFactory>(library, crate::s!("DllGetActivationFactory")).ok_or_else(err:crate::Error::from_win32)?;
123 let mut abi: *mut c_void = std::ptr::null_mut();
124 function(std::mem::transmute_copy(src:name), &mut abi).from_abi(abi)
125}
126
127type CoIncrementMTAUsage = extern "system" fn(cookie: *mut *mut std::ffi::c_void) -> crate::HRESULT;
128type RoGetActivationFactory = extern "system" fn(hstring: *mut std::ffi::c_void, interface: &crate::GUID, result: *mut *mut std::ffi::c_void) -> crate::HRESULT;
129type DllGetActivationFactory = extern "system" fn(name: *mut std::ffi::c_void, factory: *mut *mut std::ffi::c_void) -> crate::HRESULT;
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dll_search() {
        let path = "A.B.TypeName";

        // Library found on the second candidate: the search stops there.
        let mut results = Vec::new();
        let end_result = search_path(path, |library| {
            results.push(unsafe { library.to_string().unwrap() });
            if unsafe { library.as_bytes() } == &b"A.dll"[..] {
                Ok(42)
            } else {
                Err(crate::Error::OK)
            }
        });
        assert_eq!(end_result, Some(42));
        assert_eq!(results, vec!["A.B.dll", "A.dll"]);

        // Library never found: every candidate is tried, then `None`.
        let mut results = Vec::new();
        let end_result = search_path(path, |library| {
            results.push(unsafe { library.to_string().unwrap() });
            crate::Result::<()>::Err(crate::Error::OK)
        });
        assert!(end_result.is_none());
        assert_eq!(results, vec!["A.B.dll", "A.dll"]);

        // No namespace separator: there is no candidate, so the callback never runs.
        let mut calls = 0;
        let end_result = search_path("TypeName", |_| {
            calls += 1;
            crate::Result::<()>::Err(crate::Error::OK)
        });
        assert!(end_result.is_none());
        assert_eq!(calls, 0);
    }
}
162