diff --git a/newsfragments/5877.added.md b/newsfragments/5877.added.md new file mode 100644 index 00000000000..5666e5d5c67 --- /dev/null +++ b/newsfragments/5877.added.md @@ -0,0 +1 @@ +Introspection: allow to set custom stub imports in `#[pymodule]` macro \ No newline at end of file diff --git a/pyo3-introspection/src/introspection.rs b/pyo3-introspection/src/introspection.rs index 88931bb38bd..5f949e8f623 100644 --- a/pyo3-introspection/src/introspection.rs +++ b/pyo3-introspection/src/introspection.rs @@ -1,6 +1,6 @@ use crate::model::{ - Argument, Arguments, Attribute, Class, Constant, Expr, Function, Module, Operator, - VariableLengthArgument, + Argument, Arguments, Attribute, Class, Constant, Expr, Function, ImportAlias, Module, Operator, + Statement, VariableLengthArgument, }; use anyhow::{anyhow, bail, ensure, Context, Result}; use goblin::elf::section_header::SHN_XINDEX; @@ -52,6 +52,7 @@ fn parse_chunks(chunks: &[Chunk], main_module_name: &str) -> Result { members, doc, incomplete, + stubs, } = chunk { if name == main_module_name { @@ -65,6 +66,7 @@ fn parse_chunks(chunks: &[Chunk], main_module_name: &str) -> Result { name, members, *incomplete, + stubs, doc.as_deref(), &chunks_by_id, &chunks_by_parent, @@ -82,6 +84,7 @@ fn convert_module( name: &str, members: &[String], mut incomplete: bool, + stubs: &[ChunkStatement], docstring: Option<&str>, chunks_by_id: &HashMap<&str, &Chunk>, chunks_by_parent: &HashMap<&str, Vec<&Chunk>>, @@ -114,6 +117,35 @@ fn convert_module( functions, attributes, incomplete, + stubs: stubs + .iter() + .map(|statement| match statement { + ChunkStatement::ImportFrom { + module, + names, + level, + } => Statement::ImportFrom { + module: module.clone(), + names: names + .iter() + .map(|alias| ImportAlias { + name: alias.name.clone(), + asname: alias.asname.clone(), + }) + .collect(), + level: *level, + }, + ChunkStatement::Import { names } => Statement::Import { + names: names + .iter() + .map(|alias| ImportAlias { + name: alias.name.clone(), + asname: alias.asname.clone(), + }) + .collect(), + }, + }) + .collect(), docstring: docstring.map(Into::into), }) } @@ -139,12 +171,14 @@ fn convert_members<'a>( members, incomplete, doc, + stubs, } => { modules.push(convert_module( id, name, members, *incomplete, + stubs, doc.as_deref(), chunks_by_id, chunks_by_parent, @@ -672,6 +706,8 @@ enum Chunk { #[serde(default)] doc: Option, incomplete: bool, + #[serde(default)] + stubs: Vec, }, Class { id: String, @@ -739,6 +775,25 @@ struct ChunkArgument { annotation: Option, } +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +enum ChunkStatement { + ImportFrom { + module: String, + names: Vec, + level: usize, + }, + Import { + names: Vec, + }, +} + +#[derive(Deserialize)] +struct ChunkAlias { + name: String, + asname: Option, +} + #[derive(Deserialize)] #[serde(tag = "type", rename_all = "lowercase")] enum ChunkExpr { diff --git a/pyo3-introspection/src/model.rs b/pyo3-introspection/src/model.rs index b6ca8d28a8d..f30e93c8aa8 100644 --- a/pyo3-introspection/src/model.rs +++ b/pyo3-introspection/src/model.rs @@ -6,6 +6,7 @@ pub struct Module { pub functions: Vec, pub attributes: Vec, pub incomplete: bool, + pub stubs: Vec, pub docstring: Option, } @@ -74,6 +75,30 @@ pub struct VariableLengthArgument { pub annotation: Option, } +/// A python statement +/// +/// This is the `stmt` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) +#[derive(Debug, Eq, PartialEq, Clone, Hash)] +pub enum Statement { + /// `from {module} import {names}` + ImportFrom { + module: String, + names: Vec, + level: usize, + }, + /// `import {names}` + Import { names: Vec }, +} + +/// A python import alias `{name} as {asname}` +/// +/// This is the `alias` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) +#[derive(Debug, Eq, PartialEq, Clone, Hash)] +pub struct ImportAlias { + pub name: String, + pub asname: Option, +} + /// A python expression /// /// This is the `expr` production of the [Python `ast` module grammar](https://docs.python.org/3/library/ast.html#abstract-grammar) diff --git a/pyo3-introspection/src/stubs.rs b/pyo3-introspection/src/stubs.rs index 9877652cc40..9f285e47214 100644 --- a/pyo3-introspection/src/stubs.rs +++ b/pyo3-introspection/src/stubs.rs @@ -1,6 +1,6 @@ use crate::model::{ - Argument, Arguments, Attribute, Class, Constant, Expr, Function, Module, Operator, - VariableLengthArgument, + Argument, Arguments, Attribute, Class, Constant, Expr, Function, ImportAlias, Module, Operator, + Statement, VariableLengthArgument, }; use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::fmt::Write; @@ -95,7 +95,27 @@ fn module_stubs(module: &Module, parents: &[&str]) -> String { if let Some(docstring) = &module.docstring { final_elements.push(format!("\"\"\"\n{docstring}\n\"\"\"")); } - final_elements.extend(imports.imports); + for (module, names) in &imports.imports { + let mut line = String::new(); + if let Some(module) = module { + line.push_str("from "); + line.push_str(module); + line.push(' '); + } + line.push_str("import"); + for (i, name) in names.iter().enumerate() { + if i > 0 { + line.push(','); + } + line.push(' '); + line.push_str(&name.name); + if let Some(asname) = &name.asname { + line.push_str(" as "); + line.push_str(asname); + } + } + final_elements.push(line); + } final_elements.extend(elements); let mut output = String::new(); @@ -294,7 +314,7 @@ fn variable_length_argument_stub(argument: &VariableLengthArgument, imports: &Im #[derive(Default)] struct Imports { /// Import lines ready to use - imports: Vec, + imports: BTreeMap, Vec>, /// Renaming map: from module name and member name return the name to use in type hints renaming: BTreeMap<(String, String), String>, } @@ -311,7 +331,7 @@ impl Imports { let mut elements_used_in_annotations = ElementsUsedInAnnotations::new(); elements_used_in_annotations.walk_module(module); - let mut imports = Vec::new(); + let mut imports = BTreeMap::, Vec>::new(); let mut renaming = BTreeMap::new(); let mut local_name_to_module_and_attribute = BTreeMap::new(); @@ -334,10 +354,38 @@ impl Imports { local_name_to_module_and_attribute .insert(name.clone(), (current_module_name.clone(), name.clone())); } - // We don't process the current module elements, no need to care about them - local_name_to_module_and_attribute.remove(¤t_module_name); - // We process then imports, normalizing local imports + // Also, we insert the imports from the user-provided stub + for statement in &module.stubs { + let (module, names) = match statement { + Statement::Import { names } => (None, names), + Statement::ImportFrom { + module, + names, + level, + } => { + // We build the python module relative path + let mut module = module.clone(); + for _ in 0..*level { + module = format!(".{module}") + } + (Some(module), names) + } + }; + for name in names { + let module_and_name = (module.clone().unwrap_or_default(), name.name.clone()); + let local_name = name.asname.as_ref().unwrap_or(&name.name).clone(); + local_name_to_module_and_attribute + .insert(local_name.clone(), module_and_name.clone()); + renaming.insert(module_and_name, local_name); + } + imports + .entry(module) + .or_default() + .extend(names.iter().cloned()); + } + + // Finally, We process imports from built-in annotations (always absolute) for (module, attrs) in &elements_used_in_annotations.module_to_name { let mut import_for_module = Vec::new(); for attr in attrs { @@ -345,6 +393,18 @@ impl Imports { let (root_attr, attr_path) = attr .split_once('.') .map_or((attr.as_str(), None), |(root, path)| (root, Some(path))); + + if let Some(local_name) = renaming.get(&(module.clone(), root_attr.to_owned())) { + // it's already imported, we make sure to get a renaming for the nested class if relevant + if let Some(attr_path) = &attr_path { + renaming.insert( + (module.clone(), attr.clone()), + format!("{local_name}.{attr_path}"), + ); + } + continue; + } + let mut local_name = root_attr.to_owned(); let mut already_imported = false; while let Some((possible_conflict_module, possible_conflict_attr)) = @@ -383,21 +443,31 @@ impl Imports { let is_not_aliased_builtin = module == "builtins" && local_name == root_attr; if !is_not_aliased_builtin { import_for_module.push(if local_name == root_attr { - local_name + ImportAlias { + name: local_name, + asname: None, + } } else { - format!("{root_attr} as {local_name}") + ImportAlias { + name: root_attr.into(), + asname: Some(local_name), + } }); } } } if !import_for_module.is_empty() { - imports.push(format!( - "from {module} import {}", - import_for_module.join(", ") - )); + imports + .entry(Some(module.clone())) + .or_default() + .extend(import_for_module); } } - imports.sort(); // We make sure they are sorted + + // We sort imports + for names in imports.values_mut() { + names.sort_by(|l, r| (&l.name, &l.asname).cmp(&(&r.name, &r.asname))); + } Self { imports, renaming } } @@ -628,7 +698,7 @@ impl ElementsUsedInAnnotations { #[cfg(test)] mod tests { use super::*; - use crate::model::Arguments; + use crate::model::{Arguments, ImportAlias, Statement}; #[test] fn function_stubs_with_variable_length() { @@ -769,6 +839,14 @@ mod tests { value: Box::new(Expr::Name { id: "foo".into() }), attr: "B".into(), }, + Expr::Attribute { + value: Box::new(Expr::Name { id: "foo".into() }), + attr: "C".into(), + }, + Expr::Attribute { + value: Box::new(Expr::Name { id: "foo".into() }), + attr: "D".into(), + }, Expr::Attribute { value: Box::new(Expr::Name { id: "bat".into() }), attr: "A".into(), @@ -831,22 +909,116 @@ mod tests { }], attributes: Vec::new(), incomplete: true, + stubs: vec![ + Statement::ImportFrom { + module: "foo".into(), + names: vec![ + ImportAlias { + name: "A".into(), + asname: Some("AAlt".into()), + }, + ImportAlias { + name: "B".into(), + asname: Some("B2".into()), + }, + ImportAlias { + name: "C".into(), + asname: None, + }, + ], + level: 0, + }, + Statement::Import { + names: vec![ImportAlias { + name: "bat".into(), + asname: None, + }], + }, + Statement::ImportFrom { + module: "bat".into(), + names: vec![ImportAlias { + name: "D".into(), + asname: None, + }], + level: 0, + }, + ], docstring: None, }, &["foo"], ); assert_eq!( - &imports.imports, - &[ - "from _typeshed import Incomplete", - "from bat import A as A2", - "from builtins import int as int2", - "from foo import A as A3, B", - "from typing import final" - ] + imports.imports, + BTreeMap::from([ + ( + None, + vec![ImportAlias { + name: "bat".into(), + asname: None + }] + ), + ( + Some("_typeshed".to_string()), + vec![ImportAlias { + name: "Incomplete".into(), + asname: None + }] + ), + ( + Some("bat".into()), + vec![ + ImportAlias { + name: "A".into(), + asname: Some("A2".into()) + }, + ImportAlias { + name: "D".into(), + asname: None + } + ] + ), + ( + Some("builtins".into()), + vec![ImportAlias { + name: "int".into(), + asname: Some("int2".into()) + }] + ), + ( + Some("foo".into()), + vec![ + ImportAlias { + name: "A".into(), + asname: Some("AAlt".into()) + }, + ImportAlias { + name: "B".into(), + asname: Some("B2".into()) + }, + ImportAlias { + name: "C".into(), + asname: None + }, + ImportAlias { + name: "D".into(), + asname: Some("D2".into()) + } + ] + ), + ( + Some("typing".into()), + vec![ImportAlias { + name: "final".into(), + asname: None + }] + ), + ]) ); let mut output = String::new(); imports.serialize_expr(&big_type, &mut output); - assert_eq!(output, "dict[A, (A3.C, A3.D, B, A2, int, int2, float)]"); + assert_eq!( + output, + "dict[A, (AAlt.C, AAlt.D, B2, C, D2, A2, int, int2, float)]" + ); } } diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index 9894c463628..221a2a43529 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -11,6 +11,8 @@ use syn::{ }; use crate::combine_errors::CombineErrors; +#[cfg(feature = "experimental-inspect")] +use crate::py_stubs::PyStubs; pub mod kw { syn::custom_keyword!(annotation); @@ -56,6 +58,8 @@ pub mod kw { syn::custom_keyword!(category); syn::custom_keyword!(from_py_object); syn::custom_keyword!(skip_from_py_object); + #[cfg(feature = "experimental-inspect")] + syn::custom_keyword!(stubs); } fn take_int(read: &mut &str, tracker: &mut usize) -> String { @@ -349,6 +353,8 @@ pub type TextSignatureAttribute = KeywordAttribute; pub type SubmoduleAttribute = kw::submodule; pub type GILUsedAttribute = KeywordAttribute; +#[cfg(feature = "experimental-inspect")] +pub type StubsAttribute = KeywordAttribute; impl Parse for KeywordAttribute { fn parse(input: ParseStream<'_>) -> Result { diff --git a/pyo3-macros-backend/src/introspection.rs b/pyo3-macros-backend/src/introspection.rs index 6448e30878d..e80ea9ca686 100644 --- a/pyo3-macros-backend/src/introspection.rs +++ b/pyo3-macros-backend/src/introspection.rs @@ -8,8 +8,10 @@ //! The JSON blobs format must be synchronized with the `pyo3_introspection::introspection.rs::Chunk` //! type that is used to parse them. +use crate::json::escape_json_string; use crate::method::{FnArg, RegularArg}; use crate::py_expr::PyExpr; +use crate::py_stubs::PyStubs; use crate::pyfunction::FunctionSignature; use crate::utils::{PyO3CratePath, PythonDoc, StrOrExpr}; use proc_macro2::{Span, TokenStream}; @@ -17,7 +19,6 @@ use quote::{format_ident, quote, ToTokens}; use std::borrow::Cow; use std::collections::hash_map::DefaultHasher; use std::collections::HashMap; -use std::fmt::Write; use std::hash::{Hash, Hasher}; use std::mem::take; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -32,6 +33,7 @@ pub fn module_introspection_code<'a>( members_cfg_attrs: impl IntoIterator>, doc: Option<&PythonDoc>, incomplete: bool, + extra_stubs: Option<&PyStubs>, ) -> TokenStream { let mut desc = HashMap::from([ ("type", IntrospectionNode::String("module".into())), @@ -55,6 +57,9 @@ pub fn module_introspection_code<'a>( if let Some(doc) = doc { desc.insert("doc", IntrospectionNode::Doc(doc)); } + if let Some(stubs) = extra_stubs { + desc.insert("stubs", IntrospectionNode::Stubs(stubs)); + } IntrospectionNode::Map(desc).emit(pyo3_crate_path) } @@ -332,6 +337,7 @@ enum IntrospectionNode<'a> { IntrospectionId(Option>), TypeHint(Cow<'a, PyExpr>), Doc(&'a PythonDoc), + Stubs(&'a PyStubs), Map(HashMap<&'static str, IntrospectionNode<'a>>), List(Vec>), } @@ -390,6 +396,9 @@ impl IntrospectionNode<'_> { } content.push_str("\""); } + Self::Stubs(stubs) => { + content.push_str(&stubs.as_json().to_string()); + } Self::Map(map) => { content.push_str("{"); for (i, (key, value)) in map.into_iter().enumerate() { @@ -591,23 +600,3 @@ fn ident_to_type(ident: &Ident) -> Cow<'static, Type> { .into(), ) } - -fn escape_json_string(value: &str) -> String { - let mut output = String::with_capacity(value.len()); - for c in value.chars() { - match c { - '\\' => output.push_str("\\\\"), - '"' => output.push_str("\\\""), - '\x08' => output.push_str("\\b"), - '\x0C' => output.push_str("\\f"), - '\n' => output.push_str("\\n"), - '\r' => output.push_str("\\r"), - '\t' => output.push_str("\\t"), - c @ '\0'..='\x1F' => { - write!(output, "\\u{:0>4x}", u32::from(c)).unwrap(); - } - c => output.push(c), - } - } - output -} diff --git a/pyo3-macros-backend/src/json.rs b/pyo3-macros-backend/src/json.rs new file mode 100644 index 00000000000..dbb058fff65 --- /dev/null +++ b/pyo3-macros-backend/src/json.rs @@ -0,0 +1,75 @@ +//! JSON-related utilities + +use std::borrow::Cow; +use std::collections::HashMap; +use std::fmt; +use std::fmt::Write as _; + +pub enum JsonValue { + String(Cow<'static, str>), + Number(i16), + Array(Vec), + Object(HashMap<&'static str, JsonValue>), +} + +impl fmt::Display for JsonValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::String(value) => { + f.write_char('"')?; + write_escaped_json_string(value, f)?; + f.write_char('"') + } + Self::Number(value) => value.fmt(f), + JsonValue::Array(values) => { + f.write_char('[')?; + for (i, value) in values.iter().enumerate() { + if i > 0 { + f.write_char(',')?; + } + value.fmt(f)?; + } + f.write_char(']') + } + JsonValue::Object(key_values) => { + f.write_char('{')?; + for (i, (key, value)) in key_values.iter().enumerate() { + if i > 0 { + f.write_char(',')?; + } + f.write_char('"')?; + write_escaped_json_string(key, f)?; + f.write_char('"')?; + f.write_char(':')?; + value.fmt(f)?; + } + f.write_char('}') + } + } + } +} + +pub fn escape_json_string(value: &str) -> String { + let mut output = String::with_capacity(value.len()); + write_escaped_json_string(value, &mut output).unwrap(); + output +} + +fn write_escaped_json_string(value: &str, output: &mut impl fmt::Write) -> fmt::Result { + for c in value.chars() { + match c { + '\\' => output.write_str("\\\\"), + '"' => output.write_str("\\\""), + '\x08' => output.write_str("\\b"), + '\x0C' => output.write_str("\\f"), + '\n' => output.write_str("\\n"), + '\r' => output.write_str("\\r"), + '\t' => output.write_str("\\t"), + c @ '\0'..='\x1F' => { + write!(output, "\\u{:0>4x}", u32::from(c)) + } + c => output.write_char(c), + }?; + } + Ok(()) +} diff --git a/pyo3-macros-backend/src/lib.rs b/pyo3-macros-backend/src/lib.rs index a90fa73678e..87479c956d9 100644 --- a/pyo3-macros-backend/src/lib.rs +++ b/pyo3-macros-backend/src/lib.rs @@ -15,12 +15,16 @@ mod frompyobject; mod intopyobject; #[cfg(feature = "experimental-inspect")] mod introspection; +#[cfg(feature = "experimental-inspect")] +mod json; mod konst; mod method; mod module; mod params; #[cfg(feature = "experimental-inspect")] mod py_expr; +#[cfg(feature = "experimental-inspect")] +mod py_stubs; mod pyclass; mod pyfunction; mod pyimpl; diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 001c36d1eed..63f19f46579 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -1,5 +1,7 @@ //! Code generation for the function that initializes a python module and adds classes and function. +#[cfg(feature = "experimental-inspect")] +use crate::attributes::StubsAttribute; #[cfg(feature = "experimental-inspect")] use crate::introspection::{ attribute_introspection_code, introspection_id_const, module_introspection_code, @@ -38,6 +40,8 @@ pub struct PyModuleOptions { module: Option, submodule: Option, gil_used: Option, + #[cfg(feature = "experimental-inspect")] + stubs: Option, } impl Parse for PyModuleOptions { @@ -83,9 +87,9 @@ impl PyModuleOptions { submodule, " (it is implicitly always specified for nested modules)" ), - PyModulePyO3Option::GILUsed(gil_used) => { - set_option!(gil_used) - } + PyModulePyO3Option::GILUsed(gil_used) => set_option!(gil_used), + #[cfg(feature = "experimental-inspect")] + PyModulePyO3Option::Stubs(stubs) => set_option!(stubs), } Ok(()) @@ -383,6 +387,7 @@ pub fn pymodule_module_impl( &module_items_cfg_attrs, doc.as_ref(), pymodule_init.is_some(), + options.stubs.as_ref().map(|a| &a.value), ); #[cfg(not(feature = "experimental-inspect"))] let introspection = quote! {}; @@ -471,6 +476,7 @@ pub fn pymodule_function_impl( &[], doc.as_ref(), true, + options.stubs.as_ref().map(|a| &a.value), ); #[cfg(not(feature = "experimental-inspect"))] let introspection = quote! {}; @@ -712,6 +718,8 @@ enum PyModulePyO3Option { Name(NameAttribute), Module(ModuleAttribute), GILUsed(GILUsedAttribute), + #[cfg(feature = "experimental-inspect")] + Stubs(StubsAttribute), } impl Parse for PyModulePyO3Option { @@ -728,6 +736,10 @@ impl Parse for PyModulePyO3Option { } else if lookahead.peek(attributes::kw::gil_used) { input.parse().map(PyModulePyO3Option::GILUsed) } else { + #[cfg(feature = "experimental-inspect")] + if lookahead.peek(attributes::kw::stubs) { + return input.parse().map(PyModulePyO3Option::Stubs); + } Err(lookahead.error()) } } diff --git a/pyo3-macros-backend/src/py_stubs.rs b/pyo3-macros-backend/src/py_stubs.rs new file mode 100644 index 00000000000..f58cf405dbf --- /dev/null +++ b/pyo3-macros-backend/src/py_stubs.rs @@ -0,0 +1,229 @@ +//! Parsing and serialization code for custom type stubs + +use crate::json::JsonValue; +use proc_macro2::{Ident, TokenStream}; +use quote::ToTokens; +use std::collections::HashMap; +use syn::ext::IdentExt; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::token::Brace; +use syn::{braced, Token}; + +mod kw { + syn::custom_keyword!(import); + syn::custom_keyword!(from); +} + +/// Custom provided stubs in #[pymodule] +pub struct PyStubs { + bracket_token: Brace, + imports: Punctuated, +} + +impl PyStubs { + /// Returns a JSON object following the https://docs.python.org/fr/3/library/ast.html syntax tree + pub fn as_json(&self) -> JsonValue { + JsonValue::Array(self.imports.iter().map(|i| i.as_json()).collect()) + } +} + +impl Parse for PyStubs { + fn parse(input: ParseStream<'_>) -> syn::Result { + let content; + Ok(Self { + bracket_token: braced!(content in input), + imports: content.parse_terminated(PyStatement::parse, Token![;])?, + }) + } +} + +impl ToTokens for PyStubs { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.bracket_token + .surround(tokens, |tokens| self.imports.to_tokens(tokens)) + } +} + +/// A Python statement +enum PyStatement { + ImportFrom(PyImportFrom), + Import(PyImport), +} + +impl PyStatement { + pub fn as_json(&self) -> JsonValue { + match self { + Self::ImportFrom(s) => s.as_json(), + Self::Import(s) => s.as_json(), + } + } +} + +impl Parse for PyStatement { + fn parse(input: ParseStream<'_>) -> syn::Result { + let lookahead = input.lookahead1(); + if lookahead.peek(kw::from) { + input.parse().map(Self::ImportFrom) + } else if lookahead.peek(kw::import) { + input.parse().map(Self::Import) + } else { + Err(lookahead.error()) + } + } +} + +impl ToTokens for PyStatement { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + Self::ImportFrom(s) => s.to_tokens(tokens), + Self::Import(s) => s.to_tokens(tokens), + } + } +} + +/// `from {module} import {names}` +struct PyImportFrom { + pub from_token: kw::from, + pub module: Ident, + pub import_token: kw::import, + pub names: Punctuated, +} + +impl PyImportFrom { + pub fn as_json(&self) -> JsonValue { + JsonValue::Object(HashMap::from([ + ("type", JsonValue::String("importfrom".into())), + ( + "module", + JsonValue::String(self.module.unraw().to_string().into()), + ), + ( + "names", + JsonValue::Array(self.names.iter().map(|i| i.as_json()).collect()), + ), + ("level", JsonValue::Number(0)), + ])) + } +} + +impl Parse for PyImportFrom { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + from_token: input.parse()?, + module: input.parse()?, + import_token: input.parse()?, + names: Punctuated::parse_separated_nonempty(input)?, + }) + } +} + +impl ToTokens for PyImportFrom { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.from_token.to_tokens(tokens); + self.module.to_tokens(tokens); + self.import_token.to_tokens(tokens); + self.names.to_tokens(tokens); + } +} + +/// `import {names}` +struct PyImport { + pub import_token: kw::import, + pub names: Punctuated, +} + +impl PyImport { + pub fn as_json(&self) -> JsonValue { + JsonValue::Object(HashMap::from([ + ("type", JsonValue::String("import".into())), + ( + "names", + JsonValue::Array(self.names.iter().map(|i| i.as_json()).collect()), + ), + ])) + } +} + +impl Parse for PyImport { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + import_token: input.parse()?, + names: Punctuated::parse_separated_nonempty(input)?, + }) + } +} + +impl ToTokens for PyImport { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.import_token.to_tokens(tokens); + self.names.to_tokens(tokens); + } +} + +/// `{name} [as {as_name}]` +struct PyAlias { + pub name: Ident, + pub as_name: Option, +} + +impl PyAlias { + pub fn as_json(&self) -> JsonValue { + let mut args = HashMap::from([ + ("type", JsonValue::String("alias".into())), + ( + "name", + JsonValue::String(self.name.unraw().to_string().into()), + ), + ]); + if let Some(as_name) = &self.as_name { + args.insert( + "asname", + JsonValue::String(as_name.name.unraw().to_string().into()), + ); + } + JsonValue::Object(args) + } +} + +impl Parse for PyAlias { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + name: input.parse()?, + as_name: if input.lookahead1().peek(Token![as]) { + Some(input.parse()?) + } else { + None + }, + }) + } +} + +impl ToTokens for PyAlias { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.name.to_tokens(tokens); + self.as_name.to_tokens(tokens); + } +} + +/// `as {name}` +struct PyAliasAsName { + pub as_token: Token![as], + pub name: Ident, +} + +impl Parse for PyAliasAsName { + fn parse(input: ParseStream<'_>) -> syn::Result { + Ok(Self { + as_token: input.parse()?, + name: input.parse()?, + }) + } +} + +impl ToTokens for PyAliasAsName { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.as_token.to_tokens(tokens); + self.name.to_tokens(tokens); + } +} diff --git a/pytests/src/annotations.rs b/pytests/src/annotations.rs new file mode 100644 index 00000000000..1abe4fc77d1 --- /dev/null +++ b/pytests/src/annotations.rs @@ -0,0 +1,30 @@ +//! Example of custom annotations. + +use pyo3::prelude::*; + +#[pymodule(stubs = { + from datetime import datetime as dt, time; + from uuid import UUID; +})] +pub mod annotations { + use pyo3::prelude::*; + use pyo3::types::{PyDate, PyDateTime, PyDict, PyTime, PyTuple}; + + #[pyfunction(signature = (a: "dt | time | UUID", *_args: "str", _b: "int | None" = None, **_kwargs: "bool") -> "int")] + fn with_custom_type_annotations<'py>( + a: Bound<'py, PyAny>, + _args: Bound<'py, PyTuple>, + _b: Option>, + _kwargs: Option>, + ) -> Bound<'py, PyAny> { + a + } + + #[pyfunction] + fn with_built_in_type_annotations( + _date_time: Bound<'_, PyDateTime>, + _time: Bound<'_, PyTime>, + _date: Bound<'_, PyDate>, + ) { + } +} diff --git a/pytests/src/lib.rs b/pytests/src/lib.rs index f6f4b151e6e..00b8d4f8e21 100644 --- a/pytests/src/lib.rs +++ b/pytests/src/lib.rs @@ -1,6 +1,8 @@ use pyo3::prelude::*; use pyo3::types::PyDict; +#[cfg(feature = "experimental-inspect")] +mod annotations; mod awaitable; mod buf_and_str; mod comparisons; @@ -32,6 +34,10 @@ mod pyo3_pytests { #[pymodule_export] use datetime::datetime; + #[cfg(feature = "experimental-inspect")] + #[pymodule_export] + use annotations::annotations; + #[pymodule_export] use { awaitable::awaitable, comparisons::comparisons, consts::consts, dict_iter::dict_iter, diff --git a/pytests/src/pyfunctions.rs b/pytests/src/pyfunctions.rs index 6e1015e7627..e1ffb444cac 100644 --- a/pytests/src/pyfunctions.rs +++ b/pytests/src/pyfunctions.rs @@ -77,17 +77,6 @@ fn with_typed_args(a: bool, b: u64, c: f64, d: &str) -> (bool, u64, f64, &str) { (a, b, c, d) } -#[cfg(feature = "experimental-inspect")] -#[pyfunction(signature = (a: "int", *_args: "str", _b: "int | None" = None, **_kwargs: "bool") -> "int")] -fn with_custom_type_annotations<'py>( - a: Any<'py>, - _args: Tuple<'py>, - _b: Option>, - _kwargs: Option>, -) -> Any<'py> { - a -} - #[cfg(feature = "experimental-async")] #[pyfunction] async fn with_async() {} @@ -143,9 +132,6 @@ pub mod pyfunctions { #[cfg(feature = "experimental-async")] #[pymodule_export] use super::with_async; - #[cfg(feature = "experimental-inspect")] - #[pymodule_export] - use super::with_custom_type_annotations; #[pymodule_export] use super::{ args_kwargs, many_keyword_arguments, none, positional_only, simple, simple_args, diff --git a/pytests/stubs/annotations.pyi b/pytests/stubs/annotations.pyi new file mode 100644 index 00000000000..e7e3fb32c97 --- /dev/null +++ b/pytests/stubs/annotations.pyi @@ -0,0 +1,9 @@ +from datetime import date, datetime as dt, time +from uuid import UUID + +def with_built_in_type_annotations( + _date_time: dt, _time: time, _date: date +) -> None: ... +def with_custom_type_annotations( + a: "dt | time | UUID", *_args: "str", _b: "int | None" = None, **_kwargs: "bool" +) -> "int": ... diff --git a/pytests/stubs/pyfunctions.pyi b/pytests/stubs/pyfunctions.pyi index 1d5cca9a33d..322f2642339 100644 --- a/pytests/stubs/pyfunctions.pyi +++ b/pytests/stubs/pyfunctions.pyi @@ -35,9 +35,6 @@ def simple_kwargs( a: Any, b: Any | None = None, c: Any | None = None, **kwargs ) -> tuple[Any, Any | None, Any | None, dict | None]: ... async def with_async() -> None: ... -def with_custom_type_annotations( - a: "int", *_args: "str", _b: "int | None" = None, **_kwargs: "bool" -) -> "int": ... def with_typed_args( a: bool = False, b: int = 0, c: float = 0.0, d: str = "" ) -> tuple[bool, int, float, str]: ... diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 423f0dd5cea..4bceeb5dbf4 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -31,6 +31,7 @@ fn test_compile_errors() { t.compile_fail("tests/ui/invalid_pymethods_duplicates.rs"); t.compile_fail("tests/ui/invalid_pymethod_enum.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); + #[cfg(not(feature = "experimental-inspect"))] t.compile_fail("tests/ui/invalid_pymodule_args.rs"); t.compile_fail("tests/ui/invalid_pycallargs.rs"); t.compile_fail("tests/ui/reject_generics.rs");