1use super::*;
2use crate::ComInterface;
3use std::sync::atomic::{AtomicIsize, Ordering};
4
/// Strong reference count for a COM object that can also hand out weak references.
///
/// The single `AtomicIsize` holds one of two encodings:
/// - a non-negative value: the object's strong reference count, used while no weak
///   reference tear-off has been published;
/// - a negative value: a pointer to a heap-allocated `TearOff`, shifted right by one bit
///   with the sign bit set. Once this encoding is stored, the strong count lives in the
///   tear-off and every `add_ref`/`release` is forwarded there.
///
/// NOTE(review): `#[derive(Default)]` starts the value at 0, whereas `WeakRefCount::new`
/// starts it at 1 — confirm that any caller constructing via `Default` expects this.
#[doc(hidden)]
#[repr(transparent)]
#[derive(Default)]
pub struct WeakRefCount(AtomicIsize);
9
10impl WeakRefCount {
11 pub fn new() -> Self {
12 Self(AtomicIsize::new(1))
13 }
14
15 pub fn add_ref(&self) -> u32 {
16 self.0.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count_or_pointer| then_some(!is_weak_ref(count_or_pointer), count_or_pointer + 1)).map(|u| u as u32 + 1).unwrap_or_else(|pointer| unsafe { TearOff::decode(pointer).strong_count.add_ref() })
17 }
18
19 pub fn release(&self) -> u32 {
20 self.0.fetch_update(Ordering::Release, Ordering::Relaxed, |count_or_pointer| then_some(!is_weak_ref(count_or_pointer), count_or_pointer - 1)).map(|u| u as u32 - 1).unwrap_or_else(|pointer| unsafe {
21 let tear_off = TearOff::decode(pointer);
22 let remaining = tear_off.strong_count.release();
23
24 // If this is the last strong reference, we can release the weak reference implied by the strong reference.
25 // There may still be weak references, so the WeakRelease is called to handle such possibilities.
26 if remaining == 0 {
27 TearOff::WeakRelease(&mut tear_off.weak_vtable as *mut _ as _);
28 }
29
30 remaining
31 })
32 }
33
34 /// # Safety
35 pub unsafe fn query(&self, iid: *const crate::GUID, object: *mut std::ffi::c_void) -> *mut std::ffi::c_void {
36 if *iid != IWeakReferenceSource::IID {
37 return std::ptr::null_mut();
38 }
39
40 let mut count_or_pointer = self.0.load(Ordering::Relaxed);
41
42 if is_weak_ref(count_or_pointer) {
43 return TearOff::from_encoding(count_or_pointer);
44 }
45
46 let tear_off = TearOff::new(object, count_or_pointer as u32);
47 let tear_off_ptr: *mut std::ffi::c_void = std::mem::transmute_copy(&tear_off);
48 let encoding: usize = ((tear_off_ptr as usize) >> 1) | (1 << (std::mem::size_of::<usize>() * 8 - 1));
49
50 loop {
51 match self.0.compare_exchange_weak(count_or_pointer, encoding as isize, Ordering::AcqRel, Ordering::Relaxed) {
52 Ok(_) => {
53 let result: *mut std::ffi::c_void = std::mem::transmute(tear_off);
54 TearOff::from_strong_ptr(result).strong_count.add_ref();
55 return result;
56 }
57 Err(pointer) => count_or_pointer = pointer,
58 }
59
60 if is_weak_ref(count_or_pointer) {
61 return TearOff::from_encoding(count_or_pointer);
62 }
63
64 TearOff::from_strong_ptr(tear_off_ptr).strong_count.0.store(count_or_pointer as i32, Ordering::SeqCst);
65 }
66 }
67}
68
/// Reports whether a `WeakRefCount` value holds an encoded tear-off pointer (sign bit set)
/// rather than a plain strong count.
fn is_weak_ref(value: isize) -> bool {
    value.is_negative()
}
72
/// Heap-allocated companion object that implements weak-reference support for an object.
///
/// `#[repr(C)]` fixes the field order: a pointer to the struct is also a pointer to
/// `strong_vtable` (the `IWeakReferenceSource` interface), and `weak_vtable` (the
/// `IWeakReference` interface) sits one pointer-width after it. `from_strong_ptr` and
/// `from_weak_ptr` depend on this layout.
#[repr(C)]
struct TearOff {
    // `IWeakReferenceSource` vtable; must remain the first field.
    strong_vtable: *const IWeakReferenceSource_Vtbl,
    // `IWeakReference` vtable; must directly follow `strong_vtable`.
    weak_vtable: *const IWeakReference_Vtbl,
    // The object whose identity the strong interface shares; queries and strong releases
    // are forwarded to it.
    object: *mut std::ffi::c_void,
    // The object's strong reference count once this tear-off has been published.
    strong_count: RefCount,
    // Weak references; all strong references together also hold one weak reference
    // (the count starts at 1 in `TearOff::new`).
    weak_count: RefCount,
}
81
82impl TearOff {
83 #[allow(clippy::new_ret_no_self)]
84 unsafe fn new(object: *mut std::ffi::c_void, strong_count: u32) -> IWeakReferenceSource {
85 std::mem::transmute(std::boxed::Box::new(TearOff {
86 strong_vtable: &Self::STRONG_VTABLE,
87 weak_vtable: &Self::WEAK_VTABLE,
88 object,
89 strong_count: RefCount::new(strong_count),
90 weak_count: RefCount::new(1),
91 }))
92 }
93
94 unsafe fn from_encoding(encoding: isize) -> *mut std::ffi::c_void {
95 let tear_off = TearOff::decode(encoding);
96 tear_off.strong_count.add_ref();
97 tear_off as *mut _ as *mut _
98 }
99
100 const STRONG_VTABLE: IWeakReferenceSource_Vtbl = IWeakReferenceSource_Vtbl {
101 base__: crate::IUnknown_Vtbl { QueryInterface: Self::StrongQueryInterface, AddRef: Self::StrongAddRef, Release: Self::StrongRelease },
102 GetWeakReference: Self::StrongDowngrade,
103 };
104
105 const WEAK_VTABLE: IWeakReference_Vtbl = IWeakReference_Vtbl {
106 base__: crate::IUnknown_Vtbl { QueryInterface: Self::WeakQueryInterface, AddRef: Self::WeakAddRef, Release: Self::WeakRelease },
107 Resolve: Self::WeakUpgrade,
108 };
109
110 unsafe fn from_strong_ptr<'a>(this: *mut std::ffi::c_void) -> &'a mut Self {
111 &mut *(this as *mut *mut std::ffi::c_void as *mut Self)
112 }
113
114 unsafe fn from_weak_ptr<'a>(this: *mut std::ffi::c_void) -> &'a mut Self {
115 &mut *((this as *mut *mut std::ffi::c_void).sub(1) as *mut Self)
116 }
117
118 unsafe fn decode<'a>(value: isize) -> &'a mut Self {
119 std::mem::transmute(value << 1)
120 }
121
122 unsafe fn query_interface(&self, iid: *const crate::GUID, interface: *mut *mut std::ffi::c_void) -> crate::HRESULT {
123 ((*(*(self.object as *mut *mut crate::IUnknown_Vtbl))).QueryInterface)(self.object, iid, interface)
124 }
125
126 unsafe extern "system" fn StrongQueryInterface(ptr: *mut std::ffi::c_void, iid: *const crate::GUID, interface: *mut *mut std::ffi::c_void) -> crate::HRESULT {
127 let this = Self::from_strong_ptr(ptr);
128
129 if iid.is_null() || interface.is_null() {
130 return ::windows_core::HRESULT(-2147467261); // E_POINTER
131 }
132
133 // Only directly respond to queries for the the tear-off's strong interface. This is
134 // effectively a self-query.
135 if *iid == IWeakReferenceSource::IID {
136 *interface = ptr;
137 this.strong_count.add_ref();
138 return crate::HRESULT(0);
139 }
140
141 // As the tear-off is sharing the identity of the object, simply delegate any remaining
142 // queries to the object.
143 this.query_interface(iid, interface)
144 }
145
146 unsafe extern "system" fn WeakQueryInterface(ptr: *mut std::ffi::c_void, iid: *const crate::GUID, interface: *mut *mut std::ffi::c_void) -> crate::HRESULT {
147 let this = Self::from_weak_ptr(ptr);
148
149 if iid.is_null() || interface.is_null() {
150 return ::windows_core::HRESULT(-2147467261); // E_POINTER
151 }
152
153 // While the weak vtable is packed into the same allocation as the strong vtable and
154 // tear-off, it represents a distinct COM identity and thus does not share or delegate to
155 // the object.
156
157 *interface = if *iid == IWeakReference::IID || *iid == crate::IUnknown::IID || *iid == IAgileObject::IID { ptr } else { std::ptr::null_mut() };
158
159 // TODO: implement IMarshal
160
161 if (*interface).is_null() {
162 E_NOINTERFACE
163 } else {
164 this.weak_count.add_ref();
165 crate::HRESULT(0)
166 }
167 }
168
169 unsafe extern "system" fn StrongAddRef(ptr: *mut std::ffi::c_void) -> u32 {
170 let this = Self::from_strong_ptr(ptr);
171
172 // Implement `AddRef` directly as we own the strong reference.
173 this.strong_count.add_ref()
174 }
175
176 unsafe extern "system" fn WeakAddRef(ptr: *mut std::ffi::c_void) -> u32 {
177 let this = Self::from_weak_ptr(ptr);
178
179 // Implement `AddRef` directly as we own the weak reference.
180 this.weak_count.add_ref()
181 }
182
183 unsafe extern "system" fn StrongRelease(ptr: *mut std::ffi::c_void) -> u32 {
184 let this = Self::from_strong_ptr(ptr);
185
186 // Forward strong `Release` to the object so that it can destroy itself. It will then
187 // decrement its weak reference and allow the tear-off to be released as needed.
188 ((*(*(this.object as *mut *mut crate::IUnknown_Vtbl))).Release)(this.object)
189 }
190
191 unsafe extern "system" fn WeakRelease(ptr: *mut std::ffi::c_void) -> u32 {
192 let this = Self::from_weak_ptr(ptr);
193
194 // Implement `Release` directly as we own the weak reference.
195 let remaining = this.weak_count.release();
196
197 // If there are no remaining references, it means that the object has already been
198 // destroyed. Go ahead and destroy the tear-off.
199 if remaining == 0 {
200 let _ = std::boxed::Box::from_raw(this);
201 }
202
203 remaining
204 }
205
206 unsafe extern "system" fn StrongDowngrade(ptr: *mut std::ffi::c_void, interface: *mut *mut std::ffi::c_void) -> crate::HRESULT {
207 let this = Self::from_strong_ptr(ptr);
208
209 // The strong vtable hands out a reference to the weak vtable. This is always safe and
210 // straightforward since a strong reference guarantees there is at least one weak
211 // reference.
212 *interface = &mut this.weak_vtable as *mut _ as _;
213 this.weak_count.add_ref();
214 crate::HRESULT(0)
215 }
216
217 unsafe extern "system" fn WeakUpgrade(ptr: *mut std::ffi::c_void, iid: *const crate::GUID, interface: *mut *mut std::ffi::c_void) -> crate::HRESULT {
218 let this = Self::from_weak_ptr(ptr);
219
220 this.strong_count
221 .0
222 .fetch_update(Ordering::Acquire, Ordering::Relaxed, |count| {
223 // Attempt to acquire a strong reference count to stabilize the object for the duration
224 // of the `QueryInterface` call.
225 then_some(count != 0, count + 1)
226 })
227 .map(|_| {
228 // Let the object respond to the upgrade query.
229 let result = this.query_interface(iid, interface);
230 // Decrement the temporary reference account used to stabilize the object.
231 this.strong_count.0.fetch_sub(1, Ordering::Relaxed);
232 // Return the result of the query.
233 result
234 })
235 .unwrap_or_else(|_| {
236 *interface = std::ptr::null_mut();
237 crate::HRESULT(0)
238 })
239 }
240}
241