1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4 collections::{BTreeMap, BTreeSet},
5 fs,
6 io::Write,
7 path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12 pub modules: BTreeMap<String, Module>,
13 pub pyproject: PyProject,
14}
15
16impl StubInfo {
17 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
18 let pyproject = PyProject::parse_toml(path)?;
19 Ok(StubInfoBuilder::new(pyproject).build())
20 }
21
22 pub fn generate(&self) -> Result<()> {
23 let python_root = self
24 .pyproject
25 .python_source()
26 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()));
27
28 for (name, module) in self.modules.iter() {
29 let path = name.replace(".", "/");
30 let dest = if module.submodules.is_empty() {
31 python_root.join(format!("{path}.pyi"))
32 } else {
33 python_root.join(path).join("__init__.pyi")
34 };
35
36 let dir = dest.parent().context("Cannot get parent directory")?;
37 if !dir.exists() {
38 fs::create_dir_all(dir)?;
39 }
40
41 let mut f = fs::File::create(&dest)?;
42 write!(f, "{}", module)?;
43 log::info!(
44 "Generate stub file of a module `{name}` at {dest}",
45 dest = dest.display()
46 );
47 }
48 Ok(())
49 }
50}
51
52struct StubInfoBuilder {
53 modules: BTreeMap<String, Module>,
54 default_module_name: String,
55 pyproject: PyProject,
56}
57
58impl StubInfoBuilder {
59 fn new(pyproject: PyProject) -> Self {
60 Self {
61 modules: BTreeMap::new(),
62 default_module_name: pyproject.module_name().to_string(),
63 pyproject,
64 }
65 }
66
67 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
68 let name = name.unwrap_or(&self.default_module_name).to_string();
69 let module = self.modules.entry(name.clone()).or_default();
70 module.name = name;
71 module.default_module_name = self.default_module_name.clone();
72 module
73 }
74
75 fn register_submodules(&mut self) {
76 let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
77 for module in self.modules.keys() {
78 let path = module.split('.').collect::<Vec<_>>();
79 let n = path.len();
80 if n <= 1 {
81 continue;
82 }
83 map.entry(path[..n - 1].join("."))
84 .or_default()
85 .insert(path[n - 1].to_string());
86 }
87 for (parent, children) in map {
88 if let Some(module) = self.modules.get_mut(&parent) {
89 module.submodules.extend(children);
90 }
91 }
92 }
93
94 fn add_class(&mut self, info: &PyClassInfo) {
95 self.get_module(info.module)
96 .class
97 .insert((info.struct_id)(), ClassDef::from(info));
98 }
99
100 fn add_enum(&mut self, info: &PyEnumInfo) {
101 self.get_module(info.module)
102 .enum_
103 .insert((info.enum_id)(), EnumDef::from(info));
104 }
105
106 fn add_function(&mut self, info: &PyFunctionInfo) {
107 self.get_module(info.module)
108 .function
109 .insert(info.name, FunctionDef::from(info));
110 }
111
112 fn add_error(&mut self, info: &PyErrorInfo) {
113 self.get_module(Some(info.module))
114 .error
115 .insert(info.name, ErrorDef::from(info));
116 }
117
118 fn add_variable(&mut self, info: &PyVariableInfo) {
119 self.get_module(Some(info.module))
120 .variables
121 .insert(info.name, VariableDef::from(info));
122 }
123
124 fn add_methods(&mut self, info: &PyMethodsInfo) {
125 let struct_id = (info.struct_id)();
126 for module in self.modules.values_mut() {
127 if let Some(entry) = module.class.get_mut(&struct_id) {
128 for getter in info.getters {
129 entry.members.push(MemberDef {
130 name: getter.name,
131 r#type: (getter.r#type)(),
132 });
133 }
134 for method in info.methods {
135 entry.methods.push(MethodDef::from(method))
136 }
137 if let Some(new) = &info.new {
138 entry.new = Some(NewDef::from(new));
139 }
140 return;
141 }
142 }
143 unreachable!("Missing struct_id = {:?}", struct_id);
144 }
145
146 fn build(mut self) -> StubInfo {
147 for info in inventory::iter::<PyClassInfo> {
148 self.add_class(info);
149 }
150 for info in inventory::iter::<PyEnumInfo> {
151 self.add_enum(info);
152 }
153 for info in inventory::iter::<PyFunctionInfo> {
154 self.add_function(info);
155 }
156 for info in inventory::iter::<PyErrorInfo> {
157 self.add_error(info);
158 }
159 for info in inventory::iter::<PyVariableInfo> {
160 self.add_variable(info);
161 }
162 for info in inventory::iter::<PyMethodsInfo> {
163 self.add_methods(info);
164 }
165 self.register_submodules();
166 StubInfo {
167 modules: self.modules,
168 pyproject: self.pyproject,
169 }
170 }
171}
172