1mod builtins;
2mod collections;
3mod pyo3;
4
5#[cfg(feature = "numpy")]
6mod numpy;
7
8use maplit::hashset;
9use std::{collections::HashSet, fmt, ops};
10
/// A Python module that a generated stub file must import.
///
/// The derived `PartialOrd`/`Ord` compare variants in declaration order, so every `Named`
/// value orders before `Default`; keep the variant order stable.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
pub enum ModuleRef {
    /// A module referenced by an explicit name, e.g. `typing` or `builtins`.
    Named(String),

    /// Default module that PyO3 creates.
    ///
    /// - For a pure Rust project, the default module name is the crate name specified in `Cargo.toml`
    ///   or `project.name` specified in `pyproject.toml`.
    /// - For a mixed Rust/Python project, the default module name is `tool.maturin.module-name` specified in `pyproject.toml`.
    ///
    /// Because the default module name cannot be known at compile time, it is resolved when the stub file is generated.
    /// This variant is a placeholder for that name, and it is also the `Default` value of `ModuleRef`.
    #[default]
    Default,
}
26
27impl ModuleRef {
28 pub fn get(&self) -> Option<&str> {
29 match self {
30 Self::Named(name: &String) => Some(name),
31 Self::Default => None,
32 }
33 }
34}
35
36impl From<&str> for ModuleRef {
37 fn from(s: &str) -> Self {
38 Self::Named(s.to_string())
39 }
40}
41
42/// Type information for creating Python stub files annotated by [PyStubType] trait.
43#[derive(Debug, Clone, PartialEq, Eq)]
44pub struct TypeInfo {
45 /// The Python type name.
46 pub name: String,
47
48 /// Python modules must be imported in the stub file.
49 ///
50 /// For example, when `name` is `typing.Sequence[int]`, `import` should contain `typing`.
51 /// This makes it possible to use user-defined types in the stub file.
52 pub import: HashSet<ModuleRef>,
53}
54
55impl fmt::Display for TypeInfo {
56 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
57 write!(f, "{}", self.name)
58 }
59}
60
61impl TypeInfo {
62 /// A `None` type annotation.
63 pub fn none() -> Self {
64 // NOTE: since 3.10, NoneType is provided from types module,
65 // but there is no corresponding definitions prior to 3.10.
66 Self {
67 name: "None".to_string(),
68 import: HashSet::new(),
69 }
70 }
71
72 /// A `typing.Any` type annotation.
73 pub fn any() -> Self {
74 Self {
75 name: "typing.Any".to_string(),
76 import: hashset! { "typing".into() },
77 }
78 }
79
80 /// A `list[Type]` type annotation.
81 pub fn list_of<T: PyStubType>() -> Self {
82 let TypeInfo { name, mut import } = T::type_output();
83 import.insert("builtins".into());
84 TypeInfo {
85 name: format!("builtins.list[{}]", name),
86 import,
87 }
88 }
89
90 /// A `set[Type]` type annotation.
91 pub fn set_of<T: PyStubType>() -> Self {
92 let TypeInfo { name, mut import } = T::type_output();
93 import.insert("builtins".into());
94 TypeInfo {
95 name: format!("builtins.set[{}]", name),
96 import,
97 }
98 }
99
100 /// A `dict[Type]` type annotation.
101 pub fn dict_of<K: PyStubType, V: PyStubType>() -> Self {
102 let TypeInfo {
103 name: name_k,
104 mut import,
105 } = K::type_output();
106 let TypeInfo {
107 name: name_v,
108 import: import_v,
109 } = V::type_output();
110 import.extend(import_v);
111 import.insert("builtins".into());
112 TypeInfo {
113 name: format!("builtins.set[{}, {}]", name_k, name_v),
114 import,
115 }
116 }
117
118 /// A type annotation of a built-in type provided from `builtins` module, such as `int`, `str`, or `float`. Generic builtin types are also possible, such as `dict[str, str]`.
119 pub fn builtin(name: &str) -> Self {
120 Self {
121 name: format!("builtins.{name}"),
122 import: hashset! { "builtins".into() },
123 }
124 }
125
126 /// Unqualified type.
127 pub fn unqualified(name: &str) -> Self {
128 Self {
129 name: name.to_string(),
130 import: hashset! {},
131 }
132 }
133
134 /// A type annotation of a type that must be imported. The type name must be qualified with the module name:
135 ///
136 /// ```
137 /// pyo3_stub_gen::TypeInfo::with_module("pathlib.Path", "pathlib".into());
138 /// ```
139 pub fn with_module(name: &str, module: ModuleRef) -> Self {
140 let mut import = HashSet::new();
141 import.insert(module);
142 Self {
143 name: name.to_string(),
144 import,
145 }
146 }
147}
148
149impl ops::BitOr for TypeInfo {
150 type Output = Self;
151
152 fn bitor(mut self, rhs: Self) -> Self {
153 self.import.extend(iter:rhs.import);
154 Self {
155 name: format!("{} | {}", self.name, rhs.name),
156 import: self.import,
157 }
158 }
159}
160
/// Implement [PyStubType] for a type by reusing the stub types of existing types.
///
/// - `impl_stub_type!(X = A)` makes `X` appear exactly like `A` in stubs.
/// - `impl_stub_type!(E = A | B)` makes `E` appear as the union `A | B`. The imports of every
///   alternative are merged.
///
/// The generated code names the trait through `$crate`, so [PyStubType] does not need to be
/// imported at the call site.
///
/// ```rust
/// use pyo3::*;
/// use pyo3_stub_gen::{impl_stub_type, derive::*};
///
/// #[gen_stub_pyclass]
/// #[pyclass]
/// struct A;
///
/// #[gen_stub_pyclass]
/// #[pyclass]
/// struct B;
///
/// enum E {
///     A(A),
///     B(B),
/// }
/// impl_stub_type!(E = A | B);
///
/// struct X(A);
/// impl_stub_type!(X = A);
///
/// struct Y {
///     a: A,
///     b: B,
/// }
/// impl_stub_type!(Y = (A, B));
/// ```
#[macro_export]
macro_rules! impl_stub_type {
    // One arm covers both the single-type alias (`X = A`) and unions (`E = A | B`), since
    // `$(...)|+` also matches exactly one base type.
    ($ty:ty = $($base:ty)|+) => {
        impl $crate::PyStubType for $ty {
            fn type_output() -> $crate::TypeInfo {
                // Fully qualified so the call resolves even if the trait is not in scope.
                $(<$base as $crate::PyStubType>::type_output())|*
            }
            fn type_input() -> $crate::TypeInfo {
                $(<$base as $crate::PyStubType>::type_input())|*
            }
        }
    };
}
213
214/// Annotate Rust types with Python type information.
215pub trait PyStubType {
216 /// The type to be used in the output signature, i.e. return type of the Python function or methods.
217 fn type_output() -> TypeInfo;
218
219 /// The type to be used in the input signature, i.e. the arguments of the Python function or methods.
220 ///
221 /// This defaults to the output type, but can be overridden for types that are not valid input types.
222 /// For example, `Vec::<T>::type_output` returns `list[T]` while `Vec::<T>::type_input` returns `typing.Sequence[T]`.
223 fn type_input() -> TypeInfo {
224 Self::type_output()
225 }
226}
227
#[cfg(test)]
mod test {
    use super::*;
    use maplit::hashset;
    use std::collections::HashMap;
    use test_case::test_case;

