From 65b9d1201a2ee1ad337e4228085600367ec86f9f Mon Sep 17 00:00:00 2001 From: Dev-iL <6509619+Dev-iL@users.noreply.github.com> Date: Thu, 14 May 2026 16:26:17 +0300 Subject: [PATCH] Add copy_records_to_table for COPY FROM STDIN bulk-load Closes #166. The existing binary_copy_to_table required callers to pre-encode PostgreSQL's binary COPY wire format, leaving no ergonomic bulk-load path comparable to asyncpg's copy_records_to_table or psycopg3's cursor.copy(...). The new method on Connection and Transaction accepts an iterable of records, introspects column types from the target table, and streams rows via tokio-postgres' BinaryCopyInWriter using the same PythonDTO conversions used by execute(). --- python/psqlpy/_internal/__init__.pyi | 52 ++++++++ python/tests/test_copy_records.py | 174 +++++++++++++++++++++++++++ src/driver/common.rs | 142 +++++++++++++++++++++- 3 files changed, 365 insertions(+), 3 deletions(-) create mode 100644 python/tests/test_copy_records.py diff --git a/python/psqlpy/_internal/__init__.pyi b/python/psqlpy/_internal/__init__.pyi index 336bc311..e0f4f794 100644 --- a/python/psqlpy/_internal/__init__.pyi +++ b/python/psqlpy/_internal/__init__.pyi @@ -820,6 +820,32 @@ class Transaction: number of inserted rows; """ + async def copy_records_to_table( + self: Self, + table_name: str, + records: typing.Iterable[Sequence[Any]], + columns: Sequence[str] | None = None, + schema_name: str | None = None, + ) -> int: + """Copy records into a table using the binary COPY FROM STDIN protocol. + + Column types are introspected from the target table, so each record + may contain raw Python values (the same conversions used by + `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`. + + ### Parameters: + - `table_name`: name of the table. + - `records`: iterable of records (each a sequence of column values + matching the order of `columns`, or of the table's columns when + `columns` is `None`). + - `columns`: sequence of column names to load into. 
When `None`, + all columns of the table are used in their declared order. + - `schema_name`: optional schema for `table_name`. + + ### Returns: + number of inserted rows; + """ + async def connect( dsn: str | None = None, username: str | None = None, @@ -1146,6 +1172,32 @@ class Connection: number of inserted rows; """ + async def copy_records_to_table( + self: Self, + table_name: str, + records: typing.Iterable[Sequence[Any]], + columns: Sequence[str] | None = None, + schema_name: str | None = None, + ) -> int: + """Copy records into a table using the binary COPY FROM STDIN protocol. + + Column types are introspected from the target table, so each record + may contain raw Python values (the same conversions used by + `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`. + + ### Parameters: + - `table_name`: name of the table. + - `records`: iterable of records (each a sequence of column values + matching the order of `columns`, or of the table's columns when + `columns` is `None`). + - `columns`: sequence of column names to load into. When `None`, + all columns of the table are used in their declared order. + - `schema_name`: optional schema for `table_name`. 
+ + ### Returns: + number of inserted rows; + """ + class ConnectionPoolStatus: max_size: int size: int diff --git a/python/tests/test_copy_records.py b/python/tests/test_copy_records.py new file mode 100644 index 00000000..28aba34a --- /dev/null +++ b/python/tests/test_copy_records.py @@ -0,0 +1,174 @@ +import typing +from datetime import datetime, timezone + +import pytest +from psqlpy import ConnectionPool +from psqlpy.exceptions import PyToRustValueMappingError + +pytestmark = pytest.mark.anyio + + +async def _setup_target_table(psql_pool: ConnectionPool, name: str) -> None: + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP TABLE IF EXISTS {name}") + await connection.execute( + f""" + CREATE TABLE {name} ( + id INTEGER, + label TEXT, + weight DOUBLE PRECISION, + created_at TIMESTAMPTZ + ) + """, + ) + + +async def _drop_target_table(psql_pool: ConnectionPool, name: str) -> None: + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP TABLE IF EXISTS {name}") + + +async def test_copy_records_to_table_on_connection( + psql_pool: ConnectionPool, +) -> None: + target: typing.Final = "copy_records_conn" + await _setup_target_table(psql_pool, target) + try: + records = [ + (1, "alpha", 1.5, datetime(2026, 1, 1, tzinfo=timezone.utc)), + (2, "beta", 2.25, datetime(2026, 1, 2, tzinfo=timezone.utc)), + (3, "gamma", None, datetime(2026, 1, 3, tzinfo=timezone.utc)), + ] + + async with psql_pool.acquire() as connection: + inserted = await connection.copy_records_to_table( + table_name=target, + records=records, + ) + + assert inserted == len(records) + + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT id, label, weight FROM {target} ORDER BY id", + ) + rows = result.result() + assert [(r["id"], r["label"], r["weight"]) for r in rows] == [ + (1, "alpha", 1.5), + (2, "beta", 2.25), + (3, "gamma", None), + ] + finally: + await _drop_target_table(psql_pool, target) + + +async 
def test_copy_records_to_table_with_columns_subset( + psql_pool: ConnectionPool, +) -> None: + target: typing.Final = "copy_records_subset" + await _setup_target_table(psql_pool, target) + try: + records = [(10, "only-label"), (11, "another")] + + async with psql_pool.acquire() as connection: + inserted = await connection.copy_records_to_table( + table_name=target, + records=records, + columns=["id", "label"], + ) + + assert inserted == len(records) + + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT id, label, weight, created_at FROM {target} ORDER BY id", + ) + rows = result.result() + assert [(r["id"], r["label"]) for r in rows] == [ + (10, "only-label"), + (11, "another"), + ] + # Untouched columns must remain NULL + assert all(r["weight"] is None and r["created_at"] is None for r in rows) + finally: + await _drop_target_table(psql_pool, target) + + +async def test_copy_records_to_table_in_transaction( + psql_pool: ConnectionPool, +) -> None: + target: typing.Final = "copy_records_tx" + await _setup_target_table(psql_pool, target) + try: + records = [(100, "tx-row", 0.0, datetime(2026, 5, 1, tzinfo=timezone.utc))] + + async with ( + psql_pool.acquire() as connection, + connection.transaction() as tx, + ): + inserted = await tx.copy_records_to_table( + table_name=target, + records=records, + ) + + assert inserted == 1 + + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT COUNT(*) AS c FROM {target}", + ) + assert result.result()[0]["c"] == 1 + finally: + await _drop_target_table(psql_pool, target) + + +async def test_copy_records_to_table_rejects_record_arity_mismatch( + psql_pool: ConnectionPool, +) -> None: + target: typing.Final = "copy_records_mismatch" + await _setup_target_table(psql_pool, target) + try: + records = [(1, "missing-rest")] # table has 4 columns + + async with psql_pool.acquire() as connection: + with pytest.raises(PyToRustValueMappingError): + await 
connection.copy_records_to_table( + table_name=target, + records=records, + ) + finally: + await _drop_target_table(psql_pool, target) + + +async def test_copy_records_to_table_uses_schema_qualifier( + psql_pool: ConnectionPool, +) -> None: + schema: typing.Final = "copy_records_schema" + target: typing.Final = "tbl" + + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE") + await connection.execute(f"CREATE SCHEMA {schema}") + await connection.execute( + f"CREATE TABLE {schema}.{target} (id INTEGER, label TEXT)", + ) + + try: + records = [(1, "schema-a"), (2, "schema-b")] + async with psql_pool.acquire() as connection: + inserted = await connection.copy_records_to_table( + table_name=target, + records=records, + schema_name=schema, + ) + + assert inserted == len(records) + + async with psql_pool.acquire() as connection: + result = await connection.execute( + f"SELECT id, label FROM {schema}.{target} ORDER BY id", + ) + assert [(r["id"], r["label"]) for r in result.result()] == records + finally: + async with psql_pool.acquire() as connection: + await connection.execute(f"DROP SCHEMA IF EXISTS {schema} CASCADE") diff --git a/src/driver/common.rs b/src/driver/common.rs index b2ff6a52..2794cba2 100644 --- a/src/driver/common.rs +++ b/src/driver/common.rs @@ -10,14 +10,15 @@ use super::{ use pyo3::{pymethods, Py, PyAny}; use crate::{ - connection::traits::CloseTransaction, + connection::traits::{CloseTransaction, Connection as _}, exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + value_converter::{dto::enums::PythonDTO, from_python::from_python_typed}, }; use bytes::BytesMut; use futures_util::pin_mut; -use pyo3::{buffer::PyBuffer, Python}; -use tokio_postgres::binary_copy::BinaryCopyInWriter; +use pyo3::{buffer::PyBuffer, types::PyAnyMethods, Python}; +use tokio_postgres::{binary_copy::BinaryCopyInWriter, types::ToSql}; use crate::format_helpers::quote_ident; @@ -320,3 +321,138 @@ 
macro_rules! impl_binary_copy_method {
 
 impl_binary_copy_method!(Connection);
 impl_binary_copy_method!(Transaction);
+
+macro_rules! impl_copy_records_method {
+    ($name:ident) => {
+        #[pymethods]
+        impl $name {
+            /// Copy a list of records into a table using the COPY FROM STDIN
+            /// binary protocol.
+            ///
+            /// Column types are introspected from the target table, so callers
+            /// pass Python values directly (the same conversions used by
+            /// `execute`). Mirrors `asyncpg.Connection.copy_records_to_table`.
+            ///
+            /// # Errors
+            /// May return error if there is some problem with DB communication,
+            /// the table cannot be introspected, or a value cannot be converted.
+            #[pyo3(signature = (table_name, records, columns=None, schema_name=None))]
+            pub async fn copy_records_to_table(
+                self_: pyo3::Py<Self>,
+                table_name: String,
+                records: Py<PyAny>,
+                columns: Option<Vec<String>>,
+                schema_name: Option<String>,
+            ) -> PSQLPyResult<u64> {
+                let (db_client, raw_records) = Python::with_gil(
+                    |gil| -> PSQLPyResult<(Option<_>, Vec<Vec<Py<PyAny>>>)> {
+                        let db_client = self_.borrow(gil).conn.clone();
+
+                        let Some(db_client) = db_client else {
+                            return Ok((None, Vec::new()));
+                        };
+
+                        let bound = records.bind(gil);
+                        let mut rows: Vec<Vec<Py<PyAny>>> = Vec::new();
+                        for item in bound.try_iter()? {
+                            let row = item?;
+                            let mut row_vec: Vec<Py<PyAny>> = Vec::new();
+                            for cell in row.try_iter()? {
+                                row_vec.push(cell?.unbind());
+                            }
+                            rows.push(row_vec);
+                        }
+
+                        Ok((Some(db_client), rows))
+                    },
+                )?;
+
+                let Some(db_client) = db_client else {
+                    return Ok(0);
+                };
+
+                let full_table_name = match schema_name {
+                    Some(ref schema) => {
+                        format!("{}.{}", quote_ident(schema), quote_ident(&table_name))
+                    }
+                    None => quote_ident(&table_name),
+                };
+
+                let columns_sql = match columns {
+                    Some(ref cols) if !cols.is_empty() => Some(
+                        cols.iter()
+                            .map(|c| quote_ident(c))
+                            .collect::<Vec<_>>()
+                            .join(", "),
+                    ),
+                    _ => None,
+                };
+
+                let introspect_qs = match &columns_sql {
+                    Some(cols) => format!("SELECT {} FROM {} WHERE false", cols, full_table_name),
+                    None => format!("SELECT * FROM {} WHERE false", full_table_name),
+                };
+
+                let read_conn_g = db_client.read().await;
+
+                let stmt = read_conn_g.prepare(&introspect_qs, false).await?;
+                let column_types: Vec<tokio_postgres::types::Type> =
+                    stmt.columns().iter().map(|c| c.type_().clone()).collect();
+
+                if column_types.is_empty() {
+                    return Err(RustPSQLDriverError::PyToRustValueConversionError(
+                        "Cannot introspect column types from target table".into(),
+                    ));
+                }
+
+                let typed_rows: Vec<Vec<PythonDTO>> =
+                    Python::with_gil(|gil| -> PSQLPyResult<Vec<Vec<PythonDTO>>> {
+                        let mut typed: Vec<Vec<PythonDTO>> = Vec::with_capacity(raw_records.len());
+                        for (row_idx, row) in raw_records.iter().enumerate() {
+                            if row.len() != column_types.len() {
+                                return Err(RustPSQLDriverError::PyToRustValueConversionError(
+                                    format!(
+                                        "Record at index {} has {} fields, expected {}",
+                                        row_idx,
+                                        row.len(),
+                                        column_types.len()
+                                    ),
+                                ));
+                            }
+                            let mut row_dto: Vec<PythonDTO> = Vec::with_capacity(row.len());
+                            for (cell, ty) in row.iter().zip(column_types.iter()) {
+                                row_dto.push(from_python_typed(cell.bind(gil), ty)?);
+                            }
+                            typed.push(row_dto);
+                        }
+                        Ok(typed)
+                    })?;
+
+                let copy_qs = match &columns_sql {
+                    Some(cols) => format!(
+                        "COPY {}({}) FROM STDIN (FORMAT binary)",
+                        full_table_name, cols
+                    ),
+                    None => format!("COPY {} FROM STDIN (FORMAT binary)", full_table_name),
+                };
+
+                let sink = read_conn_g.copy_in(&copy_qs).await?;
+ let writer = BinaryCopyInWriter::new(sink, &column_types); + pin_mut!(writer); + + for row in &typed_rows { + let row_refs: Vec<&(dyn ToSql + Sync)> = + row.iter().map(|v| v as &(dyn ToSql + Sync)).collect(); + writer.as_mut().write(&row_refs).await?; + } + + let rows_created = writer.as_mut().finish().await?; + + Ok(rows_created) + } + } + }; +} + +impl_copy_records_method!(Connection); +impl_copy_records_method!(Transaction);