1 | use super::*; |
2 | use crate::{IUnknown, IUnknown_Vtbl, Interface, GUID, HRESULT}; |
3 | use core::ffi::c_void; |
4 | use core::mem::{transmute, transmute_copy}; |
5 | use core::ptr::null_mut; |
6 | use core::sync::atomic::{AtomicIsize, Ordering}; |
7 | |
/// Strong reference count for a COM object that can also hold a pointer to a weak reference
/// tear-off.
///
/// A non-negative value is the plain strong count. A negative value (sign bit set) is an encoded
/// pointer to the tear-off, which owns the strong count once weak references have been requested.
///
/// Note: the derived `Default` starts the count at zero, whereas `WeakRefCount::new` starts it at
/// one.
#[derive(Default)]
#[repr(transparent)]
pub struct WeakRefCount(AtomicIsize);
11 | |
12 | impl WeakRefCount { |
13 | pub const fn new() -> Self { |
14 | Self(AtomicIsize::new(1)) |
15 | } |
16 | |
17 | pub fn add_ref(&self) -> u32 { |
18 | self.0 |
19 | .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count_or_pointer| { |
20 | bool::then_some(!is_weak_ref(count_or_pointer), count_or_pointer + 1) |
21 | }) |
22 | .map(|u| u as u32 + 1) |
23 | .unwrap_or_else(|pointer| unsafe { TearOff::decode(pointer).strong_count.add_ref() }) |
24 | } |
25 | |
26 | #[inline (always)] |
27 | pub fn is_one(&self) -> bool { |
28 | self.0.load(Ordering::Acquire) == 1 |
29 | } |
30 | |
31 | pub fn release(&self) -> u32 { |
32 | self.0 |
33 | .fetch_update(Ordering::Release, Ordering::Relaxed, |count_or_pointer| { |
34 | bool::then_some(!is_weak_ref(count_or_pointer), count_or_pointer - 1) |
35 | }) |
36 | .map(|u| u as u32 - 1) |
37 | .unwrap_or_else(|pointer| unsafe { |
38 | let tear_off = TearOff::decode(pointer); |
39 | let remaining = tear_off.strong_count.release(); |
40 | |
41 | // If this is the last strong reference, we can release the weak reference implied by the strong reference. |
42 | // There may still be weak references, so the WeakRelease is called to handle such possibilities. |
43 | if remaining == 0 { |
44 | TearOff::WeakRelease(&mut tear_off.weak_vtable as *mut _ as _); |
45 | } |
46 | |
47 | remaining |
48 | }) |
49 | } |
50 | |
51 | /// # Safety |
52 | pub unsafe fn query(&self, iid: &GUID, object: *mut c_void) -> *mut c_void { |
53 | unsafe { |
54 | if iid != &IWeakReferenceSource::IID { |
55 | return null_mut(); |
56 | } |
57 | |
58 | let mut count_or_pointer = self.0.load(Ordering::Relaxed); |
59 | |
60 | if is_weak_ref(count_or_pointer) { |
61 | return TearOff::from_encoding(count_or_pointer); |
62 | } |
63 | |
64 | let tear_off = TearOff::new(object, count_or_pointer as u32); |
65 | let tear_off_ptr: *mut c_void = transmute_copy(&tear_off); |
66 | let encoding: usize = ((tear_off_ptr as usize) >> 1) | (1 << (usize::BITS - 1)); |
67 | |
68 | loop { |
69 | match self.0.compare_exchange_weak( |
70 | count_or_pointer, |
71 | encoding as isize, |
72 | Ordering::AcqRel, |
73 | Ordering::Relaxed, |
74 | ) { |
75 | Ok(_) => { |
76 | let result: *mut c_void = transmute(tear_off); |
77 | TearOff::from_strong_ptr(result).strong_count.add_ref(); |
78 | return result; |
79 | } |
80 | Err(pointer) => count_or_pointer = pointer, |
81 | } |
82 | |
83 | if is_weak_ref(count_or_pointer) { |
84 | return TearOff::from_encoding(count_or_pointer); |
85 | } |
86 | |
87 | TearOff::from_strong_ptr(tear_off_ptr) |
88 | .strong_count |
89 | .0 |
90 | .store(count_or_pointer as i32, Ordering::SeqCst); |
91 | } |
92 | } |
93 | } |
94 | } |
95 | |
/// Returns `true` when `value` is an encoded tear-off pointer rather than a plain strong count.
///
/// Encodings always carry the sign bit, so every negative value denotes a tear-off.
fn is_weak_ref(value: isize) -> bool {
    value.is_negative()
}
99 | |
/// Heap-allocated weak reference tear-off shared by an object and its weak references.
///
/// The layout is `#[repr(C)]` and the field order is load-bearing:
/// - a pointer to `strong_vtable` is the tear-off's `IWeakReferenceSource` interface pointer;
/// - a pointer to `weak_vtable`, one pointer-width later, is its `IWeakReference` interface
///   pointer.
///
/// `from_strong_ptr` and `from_weak_ptr` rely on exactly this arrangement.
#[repr(C)]
struct TearOff {
    /// Vtable for the `IWeakReferenceSource` interface; must remain the first field.
    strong_vtable: *const IWeakReferenceSource_Vtbl,
    /// Vtable for the `IWeakReference` interface; must immediately follow `strong_vtable`.
    weak_vtable: *const IWeakReference_Vtbl,
    /// The object this tear-off was created for; queries and strong releases are forwarded to it.
    object: *mut c_void,
    /// The object's strong reference count, moved here when the tear-off was published.
    strong_count: RefCount,
    /// Weak references to the tear-off, plus the one implied by any outstanding strong references.
    weak_count: RefCount,
}
108 | |
109 | impl TearOff { |
110 | #[allow (clippy::new_ret_no_self)] |
111 | unsafe fn new(object: *mut c_void, strong_count: u32) -> IWeakReferenceSource { |
112 | unsafe { |
113 | transmute(Box::new(TearOff { |
114 | strong_vtable: &Self::STRONG_VTABLE, |
115 | weak_vtable: &Self::WEAK_VTABLE, |
116 | object, |
117 | strong_count: RefCount::new(strong_count), |
118 | weak_count: RefCount::new(1), |
119 | })) |
120 | } |
121 | } |
122 | |
123 | unsafe fn from_encoding(encoding: isize) -> *mut c_void { |
124 | unsafe { |
125 | let tear_off = TearOff::decode(encoding); |
126 | tear_off.strong_count.add_ref(); |
127 | tear_off as *mut _ as *mut _ |
128 | } |
129 | } |
130 | |
131 | const STRONG_VTABLE: IWeakReferenceSource_Vtbl = IWeakReferenceSource_Vtbl { |
132 | base__: IUnknown_Vtbl { |
133 | QueryInterface: Self::StrongQueryInterface, |
134 | AddRef: Self::StrongAddRef, |
135 | Release: Self::StrongRelease, |
136 | }, |
137 | GetWeakReference: Self::StrongDowngrade, |
138 | }; |
139 | |
140 | const WEAK_VTABLE: IWeakReference_Vtbl = IWeakReference_Vtbl { |
141 | base__: IUnknown_Vtbl { |
142 | QueryInterface: Self::WeakQueryInterface, |
143 | AddRef: Self::WeakAddRef, |
144 | Release: Self::WeakRelease, |
145 | }, |
146 | Resolve: Self::WeakUpgrade, |
147 | }; |
148 | |
149 | unsafe fn from_strong_ptr<'a>(this: *mut c_void) -> &'a mut Self { |
150 | unsafe { &mut *(this as *mut *mut c_void as *mut Self) } |
151 | } |
152 | |
153 | unsafe fn from_weak_ptr<'a>(this: *mut c_void) -> &'a mut Self { |
154 | unsafe { &mut *((this as *mut *mut c_void).sub(1) as *mut Self) } |
155 | } |
156 | |
157 | unsafe fn decode<'a>(value: isize) -> &'a mut Self { |
158 | unsafe { transmute(value << 1) } |
159 | } |
160 | |
161 | unsafe fn query_interface(&self, iid: *const GUID, interface: *mut *mut c_void) -> HRESULT { |
162 | unsafe { |
163 | ((*(*(self.object as *mut *mut IUnknown_Vtbl))).QueryInterface)( |
164 | self.object, |
165 | iid, |
166 | interface, |
167 | ) |
168 | } |
169 | } |
170 | |
171 | unsafe extern "system" fn StrongQueryInterface( |
172 | ptr: *mut c_void, |
173 | iid: *const GUID, |
174 | interface: *mut *mut c_void, |
175 | ) -> HRESULT { |
176 | unsafe { |
177 | let this = Self::from_strong_ptr(ptr); |
178 | |
179 | if iid.is_null() || interface.is_null() { |
180 | return E_POINTER; |
181 | } |
182 | |
183 | // Only directly respond to queries for the the tear-off's strong interface. This is |
184 | // effectively a self-query. |
185 | if *iid == IWeakReferenceSource::IID { |
186 | *interface = ptr; |
187 | this.strong_count.add_ref(); |
188 | return HRESULT(0); |
189 | } |
190 | |
191 | // As the tear-off is sharing the identity of the object, simply delegate any remaining |
192 | // queries to the object. |
193 | this.query_interface(iid, interface) |
194 | } |
195 | } |
196 | |
197 | unsafe extern "system" fn WeakQueryInterface( |
198 | ptr: *mut c_void, |
199 | iid: *const GUID, |
200 | interface: *mut *mut c_void, |
201 | ) -> HRESULT { |
202 | unsafe { |
203 | let this = Self::from_weak_ptr(ptr); |
204 | |
205 | if iid.is_null() || interface.is_null() { |
206 | return E_POINTER; |
207 | } |
208 | |
209 | // While the weak vtable is packed into the same allocation as the strong vtable and |
210 | // tear-off, it represents a distinct COM identity and thus does not share or delegate to |
211 | // the object. |
212 | |
213 | *interface = if *iid == IWeakReference::IID |
214 | || *iid == IUnknown::IID |
215 | || *iid == IAgileObject::IID |
216 | { |
217 | ptr |
218 | } else { |
219 | #[cfg (windows)] |
220 | if *iid == IMarshal::IID { |
221 | this.weak_count.add_ref(); |
222 | return marshaler(transmute::<*mut c_void, IUnknown>(ptr), interface); |
223 | } |
224 | |
225 | null_mut() |
226 | }; |
227 | |
228 | if (*interface).is_null() { |
229 | E_NOINTERFACE |
230 | } else { |
231 | this.weak_count.add_ref(); |
232 | HRESULT(0) |
233 | } |
234 | } |
235 | } |
236 | |
237 | unsafe extern "system" fn StrongAddRef(ptr: *mut c_void) -> u32 { |
238 | unsafe { |
239 | let this = Self::from_strong_ptr(ptr); |
240 | |
241 | // Implement `AddRef` directly as we own the strong reference. |
242 | this.strong_count.add_ref() |
243 | } |
244 | } |
245 | |
246 | unsafe extern "system" fn WeakAddRef(ptr: *mut c_void) -> u32 { |
247 | unsafe { |
248 | let this = Self::from_weak_ptr(ptr); |
249 | |
250 | // Implement `AddRef` directly as we own the weak reference. |
251 | this.weak_count.add_ref() |
252 | } |
253 | } |
254 | |
255 | unsafe extern "system" fn StrongRelease(ptr: *mut c_void) -> u32 { |
256 | unsafe { |
257 | let this = Self::from_strong_ptr(ptr); |
258 | |
259 | // Forward strong `Release` to the object so that it can destroy itself. It will then |
260 | // decrement its weak reference and allow the tear-off to be released as needed. |
261 | ((*(*(this.object as *mut *mut IUnknown_Vtbl))).Release)(this.object) |
262 | } |
263 | } |
264 | |
265 | unsafe extern "system" fn WeakRelease(ptr: *mut c_void) -> u32 { |
266 | unsafe { |
267 | let this = Self::from_weak_ptr(ptr); |
268 | |
269 | // Implement `Release` directly as we own the weak reference. |
270 | let remaining = this.weak_count.release(); |
271 | |
272 | // If there are no remaining references, it means that the object has already been |
273 | // destroyed. Go ahead and destroy the tear-off. |
274 | if remaining == 0 { |
275 | let _ = Box::from_raw(this); |
276 | } |
277 | |
278 | remaining |
279 | } |
280 | } |
281 | |
282 | unsafe extern "system" fn StrongDowngrade( |
283 | ptr: *mut c_void, |
284 | interface: *mut *mut c_void, |
285 | ) -> HRESULT { |
286 | unsafe { |
287 | let this = Self::from_strong_ptr(ptr); |
288 | |
289 | // The strong vtable hands out a reference to the weak vtable. This is always safe and |
290 | // straightforward since a strong reference guarantees there is at least one weak |
291 | // reference. |
292 | *interface = &mut this.weak_vtable as *mut _ as _; |
293 | this.weak_count.add_ref(); |
294 | HRESULT(0) |
295 | } |
296 | } |
297 | |
298 | unsafe extern "system" fn WeakUpgrade( |
299 | ptr: *mut c_void, |
300 | iid: *const GUID, |
301 | interface: *mut *mut c_void, |
302 | ) -> HRESULT { |
303 | unsafe { |
304 | let this = Self::from_weak_ptr(ptr); |
305 | |
306 | this.strong_count |
307 | .0 |
308 | .fetch_update(Ordering::Acquire, Ordering::Relaxed, |count| { |
309 | // Attempt to acquire a strong reference count to stabilize the object for the duration |
310 | // of the `QueryInterface` call. |
311 | bool::then_some(count != 0, count + 1) |
312 | }) |
313 | .map(|_| { |
314 | // Let the object respond to the upgrade query. |
315 | let result = this.query_interface(iid, interface); |
316 | // Decrement the temporary reference account used to stabilize the object. |
317 | this.strong_count.0.fetch_sub(1, Ordering::Relaxed); |
318 | // Return the result of the query. |
319 | result |
320 | }) |
321 | .unwrap_or_else(|_| { |
322 | *interface = null_mut(); |
323 | HRESULT(0) |
324 | }) |
325 | } |
326 | } |
327 | } |
328 | |