From 4fbc371442b6ffaa50da903ece0f9269352e6d68 Mon Sep 17 00:00:00 2001 From: Thomas Pellissier-Tanon Date: Fri, 27 Feb 2026 09:22:28 +0100 Subject: [PATCH] Stub: allow custom imports These imports are written in a small DSL like: ```rust #[pymodule(stubs = { from datetime import datetime as dt, time; from uuid import UUID; })] ``` Then parsed, sent as an AST inside the introspection data (following the same AST format as the type hints) and serialized by the introspection crate that merges these imports with the auto generated ones The `#[pymodule]` parameter is named `stub` because we might include some other features in the future like protocols --- newsfragments/5877.added.md | 1 + pyo3-introspection/src/introspection.rs | 59 +++++- pyo3-introspection/src/model.rs | 25 +++ pyo3-introspection/src/stubs.rs | 222 +++++++++++++++++++--- pyo3-macros-backend/src/attributes.rs | 6 + pyo3-macros-backend/src/introspection.rs | 31 +-- pyo3-macros-backend/src/json.rs | 75 ++++++++ pyo3-macros-backend/src/lib.rs | 4 + pyo3-macros-backend/src/module.rs | 18 +- pyo3-macros-backend/src/py_stubs.rs | 229 +++++++++++++++++++++++ pytests/src/annotations.rs | 30 +++ pytests/src/lib.rs | 6 + pytests/src/pyfunctions.rs | 14 -- pytests/stubs/annotations.pyi | 9 + pytests/stubs/pyfunctions.pyi | 3 - tests/test_compile_error.rs | 1 + 16 files changed, 665 insertions(+), 68 deletions(-) create mode 100644 newsfragments/5877.added.md create mode 100644 pyo3-macros-backend/src/json.rs create mode 100644 pyo3-macros-backend/src/py_stubs.rs create mode 100644 pytests/src/annotations.rs create mode 100644 pytests/stubs/annotations.pyi diff --git a/newsfragments/5877.added.md b/newsfragments/5877.added.md new file mode 100644 index 00000000000..5666e5d5c67 --- /dev/null +++ b/newsfragments/5877.added.md @@ -0,0 +1 @@ +Introspection: allow to set custom stub imports in `#[pymodule]` macro \ No newline at end of file diff --git a/pyo3-introspection/src/introspection.rs b/pyo3-introspection/src/introspection.rs index 88931bb38bd..5f949e8f623 100644 --- a/pyo3-introspection/src/introspection.rs +++ b/pyo3-introspection/src/introspection.rs @@ -1,6 +1,6 @@ use crate::model::{ - Argument, Arguments, Attribute, Class, Constant, Expr, Function, Module, Operator, - VariableLengthArgument, + Argument, Arguments, Attribute, Class, Constant, Expr, Function, ImportAlias, Module, Operator, + Statement, VariableLengthArgument, }; use anyhow::{anyhow, bail, ensure, Context, Result}; use goblin::elf::section_header::SHN_XINDEX; @@ -52,6 +52,7 @@ fn parse_chunks(chunks: &[Chunk], main_module_name: &str) -> Result { members, doc, incomplete, + stubs, } = chunk { if name == main_module_name { @@ -65,6 +66,7 @@ fn parse_chunks(chunks: &[Chunk], main_module_name: &str) -> Result { name, members, *incomplete, + stubs, doc.as_deref(), &chunks_by_id, &chunks_by_parent, @@ -82,6 +84,7 @@ fn convert_module( name: &str, members: &[String], mut incomplete: bool, + stubs: &[ChunkStatement], docstring: Option<&str>, chunks_by_id: &HashMap<&str, &Chunk>, chunks_by_parent: &HashMap<&str, Vec<&Chunk>>, @@ -114,6 +117,35 @@ fn convert_module( functions, attributes, incomplete, + stubs: stubs + .iter() + .map(|statement| match statement { + ChunkStatement::ImportFrom { + module, + names, + level, + } => Statement::ImportFrom { + module: module.clone(), + names: names + .iter() + .map(|alias| ImportAlias { + name: alias.name.clone(), + asname: alias.asname.clone(), + }) + .collect(), + level: *level, + }, + ChunkStatement::Import { names } => Statement::Import { + names: names + .iter() + .map(|alias| ImportAlias { + name: alias.name.clone(), + asname: alias.asname.clone(), + }) + .collect(), + }, + }) + .collect(), docstring: docstring.map(Into::into), }) } @@ -139,12 +171,14 @@ fn convert_members<'a>( members, incomplete, doc, + stubs, } => { modules.push(convert_module( id, name, members, *incomplete, + stubs, doc.as_deref(), chunks_by_id, chunks_by_parent, @@ -672,6 +706,8 @@ enum Chunk { #[serde(default)] doc: Option, incomplete: bool, + #[serde(default)] + stubs: Vec, }, Class { id: String, @@ -739,6 +775,25 @@ struct ChunkArgument { annotation: Option, } +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +enum ChunkStatement { + ImportFrom { + module: String, + names: Vec, + level: usize, + }, + Import { + names: Vec, + }, +} + +#[derive(Deserialize)] +struct ChunkAlias { + name: String, + asname: Option, +} + #[derive(Deserialize)] #[serde(tag = "type", rename_all = "lowercase")] enum ChunkExpr { diff --git a/pyo3-introspection/src/model.rs b/pyo3-introspection/src/model.rs index b6ca8d28a8d..f30e93c8aa8 100644 --- a/pyo3-introspection/src/model.rs +++ b/pyo3-introspection/src/model.rs @@ -6,6 +6,7 @@ pub struct Module { pub functions: Vec, pub attributes: Vec, pub incomplete: bool, + pub stubs: Vec, pub docstring: Option, } @@ -74,6 +75,30 @@ pub struct VariableLengthArgument { pub annotation: Option, } +/// A python statement +/// +/// This is the `stmt` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) +#[derive(Debug, Eq, PartialEq, Clone, Hash)] +pub enum Statement { + /// `from {module} import {names}` + ImportFrom { + module: String, + names: Vec, + level: usize, + }, + /// `import {names}` + Import { names: Vec }, +} + +/// A python import alias `{name} as {asname}` +/// +/// This is the `alias` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) +#[derive(Debug, Eq, PartialEq, Clone, Hash)] +pub struct ImportAlias { + pub name: String, + pub asname: Option, +} + /// A python expression /// /// This is the `expr` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) diff --git a/pyo3-introspection/src/stubs.rs b/pyo3-introspection/src/stubs.rs index 9877652cc40..9f285e47214 100644 --- a/pyo3-introspection/src/stubs.rs +++ b/pyo3-introspection/src/stubs.rs @@ -1,6 +1,6 @@ use crate::model::{ - Argument, Arguments, Attribute, Class, Constant, Expr, Function, Module, Operator, - VariableLengthArgument, + Argument, Arguments, Attribute, Class, Constant, Expr, Function, ImportAlias, Module, Operator, + Statement, VariableLengthArgument, }; use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::fmt::Write; @@ -95,7 +95,27 @@ fn module_stubs(module: &Module, parents: &[&str]) -> String { if let Some(docstring) = &module.docstring { final_elements.push(format!("\"\"\"\n{docstring}\n\"\"\"")); } - final_elements.extend(imports.imports); + for (module, names) in &imports.imports { + let mut line = String::new(); + if let Some(module) = module { + line.push_str("from "); + line.push_str(module); + line.push(' '); + } + line.push_str("import"); + for (i, name) in names.iter().enumerate() { + if i > 0 { + line.push(','); + } + line.push(' '); + line.push_str(&name.name); + if let Some(asname) = &name.asname { + line.push_str(" as "); + line.push_str(asname); + } + } + final_elements.push(line); + } final_elements.extend(elements); let mut output = String::new(); @@ -294,7 +314,7 @@ fn variable_length_argument_stub(argument: &VariableLengthArgument, imports: &Im #[derive(Default)] struct Imports { /// Import lines ready to use - imports: Vec, + imports: BTreeMap, Vec>, /// Renaming map: from module name and member name return the name to use in type hints renaming: BTreeMap<(String, String), String>, } @@ -311,7 +331,7 @@ impl Imports { let mut elements_used_in_annotations = ElementsUsedInAnnotations::new(); elements_used_in_annotations.walk_module(module); - let mut imports = Vec::new(); + let mut imports = BTreeMap::, Vec>::new(); let mut renaming = BTreeMap::new(); let mut local_name_to_module_and_attribute = BTreeMap::new(); @@ -334,10 +354,38 @@ impl Imports { local_name_to_module_and_attribute .insert(name.clone(), (current_module_name.clone(), name.clone())); } - // We don't process the current module elements, no need to care about them - local_name_to_module_and_attribute.remove(¤t_module_name); - // We process then imports, normalizing local imports + // Also, we insert the imports from the user-provided stub + for statement in &module.stubs { + let (module, names) = match statement { + Statement::Import { names } => (None, names), + Statement::ImportFrom { + module, + names, + level, + } => { + // We build the python module relative path + let mut module = module.clone(); + for _ in 0..*level { + module = format!(".{module}") + } + (Some(module), names) + } + }; + for name in names { + let module_and_name = (module.clone().unwrap_or_default(), name.name.clone()); + let local_name = name.asname.as_ref().unwrap_or(&name.name).clone(); + local_name_to_module_and_attribute + .insert(local_name.clone(), module_and_name.clone()); + renaming.insert(module_and_name, local_name); + } + imports + .entry(module) + .or_default() + .extend(names.iter().cloned()); + } + + // Finally, We process imports from built-in annotations (always absolute) for (module, attrs) in &elements_used_in_annotations.module_to_name { let mut import_for_module = Vec::new(); for attr in attrs { @@ -345,6 +393,18 @@ impl Imports { let (root_attr, attr_path) = attr .split_once('.') .map_or((attr.as_str(), None), |(root, path)| (root, Some(path))); + + if let Some(local_name) = renaming.get(&(module.clone(), root_attr.to_owned())) { + // it's already imported, we make sure to get a renaming for the nested class if relevant + if let Some(attr_path) = &attr_path { + renaming.insert( + (module.clone(), attr.clone()), + format!("{local_name}.{attr_path}"), + ); + } + continue; + } + let mut local_name = root_attr.to_owned(); let mut already_imported = false; while let Some((possible_conflict_module, possible_conflict_attr)) = @@ -383,21 +443,31 @@ impl Imports { let is_not_aliased_builtin = module == "builtins" && local_name == root_attr; if !is_not_aliased_builtin { import_for_module.push(if local_name == root_attr { - local_name + ImportAlias { + name: local_name, + asname: None, + } } else { - format!("{root_attr} as {local_name}") + ImportAlias { + name: root_attr.into(), + asname: Some(local_name), + } }); } } } if !import_for_module.is_empty() { - imports.push(format!( - "from {module} import {}", - import_for_module.join(", ") - )); + imports + .entry(Some(module.clone())) + .or_default() + .extend(import_for_module); } } - imports.sort(); // We make sure they are sorted + + // We sort imports + for names in imports.values_mut() { + names.sort_by(|l, r| (&l.name, &l.asname).cmp(&(&r.name, &r.asname))); + } Self { imports, renaming } } @@ -628,7 +698,7 @@ impl ElementsUsedInAnnotations { #[cfg(test)] mod tests { use super::*; - use crate::model::Arguments; + use crate::model::{Arguments, ImportAlias, Statement}; #[test] fn function_stubs_with_variable_length() { @@ -769,6 +839,14 @@ mod tests { value: Box::new(Expr::Name { id: "foo".into() }), attr: "B".into(), }, + Expr::Attribute { + value: Box::new(Expr::Name { id: "foo".into() }), + attr: "C".into(), + }, + Expr::Attribute { + value: Box::new(Expr::Name { id: "foo".into() }), + attr: "D".into(), + }, Expr::Attribute { value: Box::new(Expr::Name { id: "bat".into() }), attr: "A".into(), @@ -831,22 +909,116 @@ mod tests { }], attributes: Vec::new(), incomplete: true, + stubs: vec![ + Statement::ImportFrom { + module: "foo".into(), + names: vec![ + ImportAlias { + name: "A".into(), + asname: Some("AAlt".into()), + }, + ImportAlias { + name: "B".into(), + asname: Some("B2".into()), + }, + ImportAlias { + name: "C".into(), + asname: None, + }, + ], + level: 0, + }, + Statement::Import { + names: vec![ImportAlias { + name: "bat".into(), + asname: None, + }], + }, + Statement::ImportFrom { + module: "bat".into(), + names: vec![ImportAlias { + name: "D".into(), + asname: None, + }], + level: 0, + }, + ], docstring: None, }, &["foo"], ); assert_eq!( - &imports.imports, - &[ - "from _typeshed import Incomplete", - "from bat import A as A2", - "from builtins import int as int2", - "from foo import A as A3, B", - "from typing import final" - ] + imports.imports, + BTreeMap::from([ + ( + None, + vec![ImportAlias { + name: "bat".into(), + asname: None + }] + ), + ( + Some("_typeshed".to_string()), + vec![ImportAlias { + name: "Incomplete".into(), + asname: None + }] + ), + ( + Some("bat".into()), + vec![ + ImportAlias { + name: "A".into(), + asname: Some("A2".into()) + }, + ImportAlias { + name: "D".into(), + asname: None + } + ] + ), + ( + Some("builtins".into()), + vec![ImportAlias { + name: "int".into(), + asname: Some("int2".into()) + }] + ), + ( + Some("foo".into()), + vec![ + ImportAlias { + name: "A".into(), + asname: Some("AAlt".into()) + }, + ImportAlias { + name: "B".into(), + asname: Some("B2".into()) + }, + ImportAlias { + name: "C".into(), + asname: None + }, + ImportAlias { + name: "D".into(), + asname: Some("D2".into()) + } + ] + ), + ( + Some("typing".into()), + vec![ImportAlias { + name: "final".into(), + asname: None + }] + ), + ]) ); let mut output = String::new(); imports.serialize_expr(&big_type, &mut output); - assert_eq!(output, "dict[A, (A3.C, A3.D, B, A2, int, int2, float)]"); + assert_eq!( + output, + "dict[A, (AAlt.C, AAlt.D, B2, C, D2, A2, int, int2, float)]" + ); } } diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index 9894c463628..221a2a43529 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -11,6 +11,8 @@ use syn::{ }; use crate::combine_errors::CombineErrors; +#[cfg(feature = "experimental-inspect")] +use crate::py_stubs::PyStubs; pub mod kw { syn::custom_keyword!(annotation); @@ -56,6 +58,8 @@ pub mod kw { syn::custom_keyword!(category); syn::custom_keyword!(from_py_object); syn::custom_keyword!(skip_from_py_object); + #[cfg(feature = "experimental-inspect")] + syn::custom_keyword!(stubs); } fn take_int(read: &mut &str, tracker: &mut usize) -> String { @@ -349,6 +353,8 @@ pub type TextSignatureAttribute = KeywordAttribute; pub type SubmoduleAttribute = kw::submodule; pub type GILUsedAttribute = KeywordAttribute; +#[cfg(feature = "experimental-inspect")] +pub type StubsAttribute = KeywordAttribute; impl Parse for KeywordAttribute { fn parse(input: ParseStream<'_>) -> Result { diff --git a/pyo3-macros-backend/src/introspection.rs b/pyo3-macros-backend/src/introspection.rs index 6448e30878d..e80ea9ca686 100644 --- a/pyo3-macros-backend/src/introspection.rs +++ b/pyo3-macros-backend/src/introspection.rs @@ -8,8 +8,10 @@ //! The JSON blobs format must be synchronized with the `pyo3_introspection::introspection.rs::Chunk` //! type that is used to parse them. +use crate::json::escape_json_string; use crate::method::{FnArg, RegularArg}; use crate::py_expr::PyExpr; +use crate::py_stubs::PyStubs; use crate::pyfunction::FunctionSignature; use crate::utils::{PyO3CratePath, PythonDoc, StrOrExpr}; use proc_macro2::{Span, TokenStream}; @@ -17,7 +19,6 @@ use quote::{format_ident, quote, ToTokens}; use std::borrow::Cow; use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; -use std::fmt::Write; use std::hash::{Hash, Hasher}; use std::mem::take; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -32,6 +33,7 @@ pub fn module_introspection_code<'a>( members_cfg_attrs: impl IntoIterator>, doc: Option<&PythonDoc>, incomplete: bool, + extra_stubs: Option<&PyStubs>, ) -> TokenStream { let mut desc = HashMap::from([ ("type", IntrospectionNode::String("module".into())), @@ -55,6 +57,9 @@ pub fn module_introspection_code<'a>( if let Some(doc) = doc { desc.insert("doc", IntrospectionNode::Doc(doc)); } + if let Some(stubs) = extra_stubs { + desc.insert("stubs", IntrospectionNode::Stubs(stubs)); + } IntrospectionNode::Map(desc).emit(pyo3_crate_path) } @@ -332,6 +337,7 @@ enum IntrospectionNode<'a> { IntrospectionId(Option>), TypeHint(Cow<'a, PyExpr>), Doc(&'a PythonDoc), + Stubs(&'a PyStubs), Map(HashMap<&'static str, IntrospectionNode<'a>>), List(Vec>), } @@ -390,6 +396,9 @@ impl IntrospectionNode<'_> { } content.push_str("\""); } + Self::Stubs(stubs) => { + content.push_str(&stubs.as_json().to_string()); + } Self::Map(map) => { content.push_str("{"); for (i, (key, value)) in map.into_iter().enumerate() { @@ -591,23 +600,3 @@ fn ident_to_type(ident: &Ident) -> Cow<'static, Type> { .into(), ) } - -fn escape_json_string(value: &str) -> String { - let mut output = String::with_capacity(value.len()); - for c in value.chars() { - match c { - '\\' => output.push_str("\\\\"), - '"' => output.push_str("\\\""), - '\x08' => output.push_str("\\b"), - '\x0C' => output.push_str("\\f"), - '\n' => output.push_str("\\n"), - '\r' => output.push_str("\\r"), - '\t' => output.push_str("\\t"), - c @ '\0'..='\x1F' => { - write!(output, "\\u{:0>4x}", u32::from(c)).unwrap(); - } - c => output.push(c), - } - } - output -} diff --git a/pyo3-macros-backend/src/json.rs b/pyo3-macros-backend/src/json.rs new file mode 100644 index 00000000000..dbb058fff65 --- /dev/null +++ b/pyo3-macros-backend/src/json.rs @@ -0,0 +1,75 @@ +//! JSON-related utilities + +use std::borrow::Cow; +use std::collections::HashMap; +use std::fmt; +use std::fmt::Write as _; + +pub enum JsonValue { + String(Cow<'static, str>), + Number(i16), + Array(Vec), + Object(HashMap<&'static str, JsonValue>), +} + +impl fmt::Display for JsonValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::String(value) => { + f.write_char('"')?; + write_escaped_json_string(value, f)?; + f.write_char('"') + } + Self::Number(value) => value.fmt(f), + JsonValue::Array(values) => { + f.write_char('[')?; + for (i, value) in values.iter().enumerate() { + if i > 0 { + f.write_char(',')?; + } + value.fmt(f)?; + } + f.write_char(']') + } + JsonValue::Object(key_values) => { + f.write_char('{')?; + for (i, (key, value)) in key_values.iter().enumerate() { + if i > 0 { + f.write_char(',')?; + } + f.write_char('"')?; + write_escaped_json_string(key, f)?; + f.write_char('"')?; + f.write_char(':')?; + value.fmt(f)?; + } + f.write_char('}') + } + } + } +} + +pub fn escape_json_string(value: &str) -> String { + let mut output = String::with_capacity(value.len()); + write_escaped_json_string(value, &mut output).unwrap(); + output +} + +fn write_escaped_json_string(value: &str, output: &mut impl fmt::Write) -> fmt::Result { + for c in value.chars() { + match c { + '\\' => output.write_str("\\\\"), + '"' => output.write_str("\\\""), + '\x08' => output.write_str("\\b"), + '\x0C' => output.write_str("\\f"), + '\n' => output.write_str("\\n"), + '\r' => output.write_str("\\r"), + '\t' => output.write_str("\\t"), + c @ '\0'..='\x1F' => { + write!(output, "\\u{:0>4x}", u32::from(c)) + } + c => output.write_char(c), + }?; + } + Ok(()) +} diff --git a/pyo3-macros-backend/src/lib.rs b/pyo3-macros-backend/src/lib.rs index a90fa73678e..87479c956d9 100644 --- a/pyo3-macros-backend/src/lib.rs +++ b/pyo3-macros-backend/src/lib.rs @@ -15,12 +15,16 @@ mod frompyobject; mod intopyobject; #[cfg(feature = "experimental-inspect")] mod introspection; +#[cfg(feature = "experimental-inspect")] +mod json; mod konst; mod method; mod module; mod params; #[cfg(feature = "experimental-inspect")] mod py_expr; +#[cfg(feature = "experimental-inspect")] +mod py_stubs; mod pyclass; mod pyfunction; mod pyimpl; diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 001c36d1eed..63f19f46579 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -1,5 +1,7 @@ //! Code generation for the function that initializes a python module and adds classes and function. +#[cfg(feature = "experimental-inspect")] +use crate::attributes::StubsAttribute; #[cfg(feature = "experimental-inspect")] use crate::introspection::{ attribute_introspection_code, introspection_id_const, module_introspection_code, @@ -38,6 +40,8 @@ pub struct PyModuleOptions { module: Option, submodule: Option, gil_used: Option, + #[cfg(feature = "experimental-inspect")] + stubs: Option, } impl Parse for PyModuleOptions { @@ -83,9 +87,9 @@ impl PyModuleOptions { submodule, " (it is implicitly always specified for nested modules)" ), - PyModulePyO3Option::GILUsed(gil_used) => { - set_option!(gil_used) - } + PyModulePyO3Option::GILUsed(gil_used) => set_option!(gil_used), + #[cfg(feature = "experimental-inspect")] + PyModulePyO3Option::Stubs(stubs) => set_option!(stubs), } Ok(()) @@ -383,6 +387,7 @@ pub fn pymodule_module_impl( &module_items_cfg_attrs, doc.as_ref(), pymodule_init.is_some(), + options.stubs.as_ref().map(|a| &a.value), ); #[cfg(not(feature = "experimental-inspect"))] let introspection = quote! {}; @@ -471,6 +476,7 @@ pub fn pymodule_function_impl( &[], doc.as_ref(), true, + options.stubs.as_ref().map(|a| &a.value), ); #[cfg(not(feature = "experimental-inspect"))] let introspection = quote! {}; @@ -712,6 +718,8 @@ enum PyModulePyO3Option { Name(NameAttribute), Module(ModuleAttribute), GILUsed(GILUsedAttribute), + #[cfg(feature = "experimental-inspect")] + Stubs(StubsAttribute), } impl Parse for PyModulePyO3Option { @@ -728,6 +736,10 @@ impl Parse for PyModulePyO3Option { } else if lookahead.peek(attributes::kw::gil_used) { input.parse().map(PyModulePyO3Option::GILUsed) } else { + #[cfg(feature = "experimental-inspect")] + if lookahead.peek(attributes::kw::stubs) { + return input.parse().map(PyModulePyO3Option::Stubs); + } Err(lookahead.error()) } } diff --git a/pyo3-macros-backend/src/py_stubs.rs b/pyo3-macros-backend/src/py_stubs.rs new file mode 100644 index 00000000000..f58cf405dbf --- /dev/null +++ b/pyo3-macros-backend/src/py_stubs.rs @@ -0,0 +1,229 @@ +//! Parsing and serialization code for custom type stubs + +use crate::json::JsonValue; +use proc_macro2::{Ident, TokenStream}; +use quote::ToTokens; +use std::collections::HashMap; +use syn::ext::IdentExt; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::token::Brace; +use syn::{braced, Token}; + +mod kw { + syn::custom_keyword!(import); + syn::custom_keyword!(from); +} + +/// Custom provided stubs in #[pymodule] +pub struct PyStubs { + bracket_token: Brace, + imports: Punctuated, +} + +impl PyStubs { + /// Returns a JSON object following the https://docs.python.org/fr/3/library/ast.html syntax tree + pub fn as_json(&self) -> JsonValue { + JsonValue::Array(self.imports.iter().map(|i| i.as_json()).collect()) + } +} + +impl Parse for PyStubs { + fn parse(input: ParseStream<'_>) -> syn::Result { + let content; + Ok(Self { + bracket_token: braced!(content in input), + imports: content.parse_terminated(PyStatement::parse, Token![;])?, + }) + } +} + +impl ToTokens for PyStubs { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.bracket_token + .surround(tokens, |tokens| self.imports.to_tokens(tokens)) + } +} + +/// A Python statement +enum PyStatement { + ImportFrom(PyImportFrom), + Import(PyImport), +} + +impl PyStatement { + pub fn as_json(&self) -> JsonValue { + match self { + Self::ImportFrom(s) => s.as_json(), + Self::Import(s) => s.as_json(), + } + } +} + +impl Parse for PyStatement { + fn parse(input: ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::from) { + input.parse().map(Self::ImportFrom) + } else if lookahead.peek(kw::import) { + input.parse().map(Self::Import) + } else { + Err(lookahead.error()) + } + } +} + +impl ToTokens for PyStatement { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Self::ImportFrom(s) => s.to_tokens(tokens), + Self::Import(s) => s.to_tokens(tokens), + } + } +} + +/// `from {module} import {names}` +struct PyImportFrom { + pub from_token: kw::from, + pub module: Ident, + pub import_token: kw::import, + pub names: Punctuated, +} + +impl PyImportFrom { + pub fn as_json(&self) -> JsonValue { + JsonValue::Object(HashMap::from([ + ("type", JsonValue::String("importfrom".into())), + ( + "module", + JsonValue::String(self.module.unraw().to_string().into()), + ), + ( + "names", + JsonValue::Array(self.names.iter().map(|i| i.as_json()).collect()), + ), + ("level", JsonValue::Number(0)), + ])) + } +} + +impl Parse for PyImportFrom { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + from_token: input.parse()?, + module: input.parse()?, + import_token: input.parse()?, + names: Punctuated::parse_separated_nonempty(input)?, + }) + } +} + +impl ToTokens for PyImportFrom { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.from_token.to_tokens(tokens); + self.module.to_tokens(tokens); + self.import_token.to_tokens(tokens); + self.names.to_tokens(tokens); + } +} + +/// `import {names}` +struct PyImport { + pub import_token: kw::import, + pub names: Punctuated, +} + +impl PyImport { + pub fn as_json(&self) -> JsonValue { + JsonValue::Object(HashMap::from([ + ("type", JsonValue::String("import".into())), + ( + "names", + JsonValue::Array(self.names.iter().map(|i| i.as_json()).collect()), + ), + ])) + } +} + +impl Parse for PyImport { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + import_token: input.parse()?, + names: Punctuated::parse_separated_nonempty(input)?, + }) + } +} + +impl ToTokens for PyImport { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.import_token.to_tokens(tokens); + self.names.to_tokens(tokens); + } +} + +/// `{name} [as {as_name}]` +struct PyAlias { + pub name: Ident, + pub as_name: Option, +} + +impl PyAlias { + pub fn as_json(&self) -> JsonValue { + let mut args = HashMap::from([ + ("type", JsonValue::String("alias".into())), + ( + "name", + JsonValue::String(self.name.unraw().to_string().into()), + ), + ]); + if let Some(as_name) = &self.as_name { + args.insert( + "asname", + JsonValue::String(as_name.name.unraw().to_string().into()), + ); + } + JsonValue::Object(args) + } +} + +impl Parse for PyAlias { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + name: input.parse()?, + as_name: if input.lookahead1().peek(Token![as]) { + Some(input.parse()?) + } else { + None + }, + }) + } +} + +impl ToTokens for PyAlias { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.name.to_tokens(tokens); + self.as_name.to_tokens(tokens); + } +} + +/// `as {name}` +struct PyAliasAsName { + pub as_token: Token![as], + pub name: Ident, +} + +impl Parse for PyAliasAsName { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + as_token: input.parse()?, + name: input.parse()?, + }) + } +} + +impl ToTokens for PyAliasAsName { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.as_token.to_tokens(tokens); + self.name.to_tokens(tokens); + } +} diff --git a/pytests/src/annotations.rs b/pytests/src/annotations.rs new file mode 100644 index 00000000000..1abe4fc77d1 --- /dev/null +++ b/pytests/src/annotations.rs @@ -0,0 +1,30 @@ +//! Example of custom annotations. + +use pyo3::prelude::*; + +#[pymodule(stubs = { + from datetime import datetime as dt, time; + from uuid import UUID; +})] +pub mod annotations { + use pyo3::prelude::*; + use pyo3::types::{PyDate, PyDateTime, PyDict, PyTime, PyTuple}; + + #[pyfunction(signature = (a: "dt | time | UUID", *_args: "str", _b: "int | None" = None, **_kwargs: "bool") -> "int")] + fn with_custom_type_annotations<'py>( + a: Bound<'py, PyAny>, + _args: Bound<'py, PyTuple>, + _b: Option>, + _kwargs: Option>, + ) -> Bound<'py, PyAny> { + a + } + + #[pyfunction] + fn with_built_in_type_annotations( + _date_time: Bound<'_, PyDateTime>, + _time: Bound<'_, PyTime>, + _date: Bound<'_, PyDate>, + ) { + } +} diff --git a/pytests/src/lib.rs b/pytests/src/lib.rs index f6f4b151e6e..00b8d4f8e21 100644 --- a/pytests/src/lib.rs +++ b/pytests/src/lib.rs @@ -1,6 +1,8 @@ use pyo3::prelude::*; use pyo3::types::PyDict; +#[cfg(feature = "experimental-inspect")] +mod annotations; mod awaitable; mod buf_and_str; mod comparisons; @@ -32,6 +34,10 @@ mod pyo3_pytests { #[pymodule_export] use datetime::datetime; + #[cfg(feature = "experimental-inspect")] + #[pymodule_export] + use annotations::annotations; + #[pymodule_export] use { awaitable::awaitable, comparisons::comparisons, consts::consts, dict_iter::dict_iter, diff --git a/pytests/src/pyfunctions.rs b/pytests/src/pyfunctions.rs index 6e1015e7627..e1ffb444cac 100644 --- a/pytests/src/pyfunctions.rs +++ b/pytests/src/pyfunctions.rs @@ -77,17 +77,6 @@ fn with_typed_args(a: bool, b: u64, c: f64, d: &str) -> (bool, u64, f64, &str) { (a, b, c, d) } -#[cfg(feature = "experimental-inspect")] -#[pyfunction(signature = (a: "int", *_args: "str", _b: "int | None" = None, **_kwargs: "bool") -> "int")] -fn with_custom_type_annotations<'py>( - a: Any<'py>, - _args: Tuple<'py>, - _b: Option>, - _kwargs: Option>, -) -> Any<'py> { - a -} - #[cfg(feature = "experimental-async")] #[pyfunction] async fn with_async() {} @@ -143,9 +132,6 @@ pub mod pyfunctions { #[cfg(feature = "experimental-async")] #[pymodule_export] use super::with_async; - #[cfg(feature = "experimental-inspect")] - #[pymodule_export] - use super::with_custom_type_annotations; #[pymodule_export] use super::{ args_kwargs, many_keyword_arguments, none, positional_only, simple, simple_args, diff --git a/pytests/stubs/annotations.pyi b/pytests/stubs/annotations.pyi new file mode 100644 index 00000000000..e7e3fb32c97 --- /dev/null +++ b/pytests/stubs/annotations.pyi @@ -0,0 +1,9 @@ +from datetime import date, datetime as dt, time +from uuid import UUID + +def with_built_in_type_annotations( + _date_time: dt, _time: time, _date: date +) -> None: ... +def with_custom_type_annotations( + a: "dt | time | UUID", *_args: "str", _b: "int | None" = None, **_kwargs: "bool" +) -> "int": ... diff --git a/pytests/stubs/pyfunctions.pyi b/pytests/stubs/pyfunctions.pyi index 1d5cca9a33d..322f2642339 100644 --- a/pytests/stubs/pyfunctions.pyi +++ b/pytests/stubs/pyfunctions.pyi @@ -35,9 +35,6 @@ def simple_kwargs( a: Any, b: Any | None = None, c: Any | None = None, **kwargs ) -> tuple[Any, Any | None, Any | None, dict | None]: ... async def with_async() -> None: ... -def with_custom_type_annotations( - a: "int", *_args: "str", _b: "int | None" = None, **_kwargs: "bool" -) -> "int": ... def with_typed_args( a: bool = False, b: int = 0, c: float = 0.0, d: str = "" ) -> tuple[bool, int, float, str]: ... diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 423f0dd5cea..4bceeb5dbf4 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -31,6 +31,7 @@ fn test_compile_errors() { t.compile_fail("tests/ui/invalid_pymethods_duplicates.rs"); t.compile_fail("tests/ui/invalid_pymethod_enum.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); + #[cfg(not(feature = "experimental-inspect"))] t.compile_fail("tests/ui/invalid_pymodule_args.rs"); t.compile_fail("tests/ui/invalid_pycallargs.rs"); t.compile_fail("tests/ui/reject_generics.rs");