1 | use crate::{generate::*, pyproject::PyProject, type_info::*}; |
2 | use anyhow::{Context, Result}; |
3 | use std::{ |
4 | collections::{BTreeMap, BTreeSet}, |
5 | fs, |
6 | io::Write, |
7 | path::*, |
8 | }; |
9 | |
/// Stub information for every Python module collected by [`StubInfoBuilder`],
/// together with the parsed `pyproject.toml` that decides where stubs are written.
#[derive(Debug, Clone, PartialEq)]
pub struct StubInfo {
    /// Collected modules keyed by their dotted Python module name (e.g. `pkg.sub`).
    pub modules: BTreeMap<String, Module>,
    /// Parsed `pyproject.toml`; its `python_source()` is used as the output root.
    pub pyproject: PyProject,
}
15 | |
16 | impl StubInfo { |
17 | pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> { |
18 | let pyproject = PyProject::parse_toml(path)?; |
19 | Ok(StubInfoBuilder::new(pyproject).build()) |
20 | } |
21 | |
22 | pub fn generate(&self) -> Result<()> { |
23 | let python_root = self |
24 | .pyproject |
25 | .python_source() |
26 | .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR" ).unwrap())); |
27 | |
28 | for (name, module) in self.modules.iter() { |
29 | let path = name.replace("." , "/" ); |
30 | let dest = if module.submodules.is_empty() { |
31 | python_root.join(format!(" {path}.pyi" )) |
32 | } else { |
33 | python_root.join(path).join("__init__.pyi" ) |
34 | }; |
35 | |
36 | let dir = dest.parent().context("Cannot get parent directory" )?; |
37 | if !dir.exists() { |
38 | fs::create_dir_all(dir)?; |
39 | } |
40 | |
41 | let mut f = fs::File::create(&dest)?; |
42 | write!(f, " {}" , module)?; |
43 | log::info!( |
44 | "Generate stub file of a module ` {name}` at {dest}" , |
45 | dest = dest.display() |
46 | ); |
47 | } |
48 | Ok(()) |
49 | } |
50 | } |
51 | |
/// Accumulates [`Module`]s from the `inventory` registrations before they are
/// turned into a [`StubInfo`] by `build`.
struct StubInfoBuilder {
    /// Modules collected so far, keyed by dotted module name.
    modules: BTreeMap<String, Module>,
    /// Module used for items that do not name one; taken from `pyproject.module_name()`.
    default_module_name: String,
    /// Parsed `pyproject.toml`, moved unchanged into the resulting `StubInfo`.
    pyproject: PyProject,
}
57 | |
58 | impl StubInfoBuilder { |
59 | fn new(pyproject: PyProject) -> Self { |
60 | Self { |
61 | modules: BTreeMap::new(), |
62 | default_module_name: pyproject.module_name().to_string(), |
63 | pyproject, |
64 | } |
65 | } |
66 | |
67 | fn get_module(&mut self, name: Option<&str>) -> &mut Module { |
68 | let name = name.unwrap_or(&self.default_module_name).to_string(); |
69 | let module = self.modules.entry(name.clone()).or_default(); |
70 | module.name = name; |
71 | module.default_module_name = self.default_module_name.clone(); |
72 | module |
73 | } |
74 | |
75 | fn register_submodules(&mut self) { |
76 | let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new(); |
77 | for module in self.modules.keys() { |
78 | let path = module.split('.' ).collect::<Vec<_>>(); |
79 | let n = path.len(); |
80 | if n <= 1 { |
81 | continue; |
82 | } |
83 | map.entry(path[..n - 1].join("." )) |
84 | .or_default() |
85 | .insert(path[n - 1].to_string()); |
86 | } |
87 | for (parent, children) in map { |
88 | if let Some(module) = self.modules.get_mut(&parent) { |
89 | module.submodules.extend(children); |
90 | } |
91 | } |
92 | } |
93 | |
94 | fn add_class(&mut self, info: &PyClassInfo) { |
95 | self.get_module(info.module) |
96 | .class |
97 | .insert((info.struct_id)(), ClassDef::from(info)); |
98 | } |
99 | |
100 | fn add_enum(&mut self, info: &PyEnumInfo) { |
101 | self.get_module(info.module) |
102 | .enum_ |
103 | .insert((info.enum_id)(), EnumDef::from(info)); |
104 | } |
105 | |
106 | fn add_function(&mut self, info: &PyFunctionInfo) { |
107 | self.get_module(info.module) |
108 | .function |
109 | .insert(info.name, FunctionDef::from(info)); |
110 | } |
111 | |
112 | fn add_error(&mut self, info: &PyErrorInfo) { |
113 | self.get_module(Some(info.module)) |
114 | .error |
115 | .insert(info.name, ErrorDef::from(info)); |
116 | } |
117 | |
118 | fn add_variable(&mut self, info: &PyVariableInfo) { |
119 | self.get_module(Some(info.module)) |
120 | .variables |
121 | .insert(info.name, VariableDef::from(info)); |
122 | } |
123 | |
124 | fn add_methods(&mut self, info: &PyMethodsInfo) { |
125 | let struct_id = (info.struct_id)(); |
126 | for module in self.modules.values_mut() { |
127 | if let Some(entry) = module.class.get_mut(&struct_id) { |
128 | for getter in info.getters { |
129 | entry.members.push(MemberDef { |
130 | name: getter.name, |
131 | r#type: (getter.r#type)(), |
132 | }); |
133 | } |
134 | for method in info.methods { |
135 | entry.methods.push(MethodDef::from(method)) |
136 | } |
137 | if let Some(new) = &info.new { |
138 | entry.new = Some(NewDef::from(new)); |
139 | } |
140 | return; |
141 | } |
142 | } |
143 | unreachable!("Missing struct_id = {:?}" , struct_id); |
144 | } |
145 | |
146 | fn build(mut self) -> StubInfo { |
147 | for info in inventory::iter::<PyClassInfo> { |
148 | self.add_class(info); |
149 | } |
150 | for info in inventory::iter::<PyEnumInfo> { |
151 | self.add_enum(info); |
152 | } |
153 | for info in inventory::iter::<PyFunctionInfo> { |
154 | self.add_function(info); |
155 | } |
156 | for info in inventory::iter::<PyErrorInfo> { |
157 | self.add_error(info); |
158 | } |
159 | for info in inventory::iter::<PyVariableInfo> { |
160 | self.add_variable(info); |
161 | } |
162 | for info in inventory::iter::<PyMethodsInfo> { |
163 | self.add_methods(info); |
164 | } |
165 | self.register_submodules(); |
166 | StubInfo { |
167 | modules: self.modules, |
168 | pyproject: self.pyproject, |
169 | } |
170 | } |
171 | } |
172 | |