Skip to content

Commit 57676ea

Browse files
committed
Fix ctypes
1 parent 3367028 commit 57676ea

File tree

7 files changed

+309
-129
lines changed

7 files changed

+309
-129
lines changed

crates/vm/src/stdlib/ctypes.rs

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,8 @@ pub(crate) mod _ctypes {
385385
#[pyattr]
386386
const RTLD_GLOBAL: i32 = 0;
387387

388-
#[cfg(target_os = "windows")]
389-
#[pyattr]
390-
const SIZEOF_TIME_T: usize = 8;
391-
#[cfg(not(target_os = "windows"))]
392388
#[pyattr]
393-
const SIZEOF_TIME_T: usize = 4;
389+
const SIZEOF_TIME_T: usize = std::mem::size_of::<libc::time_t>();
394390

395391
#[pyattr]
396392
const CTYPES_MAX_ARGCOUNT: usize = 1024;
@@ -578,30 +574,42 @@ pub(crate) mod _ctypes {
578574
#[pyfunction(name = "dlopen")]
579575
fn load_library_unix(
580576
name: Option<crate::function::FsPath>,
581-
_load_flags: OptionalArg<i32>,
577+
load_flags: OptionalArg<i32>,
582578
vm: &VirtualMachine,
583579
) -> PyResult<usize> {
584-
// TODO: audit functions first
585-
// TODO: load_flags
580+
// Default mode: RTLD_NOW | RTLD_LOCAL, always force RTLD_NOW
581+
let mode = load_flags.unwrap_or(libc::RTLD_NOW | libc::RTLD_LOCAL) | libc::RTLD_NOW;
582+
586583
match name {
587584
Some(name) => {
588585
let cache = library::libcache();
589586
let mut cache_write = cache.write();
590587
let os_str = name.as_os_str(vm)?;
591-
let (id, _) = cache_write.get_or_insert_lib(&*os_str, vm).map_err(|e| {
592-
// Include filename in error message for better diagnostics
593-
let name_str = os_str.to_string_lossy();
594-
vm.new_os_error(format!("{}: {}", name_str, e))
595-
})?;
588+
let (id, _) = cache_write
589+
.get_or_insert_lib_with_mode(&*os_str, mode, vm)
590+
.map_err(|e| {
591+
let name_str = os_str.to_string_lossy();
592+
vm.new_os_error(format!("{}: {}", name_str, e))
593+
})?;
596594
Ok(id)
597595
}
598596
None => {
599-
// If None, call libc::dlopen(null, mode) to get the current process handle
600-
let handle = unsafe { libc::dlopen(std::ptr::null(), libc::RTLD_NOW) };
597+
// dlopen(NULL, mode) to get the current process handle (for pythonapi)
598+
let handle = unsafe { libc::dlopen(std::ptr::null(), mode) };
601599
if handle.is_null() {
602-
return Err(vm.new_os_error("dlopen() error"));
600+
let err = unsafe { libc::dlerror() };
601+
let msg = if err.is_null() {
602+
"dlopen() error".to_string()
603+
} else {
604+
unsafe { std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() }
605+
};
606+
return Err(vm.new_os_error(msg));
603607
}
604-
Ok(handle as usize)
608+
// Add to library cache so symbol lookup works
609+
let cache = library::libcache();
610+
let mut cache_write = cache.write();
611+
let id = cache_write.insert_raw_handle(handle);
612+
Ok(id)
605613
}
606614
}
607615
}
@@ -614,6 +622,47 @@ pub(crate) mod _ctypes {
614622
Ok(())
615623
}
616624

625+
#[cfg(not(windows))]
626+
#[pyfunction]
627+
fn dlclose(handle: usize, vm: &VirtualMachine) -> PyResult<()> {
628+
let result = unsafe { libc::dlclose(handle as *mut libc::c_void) };
629+
if result != 0 {
630+
let err = unsafe { libc::dlerror() };
631+
let msg = if err.is_null() {
632+
"dlclose() error".to_string()
633+
} else {
634+
unsafe { std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() }
635+
};
636+
return Err(vm.new_os_error(msg));
637+
}
638+
Ok(())
639+
}
640+
641+
#[cfg(not(windows))]
642+
#[pyfunction]
643+
fn dlsym(
644+
handle: usize,
645+
name: crate::builtins::PyStrRef,
646+
vm: &VirtualMachine,
647+
) -> PyResult<usize> {
648+
let symbol_name = std::ffi::CString::new(name.as_str())
649+
.map_err(|_| vm.new_value_error("symbol name contains null byte"))?;
650+
651+
// Clear previous error
652+
unsafe { libc::dlerror() };
653+
654+
let ptr = unsafe { libc::dlsym(handle as *mut libc::c_void, symbol_name.as_ptr()) };
655+
656+
// Check for error (ptr can be NULL for valid symbols, so check dlerror)
657+
let err = unsafe { libc::dlerror() };
658+
if !err.is_null() {
659+
let msg = unsafe { std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() };
660+
return Err(vm.new_os_error(msg));
661+
}
662+
663+
Ok(ptr as usize)
664+
}
665+
617666
#[pyfunction(name = "POINTER")]
618667
fn create_pointer_type(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult {
619668
use crate::builtins::PyStr;

crates/vm/src/stdlib/ctypes/base.rs

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,31 +1168,29 @@ impl PyCData {
11681168
.ok_or_else(|| vm.new_value_error("Invalid library handle"))?
11691169
};
11701170

1171-
// Get symbol address using platform-specific API
1172-
let symbol_name = std::ffi::CString::new(name.as_str())
1173-
.map_err(|_| vm.new_value_error("Invalid symbol name"))?;
1174-
1175-
#[cfg(windows)]
1176-
let ptr: *const u8 = unsafe {
1177-
match windows_sys::Win32::System::LibraryLoader::GetProcAddress(
1178-
handle as windows_sys::Win32::Foundation::HMODULE,
1179-
symbol_name.as_ptr() as *const u8,
1180-
) {
1181-
Some(p) => p as *const u8,
1182-
None => std::ptr::null(),
1171+
// Look up the library in the cache and use lib.get() for symbol lookup
1172+
let library_cache = super::library::libcache().read();
1173+
let library = library_cache
1174+
.get_lib(handle)
1175+
.ok_or_else(|| vm.new_value_error("Library not found"))?;
1176+
let inner_lib = library.lib.lock();
1177+
1178+
let symbol_name_with_nul = format!("{}\0", name.as_str());
1179+
let ptr: *const u8 = if let Some(lib) = &*inner_lib {
1180+
unsafe {
1181+
lib.get::<*const u8>(symbol_name_with_nul.as_bytes())
1182+
.map(|sym| *sym)
1183+
.map_err(|_| {
1184+
vm.new_value_error(format!(
1185+
"symbol '{}' not found in library",
1186+
name.as_str()
1187+
))
1188+
})?
11831189
}
1190+
} else {
1191+
return Err(vm.new_value_error("Library closed"));
11841192
};
11851193

1186-
#[cfg(not(windows))]
1187-
let ptr: *const u8 =
1188-
unsafe { libc::dlsym(handle as *mut libc::c_void, symbol_name.as_ptr()) as *const u8 };
1189-
1190-
if ptr.is_null() {
1191-
return Err(
1192-
vm.new_value_error(format!("symbol '{}' not found in library", name.as_str()))
1193-
);
1194-
}
1195-
11961194
// PyCData_AtAddress
11971195
let cdata = unsafe { Self::at_address(ptr, size) };
11981196
cdata.into_ref_with_type(vm, cls).map(Into::into)

crates/vm/src/stdlib/ctypes/function.rs

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
// spell-checker:disable
22

33
use super::{
4-
_ctypes::CArgObject, PyCArray, PyCData, PyCPointer, PyCStructure, base::FfiArgValue,
5-
simple::PyCSimple, type_info,
4+
_ctypes::CArgObject,
5+
PyCArray, PyCData, PyCPointer, PyCStructure,
6+
base::{FfiArgValue, StgInfoFlags},
7+
simple::PyCSimple,
8+
type_info,
69
};
710
use crate::{
811
AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
@@ -18,7 +21,7 @@ use libffi::{
1821
middle::{Arg, Cif, Closure, CodePtr, Type},
1922
};
2023
use libloading::Symbol;
21-
use num_traits::ToPrimitive;
24+
use num_traits::{Signed, ToPrimitive};
2225
use rustpython_common::lock::PyRwLock;
2326
use std::ffi::{self, c_void};
2427
use std::fmt::Debug;
@@ -131,11 +134,19 @@ fn convert_to_pointer(value: &PyObject, vm: &VirtualMachine) -> PyResult<FfiArgV
131134
return Ok(FfiArgValue::Pointer(addr));
132135
}
133136

134-
// 7. Integer -> direct value
137+
// 7. Integer -> direct value (PyLong_AsVoidPtr behavior)
135138
if let Ok(int_val) = value.try_int(vm) {
136-
return Ok(FfiArgValue::Pointer(
137-
int_val.as_bigint().to_usize().unwrap_or(0),
138-
));
139+
let bigint = int_val.as_bigint();
140+
// Negative values: use signed conversion (allows -1 as 0xFFFF...)
141+
if bigint.is_negative() {
142+
if let Some(signed_val) = bigint.to_isize() {
143+
return Ok(FfiArgValue::Pointer(signed_val as usize));
144+
}
145+
} else if let Some(unsigned_val) = bigint.to_usize() {
146+
return Ok(FfiArgValue::Pointer(unsigned_val));
147+
}
148+
// Value out of range - raise OverflowError
149+
return Err(vm.new_overflow_error("int too large to convert to pointer".to_string()));
139150
}
140151

141152
// 8. Check _as_parameter_ attribute (recursive ConvParam)
@@ -1060,13 +1071,16 @@ fn extract_call_info(zelf: &Py<PyCFuncPtr>, vm: &VirtualMachine) -> PyResult<Cal
10601071
.unwrap_or_else(Type::i32)
10611072
};
10621073

1063-
// Check if return type is a pointer type (P, z, Z) - need special handling on 64-bit
1074+
// Check if return type is a pointer type via TYPEFLAG_ISPOINTER
1075+
// This handles c_void_p, c_char_p, c_wchar_p, and POINTER(T) types
10641076
let is_pointer_return = restype_obj
10651077
.as_ref()
10661078
.and_then(|t| t.clone().downcast::<PyType>().ok())
1067-
.and_then(|t| t.as_object().get_attr(vm.ctx.intern_str("_type_"), vm).ok())
1068-
.and_then(|t| t.downcast_ref::<PyStr>().map(|s| s.to_string()))
1069-
.is_some_and(|tc| matches!(tc.as_str(), "P" | "z" | "Z"));
1079+
.and_then(|t| {
1080+
t.stg_info_opt()
1081+
.map(|info| info.flags.contains(StgInfoFlags::TYPEFLAG_ISPOINTER))
1082+
})
1083+
.unwrap_or(false);
10701084

10711085
Ok(CallInfo {
10721086
explicit_arg_types,
@@ -1477,11 +1491,23 @@ fn convert_raw_result(
14771491
}
14781492
}
14791493
let slice = unsafe { std::slice::from_raw_parts(wstr_ptr, len) };
1480-
let s: String = slice
1481-
.iter()
1482-
.filter_map(|&c| char::from_u32(c as u32))
1483-
.collect();
1484-
Some(vm.ctx.new_str(s).into())
1494+
// Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide
1495+
// Unix: wchar_t = i32 (UTF-32) -> convert via char::from_u32
1496+
#[cfg(windows)]
1497+
{
1498+
use rustpython_common::wtf8::Wtf8Buf;
1499+
let wide: Vec<u16> = slice.iter().map(|&c| c).collect();
1500+
let wtf8 = Wtf8Buf::from_wide(&wide);
1501+
Some(vm.ctx.new_str(wtf8).into())
1502+
}
1503+
#[cfg(not(windows))]
1504+
{
1505+
let s: String = slice
1506+
.iter()
1507+
.filter_map(|&c| char::from_u32(c as u32))
1508+
.collect();
1509+
Some(vm.ctx.new_str(s).into())
1510+
}
14851511
}
14861512
}
14871513
_ => {
@@ -1685,7 +1711,6 @@ impl PyCFuncPtr {
16851711
}
16861712

16871713
// Fallback to StgInfo for native types
1688-
use super::base::StgInfoFlags;
16891714
zelf.class()
16901715
.stg_info_opt()
16911716
.map(|stg| stg.flags.bits())
@@ -1758,11 +1783,23 @@ fn ffi_to_python(ty: &Py<PyType>, ptr: *const c_void, vm: &VirtualMachine) -> Py
17581783
len += 1;
17591784
}
17601785
let slice = std::slice::from_raw_parts(wstr_ptr, len);
1761-
let s: String = slice
1762-
.iter()
1763-
.filter_map(|&c| char::from_u32(c as u32))
1764-
.collect();
1765-
vm.ctx.new_str(s).into()
1786+
// Windows: wchar_t = u16 (UTF-16) -> use Wtf8Buf::from_wide
1787+
// Unix: wchar_t = i32 (UTF-32) -> convert via char::from_u32
1788+
#[cfg(windows)]
1789+
{
1790+
use rustpython_common::wtf8::Wtf8Buf;
1791+
let wide: Vec<u16> = slice.iter().map(|&c| c).collect();
1792+
let wtf8 = Wtf8Buf::from_wide(&wide);
1793+
vm.ctx.new_str(wtf8).into()
1794+
}
1795+
#[cfg(not(windows))]
1796+
{
1797+
let s: String = slice
1798+
.iter()
1799+
.filter_map(|&c| char::from_u32(c as u32))
1800+
.collect();
1801+
vm.ctx.new_str(s).into()
1802+
}
17661803
}
17671804
}
17681805
Some("P") => vm.ctx.new_int(*(ptr as *const usize)).into(),

0 commit comments

Comments
 (0)