| 1 | use crate::{generate::*, pyproject::PyProject, type_info::*}; |
| 2 | use anyhow::{Context, Result}; |
| 3 | use std::{ |
| 4 | collections::{BTreeMap, BTreeSet}, |
| 5 | fs, |
| 6 | io::Write, |
| 7 | path::*, |
| 8 | }; |
| 9 | |
| 10 | #[derive (Debug, Clone, PartialEq)] |
| 11 | pub struct StubInfo { |
| 12 | pub modules: BTreeMap<String, Module>, |
| 13 | pub pyproject: PyProject, |
| 14 | } |
| 15 | |
| 16 | impl StubInfo { |
| 17 | pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> { |
| 18 | let pyproject = PyProject::parse_toml(path)?; |
| 19 | Ok(StubInfoBuilder::new(pyproject).build()) |
| 20 | } |
| 21 | |
| 22 | pub fn generate(&self) -> Result<()> { |
| 23 | let python_root = self |
| 24 | .pyproject |
| 25 | .python_source() |
| 26 | .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR" ).unwrap())); |
| 27 | |
| 28 | for (name, module) in self.modules.iter() { |
| 29 | let path = name.replace("." , "/" ); |
| 30 | let dest = if module.submodules.is_empty() { |
| 31 | python_root.join(format!(" {path}.pyi" )) |
| 32 | } else { |
| 33 | python_root.join(path).join("__init__.pyi" ) |
| 34 | }; |
| 35 | |
| 36 | let dir = dest.parent().context("Cannot get parent directory" )?; |
| 37 | if !dir.exists() { |
| 38 | fs::create_dir_all(dir)?; |
| 39 | } |
| 40 | |
| 41 | let mut f = fs::File::create(&dest)?; |
| 42 | write!(f, " {}" , module)?; |
| 43 | log::info!( |
| 44 | "Generate stub file of a module ` {name}` at {dest}" , |
| 45 | dest = dest.display() |
| 46 | ); |
| 47 | } |
| 48 | Ok(()) |
| 49 | } |
| 50 | } |
| 51 | |
| 52 | struct StubInfoBuilder { |
| 53 | modules: BTreeMap<String, Module>, |
| 54 | default_module_name: String, |
| 55 | pyproject: PyProject, |
| 56 | } |
| 57 | |
| 58 | impl StubInfoBuilder { |
| 59 | fn new(pyproject: PyProject) -> Self { |
| 60 | Self { |
| 61 | modules: BTreeMap::new(), |
| 62 | default_module_name: pyproject.module_name().to_string(), |
| 63 | pyproject, |
| 64 | } |
| 65 | } |
| 66 | |
| 67 | fn get_module(&mut self, name: Option<&str>) -> &mut Module { |
| 68 | let name = name.unwrap_or(&self.default_module_name).to_string(); |
| 69 | let module = self.modules.entry(name.clone()).or_default(); |
| 70 | module.name = name; |
| 71 | module.default_module_name = self.default_module_name.clone(); |
| 72 | module |
| 73 | } |
| 74 | |
| 75 | fn register_submodules(&mut self) { |
| 76 | let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new(); |
| 77 | for module in self.modules.keys() { |
| 78 | let path = module.split('.' ).collect::<Vec<_>>(); |
| 79 | let n = path.len(); |
| 80 | if n <= 1 { |
| 81 | continue; |
| 82 | } |
| 83 | map.entry(path[..n - 1].join("." )) |
| 84 | .or_default() |
| 85 | .insert(path[n - 1].to_string()); |
| 86 | } |
| 87 | for (parent, children) in map { |
| 88 | if let Some(module) = self.modules.get_mut(&parent) { |
| 89 | module.submodules.extend(children); |
| 90 | } |
| 91 | } |
| 92 | } |
| 93 | |
| 94 | fn add_class(&mut self, info: &PyClassInfo) { |
| 95 | self.get_module(info.module) |
| 96 | .class |
| 97 | .insert((info.struct_id)(), ClassDef::from(info)); |
| 98 | } |
| 99 | |
| 100 | fn add_enum(&mut self, info: &PyEnumInfo) { |
| 101 | self.get_module(info.module) |
| 102 | .enum_ |
| 103 | .insert((info.enum_id)(), EnumDef::from(info)); |
| 104 | } |
| 105 | |
| 106 | fn add_function(&mut self, info: &PyFunctionInfo) { |
| 107 | self.get_module(info.module) |
| 108 | .function |
| 109 | .insert(info.name, FunctionDef::from(info)); |
| 110 | } |
| 111 | |
| 112 | fn add_error(&mut self, info: &PyErrorInfo) { |
| 113 | self.get_module(Some(info.module)) |
| 114 | .error |
| 115 | .insert(info.name, ErrorDef::from(info)); |
| 116 | } |
| 117 | |
| 118 | fn add_variable(&mut self, info: &PyVariableInfo) { |
| 119 | self.get_module(Some(info.module)) |
| 120 | .variables |
| 121 | .insert(info.name, VariableDef::from(info)); |
| 122 | } |
| 123 | |
| 124 | fn add_methods(&mut self, info: &PyMethodsInfo) { |
| 125 | let struct_id = (info.struct_id)(); |
| 126 | for module in self.modules.values_mut() { |
| 127 | if let Some(entry) = module.class.get_mut(&struct_id) { |
| 128 | for getter in info.getters { |
| 129 | entry.members.push(MemberDef { |
| 130 | name: getter.name, |
| 131 | r#type: (getter.r#type)(), |
| 132 | }); |
| 133 | } |
| 134 | for method in info.methods { |
| 135 | entry.methods.push(MethodDef::from(method)) |
| 136 | } |
| 137 | if let Some(new) = &info.new { |
| 138 | entry.new = Some(NewDef::from(new)); |
| 139 | } |
| 140 | return; |
| 141 | } |
| 142 | } |
| 143 | unreachable!("Missing struct_id = {:?}" , struct_id); |
| 144 | } |
| 145 | |
| 146 | fn build(mut self) -> StubInfo { |
| 147 | for info in inventory::iter::<PyClassInfo> { |
| 148 | self.add_class(info); |
| 149 | } |
| 150 | for info in inventory::iter::<PyEnumInfo> { |
| 151 | self.add_enum(info); |
| 152 | } |
| 153 | for info in inventory::iter::<PyFunctionInfo> { |
| 154 | self.add_function(info); |
| 155 | } |
| 156 | for info in inventory::iter::<PyErrorInfo> { |
| 157 | self.add_error(info); |
| 158 | } |
| 159 | for info in inventory::iter::<PyVariableInfo> { |
| 160 | self.add_variable(info); |
| 161 | } |
| 162 | for info in inventory::iter::<PyMethodsInfo> { |
| 163 | self.add_methods(info); |
| 164 | } |
| 165 | self.register_submodules(); |
| 166 | StubInfo { |
| 167 | modules: self.modules, |
| 168 | pyproject: self.pyproject, |
| 169 | } |
| 170 | } |
| 171 | } |
| 172 | |