| 1 | use super::*; |
| 2 | |
| 3 | #[derive (Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] |
| 4 | pub struct Class { |
| 5 | pub def: TypeDef, |
| 6 | } |
| 7 | |
| 8 | impl Class { |
| 9 | pub fn type_name(&self) -> TypeName { |
| 10 | self.def.type_name() |
| 11 | } |
| 12 | |
| 13 | pub fn write(&self, writer: &Writer) -> TokenStream { |
| 14 | let mut required_interfaces = self.required_interfaces(); |
| 15 | required_interfaces.sort(); |
| 16 | let type_name = self.def.type_name(); |
| 17 | let name = to_ident(type_name.name()); |
| 18 | let mut dependencies = TypeMap::new(); |
| 19 | |
| 20 | if writer.config.package { |
| 21 | self.dependencies(&mut dependencies); |
| 22 | } |
| 23 | |
| 24 | let cfg = writer.write_cfg(self.def, type_name.namespace(), &dependencies, false); |
| 25 | let runtime_name = format!(" {type_name}" ); |
| 26 | |
| 27 | let runtime_name = quote! { |
| 28 | #cfg |
| 29 | impl windows_core::RuntimeName for #name { |
| 30 | const NAME: &'static str = #runtime_name; |
| 31 | } |
| 32 | }; |
| 33 | |
| 34 | let mut methods = quote! {}; |
| 35 | let mut method_names = MethodNames::new(); |
| 36 | |
| 37 | for interface in &required_interfaces { |
| 38 | let mut virtual_names = MethodNames::new(); |
| 39 | |
| 40 | for method in interface |
| 41 | .get_methods(writer) |
| 42 | .iter() |
| 43 | .filter_map(|method| match &method { |
| 44 | MethodOrName::Method(method) => Some(method), |
| 45 | _ => None, |
| 46 | }) |
| 47 | { |
| 48 | let mut difference = TypeMap::new(); |
| 49 | |
| 50 | if writer.config.package { |
| 51 | difference = method.dependencies.difference(&dependencies); |
| 52 | } |
| 53 | |
| 54 | let cfg = writer.write_cfg(self.def, type_name.namespace(), &difference, false); |
| 55 | |
| 56 | let method = method.write( |
| 57 | writer, |
| 58 | Some(interface), |
| 59 | interface.kind, |
| 60 | &mut method_names, |
| 61 | &mut virtual_names, |
| 62 | ); |
| 63 | |
| 64 | methods.combine(quote! { |
| 65 | #cfg |
| 66 | #method |
| 67 | }); |
| 68 | } |
| 69 | } |
| 70 | |
| 71 | let new = self.has_default_constructor().then(|| |
| 72 | quote! { |
| 73 | pub fn new() -> windows_core::Result<Self> { |
| 74 | Self::IActivationFactory(|f| f.ActivateInstance::<Self>()) |
| 75 | } |
| 76 | fn IActivationFactory<R, F: FnOnce(&windows_core::imp::IGenericFactory) -> windows_core::Result<R>>( |
| 77 | callback: F, |
| 78 | ) -> windows_core::Result<R> { |
| 79 | static SHARED: windows_core::imp::FactoryCache<#name, windows_core::imp::IGenericFactory> = |
| 80 | windows_core::imp::FactoryCache::new(); |
| 81 | SHARED.call(callback) |
| 82 | } |
| 83 | } |
| 84 | ); |
| 85 | |
| 86 | let factories = required_interfaces.iter().filter_map(|interface| match interface.kind { |
| 87 | InterfaceKind::Static | InterfaceKind::Composable => { |
| 88 | if interface.def.methods().next().is_none() { |
| 89 | None |
| 90 | } else { |
| 91 | let method_name = to_ident(interface.def.name()); |
| 92 | let interface_type = interface.write_name(writer); |
| 93 | let cfg = quote! {}; |
| 94 | |
| 95 | Some(quote! { |
| 96 | #cfg |
| 97 | fn #method_name<R, F: FnOnce(&#interface_type) -> windows_core::Result<R>>( |
| 98 | callback: F, |
| 99 | ) -> windows_core::Result<R> { |
| 100 | static SHARED: windows_core::imp::FactoryCache<#name, #interface_type> = |
| 101 | windows_core::imp::FactoryCache::new(); |
| 102 | SHARED.call(callback) |
| 103 | } |
| 104 | }) |
| 105 | } |
| 106 | } |
| 107 | _ => None, |
| 108 | }); |
| 109 | |
| 110 | if let Some(default_interface) = self.default_interface() { |
| 111 | if default_interface.is_async() { |
| 112 | let default_interface = default_interface.write_name(writer); |
| 113 | |
| 114 | return quote! { |
| 115 | #cfg |
| 116 | pub type #name = #default_interface; |
| 117 | }; |
| 118 | } |
| 119 | |
| 120 | let is_exclusive = default_interface.is_exclusive(); |
| 121 | let default_interface = default_interface.write_name(writer); |
| 122 | |
| 123 | let interface_hierarchy = if is_exclusive { |
| 124 | quote! { windows_core::imp::interface_hierarchy!(#name, windows_core::IUnknown, windows_core::IInspectable); } |
| 125 | } else { |
| 126 | quote! { windows_core::imp::interface_hierarchy!(#name, windows_core::IUnknown, windows_core::IInspectable, #default_interface); } |
| 127 | }; |
| 128 | |
| 129 | let required_hierarchy = { |
| 130 | let mut interfaces: Vec<_> = required_interfaces |
| 131 | .iter() |
| 132 | .filter(|ty| !ty.is_exclusive() && ty.kind != InterfaceKind::Default) |
| 133 | .map(|ty| ty.write_name(writer)) |
| 134 | .collect(); |
| 135 | |
| 136 | interfaces.extend(self.bases().iter().map(|ty| ty.write_name(writer))); |
| 137 | |
| 138 | if interfaces.is_empty() { |
| 139 | quote! {} |
| 140 | } else { |
| 141 | quote! { |
| 142 | #cfg |
| 143 | windows_core::imp::required_hierarchy!(#name, #(#interfaces),*); |
| 144 | } |
| 145 | } |
| 146 | }; |
| 147 | |
| 148 | let agile = self.def.is_agile().then(|| { |
| 149 | quote! { |
| 150 | #cfg |
| 151 | unsafe impl Send for #name {} |
| 152 | #cfg |
| 153 | unsafe impl Sync for #name {} |
| 154 | } |
| 155 | }); |
| 156 | |
| 157 | let into_iterator = required_interfaces |
| 158 | .iter() |
| 159 | .find(|interface| interface.def.type_name() == TypeName::IIterable) |
| 160 | .map(|interface| { |
| 161 | let ty = interface.generics[0].write_name(writer); |
| 162 | let namespace = writer.write_namespace(TypeName::IIterator); |
| 163 | |
| 164 | quote! { |
| 165 | #cfg |
| 166 | impl IntoIterator for #name { |
| 167 | type Item = #ty; |
| 168 | type IntoIter = #namespace IIterator<Self::Item>; |
| 169 | |
| 170 | fn into_iter(self) -> Self::IntoIter { |
| 171 | IntoIterator::into_iter(&self) |
| 172 | } |
| 173 | } |
| 174 | #cfg |
| 175 | impl IntoIterator for &#name { |
| 176 | type Item = #ty; |
| 177 | type IntoIter = #namespace IIterator<Self::Item>; |
| 178 | |
| 179 | fn into_iter(self) -> Self::IntoIter { |
| 180 | self.First().unwrap() |
| 181 | } |
| 182 | } |
| 183 | |
| 184 | } |
| 185 | }); |
| 186 | |
| 187 | quote! { |
| 188 | #cfg |
| 189 | #[repr(transparent)] |
| 190 | #[derive(Clone, Debug, Eq, PartialEq)] |
| 191 | pub struct #name(windows_core::IUnknown); |
| 192 | #cfg |
| 193 | #interface_hierarchy |
| 194 | #required_hierarchy |
| 195 | #cfg |
| 196 | impl #name { |
| 197 | #new |
| 198 | #methods |
| 199 | #(#factories)* |
| 200 | } |
| 201 | #cfg |
| 202 | impl windows_core::RuntimeType for #name { |
| 203 | const SIGNATURE: windows_core::imp::ConstBuffer = windows_core::imp::ConstBuffer::for_class::<Self, #default_interface>(); |
| 204 | } |
| 205 | #cfg |
| 206 | unsafe impl windows_core::Interface for #name { |
| 207 | type Vtable = <#default_interface as windows_core::Interface>::Vtable; |
| 208 | const IID: windows_core::GUID = <#default_interface as windows_core::Interface>::IID; |
| 209 | } |
| 210 | #runtime_name |
| 211 | #agile |
| 212 | #into_iterator |
| 213 | } |
| 214 | } else { |
| 215 | quote! { |
| 216 | #cfg |
| 217 | pub struct #name; |
| 218 | #cfg |
| 219 | impl #name { |
| 220 | #methods |
| 221 | #(#factories)* |
| 222 | } |
| 223 | #runtime_name |
| 224 | } |
| 225 | } |
| 226 | } |
| 227 | |
| 228 | pub fn write_name(&self, writer: &Writer) -> TokenStream { |
| 229 | self.type_name().write(writer, &[]) |
| 230 | } |
| 231 | |
| 232 | fn default_interface(&self) -> Option<Type> { |
| 233 | self.def |
| 234 | .interface_impls() |
| 235 | .find(|imp| imp.has_attribute("DefaultAttribute" )) |
| 236 | .map(|imp| imp.ty(&[])) |
| 237 | } |
| 238 | |
| 239 | pub fn runtime_signature(&self) -> String { |
| 240 | format!( |
| 241 | "rc( {}; {})" , |
| 242 | self.type_name(), |
| 243 | self.default_interface().unwrap().runtime_signature() |
| 244 | ) |
| 245 | } |
| 246 | |
| 247 | pub fn dependencies(&self, dependencies: &mut TypeMap) { |
| 248 | for interface in self.required_interfaces() { |
| 249 | Type::Interface(interface).dependencies(dependencies); |
| 250 | } |
| 251 | } |
| 252 | |
| 253 | fn bases(&self) -> Vec<Self> { |
| 254 | let mut bases = Vec::new(); |
| 255 | let mut def = self.def; |
| 256 | let reader = def.reader(); |
| 257 | |
| 258 | loop { |
| 259 | let extends = def.extends().unwrap(); |
| 260 | |
| 261 | if extends == TypeName::Object { |
| 262 | break; |
| 263 | } |
| 264 | |
| 265 | let Type::Class(base) = reader.unwrap_full_name(extends.namespace(), extends.name()) |
| 266 | else { |
| 267 | panic!("type not found: {extends}" ); |
| 268 | }; |
| 269 | |
| 270 | def = base.def; |
| 271 | bases.push(base); |
| 272 | } |
| 273 | |
| 274 | bases |
| 275 | } |
| 276 | |
| 277 | pub fn required_interfaces(&self) -> Vec<Interface> { |
| 278 | fn walk(def: TypeDef, generics: &[Type], is_base: bool, set: &mut Vec<Interface>) { |
| 279 | for imp in def.interface_impls() { |
| 280 | let Type::Interface(mut interface) = imp.ty(generics) else { |
| 281 | panic!(); |
| 282 | }; |
| 283 | |
| 284 | interface.kind = if !is_base && imp.has_attribute("DefaultAttribute" ) { |
| 285 | InterfaceKind::Default |
| 286 | } else if is_base { |
| 287 | InterfaceKind::Base |
| 288 | } else { |
| 289 | InterfaceKind::None |
| 290 | }; |
| 291 | |
| 292 | if let Some(pos) = set |
| 293 | .iter() |
| 294 | .position(|existing| existing.def == interface.def) |
| 295 | { |
| 296 | if interface.kind == InterfaceKind::Default { |
| 297 | set[pos].kind = interface.kind; |
| 298 | } |
| 299 | } else { |
| 300 | walk(interface.def, &interface.generics, is_base, set); |
| 301 | set.push(interface); |
| 302 | } |
| 303 | } |
| 304 | } |
| 305 | let mut set = vec![]; |
| 306 | walk(self.def, &[], false, &mut set); |
| 307 | |
| 308 | for base in self.bases() { |
| 309 | walk(base.def, &[], true, &mut set); |
| 310 | } |
| 311 | |
| 312 | for attribute in self.def.attributes() { |
| 313 | let kind = match attribute.name() { |
| 314 | "StaticAttribute" | "ActivatableAttribute" => InterfaceKind::Static, |
| 315 | "ComposableAttribute" => InterfaceKind::Composable, |
| 316 | _ => continue, |
| 317 | }; |
| 318 | |
| 319 | for (_, arg) in attribute.args() { |
| 320 | if let Value::TypeName(tn) = arg { |
| 321 | let Type::Interface(mut interface) = self |
| 322 | .def |
| 323 | .reader() |
| 324 | .unwrap_full_name(tn.namespace(), tn.name()) |
| 325 | else { |
| 326 | panic!("type not found: {tn}" ); |
| 327 | }; |
| 328 | |
| 329 | interface.kind = kind; |
| 330 | set.push(interface); |
| 331 | break; |
| 332 | } |
| 333 | } |
| 334 | } |
| 335 | |
| 336 | set |
| 337 | } |
| 338 | |
| 339 | fn has_default_constructor(&self) -> bool { |
| 340 | self.def |
| 341 | .attributes() |
| 342 | .filter(|attribute| attribute.name() == "ActivatableAttribute" ) |
| 343 | .any(|attribute| { |
| 344 | !attribute |
| 345 | .args() |
| 346 | .iter() |
| 347 | .any(|arg| matches!(arg.1, Value::TypeName(_))) |
| 348 | }) |
| 349 | } |
| 350 | } |
| 351 | |