    #[test_case(bool::type_input(), "builtins.bool", hashset! { "builtins".into() } ; "bool_input")]
    #[test_case(<&str>::type_input(), "builtins.str", hashset! { "builtins".into() } ; "str_input")]
    #[test_case(Vec::<u32>::type_input(), "typing.Sequence[builtins.int]", hashset! { "typing".into(), "builtins".into() } ; "Vec_u32_input")]
    #[test_case(Vec::<u32>::type_output(), "builtins.list[builtins.int]", hashset! { "builtins".into() } ; "Vec_u32_output")]
    #[test_case(HashMap::<u32, String>::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "HashMap_u32_String_input")]
    #[test_case(HashMap::<u32, String>::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "HashMap_u32_String_output")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { "builtins".into(), "typing".into() } ; "HashMap_u32_Vec_u32_input")]
    #[test_case(HashMap::<u32, Vec<u32>>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { "builtins".into() } ; "HashMap_u32_Vec_u32_output")]
    #[test_case(HashSet::<u32>::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "HashSet_u32_input")]
    #[test_case(TypeInfo::list_of::<u32>(), "builtins.list[builtins.int]", hashset! { "builtins".into() } ; "list_of_u32")]
    #[test_case(TypeInfo::set_of::<String>(), "builtins.set[builtins.str]", hashset! { "builtins".into() } ; "set_of_String")]
    #[test_case(TypeInfo::none(), "None", hashset! {} ; "none_annotation")]
    #[test_case(TypeInfo::any(), "typing.Any", hashset! { "typing".into() } ; "any_annotation")]
    #[test_case(TypeInfo::builtin("int") | TypeInfo::with_module("pathlib.Path", "pathlib".into()), "builtins.int | pathlib.Path", hashset! { "builtins".into(), "pathlib".into() } ; "union_builtin_with_module")]
    fn test(tinfo: TypeInfo, name: &str, import: HashSet<ModuleRef>) {
        assert_eq!(tinfo.name, name);
        // Comparing the sets directly also covers the "no imports expected" case.
        assert_eq!(tinfo.import, import);
    }
}
253