use super::*;
use crate::{IUnknown, IUnknown_Vtbl, Interface, GUID, HRESULT};
use core::ffi::c_void;
use core::mem::{transmute, transmute_copy};
use core::ptr::null_mut;
use core::sync::atomic::{AtomicIsize, Ordering};

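/// A reference count that can transparently switch into a "weak" mode: the
/// single `AtomicIsize` either holds the strong reference count directly (a
/// non-negative value) or, once a weak reference has been requested, an
/// encoded pointer (a negative value) to a heap-allocated `TearOff` that then
/// carries separate strong and weak counts.
///
/// A minimal sketch of the strong-count behavior (the weak transition requires
/// a real COM object, so it is not shown):
///
/// ```ignore
/// let count = WeakRefCount::new(); // starts with one strong reference
/// assert_eq!(count.add_ref(), 2);
/// assert_eq!(count.release(), 1);
/// assert!(count.is_one());
/// ```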
#[repr(transparent)]
#[derive(Default)]
pub struct WeakRefCount(AtomicIsize);

impl WeakRefCount {
    pub const fn new() -> Self {
        Self(AtomicIsize::new(1))
    }

    pub fn add_ref(&self) -> u32 {
        self.0
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count_or_pointer| {
                (!is_weak_ref(count_or_pointer)).then_some(count_or_pointer + 1)
            })
            .map(|u| u as u32 + 1)
            .unwrap_or_else(|pointer| unsafe { TearOff::decode(pointer).strong_count.add_ref() })
    }

    #[inline(always)]
    pub fn is_one(&self) -> bool {
        self.0.load(Ordering::Acquire) == 1
    }

    pub fn release(&self) -> u32 {
        self.0
            .fetch_update(Ordering::Release, Ordering::Relaxed, |count_or_pointer| {
                (!is_weak_ref(count_or_pointer)).then_some(count_or_pointer - 1)
            })
            .map(|u| u as u32 - 1)
            .unwrap_or_else(|pointer| unsafe {
                let tear_off = TearOff::decode(pointer);
                let remaining = tear_off.strong_count.release();

                // If this was the last strong reference, release the weak reference implied
                // by the strong reference. There may still be outstanding weak references,
                // so `WeakRelease` is called to handle that possibility.
                if remaining == 0 {
                    TearOff::WeakRelease(&mut tear_off.weak_vtable as *mut _ as _);
                }

                remaining
            })
    }

    /// # Safety
    ///
    /// `object` must be a valid, non-null pointer to the containing COM
    /// object's `IUnknown` interface, and `self` must be the reference count
    /// embedded in that same object.
    pub unsafe fn query(&self, iid: &GUID, object: *mut c_void) -> *mut c_void {
        unsafe {
            if iid != &IWeakReferenceSource::IID {
                return null_mut();
            }

            let mut count_or_pointer = self.0.load(Ordering::Relaxed);

            // If a tear-off already exists, hand out another strong reference to it.
            if is_weak_ref(count_or_pointer) {
                return TearOff::from_encoding(count_or_pointer);
            }

            // Otherwise, speculatively create a tear-off that adopts the current strong
            // count, then encode its address by shifting right one bit (the allocation
            // is at least 2-byte aligned, so nothing is lost) and setting the high bit
            // to mark the value as a pointer.
            let tear_off = TearOff::new(object, count_or_pointer as u32);
            let tear_off_ptr: *mut c_void = transmute_copy(&tear_off);
            let encoding: usize = ((tear_off_ptr as usize) >> 1) | (1 << (usize::BITS - 1));

            loop {
                match self.0.compare_exchange_weak(
                    count_or_pointer,
                    encoding as isize,
                    Ordering::AcqRel,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // The encoded pointer was published; add the strong reference
                        // returned to the caller.
                        let result: *mut c_void = transmute(tear_off);
                        TearOff::from_strong_ptr(result).strong_count.add_ref();
                        return result;
                    }
                    Err(pointer) => count_or_pointer = pointer,
                }

                // Another thread published its own tear-off first, so use that one.
                if is_weak_ref(count_or_pointer) {
                    return TearOff::from_encoding(count_or_pointer);
                }

                // The strong count changed concurrently; refresh the tear-off's snapshot
                // of it before trying again.
                TearOff::from_strong_ptr(tear_off_ptr)
                    .strong_count
                    .0
                    .store(count_or_pointer as i32, Ordering::SeqCst);
            }
        }
    }
}

fn is_weak_ref(value: isize) -> bool {
    value < 0
}

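/// A single heap allocation that implements both `IWeakReferenceSource` (the
/// strong side) and `IWeakReference` (the weak side). The `#[repr(C)]` field
/// order is load-bearing: `from_strong_ptr` casts an interface pointer aimed
/// at `strong_vtable` directly to a `TearOff`, while `from_weak_ptr` steps
/// back one pointer-width from `weak_vtable`.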
#[repr(C)]
struct TearOff {
    strong_vtable: *const IWeakReferenceSource_Vtbl,
    weak_vtable: *const IWeakReference_Vtbl,
    object: *mut c_void,
    strong_count: RefCount,
    weak_count: RefCount,
}

impl TearOff {
    #[allow(clippy::new_ret_no_self)]
    unsafe fn new(object: *mut c_void, strong_count: u32) -> IWeakReferenceSource {
        unsafe {
            // The box is leaked into the returned interface and reclaimed by
            // `Box::from_raw` in `WeakRelease` once the last weak reference is
            // released. The initial weak count of one is the weak reference
            // implied by the outstanding strong references.
            transmute(Box::new(TearOff {
                strong_vtable: &Self::STRONG_VTABLE,
                weak_vtable: &Self::WEAK_VTABLE,
                object,
                strong_count: RefCount::new(strong_count),
                weak_count: RefCount::new(1),
            }))
        }
    }

    unsafe fn from_encoding(encoding: isize) -> *mut c_void {
        unsafe {
            let tear_off = TearOff::decode(encoding);
            tear_off.strong_count.add_ref();
            tear_off as *mut _ as *mut _
        }
    }

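    // The two vtables below give the tear-off two COM identities within one
    // allocation: the strong side shares the identity of the object and
    // delegates queries to it, while the weak side is a distinct identity
    // whose lifetime is governed by `weak_count` alone.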
    const STRONG_VTABLE: IWeakReferenceSource_Vtbl = IWeakReferenceSource_Vtbl {
        base__: IUnknown_Vtbl {
            QueryInterface: Self::StrongQueryInterface,
            AddRef: Self::StrongAddRef,
            Release: Self::StrongRelease,
        },
        GetWeakReference: Self::StrongDowngrade,
    };

    const WEAK_VTABLE: IWeakReference_Vtbl = IWeakReference_Vtbl {
        base__: IUnknown_Vtbl {
            QueryInterface: Self::WeakQueryInterface,
            AddRef: Self::WeakAddRef,
            Release: Self::WeakRelease,
        },
        Resolve: Self::WeakUpgrade,
    };

    unsafe fn from_strong_ptr<'a>(this: *mut c_void) -> &'a mut Self {
        unsafe { &mut *(this as *mut *mut c_void as *mut Self) }
    }

    unsafe fn from_weak_ptr<'a>(this: *mut c_void) -> &'a mut Self {
        unsafe { &mut *((this as *mut *mut c_void).sub(1) as *mut Self) }
    }

    unsafe fn decode<'a>(value: isize) -> &'a mut Self {
        // Shifting left by one discards the high "weak" marker bit and restores the
        // zero low bit of the original, aligned pointer.
        unsafe { transmute(value << 1) }
    }

    unsafe fn query_interface(&self, iid: *const GUID, interface: *mut *mut c_void) -> HRESULT {
        unsafe {
            ((*(*(self.object as *mut *mut IUnknown_Vtbl))).QueryInterface)(
                self.object,
                iid,
                interface,
            )
        }
    }

    unsafe extern "system" fn StrongQueryInterface(
        ptr: *mut c_void,
        iid: *const GUID,
        interface: *mut *mut c_void,
    ) -> HRESULT {
        unsafe {
            let this = Self::from_strong_ptr(ptr);

            if iid.is_null() || interface.is_null() {
                return E_POINTER;
            }

            // Only directly respond to queries for the tear-off's strong interface. This
            // is effectively a self-query.
            if *iid == IWeakReferenceSource::IID {
                *interface = ptr;
                this.strong_count.add_ref();
                return HRESULT(0);
            }

            // As the tear-off shares the identity of the object, simply delegate any
            // remaining queries to the object.
            this.query_interface(iid, interface)
        }
    }

    unsafe extern "system" fn WeakQueryInterface(
        ptr: *mut c_void,
        iid: *const GUID,
        interface: *mut *mut c_void,
    ) -> HRESULT {
        unsafe {
            let this = Self::from_weak_ptr(ptr);

            if iid.is_null() || interface.is_null() {
                return E_POINTER;
            }

            // While the weak vtable is packed into the same allocation as the strong
            // vtable and tear-off, it represents a distinct COM identity and thus does
            // not share or delegate to the object.

            *interface = if *iid == IWeakReference::IID
                || *iid == IUnknown::IID
                || *iid == IAgileObject::IID
            {
                ptr
            } else {
                #[cfg(windows)]
                if *iid == IMarshal::IID {
                    this.weak_count.add_ref();
                    return marshaler(transmute::<*mut c_void, IUnknown>(ptr), interface);
                }

                null_mut()
            };

            if (*interface).is_null() {
                E_NOINTERFACE
            } else {
                this.weak_count.add_ref();
                HRESULT(0)
            }
        }
    }

    unsafe extern "system" fn StrongAddRef(ptr: *mut c_void) -> u32 {
        unsafe {
            let this = Self::from_strong_ptr(ptr);

            // Implement `AddRef` directly as we own the strong reference.
            this.strong_count.add_ref()
        }
    }

    unsafe extern "system" fn WeakAddRef(ptr: *mut c_void) -> u32 {
        unsafe {
            let this = Self::from_weak_ptr(ptr);

            // Implement `AddRef` directly as we own the weak reference.
            this.weak_count.add_ref()
        }
    }

    unsafe extern "system" fn StrongRelease(ptr: *mut c_void) -> u32 {
        unsafe {
            let this = Self::from_strong_ptr(ptr);

            // Forward strong `Release` to the object so that it can destroy itself. It
            // will then decrement its weak reference and allow the tear-off to be
            // released as needed.
            ((*(*(this.object as *mut *mut IUnknown_Vtbl))).Release)(this.object)
        }
    }

    unsafe extern "system" fn WeakRelease(ptr: *mut c_void) -> u32 {
        unsafe {
            let this = Self::from_weak_ptr(ptr);

            // Implement `Release` directly as we own the weak reference.
            let remaining = this.weak_count.release();

            // If there are no remaining references, it means that the object has already
            // been destroyed. Go ahead and destroy the tear-off.
            if remaining == 0 {
                let _ = Box::from_raw(this);
            }

            remaining
        }
    }

    unsafe extern "system" fn StrongDowngrade(
        ptr: *mut c_void,
        interface: *mut *mut c_void,
    ) -> HRESULT {
        unsafe {
            let this = Self::from_strong_ptr(ptr);

            // The strong vtable hands out a reference to the weak vtable. This is always
            // safe and straightforward since a strong reference guarantees there is at
            // least one weak reference.
            *interface = &mut this.weak_vtable as *mut _ as _;
            this.weak_count.add_ref();
            HRESULT(0)
        }
    }

    unsafe extern "system" fn WeakUpgrade(
        ptr: *mut c_void,
        iid: *const GUID,
        interface: *mut *mut c_void,
    ) -> HRESULT {
        unsafe {
            let this = Self::from_weak_ptr(ptr);

            this.strong_count
                .0
                .fetch_update(Ordering::Acquire, Ordering::Relaxed, |count| {
                    // Attempt to acquire a strong reference count to stabilize the object
                    // for the duration of the `QueryInterface` call.
                    (count != 0).then_some(count + 1)
                })
                .map(|_| {
                    // Let the object respond to the upgrade query.
                    let result = this.query_interface(iid, interface);
                    // Decrement the temporary reference count used to stabilize the object.
                    this.strong_count.0.fetch_sub(1, Ordering::Relaxed);
                    // Return the result of the query.
                    result
                })
                .unwrap_or_else(|_| {
                    // The object has already been destroyed, so the upgrade yields a null
                    // interface pointer and success.
                    *interface = null_mut();
                    HRESULT(0)
                })
        }
    }
}
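
#[cfg(test)]
mod tests {
    // A minimal sketch (added for illustration, not part of the original
    // source) of the count-or-pointer encoding arithmetic used by `query` and
    // `TearOff::decode`: an aligned address survives the shift-right/set-high-bit
    // encoding and the shift-left decoding, and the encoded value is negative,
    // which is exactly what `is_weak_ref` tests for.
    #[test]
    fn encoding_round_trips() {
        let fake_ptr: usize = 0x7ffe_1000; // hypothetical aligned address
        let encoded = ((fake_ptr >> 1) | (1 << (usize::BITS - 1))) as isize;
        assert!(super::is_weak_ref(encoded));
        assert_eq!((encoded << 1) as usize, fake_ptr);
    }
}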