diff --git a/.github/actions/ccache-action/action.yml b/.github/actions/ccache-action/action.yml new file mode 100644 index 00000000..e3af01f4 --- /dev/null +++ b/.github/actions/ccache-action/action.yml @@ -0,0 +1,23 @@ +name: 'Setup Ccache' +inputs: + key: + description: 'Cache key (defaults to github.job)' + required: false + default: '' +runs: + using: "composite" + steps: + - name: Setup Ccache + uses: hendrikmuhs/ccache-action@main + with: + key: ${{ inputs.key || github.job }} + save: ${{ github.repository != 'duckdb/duckdb-python' || contains('["refs/heads/main", "refs/heads/v1.4-andium", "refs/heads/v1.5-variegata"]', github.ref) }} + # Dump verbose ccache statistics report at end of CI job. + verbose: 1 + # Increase per-directory limit: 5*1024 MB / 16 = 320 MB. + # Note: `layout=subdirs` computes the size limit divided by 16 dirs. + # See also: https://ccache.dev/manual/4.9.html#_cache_size_management + max-size: 1500MB + # Evicts all cache files that were not touched during the job run. + # Removing cache files from previous runs avoids creating huge caches. + evict-old-files: 'job' diff --git a/.github/workflows/code_quality.yml b/.github/workflows/code_quality.yml index 99b7884c..575f6f5b 100644 --- a/.github/workflows/code_quality.yml +++ b/.github/workflows/code_quality.yml @@ -32,7 +32,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: version: "0.9.0" - python-version: 3.9 + python-version: "3.12" - name: pre-commit (cache) uses: actions/cache@v4 diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index fd62b6c1..7832e314 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -71,7 +71,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: version: "0.9.0" - python-version: 3.9 + python-version: 3.12 enable-cache: true cache-suffix: -${{ github.workflow }} @@ -111,6 +111,7 @@ jobs: mkdir coverage-cpp uv run gcovr \ --gcov-ignore-errors all \ + --gcov-ignore-parse-errors=negative_hits.warn_once_per_file \ --root "$PWD" \ --filter "${PWD}/src/duckdb_py" \ --exclude '.*/\.cache/.*' \ diff --git a/.github/workflows/packaging_sdist.yml b/.github/workflows/packaging_sdist.yml index b6558744..fb45b366 100644 --- a/.github/workflows/packaging_sdist.yml +++ b/.github/workflows/packaging_sdist.yml @@ -59,7 +59,7 @@ jobs: uses: astral-sh/setup-uv@v7 with: version: "0.9.0" - python-version: 3.11 + python-version: 3.12 - name: Build sdist run: uv build --sdist diff --git a/.github/workflows/packaging_wheels.yml b/.github/workflows/packaging_wheels.yml index 23a16af7..82ed76a8 100644 --- a/.github/workflows/packaging_wheels.yml +++ b/.github/workflows/packaging_wheels.yml @@ -25,15 +25,112 @@ on: type: string jobs: + seed_wheels: + name: 'Seed: cp314-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }}' + strategy: + fail-fast: false + matrix: + python: [ cp314 ] + platform: + - { os: windows-2025, arch: amd64, cibw_system: win } + - { os: windows-11-arm, arch: ARM64, cibw_system: win } + - { os: ubuntu-24.04, arch: x86_64, cibw_system: manylinux } + - { os: ubuntu-24.04-arm, arch: aarch64, cibw_system: manylinux } + - { os: macos-15, arch: arm64, cibw_system: macosx } + - { os: macos-15, arch: universal2, cibw_system: macosx } + - { os: macos-15-intel, arch: x86_64, cibw_system: macosx } + minimal: + - ${{ inputs.minimal }} + exclude: + - { minimal: true, platform: { arch: universal2 } } + runs-on: ${{ matrix.platform.os }} + env: + CCACHE_DIR: ${{ github.workspace }}/.ccache + ### cibuildwheel configuration + # + # This is somewhat brittle, so be careful with changes. Some notes for our future selves (and others): + # - cibw will change its cwd to a temp dir and create a separate venv for testing. It then installs the wheel it + # built into that venv, and run the CIBW_TEST_COMMAND. We have to install all dependencies ourselves, and make + # sure that the pytest config in pyproject.toml is available. + # - CIBW_BEFORE_TEST installs the test dependencies by exporting them into a pylock.toml. At the time of writing, + # `uv sync --no-install-project` had problems correctly resolving dependencies using resolution environments + # across all platforms we build for. This might be solved in newer uv versions. + # - CIBW_TEST_COMMAND specifies pytest conf from pyproject.toml. --confcutdir is needed to prevent pytest from + # traversing the full filesystem, which produces an error on Windows. + # - CIBW_TEST_SKIP we always skip tests for *-macosx_universal2 builds, because we run tests for arm64 and x86_64. + CIBW_TEST_SKIP: ${{ inputs.testsuite == 'none' && '*' || '*-macosx_universal2' }} + CIBW_TEST_SOURCES: tests + CIBW_BEFORE_TEST: > + uv export --only-group test --no-emit-project --quiet --output-file pylock.toml --directory {project} && + uv pip install -r pylock.toml + CIBW_TEST_COMMAND: > + uv run -v pytest --confcutdir=. --rootdir . -c {project}/pyproject.toml ${{ inputs.testsuite == 'fast' && './tests/fast' || './tests' }} + + steps: + - name: Checkout DuckDB Python + uses: actions/checkout@v4 + with: + ref: ${{ inputs.duckdb-python-sha }} + fetch-depth: 0 + submodules: true + + - name: Checkout DuckDB + shell: bash + if: ${{ inputs.duckdb-sha }} + run: | + cd external/duckdb + git fetch origin + git checkout ${{ inputs.duckdb-sha }} + + - name: Set CIBW_ENVIRONMENT + shell: bash + run: | + cibw_env="" + if [[ "${{ matrix.platform.cibw_system }}" == "manylinux" ]]; then + cibw_env="CCACHE_DIR=/host${{ github.workspace }}/.ccache" + fi + if [[ -n "${{ inputs.set-version }}" ]]; then + cibw_env="${cibw_env:+$cibw_env }OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" + fi + if [[ -n "$cibw_env" ]]; then + echo "CIBW_ENVIRONMENT=${cibw_env}" >> $GITHUB_ENV + fi + + - name: Setup Ccache + uses: ./.github/actions/ccache-action + with: + key: ${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + + # Install Astral UV, which will be used as build-frontend for cibuildwheel + - uses: astral-sh/setup-uv@v7 + with: + version: "0.9.0" + enable-cache: false + cache-suffix: -${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + + - name: Build${{ inputs.testsuite != 'none' && ' and test ' || ' ' }}wheels + uses: pypa/cibuildwheel@v3.2 + env: + CIBW_ARCHS: ${{ matrix.platform.arch == 'amd64' && 'AMD64' || matrix.platform.arch }} + CIBW_BUILD: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + + - name: Upload wheel + uses: actions/upload-artifact@v4 + with: + name: wheel-${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} + path: wheelhouse/*.whl + compression-level: 0 + build_wheels: name: 'Wheel: ${{ matrix.python }}-${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }}' + needs: seed_wheels strategy: fail-fast: false matrix: - python: [ cp39, cp310, cp311, cp312, cp313, cp314 ] + python: [ cp310, cp311, cp312, cp313 ] platform: - { os: windows-2025, arch: amd64, cibw_system: win } - - { os: windows-11-arm, arch: ARM64, cibw_system: win } # cibw requires ARM64 to be uppercase + - { os: windows-11-arm, arch: ARM64, cibw_system: win } - { os: ubuntu-24.04, arch: x86_64, cibw_system: manylinux } - { os: ubuntu-24.04-arm, arch: aarch64, cibw_system: manylinux } - { os: macos-15, arch: arm64, cibw_system: macosx } @@ -42,15 +139,14 @@ jobs: minimal: - ${{ inputs.minimal }} exclude: - - { minimal: true, python: cp310 } - { minimal: true, python: cp311 } - { minimal: true, python: cp312 } - { minimal: true, python: cp313 } - { minimal: true, platform: { arch: universal2 } } - - { python: cp39, platform: { os: windows-11-arm, arch: ARM64 } } # too many dependency problems for win arm64 - - { python: cp310, platform: { os: windows-11-arm, arch: ARM64 } } # too many dependency problems for win arm64 + - { python: cp310, platform: { os: windows-11-arm, arch: ARM64 } } runs-on: ${{ matrix.platform.os }} env: + CCACHE_DIR: ${{ github.workspace }}/.ccache ### cibuildwheel configuration # # This is somewhat brittle, so be careful with changes. Some notes for our future selves (and others): @@ -87,11 +183,24 @@ jobs: git fetch origin git checkout ${{ inputs.duckdb-sha }} - # Make sure that OVERRIDE_GIT_DESCRIBE is propagated to cibuildwhel's env, also when it's running linux builds - - name: Set OVERRIDE_GIT_DESCRIBE + - name: Set CIBW_ENVIRONMENT shell: bash - if: ${{ inputs.set-version != '' }} - run: echo "CIBW_ENVIRONMENT=OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" >> $GITHUB_ENV + run: | + cibw_env="" + if [[ "${{ matrix.platform.cibw_system }}" == "manylinux" ]]; then + cibw_env="CCACHE_DIR=/host${{ github.workspace }}/.ccache" + fi + if [[ -n "${{ inputs.set-version }}" ]]; then + cibw_env="${cibw_env:+$cibw_env }OVERRIDE_GIT_DESCRIBE=${{ inputs.set-version }}" + fi + if [[ -n "$cibw_env" ]]; then + echo "CIBW_ENVIRONMENT=${cibw_env}" >> $GITHUB_ENV + fi + + - name: Setup Ccache + uses: ./.github/actions/ccache-action + with: + key: ${{ matrix.platform.cibw_system }}_${{ matrix.platform.arch }} # Install Astral UV, which will be used as build-frontend for cibuildwheel - uses: astral-sh/setup-uv@v7 diff --git a/.github/workflows/targeted_test.yml b/.github/workflows/targeted_test.yml index 13ae9566..d1a828de 100644 --- a/.github/workflows/targeted_test.yml +++ b/.github/workflows/targeted_test.yml @@ -19,7 +19,6 @@ on: required: true type: choice options: - - '3.9' - '3.10' - '3.11' - '3.12' diff --git a/.gitignore b/.gitignore index dce09a74..a42c13b0 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,8 @@ .sw? #OS X specific files. .DS_store +#VSCode specifics +.vscode/ #==============================================================================# # Build artifacts @@ -45,6 +47,7 @@ cmake-build-release cmake-build-relwithdebinfo duckdb_packaging/duckdb_version.txt test.db +tmp/ #==============================================================================# # Python diff --git a/LICENSE b/LICENSE index 4e1fbb76..2719c9a2 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2018-2025 Stichting DuckDB Foundation +Copyright 2018-2026 Stichting DuckDB Foundation Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: diff --git a/_duckdb-stubs/__init__.pyi b/_duckdb-stubs/__init__.pyi index 124a5d5a..81d69be7 100644 --- a/_duckdb-stubs/__init__.pyi +++ b/_duckdb-stubs/__init__.pyi @@ -1,23 +1,41 @@ +import datetime +import decimal import os import pathlib -import typing as pytyping +import typing +import uuid from typing_extensions import Self -if pytyping.TYPE_CHECKING: +if typing.TYPE_CHECKING: import fsspec import numpy as np import polars import pandas import pyarrow.lib - import torch as pytorch - import tensorflow - from collections.abc import Callable, Sequence, Mapping + from collections.abc import Callable, Iterable, Sequence, Mapping from duckdb import sqltypes, func + from builtins import list as lst # needed to avoid mypy error on DuckDBPyRelation.list method shadowing # the field_ids argument to to_parquet and write_parquet has a recursive structure - ParquetFieldIdsType = Mapping[str, pytyping.Union[int, "ParquetFieldIdsType"]] + ParquetFieldIdsType = Mapping[str, int | "ParquetFieldIdsType"] -__all__: list[str] = [ +_ExpressionLike: typing.TypeAlias = ( + "Expression" + | str + | int + | float + | bool + | bytes + | None + | datetime.date + | datetime.datetime + | datetime.time + | datetime.timedelta + | decimal.Decimal + | uuid.UUID +) + +__all__: lst[str] = [ "BinderException", "CSVLineTerminator", "CaseExpression", @@ -86,13 +104,17 @@ __all__: list[str] = [ "default_connection", "description", "df", + "disable_profiling", "distinct", "dtype", "duplicate", + "enable_profiling", "enum_type", "execute", "executemany", "extract_statements", + "to_arrow_reader", + "to_arrow_table", "fetch_arrow_table", "fetch_df", "fetch_df_chunk", @@ -109,6 +131,7 @@ __all__: list[str] = [ "from_df", "from_parquet", "from_query", + "get_profiling_information", "get_table_names", "install_extension", "interrupt", @@ -157,21 +180,21 @@ __all__: list[str] = [ class BinderException(ProgrammingError): ... class CSVLineTerminator: - CARRIAGE_RETURN_LINE_FEED: pytyping.ClassVar[ + CARRIAGE_RETURN_LINE_FEED: typing.ClassVar[ CSVLineTerminator ] # value = - LINE_FEED: pytyping.ClassVar[CSVLineTerminator] # value = - __members__: pytyping.ClassVar[ + LINE_FEED: typing.ClassVar[CSVLineTerminator] # value = + __members__: typing.ClassVar[ dict[str, CSVLineTerminator] ] # value = {'LINE_FEED': , 'CARRIAGE_RETURN_LINE_FEED': } # noqa: E501 def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property @@ -190,8 +213,12 @@ class DuckDBPyConnection: def __enter__(self) -> Self: ... def __exit__(self, exc_type: object, exc: object, traceback: object) -> None: ... def append(self, table_name: str, df: pandas.DataFrame, *, by_name: bool = False) -> DuckDBPyConnection: ... - def array_type(self, type: sqltypes.DuckDBPyType, size: pytyping.SupportsInt) -> sqltypes.DuckDBPyType: ... - def arrow(self, rows_per_batch: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def array_type(self, type: sqltypes.DuckDBPyType, size: typing.SupportsInt) -> sqltypes.DuckDBPyType: ... + def arrow(self, rows_per_batch: typing.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: + """Alias of to_arrow_reader(). We recommend using to_arrow_reader() instead.""" + ... + def to_arrow_reader(self, batch_size: typing.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def to_arrow_table(self, batch_size: typing.SupportsInt = 1000000) -> pyarrow.lib.Table: ... def begin(self) -> DuckDBPyConnection: ... def checkpoint(self) -> DuckDBPyConnection: ... def close(self) -> None: ... @@ -199,8 +226,8 @@ class DuckDBPyConnection: def create_function( self, name: str, - function: Callable[..., pytyping.Any], - parameters: list[sqltypes.DuckDBPyType] | None = None, + function: Callable[..., typing.Any], + parameters: lst[sqltypes.DuckDBPyType] | None = None, return_type: sqltypes.DuckDBPyType | None = None, *, type: func.PythonUDFType = ..., @@ -209,32 +236,34 @@ class DuckDBPyConnection: side_effects: bool = False, ) -> DuckDBPyConnection: ... def cursor(self) -> DuckDBPyConnection: ... - def decimal_type(self, width: pytyping.SupportsInt, scale: pytyping.SupportsInt) -> sqltypes.DuckDBPyType: ... + def decimal_type(self, width: typing.SupportsInt, scale: typing.SupportsInt) -> sqltypes.DuckDBPyType: ... def df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... def dtype(self, type_str: str) -> sqltypes.DuckDBPyType: ... def duplicate(self) -> DuckDBPyConnection: ... - def enum_type( - self, name: str, type: sqltypes.DuckDBPyType, values: list[pytyping.Any] - ) -> sqltypes.DuckDBPyType: ... + def enum_type(self, name: str, type: sqltypes.DuckDBPyType, values: lst[typing.Any]) -> sqltypes.DuckDBPyType: ... def execute(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ... def executemany(self, query: Statement | str, parameters: object = None) -> DuckDBPyConnection: ... - def extract_statements(self, query: str) -> list[Statement]: ... - def fetch_arrow_table(self, rows_per_batch: pytyping.SupportsInt = 1000000) -> pyarrow.lib.Table: ... + def extract_statements(self, query: str) -> lst[Statement]: ... + def fetch_arrow_table(self, rows_per_batch: typing.SupportsInt = 1000000) -> pyarrow.lib.Table: + """Deprecated: use to_arrow_table() instead.""" + ... def fetch_df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... def fetch_df_chunk( - self, vectors_per_chunk: pytyping.SupportsInt = 1, *, date_as_object: bool = False + self, vectors_per_chunk: typing.SupportsInt = 1, *, date_as_object: bool = False ) -> pandas.DataFrame: ... - def fetch_record_batch(self, rows_per_batch: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... - def fetchall(self) -> list[tuple[pytyping.Any, ...]]: ... + def fetch_record_batch(self, rows_per_batch: typing.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: + """Deprecated: use to_arrow_reader() instead.""" + ... + def fetchall(self) -> lst[tuple[typing.Any, ...]]: ... def fetchdf(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... - def fetchmany(self, size: pytyping.SupportsInt = 1) -> list[tuple[pytyping.Any, ...]]: ... - def fetchnumpy(self) -> dict[str, np.typing.NDArray[pytyping.Any] | pandas.Categorical]: ... - def fetchone(self) -> tuple[pytyping.Any, ...] | None: ... + def fetchmany(self, size: typing.SupportsInt = 1) -> lst[tuple[typing.Any, ...]]: ... + def fetchnumpy(self) -> dict[str, np.typing.NDArray[typing.Any] | pandas.Categorical]: ... + def fetchone(self) -> tuple[typing.Any, ...] | None: ... def filesystem_is_registered(self, name: str) -> bool: ... def from_arrow(self, arrow_object: object) -> DuckDBPyRelation: ... def from_csv_auto( self, - path_or_buffer: str | bytes | os.PathLike[str] | os.PathLike[bytes], + path_or_buffer: str | bytes | os.PathLike[str] | os.PathLike[bytes] | typing.IO[bytes], header: bool | int | None = None, compression: str | None = None, sep: str | None = None, @@ -242,8 +271,8 @@ class DuckDBPyConnection: files_to_sniff: int | None = None, comment: str | None = None, thousands: str | None = None, - dtype: dict[str, str] | list[str] | None = None, - na_values: str | list[str] | None = None, + dtype: dict[str, str] | lst[str] | None = None, + na_values: str | lst[str] | None = None, skiprows: int | None = None, quotechar: str | None = None, escapechar: str | None = None, @@ -256,17 +285,17 @@ class DuckDBPyConnection: all_varchar: bool | None = None, normalize_names: bool | None = None, null_padding: bool | None = None, - names: list[str] | None = None, + names: lst[str] | None = None, lineterminator: str | None = None, columns: dict[str, str] | None = None, - auto_type_candidates: list[str] | None = None, + auto_type_candidates: lst[str] | None = None, max_line_size: int | None = None, ignore_errors: bool | None = None, store_rejects: bool | None = None, rejects_table: str | None = None, rejects_scan: str | None = None, rejects_limit: int | None = None, - force_not_null: list[str] | None = None, + force_not_null: lst[str] | None = None, buffer_size: int | None = None, decimal: str | None = None, allow_quoted_nulls: bool | None = None, @@ -278,7 +307,7 @@ class DuckDBPyConnection: strict_mode: bool | None = None, ) -> DuckDBPyRelation: ... def from_df(self, df: pandas.DataFrame) -> DuckDBPyRelation: ... - @pytyping.overload + @typing.overload def from_parquet( self, file_glob: str, @@ -290,7 +319,7 @@ class DuckDBPyConnection: union_by_name: bool = False, compression: str | None = None, ) -> DuckDBPyRelation: ... - @pytyping.overload + @typing.overload def from_parquet( self, file_globs: Sequence[str], @@ -313,28 +342,29 @@ class DuckDBPyConnection: repository_url: str | None = None, version: str | None = None, ) -> None: ... + def get_profiling_information(self, format: str = "json") -> str: ... + def enable_profiling(self) -> None: ... + def disable_profiling(self) -> None: ... def interrupt(self) -> None: ... - def list_filesystems(self) -> list[str]: ... + def list_filesystems(self) -> lst[str]: ... def list_type(self, type: sqltypes.DuckDBPyType) -> sqltypes.DuckDBPyType: ... def load_extension(self, extension: str) -> None: ... def map_type(self, key: sqltypes.DuckDBPyType, value: sqltypes.DuckDBPyType) -> sqltypes.DuckDBPyType: ... - @pytyping.overload + @typing.overload def pl( - self, rows_per_batch: pytyping.SupportsInt = 1000000, *, lazy: pytyping.Literal[False] = ... + self, rows_per_batch: typing.SupportsInt = 1000000, *, lazy: typing.Literal[False] = ... ) -> polars.DataFrame: ... - @pytyping.overload + @typing.overload + def pl(self, rows_per_batch: typing.SupportsInt = 1000000, *, lazy: typing.Literal[True]) -> polars.LazyFrame: ... + @typing.overload def pl( - self, rows_per_batch: pytyping.SupportsInt = 1000000, *, lazy: pytyping.Literal[True] - ) -> polars.LazyFrame: ... - @pytyping.overload - def pl( - self, rows_per_batch: pytyping.SupportsInt = 1000000, *, lazy: bool = False - ) -> pytyping.Union[polars.DataFrame, polars.LazyFrame]: ... + self, rows_per_batch: typing.SupportsInt = 1000000, *, lazy: bool = False + ) -> polars.DataFrame | polars.LazyFrame: ... def query(self, query: str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... def query_progress(self) -> float: ... def read_csv( self, - path_or_buffer: str | bytes | os.PathLike[str], + path_or_buffer: str | bytes | os.PathLike[str] | os.PathLike[bytes] | typing.IO[bytes], header: bool | int | None = None, compression: str | None = None, sep: str | None = None, @@ -342,8 +372,8 @@ class DuckDBPyConnection: files_to_sniff: int | None = None, comment: str | None = None, thousands: str | None = None, - dtype: dict[str, str] | list[str] | None = None, - na_values: str | list[str] | None = None, + dtype: dict[str, str] | lst[str] | None = None, + na_values: str | lst[str] | None = None, skiprows: int | None = None, quotechar: str | None = None, escapechar: str | None = None, @@ -356,17 +386,17 @@ class DuckDBPyConnection: all_varchar: bool | None = None, normalize_names: bool | None = None, null_padding: bool | None = None, - names: list[str] | None = None, + names: lst[str] | None = None, lineterminator: str | None = None, columns: dict[str, str] | None = None, - auto_type_candidates: list[str] | None = None, + auto_type_candidates: lst[str] | None = None, max_line_size: int | None = None, ignore_errors: bool | None = None, store_rejects: bool | None = None, rejects_table: str | None = None, rejects_scan: str | None = None, rejects_limit: int | None = None, - force_not_null: list[str] | None = None, + force_not_null: lst[str] | None = None, buffer_size: int | None = None, decimal: str | None = None, allow_quoted_nulls: bool | None = None, @@ -401,7 +431,7 @@ class DuckDBPyConnection: hive_types: dict[str, str] | None = None, hive_types_autocast: bool | None = None, ) -> DuckDBPyRelation: ... - @pytyping.overload + @typing.overload def read_parquet( self, file_glob: str, @@ -413,7 +443,7 @@ class DuckDBPyConnection: union_by_name: bool = False, compression: str | None = None, ) -> DuckDBPyRelation: ... - @pytyping.overload + @typing.overload def read_parquet( self, file_globs: Sequence[str], @@ -423,49 +453,49 @@ class DuckDBPyConnection: filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, - compression: pytyping.Any = None, + compression: typing.Any = None, ) -> DuckDBPyRelation: ... def register(self, view_name: str, python_object: object) -> DuckDBPyConnection: ... def register_filesystem(self, filesystem: fsspec.AbstractFileSystem) -> None: ... def remove_function(self, name: str) -> DuckDBPyConnection: ... def rollback(self) -> DuckDBPyConnection: ... def row_type( - self, fields: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType] + self, fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType] ) -> sqltypes.DuckDBPyType: ... def sql(self, query: Statement | str, *, alias: str = "", params: object = None) -> DuckDBPyRelation: ... def sqltype(self, type_str: str) -> sqltypes.DuckDBPyType: ... def string_type(self, collation: str = "") -> sqltypes.DuckDBPyType: ... def struct_type( - self, fields: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType] + self, fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType] ) -> sqltypes.DuckDBPyType: ... def table(self, table_name: str) -> DuckDBPyRelation: ... def table_function(self, name: str, parameters: object = None) -> DuckDBPyRelation: ... - def tf(self) -> dict[str, tensorflow.Tensor]: ... - def torch(self) -> dict[str, pytorch.Tensor]: ... + def tf(self) -> dict[str, typing.Any]: ... + def torch(self) -> dict[str, typing.Any]: ... def type(self, type_str: str) -> sqltypes.DuckDBPyType: ... def union_type( - self, members: list[sqltypes.DuckDBPyType] | dict[str, sqltypes.DuckDBPyType] + self, members: lst[sqltypes.DuckDBPyType] | dict[str, sqltypes.DuckDBPyType] ) -> sqltypes.DuckDBPyType: ... def unregister(self, view_name: str) -> DuckDBPyConnection: ... def unregister_filesystem(self, name: str) -> None: ... - def values(self, *args: list[pytyping.Any] | tuple[Expression, ...] | Expression) -> DuckDBPyRelation: ... + def values(self, *args: lst[typing.Any] | tuple[Expression, ...] | Expression) -> DuckDBPyRelation: ... def view(self, view_name: str) -> DuckDBPyRelation: ... @property - def description(self) -> list[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]]: ... + def description(self) -> lst[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]]: ... @property def rowcount(self) -> int: ... class DuckDBPyRelation: - def __arrow_c_stream__(self, requested_schema: object | None = None) -> pytyping.Any: ... + def __arrow_c_stream__(self, requested_schema: object | None = None) -> typing.Any: ... def __contains__(self, name: str) -> bool: ... def __getattr__(self, name: str) -> DuckDBPyRelation: ... def __getitem__(self, name: str) -> DuckDBPyRelation: ... def __len__(self) -> int: ... def aggregate( - self, aggr_expr: Expression | str | list[Expression], group_expr: Expression | str = "" + self, aggr_expr: str | Iterable[_ExpressionLike], group_expr: _ExpressionLike = "" ) -> DuckDBPyRelation: ... def any_value( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def apply( self, @@ -481,22 +511,26 @@ class DuckDBPyRelation: def arg_min( self, arg_column: str, value_column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... - def arrow(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def arrow(self, batch_size: typing.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: + """Alias of to_arrow_reader(). We recommend using to_arrow_reader() instead.""" + ... + def to_arrow_reader(self, batch_size: typing.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... + def to_arrow_table(self, batch_size: typing.SupportsInt = 1000000) -> pyarrow.lib.Table: ... def avg( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def bit_and( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def bit_or( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def bit_xor( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def bitstring_agg( self, - column: str, + expression: str, min: int | None = None, max: int | None = None, groups: str = "", @@ -504,14 +538,14 @@ class DuckDBPyRelation: projected_columns: str = "", ) -> DuckDBPyRelation: ... def bool_and( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def bool_or( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def close(self) -> None: ... def count( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def create(self, table_name: str) -> None: ... def create_view(self, view_name: str, replace: bool = True) -> DuckDBPyRelation: ... @@ -525,124 +559,133 @@ class DuckDBPyRelation: def execute(self) -> DuckDBPyRelation: ... def explain(self, type: ExplainType = ExplainType.STANDARD) -> str: ... def favg( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... - def fetch_arrow_reader(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... - def fetch_arrow_table(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.lib.Table: ... + def fetch_arrow_reader(self, batch_size: typing.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: + """Deprecated: use to_arrow_reader() instead.""" + ... + def fetch_arrow_table(self, batch_size: typing.SupportsInt = 1000000) -> pyarrow.lib.Table: + """Deprecated: use to_arrow_table() instead.""" + ... def fetch_df_chunk( - self, vectors_per_chunk: pytyping.SupportsInt = 1, *, date_as_object: bool = False + self, vectors_per_chunk: typing.SupportsInt = 1, *, date_as_object: bool = False ) -> pandas.DataFrame: ... - def fetch_record_batch(self, rows_per_batch: pytyping.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: ... - def fetchall(self) -> list[tuple[pytyping.Any, ...]]: ... + def fetch_record_batch(self, rows_per_batch: typing.SupportsInt = 1000000) -> pyarrow.lib.RecordBatchReader: + """Deprecated: use to_arrow_reader() instead.""" + ... + def fetchall(self) -> lst[tuple[typing.Any, ...]]: ... def fetchdf(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... - def fetchmany(self, size: pytyping.SupportsInt = 1) -> list[tuple[pytyping.Any, ...]]: ... - def fetchnumpy(self) -> dict[str, np.typing.NDArray[pytyping.Any] | pandas.Categorical]: ... - def fetchone(self) -> tuple[pytyping.Any, ...] | None: ... + def fetchmany(self, size: typing.SupportsInt = 1) -> lst[tuple[typing.Any, ...]]: ... + def fetchnumpy(self) -> dict[str, np.typing.NDArray[typing.Any] | pandas.Categorical]: ... + def fetchone(self) -> tuple[typing.Any, ...] | None: ... def filter(self, filter_expr: Expression | str) -> DuckDBPyRelation: ... - def first(self, column: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... - def first_value(self, column: str, window_spec: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def first(self, expression: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def first_value(self, expression: str, window_spec: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... def fsum( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... - def geomean(self, column: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def geomean(self, expression: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... def histogram( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... - def insert(self, values: pytyping.List[object]) -> None: ... + def insert(self, values: lst[object]) -> None: ... def insert_into(self, table_name: str) -> None: ... def intersect(self, other_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... def join( - self, other_rel: DuckDBPyRelation, condition: Expression | str, how: str = "inner" + self, + other_rel: DuckDBPyRelation, + condition: Expression | str, + how: typing.Literal["inner", "left", "right", "outer", "semi", "anti"] = "inner", ) -> DuckDBPyRelation: ... def lag( self, - column: str, + expression: str, window_spec: str, - offset: pytyping.SupportsInt = 1, + offset: typing.SupportsInt = 1, default_value: str = "NULL", ignore_nulls: bool = False, projected_columns: str = "", ) -> DuckDBPyRelation: ... - def last(self, column: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... - def last_value(self, column: str, window_spec: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def last(self, expression: str, groups: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... + def last_value(self, expression: str, window_spec: str = "", projected_columns: str = "") -> DuckDBPyRelation: ... def lead( self, - column: str, + expression: str, window_spec: str, - offset: pytyping.SupportsInt = 1, + offset: typing.SupportsInt = 1, default_value: str = "NULL", ignore_nulls: bool = False, projected_columns: str = "", ) -> DuckDBPyRelation: ... - def limit(self, n: pytyping.SupportsInt, offset: pytyping.SupportsInt = 0) -> DuckDBPyRelation: ... + def limit(self, n: typing.SupportsInt, offset: typing.SupportsInt = 0) -> DuckDBPyRelation: ... def list( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def map( - self, map_function: Callable[..., pytyping.Any], *, schema: dict[str, sqltypes.DuckDBPyType] | None = None + self, map_function: Callable[..., typing.Any], *, schema: dict[str, sqltypes.DuckDBPyType] | None = None ) -> DuckDBPyRelation: ... def max( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def mean( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def median( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def min( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def mode( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def n_tile( - self, window_spec: str, num_buckets: pytyping.SupportsInt, projected_columns: str = "" + self, window_spec: str, num_buckets: typing.SupportsInt, projected_columns: str = "" ) -> DuckDBPyRelation: ... def nth_value( self, - column: str, + expression: str, window_spec: str, - offset: pytyping.SupportsInt, + offset: typing.SupportsInt, ignore_nulls: bool = False, projected_columns: str = "", ) -> DuckDBPyRelation: ... def order(self, order_expr: str) -> DuckDBPyRelation: ... def percent_rank(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... - @pytyping.overload + @typing.overload def pl( - self, batch_size: pytyping.SupportsInt = 1000000, *, lazy: pytyping.Literal[False] = ... + self, batch_size: typing.SupportsInt = 1000000, *, lazy: typing.Literal[False] = ... ) -> polars.DataFrame: ... - @pytyping.overload - def pl(self, batch_size: pytyping.SupportsInt = 1000000, *, lazy: pytyping.Literal[True]) -> polars.LazyFrame: ... - @pytyping.overload + @typing.overload + def pl(self, batch_size: typing.SupportsInt = 1000000, *, lazy: typing.Literal[True]) -> polars.LazyFrame: ... + @typing.overload def pl( - self, batch_size: pytyping.SupportsInt = 1000000, *, lazy: bool = False - ) -> pytyping.Union[polars.DataFrame, polars.LazyFrame]: ... + self, batch_size: typing.SupportsInt = 1000000, *, lazy: bool = False + ) -> polars.DataFrame | polars.LazyFrame: ... def product( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... - def project(self, *args: str | Expression, groups: str = "") -> DuckDBPyRelation: ... + def project(self, *args: _ExpressionLike, groups: str = "") -> DuckDBPyRelation: ... def quantile( self, - column: str, - q: float | pytyping.List[float] = 0.5, + expression: str, + q: float | lst[float] = 0.5, groups: str = "", window_spec: str = "", projected_columns: str = "", ) -> DuckDBPyRelation: ... def quantile_cont( self, - column: str, - q: float | pytyping.List[float] = 0.5, + expression: str, + q: float | lst[float] = 0.5, groups: str = "", window_spec: str = "", projected_columns: str = "", ) -> DuckDBPyRelation: ... def quantile_disc( self, - column: str, - q: float | pytyping.List[float] = 0.5, + expression: str, + q: float | lst[float] = 0.5, groups: str = "", window_spec: str = "", projected_columns: str = "", @@ -650,43 +693,41 @@ class DuckDBPyRelation: def query(self, virtual_table_name: str, sql_query: str) -> DuckDBPyRelation: ... def rank(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... def rank_dense(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... - def record_batch(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.RecordBatchReader: ... def row_number(self, window_spec: str, projected_columns: str = "") -> DuckDBPyRelation: ... - def select(self, *args: str | Expression, groups: str = "") -> DuckDBPyRelation: ... - def select_dtypes(self, types: pytyping.List[sqltypes.DuckDBPyType | str]) -> DuckDBPyRelation: ... - def select_types(self, types: pytyping.List[sqltypes.DuckDBPyType | str]) -> DuckDBPyRelation: ... + def select(self, *args: _ExpressionLike, groups: str = "") -> DuckDBPyRelation: ... + def select_dtypes(self, types: lst[sqltypes.DuckDBPyType | str]) -> DuckDBPyRelation: ... + def select_types(self, types: lst[sqltypes.DuckDBPyType | str]) -> DuckDBPyRelation: ... def set_alias(self, alias: str) -> DuckDBPyRelation: ... def show( self, *, - max_width: pytyping.SupportsInt | None = None, - max_rows: pytyping.SupportsInt | None = None, - max_col_width: pytyping.SupportsInt | None = None, + max_width: typing.SupportsInt | None = None, + max_rows: typing.SupportsInt | None = None, + max_col_width: typing.SupportsInt | None = None, null_value: str | None = None, render_mode: RenderMode | None = None, ) -> None: ... - def sort(self, *args: Expression) -> DuckDBPyRelation: ... + def sort(self, *args: _ExpressionLike) -> DuckDBPyRelation: ... def sql_query(self) -> str: ... def std( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def stddev( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def stddev_pop( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def stddev_samp( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def string_agg( - self, column: str, sep: str = ",", groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, sep: str = ",", groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def sum( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... - def tf(self) -> dict[str, tensorflow.Tensor]: ... - def to_arrow_table(self, batch_size: pytyping.SupportsInt = 1000000) -> pyarrow.lib.Table: ... + def tf(self) -> dict[str, typing.Any]: ... def to_csv( self, file_name: str, @@ -704,7 +745,7 @@ class DuckDBPyRelation: overwrite: bool | None = None, per_thread_output: bool | None = None, use_tmp_file: bool | None = None, - partition_by: pytyping.List[str] | None = None, + partition_by: lst[str] | None = None, write_partition_columns: bool | None = None, ) -> None: ... def to_df(self, *, date_as_object: bool = False) -> pandas.DataFrame: ... @@ -713,13 +754,13 @@ class DuckDBPyRelation: file_name: str, *, compression: str | None = None, - field_ids: ParquetFieldIdsType | pytyping.Literal["auto"] | None = None, + field_ids: ParquetFieldIdsType | typing.Literal["auto"] | None = None, row_group_size_bytes: int | str | None = None, row_group_size: int | None = None, overwrite: bool | None = None, per_thread_output: bool | None = None, use_tmp_file: bool | None = None, - partition_by: pytyping.List[str] | None = None, + partition_by: lst[str] | None = None, write_partition_columns: bool | None = None, append: bool | None = None, filename_pattern: str | None = None, @@ -727,26 +768,27 @@ class DuckDBPyRelation: ) -> None: ... def to_table(self, table_name: str) -> None: ... def to_view(self, view_name: str, replace: bool = True) -> DuckDBPyRelation: ... - def torch(self) -> dict[str, pytorch.Tensor]: ... + def torch(self) -> dict[str, typing.Any]: ... def union(self, union_rel: DuckDBPyRelation) -> DuckDBPyRelation: ... def unique(self, unique_aggr: str) -> DuckDBPyRelation: ... - def update(self, set: Expression | str, *, condition: Expression | str | None = None) -> None: ... - def value_counts(self, column: str, groups: str = "") -> DuckDBPyRelation: ... + def update(self, set: dict[str, _ExpressionLike], *, condition: _ExpressionLike | None = None) -> None: ... + def value_counts(self, expression: str, groups: str = "") -> DuckDBPyRelation: ... def var( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def var_pop( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def var_samp( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def variance( - self, column: str, groups: str = "", window_spec: str = "", projected_columns: str = "" + self, expression: str, groups: str = "", window_spec: str = "", projected_columns: str = "" ) -> DuckDBPyRelation: ... def write_csv( self, file_name: str, + *, sep: str | None = None, na_rep: str | None = None, header: bool | None = None, @@ -760,20 +802,21 @@ class DuckDBPyRelation: overwrite: bool | None = None, per_thread_output: bool | None = None, use_tmp_file: bool | None = None, - partition_by: pytyping.List[str] | None = None, + partition_by: lst[str] | None = None, write_partition_columns: bool | None = None, ) -> None: ... def write_parquet( self, file_name: str, + *, compression: str | None = None, - field_ids: ParquetFieldIdsType | pytyping.Literal["auto"] | None = None, + field_ids: ParquetFieldIdsType | typing.Literal["auto"] | None = None, row_group_size_bytes: str | int | None = None, row_group_size: int | None = None, overwrite: bool | None = None, per_thread_output: bool | None = None, use_tmp_file: bool | None = None, - partition_by: pytyping.List[str] | None = None, + partition_by: lst[str] | None = None, write_partition_columns: bool | None = None, append: bool | None = None, filename_pattern: str | None = None, @@ -782,108 +825,108 @@ class DuckDBPyRelation: @property def alias(self) -> str: ... @property - def columns(self) -> pytyping.List[str]: ... + def columns(self) -> lst[str]: ... @property - def description(self) -> pytyping.List[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]]: ... + def description(self) -> lst[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]]: ... @property - def dtypes(self) -> pytyping.List[str]: ... + def dtypes(self) -> lst[sqltypes.DuckDBPyType]: ... @property def shape(self) -> tuple[int, int]: ... @property def type(self) -> str: ... @property - def types(self) -> pytyping.List[sqltypes.DuckDBPyType]: ... + def types(self) -> lst[sqltypes.DuckDBPyType]: ... class Error(Exception): ... class ExpectedResultType: - CHANGED_ROWS: pytyping.ClassVar[ExpectedResultType] # value = - NOTHING: pytyping.ClassVar[ExpectedResultType] # value = - QUERY_RESULT: pytyping.ClassVar[ExpectedResultType] # value = - __members__: pytyping.ClassVar[ + CHANGED_ROWS: typing.ClassVar[ExpectedResultType] # value = + NOTHING: typing.ClassVar[ExpectedResultType] # value = + QUERY_RESULT: typing.ClassVar[ExpectedResultType] # value = + __members__: typing.ClassVar[ dict[str, ExpectedResultType] ] # value = {'QUERY_RESULT': , 'CHANGED_ROWS': , 'NOTHING': } # noqa: E501 def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... class ExplainType: - ANALYZE: pytyping.ClassVar[ExplainType] # value = - STANDARD: pytyping.ClassVar[ExplainType] # value = - __members__: pytyping.ClassVar[ + ANALYZE: typing.ClassVar[ExplainType] # value = + STANDARD: typing.ClassVar[ExplainType] # value = + __members__: typing.ClassVar[ dict[str, ExplainType] ] # value = {'STANDARD': , 'ANALYZE': } def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... class Expression: - def __add__(self, other: Expression) -> Expression: ... - def __and__(self, other: Expression) -> Expression: ... - def __div__(self, other: Expression) -> Expression: ... - def __eq__(self, other: Expression) -> Expression: ... # type: ignore[override] - def __floordiv__(self, other: Expression) -> Expression: ... - def __ge__(self, other: Expression) -> Expression: ... - def __gt__(self, other: Expression) -> Expression: ... - @pytyping.overload + def __add__(self, other: _ExpressionLike) -> Expression: ... + def __and__(self, other: _ExpressionLike) -> Expression: ... + def __div__(self, other: _ExpressionLike) -> Expression: ... + def __eq__(self, other: _ExpressionLike) -> Expression: ... # type: ignore[override] + def __floordiv__(self, other: _ExpressionLike) -> Expression: ... + def __ge__(self, other: _ExpressionLike) -> Expression: ... + def __gt__(self, other: _ExpressionLike) -> Expression: ... + @typing.overload def __init__(self, arg0: str) -> None: ... - @pytyping.overload - def __init__(self, arg0: pytyping.Any) -> None: ... + @typing.overload + def __init__(self, arg0: typing.Any) -> None: ... def __invert__(self) -> Expression: ... - def __le__(self, other: Expression) -> Expression: ... - def __lt__(self, other: Expression) -> Expression: ... - def __mod__(self, other: Expression) -> Expression: ... - def __mul__(self, other: Expression) -> Expression: ... - def __ne__(self, other: Expression) -> Expression: ... # type: ignore[override] + def __le__(self, other: _ExpressionLike) -> Expression: ... + def __lt__(self, other: _ExpressionLike) -> Expression: ... + def __mod__(self, other: _ExpressionLike) -> Expression: ... + def __mul__(self, other: _ExpressionLike) -> Expression: ... + def __ne__(self, other: _ExpressionLike) -> Expression: ... # type: ignore[override] def __neg__(self) -> Expression: ... - def __or__(self, other: Expression) -> Expression: ... - def __pow__(self, other: Expression) -> Expression: ... - def __radd__(self, other: Expression) -> Expression: ... - def __rand__(self, other: Expression) -> Expression: ... - def __rdiv__(self, other: Expression) -> Expression: ... - def __rfloordiv__(self, other: Expression) -> Expression: ... - def __rmod__(self, other: Expression) -> Expression: ... - def __rmul__(self, other: Expression) -> Expression: ... - def __ror__(self, other: Expression) -> Expression: ... - def __rpow__(self, other: Expression) -> Expression: ... - def __rsub__(self, other: Expression) -> Expression: ... - def __rtruediv__(self, other: Expression) -> Expression: ... - def __sub__(self, other: Expression) -> Expression: ... - def __truediv__(self, other: Expression) -> Expression: ... + def __or__(self, other: _ExpressionLike) -> Expression: ... + def __pow__(self, other: _ExpressionLike) -> Expression: ... + def __radd__(self, other: _ExpressionLike) -> Expression: ... + def __rand__(self, other: _ExpressionLike) -> Expression: ... + def __rdiv__(self, other: _ExpressionLike) -> Expression: ... + def __rfloordiv__(self, other: _ExpressionLike) -> Expression: ... + def __rmod__(self, other: _ExpressionLike) -> Expression: ... + def __rmul__(self, other: _ExpressionLike) -> Expression: ... + def __ror__(self, other: _ExpressionLike) -> Expression: ... + def __rpow__(self, other: _ExpressionLike) -> Expression: ... + def __rsub__(self, other: _ExpressionLike) -> Expression: ... + def __rtruediv__(self, other: _ExpressionLike) -> Expression: ... + def __sub__(self, other: _ExpressionLike) -> Expression: ... + def __truediv__(self, other: _ExpressionLike) -> Expression: ... def alias(self, name: str) -> Expression: ... def asc(self) -> Expression: ... - def between(self, lower: Expression, upper: Expression) -> Expression: ... + def between(self, lower: _ExpressionLike, upper: _ExpressionLike) -> Expression: ... def cast(self, type: sqltypes.DuckDBPyType) -> Expression: ... def collate(self, collation: str) -> Expression: ... def desc(self) -> Expression: ... def get_name(self) -> str: ... - def isin(self, *args: Expression) -> Expression: ... - def isnotin(self, *args: Expression) -> Expression: ... + def isin(self, *args: _ExpressionLike) -> Expression: ... + def isnotin(self, *args: _ExpressionLike) -> Expression: ... def isnotnull(self) -> Expression: ... def isnull(self) -> Expression: ... def nulls_first(self) -> Expression: ... def nulls_last(self) -> Expression: ... - def otherwise(self, value: Expression) -> Expression: ... + def otherwise(self, value: _ExpressionLike) -> Expression: ... def show(self) -> None: ... - def when(self, condition: Expression, value: Expression) -> Expression: ... + def when(self, condition: _ExpressionLike, value: _ExpressionLike) -> Expression: ... class FatalException(DatabaseError): ... @@ -910,38 +953,38 @@ class PermissionException(DatabaseError): ... class ProgrammingError(DatabaseError): ... class PythonExceptionHandling: - DEFAULT: pytyping.ClassVar[PythonExceptionHandling] # value = - RETURN_NULL: pytyping.ClassVar[PythonExceptionHandling] # value = - __members__: pytyping.ClassVar[ + DEFAULT: typing.ClassVar[PythonExceptionHandling] # value = + RETURN_NULL: typing.ClassVar[PythonExceptionHandling] # value = + __members__: typing.ClassVar[ dict[str, PythonExceptionHandling] ] # value = {'DEFAULT': , 'RETURN_NULL': } # noqa: E501 def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... class RenderMode: - COLUMNS: pytyping.ClassVar[RenderMode] # value = - ROWS: pytyping.ClassVar[RenderMode] # value = - __members__: pytyping.ClassVar[ + COLUMNS: typing.ClassVar[RenderMode] # value = + ROWS: typing.ClassVar[RenderMode] # value = + __members__: typing.ClassVar[ dict[str, RenderMode] ] # value = {'ROWS': , 'COLUMNS': } def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property @@ -952,7 +995,7 @@ class SerializationException(OperationalError): ... class Statement: @property - def expected_result_type(self) -> list[StatementType]: ... + def expected_result_type(self) -> lst[StatementType]: ... @property def named_parameters(self) -> set[str]: ... @property @@ -961,47 +1004,47 @@ class Statement: def type(self) -> StatementType: ... class StatementType: - ALTER_STATEMENT: pytyping.ClassVar[StatementType] # value = - ANALYZE_STATEMENT: pytyping.ClassVar[StatementType] # value = - ATTACH_STATEMENT: pytyping.ClassVar[StatementType] # value = - CALL_STATEMENT: pytyping.ClassVar[StatementType] # value = - COPY_DATABASE_STATEMENT: pytyping.ClassVar[StatementType] # value = - COPY_STATEMENT: pytyping.ClassVar[StatementType] # value = - CREATE_FUNC_STATEMENT: pytyping.ClassVar[StatementType] # value = - CREATE_STATEMENT: pytyping.ClassVar[StatementType] # value = - DELETE_STATEMENT: pytyping.ClassVar[StatementType] # value = - DETACH_STATEMENT: pytyping.ClassVar[StatementType] # value = - DROP_STATEMENT: pytyping.ClassVar[StatementType] # value = - EXECUTE_STATEMENT: pytyping.ClassVar[StatementType] # value = - EXPLAIN_STATEMENT: pytyping.ClassVar[StatementType] # value = - EXPORT_STATEMENT: pytyping.ClassVar[StatementType] # value = - EXTENSION_STATEMENT: pytyping.ClassVar[StatementType] # value = - INSERT_STATEMENT: pytyping.ClassVar[StatementType] # value = - INVALID_STATEMENT: pytyping.ClassVar[StatementType] # value = - LOAD_STATEMENT: pytyping.ClassVar[StatementType] # value = - LOGICAL_PLAN_STATEMENT: pytyping.ClassVar[StatementType] # value = - MERGE_INTO_STATEMENT: pytyping.ClassVar[StatementType] # value = - MULTI_STATEMENT: pytyping.ClassVar[StatementType] # value = - PRAGMA_STATEMENT: pytyping.ClassVar[StatementType] # value = - PREPARE_STATEMENT: pytyping.ClassVar[StatementType] # value = - RELATION_STATEMENT: pytyping.ClassVar[StatementType] # value = - SELECT_STATEMENT: pytyping.ClassVar[StatementType] # value = - SET_STATEMENT: pytyping.ClassVar[StatementType] # value = - TRANSACTION_STATEMENT: pytyping.ClassVar[StatementType] # value = - UPDATE_STATEMENT: pytyping.ClassVar[StatementType] # value = - VACUUM_STATEMENT: pytyping.ClassVar[StatementType] # value = - VARIABLE_SET_STATEMENT: pytyping.ClassVar[StatementType] # value = - __members__: pytyping.ClassVar[ + ALTER_STATEMENT: typing.ClassVar[StatementType] # value = + ANALYZE_STATEMENT: typing.ClassVar[StatementType] # value = + ATTACH_STATEMENT: typing.ClassVar[StatementType] # value = + CALL_STATEMENT: typing.ClassVar[StatementType] # value = + COPY_DATABASE_STATEMENT: typing.ClassVar[StatementType] # value = + COPY_STATEMENT: typing.ClassVar[StatementType] # value = + CREATE_FUNC_STATEMENT: typing.ClassVar[StatementType] # value = + CREATE_STATEMENT: typing.ClassVar[StatementType] # value = + DELETE_STATEMENT: typing.ClassVar[StatementType] # value = + DETACH_STATEMENT: typing.ClassVar[StatementType] # value = + DROP_STATEMENT: typing.ClassVar[StatementType] # value = + EXECUTE_STATEMENT: typing.ClassVar[StatementType] # value = + EXPLAIN_STATEMENT: typing.ClassVar[StatementType] # value = + EXPORT_STATEMENT: typing.ClassVar[StatementType] # value = + EXTENSION_STATEMENT: typing.ClassVar[StatementType] # value = + INSERT_STATEMENT: typing.ClassVar[StatementType] # value = + INVALID_STATEMENT: typing.ClassVar[StatementType] # value = + LOAD_STATEMENT: typing.ClassVar[StatementType] # value = + LOGICAL_PLAN_STATEMENT: typing.ClassVar[StatementType] # value = + MERGE_INTO_STATEMENT: typing.ClassVar[StatementType] # value = + MULTI_STATEMENT: typing.ClassVar[StatementType] # value = + PRAGMA_STATEMENT: typing.ClassVar[StatementType] # value = + PREPARE_STATEMENT: typing.ClassVar[StatementType] # value = + RELATION_STATEMENT: typing.ClassVar[StatementType] # value = + SELECT_STATEMENT: typing.ClassVar[StatementType] # value = + SET_STATEMENT: typing.ClassVar[StatementType] # value = + TRANSACTION_STATEMENT: typing.ClassVar[StatementType] # value = + UPDATE_STATEMENT: typing.ClassVar[StatementType] # value = + VACUUM_STATEMENT: typing.ClassVar[StatementType] # value = + VARIABLE_SET_STATEMENT: typing.ClassVar[StatementType] # value = + __members__: typing.ClassVar[ dict[str, StatementType] ] # value = {'INVALID_STATEMENT': , 'SELECT_STATEMENT': , 'INSERT_STATEMENT': , 'UPDATE_STATEMENT': , 'CREATE_STATEMENT': , 'DELETE_STATEMENT': , 'PREPARE_STATEMENT': , 'EXECUTE_STATEMENT': , 'ALTER_STATEMENT': , 'TRANSACTION_STATEMENT': , 'COPY_STATEMENT': , 'ANALYZE_STATEMENT': , 'VARIABLE_SET_STATEMENT': , 'CREATE_FUNC_STATEMENT': , 'EXPLAIN_STATEMENT': , 'DROP_STATEMENT': , 'EXPORT_STATEMENT': , 'PRAGMA_STATEMENT': , 'VACUUM_STATEMENT': , 'CALL_STATEMENT': , 'SET_STATEMENT': , 'LOAD_STATEMENT': , 'RELATION_STATEMENT': , 'EXTENSION_STATEMENT': , 'LOGICAL_PLAN_STATEMENT': , 'ATTACH_STATEMENT': , 'DETACH_STATEMENT': , 'MULTI_STATEMENT': , 'COPY_DATABASE_STATEMENT': , 'MERGE_INTO_STATEMENT': } # noqa: E501 def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property @@ -1013,40 +1056,40 @@ class TypeMismatchException(DataError): ... class Warning(Exception): ... class token_type: - __members__: pytyping.ClassVar[ + __members__: typing.ClassVar[ dict[str, token_type] ] # value = {'identifier': , 'numeric_const': , 'string_const': , 'operator': , 'keyword': , 'comment': } # noqa: E501 - comment: pytyping.ClassVar[token_type] # value = - identifier: pytyping.ClassVar[token_type] # value = - keyword: pytyping.ClassVar[token_type] # value = - numeric_const: pytyping.ClassVar[token_type] # value = - operator: pytyping.ClassVar[token_type] # value = - string_const: pytyping.ClassVar[token_type] # value = + comment: typing.ClassVar[token_type] # value = + identifier: typing.ClassVar[token_type] # value = + keyword: typing.ClassVar[token_type] # value = + numeric_const: typing.ClassVar[token_type] # value = + operator: typing.ClassVar[token_type] # value = + string_const: typing.ClassVar[token_type] # value = def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... -def CaseExpression(condition: Expression, value: Expression) -> Expression: ... -def CoalesceOperator(*args: Expression) -> Expression: ... +def CaseExpression(condition: _ExpressionLike, value: _ExpressionLike) -> Expression: ... +def CoalesceOperator(*args: _ExpressionLike) -> Expression: ... def ColumnExpression(*args: str) -> Expression: ... -def ConstantExpression(value: pytyping.Any) -> Expression: ... +def ConstantExpression(value: typing.Any) -> Expression: ... def DefaultExpression() -> Expression: ... -def FunctionExpression(function_name: str, *args: Expression) -> Expression: ... -def LambdaExpression(lhs: pytyping.Any, rhs: Expression) -> Expression: ... +def FunctionExpression(function_name: str, *args: _ExpressionLike) -> Expression: ... +def LambdaExpression(lhs: typing.Any, rhs: _ExpressionLike) -> Expression: ... def SQLExpression(expression: str) -> Expression: ... -def StarExpression(*, exclude: pytyping.Any = None) -> Expression: ... +def StarExpression(*, exclude: Iterable[str | Expression] | None = None) -> Expression: ... def aggregate( df: pandas.DataFrame, - aggr_expr: Expression | list[Expression] | str | list[str], + aggr_expr: str | Iterable[_ExpressionLike], group_expr: str = "", *, connection: DuckDBPyConnection | None = None, @@ -1056,14 +1099,23 @@ def append( table_name: str, df: pandas.DataFrame, *, by_name: bool = False, connection: DuckDBPyConnection | None = None ) -> DuckDBPyConnection: ... def array_type( - type: sqltypes.DuckDBPyType, size: pytyping.SupportsInt, *, connection: DuckDBPyConnection | None = None + type: sqltypes.DuckDBPyType, size: typing.SupportsInt, *, connection: DuckDBPyConnection | None = None ) -> sqltypes.DuckDBPyType: ... -@pytyping.overload +@typing.overload def arrow( - rows_per_batch: pytyping.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None + rows_per_batch: typing.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None +) -> pyarrow.lib.RecordBatchReader: + """Alias of to_arrow_reader(). We recommend using to_arrow_reader() instead.""" + ... + +@typing.overload +def arrow(arrow_object: typing.Any, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... +def to_arrow_reader( + batch_size: typing.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None ) -> pyarrow.lib.RecordBatchReader: ... -@pytyping.overload -def arrow(arrow_object: pytyping.Any, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... +def to_arrow_table( + batch_size: typing.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None +) -> pyarrow.lib.Table: ... def begin(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... def checkpoint(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... def close(*, connection: DuckDBPyConnection | None = None) -> None: ... @@ -1071,12 +1123,12 @@ def commit(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnectio def connect( database: str | pathlib.Path = ":memory:", read_only: bool = False, - config: dict[str, str | bool | int | float | list[str]] | None = None, + config: dict[str, str | bool | int | float | lst[str]] | None = None, ) -> DuckDBPyConnection: ... def create_function( name: str, - function: Callable[..., pytyping.Any], - parameters: list[sqltypes.DuckDBPyType] | None = None, + function: Callable[..., typing.Any], + parameters: lst[sqltypes.DuckDBPyType] | None = None, return_type: sqltypes.DuckDBPyType | None = None, *, type: func.PythonUDFType = ..., @@ -1087,15 +1139,15 @@ def create_function( ) -> DuckDBPyConnection: ... def cursor(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... def decimal_type( - width: pytyping.SupportsInt, scale: pytyping.SupportsInt, *, connection: DuckDBPyConnection | None = None + width: typing.SupportsInt, scale: typing.SupportsInt, *, connection: DuckDBPyConnection | None = None ) -> sqltypes.DuckDBPyType: ... def default_connection() -> DuckDBPyConnection: ... def description( *, connection: DuckDBPyConnection | None = None -) -> list[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]] | None: ... -@pytyping.overload +) -> lst[tuple[str, sqltypes.DuckDBPyType, None, None, None, None, None]] | None: ... +@typing.overload def df(*, date_as_object: bool = False, connection: DuckDBPyConnection | None = None) -> pandas.DataFrame: ... -@pytyping.overload +@typing.overload def df(df: pandas.DataFrame, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... def distinct(df: pandas.DataFrame, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... def dtype(type_str: str, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ... @@ -1103,7 +1155,7 @@ def duplicate(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnec def enum_type( name: str, type: sqltypes.DuckDBPyType, - values: list[pytyping.Any], + values: lst[typing.Any], *, connection: DuckDBPyConnection | None = None, ) -> sqltypes.DuckDBPyType: ... @@ -1119,29 +1171,35 @@ def executemany( *, connection: DuckDBPyConnection | None = None, ) -> DuckDBPyConnection: ... -def extract_statements(query: str, *, connection: DuckDBPyConnection | None = None) -> list[Statement]: ... +def extract_statements(query: str, *, connection: DuckDBPyConnection | None = None) -> lst[Statement]: ... def fetch_arrow_table( - rows_per_batch: pytyping.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None -) -> pyarrow.lib.Table: ... + rows_per_batch: typing.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None +) -> pyarrow.lib.Table: + """Deprecated: use to_arrow_table() instead.""" + ... + def fetch_df(*, date_as_object: bool = False, connection: DuckDBPyConnection | None = None) -> pandas.DataFrame: ... def fetch_df_chunk( - vectors_per_chunk: pytyping.SupportsInt = 1, + vectors_per_chunk: typing.SupportsInt = 1, *, date_as_object: bool = False, connection: DuckDBPyConnection | None = None, ) -> pandas.DataFrame: ... def fetch_record_batch( - rows_per_batch: pytyping.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None -) -> pyarrow.lib.RecordBatchReader: ... -def fetchall(*, connection: DuckDBPyConnection | None = None) -> list[tuple[pytyping.Any, ...]]: ... + rows_per_batch: typing.SupportsInt = 1000000, *, connection: DuckDBPyConnection | None = None +) -> pyarrow.lib.RecordBatchReader: + """Deprecated: use to_arrow_reader() instead.""" + ... + +def fetchall(*, connection: DuckDBPyConnection | None = None) -> lst[tuple[typing.Any, ...]]: ... def fetchdf(*, date_as_object: bool = False, connection: DuckDBPyConnection | None = None) -> pandas.DataFrame: ... def fetchmany( - size: pytyping.SupportsInt = 1, *, connection: DuckDBPyConnection | None = None -) -> list[tuple[pytyping.Any, ...]]: ... + size: typing.SupportsInt = 1, *, connection: DuckDBPyConnection | None = None +) -> lst[tuple[typing.Any, ...]]: ... def fetchnumpy( *, connection: DuckDBPyConnection | None = None -) -> dict[str, np.typing.NDArray[pytyping.Any] | pandas.Categorical]: ... -def fetchone(*, connection: DuckDBPyConnection | None = None) -> tuple[pytyping.Any, ...] | None: ... +) -> dict[str, np.typing.NDArray[typing.Any] | pandas.Categorical]: ... +def fetchone(*, connection: DuckDBPyConnection | None = None) -> tuple[typing.Any, ...] | None: ... def filesystem_is_registered(name: str, *, connection: DuckDBPyConnection | None = None) -> bool: ... def filter( df: pandas.DataFrame, @@ -1155,7 +1213,7 @@ def from_arrow( connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... def from_csv_auto( - path_or_buffer: str | bytes | os.PathLike[str], + path_or_buffer: str | bytes | os.PathLike[str] | os.PathLike[bytes] | typing.IO[bytes], header: bool | int | None = None, compression: str | None = None, sep: str | None = None, @@ -1163,8 +1221,8 @@ def from_csv_auto( files_to_sniff: int | None = None, comment: str | None = None, thousands: str | None = None, - dtype: dict[str, str] | list[str] | None = None, - na_values: str | list[str] | None = None, + dtype: dict[str, str] | lst[str] | None = None, + na_values: str | lst[str] | None = None, skiprows: int | None = None, quotechar: str | None = None, escapechar: str | None = None, @@ -1177,17 +1235,17 @@ def from_csv_auto( all_varchar: bool | None = None, normalize_names: bool | None = None, null_padding: bool | None = None, - names: list[str] | None = None, + names: lst[str] | None = None, lineterminator: str | None = None, columns: dict[str, str] | None = None, - auto_type_candidates: list[str] | None = None, + auto_type_candidates: lst[str] | None = None, max_line_size: int | None = None, ignore_errors: bool | None = None, store_rejects: bool | None = None, rejects_table: str | None = None, rejects_scan: str | None = None, rejects_limit: int | None = None, - force_not_null: list[str] | None = None, + force_not_null: lst[str] | None = None, buffer_size: int | None = None, decimal: str | None = None, allow_quoted_nulls: bool | None = None, @@ -1199,7 +1257,7 @@ def from_csv_auto( strict_mode: bool | None = None, ) -> DuckDBPyRelation: ... def from_df(df: pandas.DataFrame, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... -@pytyping.overload +@typing.overload def from_parquet( file_glob: str, binary_as_string: bool = False, @@ -1211,7 +1269,7 @@ def from_parquet( compression: str | None = None, connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... -@pytyping.overload +@typing.overload def from_parquet( file_globs: Sequence[str], binary_as_string: bool = False, @@ -1220,7 +1278,7 @@ def from_parquet( filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, - compression: pytyping.Any = None, + compression: typing.Any = None, connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... def from_query( @@ -1245,12 +1303,15 @@ def install_extension( def interrupt(*, connection: DuckDBPyConnection | None = None) -> None: ... def limit( df: pandas.DataFrame, - n: pytyping.SupportsInt, - offset: pytyping.SupportsInt = 0, + n: typing.SupportsInt, + offset: typing.SupportsInt = 0, *, connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... -def list_filesystems(*, connection: DuckDBPyConnection | None = None) -> list[str]: ... +def get_profiling_information(*, connection: DuckDBPyConnection | None = None, format: str = "json") -> str: ... +def enable_profiling(*, connection: DuckDBPyConnection | None = None) -> None: ... +def disable_profiling(*, connection: DuckDBPyConnection | None = None) -> None: ... +def list_filesystems(*, connection: DuckDBPyConnection | None = None) -> lst[str]: ... def list_type( type: sqltypes.DuckDBPyType, *, connection: DuckDBPyConnection | None = None ) -> sqltypes.DuckDBPyType: ... @@ -1264,29 +1325,29 @@ def map_type( def order( df: pandas.DataFrame, order_expr: str, *, connection: DuckDBPyConnection | None = None ) -> DuckDBPyRelation: ... -@pytyping.overload +@typing.overload def pl( - rows_per_batch: pytyping.SupportsInt = 1000000, + rows_per_batch: typing.SupportsInt = 1000000, *, - lazy: pytyping.Literal[False] = ..., + lazy: typing.Literal[False] = ..., connection: DuckDBPyConnection | None = None, ) -> polars.DataFrame: ... -@pytyping.overload +@typing.overload def pl( - rows_per_batch: pytyping.SupportsInt = 1000000, + rows_per_batch: typing.SupportsInt = 1000000, *, - lazy: pytyping.Literal[True], + lazy: typing.Literal[True], connection: DuckDBPyConnection | None = None, ) -> polars.LazyFrame: ... -@pytyping.overload +@typing.overload def pl( - rows_per_batch: pytyping.SupportsInt = 1000000, + rows_per_batch: typing.SupportsInt = 1000000, *, lazy: bool = False, connection: DuckDBPyConnection | None = None, -) -> pytyping.Union[polars.DataFrame, polars.LazyFrame]: ... +) -> polars.DataFrame | polars.LazyFrame: ... def project( - df: pandas.DataFrame, *args: str | Expression, groups: str = "", connection: DuckDBPyConnection | None = None + df: pandas.DataFrame, *args: _ExpressionLike, groups: str = "", connection: DuckDBPyConnection | None = None ) -> DuckDBPyRelation: ... def query( query: Statement | str, @@ -1304,7 +1365,7 @@ def query_df( ) -> DuckDBPyRelation: ... def query_progress(*, connection: DuckDBPyConnection | None = None) -> float: ... def read_csv( - path_or_buffer: str | bytes | os.PathLike[str], + path_or_buffer: str | bytes | os.PathLike[str] | os.PathLike[bytes] | typing.IO[bytes], header: bool | int | None = None, compression: str | None = None, sep: str | None = None, @@ -1312,8 +1373,8 @@ def read_csv( files_to_sniff: int | None = None, comment: str | None = None, thousands: str | None = None, - dtype: dict[str, str] | list[str] | None = None, - na_values: str | list[str] | None = None, + dtype: dict[str, str] | lst[str] | None = None, + na_values: str | lst[str] | None = None, skiprows: int | None = None, quotechar: str | None = None, escapechar: str | None = None, @@ -1326,17 +1387,17 @@ def read_csv( all_varchar: bool | None = None, normalize_names: bool | None = None, null_padding: bool | None = None, - names: list[str] | None = None, + names: lst[str] | None = None, lineterminator: str | None = None, columns: dict[str, str] | None = None, - auto_type_candidates: list[str] | None = None, + auto_type_candidates: lst[str] | None = None, max_line_size: int | None = None, ignore_errors: bool | None = None, store_rejects: bool | None = None, rejects_table: str | None = None, rejects_scan: str | None = None, rejects_limit: int | None = None, - force_not_null: list[str] | None = None, + force_not_null: lst[str] | None = None, buffer_size: int | None = None, decimal: str | None = None, allow_quoted_nulls: bool | None = None, @@ -1370,7 +1431,7 @@ def read_json( hive_types: dict[str, str] | None = None, hive_types_autocast: bool | None = None, ) -> DuckDBPyRelation: ... -@pytyping.overload +@typing.overload def read_parquet( file_glob: str, binary_as_string: bool = False, @@ -1382,7 +1443,7 @@ def read_parquet( compression: str | None = None, connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... -@pytyping.overload +@typing.overload def read_parquet( file_globs: Sequence[str], binary_as_string: bool = False, @@ -1391,7 +1452,7 @@ def read_parquet( filename: bool = False, hive_partitioning: bool = False, union_by_name: bool = False, - compression: pytyping.Any = None, + compression: typing.Any = None, connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... def register( @@ -1406,7 +1467,7 @@ def register_filesystem( def remove_function(name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... def rollback(*, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... def row_type( - fields: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType], + fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType], *, connection: DuckDBPyConnection | None = None, ) -> sqltypes.DuckDBPyType: ... @@ -1422,7 +1483,7 @@ def sql( def sqltype(type_str: str, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ... def string_type(collation: str = "", *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ... def struct_type( - fields: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType], + fields: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType], *, connection: DuckDBPyConnection | None = None, ) -> sqltypes.DuckDBPyType: ... @@ -1433,19 +1494,19 @@ def table_function( *, connection: DuckDBPyConnection | None = None, ) -> DuckDBPyRelation: ... -def tf(*, connection: DuckDBPyConnection | None = None) -> dict[str, tensorflow.Tensor]: ... -def tokenize(query: str) -> list[tuple[int, token_type]]: ... -def torch(*, connection: DuckDBPyConnection | None = None) -> dict[str, pytorch.Tensor]: ... +def tf(*, connection: DuckDBPyConnection | None = None) -> dict[str, typing.Any]: ... +def tokenize(query: str) -> lst[tuple[int, token_type]]: ... +def torch(*, connection: DuckDBPyConnection | None = None) -> dict[str, typing.Any]: ... def type(type_str: str, *, connection: DuckDBPyConnection | None = None) -> sqltypes.DuckDBPyType: ... def union_type( - members: dict[str, sqltypes.DuckDBPyType] | list[sqltypes.DuckDBPyType], + members: dict[str, sqltypes.DuckDBPyType] | lst[sqltypes.DuckDBPyType], *, connection: DuckDBPyConnection | None = None, ) -> sqltypes.DuckDBPyType: ... def unregister(view_name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyConnection: ... def unregister_filesystem(name: str, *, connection: DuckDBPyConnection | None = None) -> None: ... def values( - *args: list[pytyping.Any] | tuple[Expression, ...] | Expression, connection: DuckDBPyConnection | None = None + *args: lst[typing.Any] | tuple[Expression, ...] | Expression, connection: DuckDBPyConnection | None = None ) -> DuckDBPyRelation: ... def view(view_name: str, *, connection: DuckDBPyConnection | None = None) -> DuckDBPyRelation: ... def write_csv( @@ -1465,7 +1526,7 @@ def write_csv( overwrite: bool | None = None, per_thread_output: bool | None = None, use_tmp_file: bool | None = None, - partition_by: list[str] | None = None, + partition_by: lst[str] | None = None, write_partition_columns: bool | None = None, ) -> None: ... @@ -1475,7 +1536,7 @@ __interactive__: bool __jupyter__: bool __standard_vector_size__: int __version__: str -_clean_default_connection: pytyping.Any # value = +_clean_default_connection: typing.Any # value = apilevel: str paramstyle: str threadsafety: int diff --git a/_duckdb-stubs/_func.pyi b/_duckdb-stubs/_func.pyi index 68484499..5330ed04 100644 --- a/_duckdb-stubs/_func.pyi +++ b/_duckdb-stubs/_func.pyi @@ -1,40 +1,40 @@ -import typing as pytyping +import typing __all__: list[str] = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] class FunctionNullHandling: - DEFAULT: pytyping.ClassVar[FunctionNullHandling] # value = - SPECIAL: pytyping.ClassVar[FunctionNullHandling] # value = - __members__: pytyping.ClassVar[ + DEFAULT: typing.ClassVar[FunctionNullHandling] # value = + SPECIAL: typing.ClassVar[FunctionNullHandling] # value = + __members__: typing.ClassVar[ dict[str, FunctionNullHandling] ] # value = {'DEFAULT': , 'SPECIAL': } def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property def value(self) -> int: ... class PythonUDFType: - ARROW: pytyping.ClassVar[PythonUDFType] # value = - NATIVE: pytyping.ClassVar[PythonUDFType] # value = - __members__: pytyping.ClassVar[ + ARROW: typing.ClassVar[PythonUDFType] # value = + NATIVE: typing.ClassVar[PythonUDFType] # value = + __members__: typing.ClassVar[ dict[str, PythonUDFType] ] # value = {'NATIVE': , 'ARROW': } def __eq__(self, other: object) -> bool: ... def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... - def __init__(self, value: pytyping.SupportsInt) -> None: ... + def __init__(self, value: typing.SupportsInt) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: pytyping.SupportsInt) -> None: ... + def __setstate__(self, state: typing.SupportsInt) -> None: ... @property def name(self) -> str: ... @property diff --git a/_duckdb-stubs/_sqltypes.pyi b/_duckdb-stubs/_sqltypes.pyi index 88abb977..82e768eb 100644 --- a/_duckdb-stubs/_sqltypes.pyi +++ b/_duckdb-stubs/_sqltypes.pyi @@ -1,5 +1,5 @@ import duckdb -import typing as pytyping +import typing __all__: list[str] = [ "BIGINT", @@ -29,6 +29,7 @@ __all__: list[str] = [ "UTINYINT", "UUID", "VARCHAR", + "VARIANT", "DuckDBPyType", ] @@ -37,12 +38,12 @@ class DuckDBPyType: def __getattr__(self, name: str) -> DuckDBPyType: ... def __getitem__(self, name: str) -> DuckDBPyType: ... def __hash__(self) -> int: ... - @pytyping.overload + @typing.overload def __init__(self, type_str: str, connection: duckdb.DuckDBPyConnection) -> None: ... - @pytyping.overload + @typing.overload def __init__(self, obj: object) -> None: ... @property - def children(self) -> list[tuple[str, object]]: ... + def children(self) -> list[tuple[str, DuckDBPyType | int | list[str]]]: ... @property def id(self) -> str: ... @@ -59,6 +60,7 @@ INTERVAL: DuckDBPyType # value = INTERVAL SMALLINT: DuckDBPyType # value = SMALLINT SQLNULL: DuckDBPyType # value = "NULL" TIME: DuckDBPyType # value = TIME +TIME_NS: DuckDBPyType # value = TIME_NS TIMESTAMP: DuckDBPyType # value = TIMESTAMP TIMESTAMP_MS: DuckDBPyType # value = TIMESTAMP_MS TIMESTAMP_NS: DuckDBPyType # value = TIMESTAMP_NS @@ -73,3 +75,4 @@ USMALLINT: DuckDBPyType # value = USMALLINT UTINYINT: DuckDBPyType # value = UTINYINT UUID: DuckDBPyType # value = UUID VARCHAR: DuckDBPyType # value = VARCHAR +VARIANT: DuckDBPyType # value = VARIANT diff --git a/adbc_driver_duckdb/__init__.py b/adbc_driver_duckdb/__init__.py index f925ea9e..c2777c90 100644 --- a/adbc_driver_duckdb/__init__.py +++ b/adbc_driver_duckdb/__init__.py @@ -20,7 +20,6 @@ import enum import functools import importlib.util -import typing import adbc_driver_manager @@ -32,7 +31,7 @@ class StatementOptions(enum.Enum): BATCH_ROWS = "adbc.duckdb.query.batch_rows" -def connect(path: typing.Optional[str] = None) -> adbc_driver_manager.AdbcDatabase: +def connect(path: str | None = None) -> adbc_driver_manager.AdbcDatabase: """Create a low level ADBC connection to DuckDB.""" if path is None: return adbc_driver_manager.AdbcDatabase(driver=driver_path(), entrypoint="duckdb_adbc_init") diff --git a/adbc_driver_duckdb/dbapi.py b/adbc_driver_duckdb/dbapi.py index 5d0a8702..377f86a0 100644 --- a/adbc_driver_duckdb/dbapi.py +++ b/adbc_driver_duckdb/dbapi.py @@ -17,8 +17,6 @@ """DBAPI 2.0-compatible facade for the ADBC DuckDB driver.""" -import typing - import adbc_driver_manager import adbc_driver_manager.dbapi @@ -91,7 +89,7 @@ # Functions -def connect(path: typing.Optional[str] = None, **kwargs) -> "Connection": +def connect(path: str | None = None, **kwargs) -> "Connection": """Connect to DuckDB via ADBC.""" db = None conn = None diff --git a/cmake/duckdb_loader.cmake b/cmake/duckdb_loader.cmake index 2b8667ec..9c21aba0 100644 --- a/cmake/duckdb_loader.cmake +++ b/cmake/duckdb_loader.cmake @@ -250,9 +250,13 @@ function(duckdb_add_library target_name) endfunction() function(duckdb_link_extensions target_name) - # Link to the DuckDB static library and extensions - target_link_libraries(${target_name} - PRIVATE duckdb_generated_extension_loader) + # Link to the DuckDB static library and extensions We use WHOLE_ARCHIVE + # because duckdb_static calls LoadAllExtensions which is defined in the + # extension loader. Without this, linkers (especially on Linux with + # --as-needed) may drop the extension loader before seeing the reference. + target_link_libraries( + ${target_name} + PRIVATE "$") if(BUILD_EXTENSIONS) message(STATUS "Linking DuckDB extensions:") foreach(ext IN LISTS BUILD_EXTENSIONS) diff --git a/duckdb/__init__.py b/duckdb/__init__.py index a7370083..d17c530f 100644 --- a/duckdb/__init__.py +++ b/duckdb/__init__.py @@ -84,9 +84,11 @@ default_connection, description, df, + disable_profiling, distinct, dtype, duplicate, + enable_profiling, enum_type, execute, executemany, @@ -107,6 +109,7 @@ from_df, from_parquet, from_query, + get_profiling_information, get_table_names, install_extension, interrupt, @@ -140,6 +143,8 @@ table_function, tf, threadsafety, + to_arrow_reader, + to_arrow_table, token_type, tokenize, torch, @@ -310,9 +315,11 @@ "default_connection", "description", "df", + "disable_profiling", "distinct", "dtype", "duplicate", + "enable_profiling", "enum_type", "execute", "executemany", @@ -333,6 +340,7 @@ "from_df", "from_parquet", "from_query", + "get_profiling_information", "get_table_names", "install_extension", "interrupt", @@ -368,6 +376,8 @@ "tf", "threadsafety", "threadsafety", + "to_arrow_reader", + "to_arrow_table", "token_type", "tokenize", "torch", diff --git a/duckdb/bytes_io_wrapper.py b/duckdb/bytes_io_wrapper.py index 722c7cb4..d0ef78bf 100644 --- a/duckdb/bytes_io_wrapper.py +++ b/duckdb/bytes_io_wrapper.py @@ -34,7 +34,7 @@ """ from io import StringIO, TextIOBase -from typing import Any, Union +from typing import Any class BytesIOWrapper: @@ -43,7 +43,7 @@ class BytesIOWrapper: Created for compat with pyarrow read_csv. """ - def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") -> None: # noqa: D107 + def __init__(self, buffer: StringIO | TextIOBase, encoding: str = "utf-8") -> None: # noqa: D107 self.buffer = buffer self.encoding = encoding # Because a character can be represented by more than 1 byte, @@ -55,7 +55,7 @@ def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8") def __getattr__(self, attr: str) -> Any: # noqa: D105, ANN401 return getattr(self.buffer, attr) - def read(self, n: Union[int, None] = -1) -> bytes: # noqa: D102 + def read(self, n: int | None = -1) -> bytes: # noqa: D102 assert self.buffer is not None bytestring = self.buffer.read(n).encode(self.encoding) # When n=-1/n greater than remaining bytes: Read entire file/rest of file diff --git a/duckdb/experimental/spark/_globals.py b/duckdb/experimental/spark/_globals.py index 0625a140..23a6f171 100644 --- a/duckdb/experimental/spark/_globals.py +++ b/duckdb/experimental/spark/_globals.py @@ -33,6 +33,7 @@ def foo(arg=pyducdkb.spark._NoValue): __ALL__ = ["_NoValue"] +from typing_extensions import Self # Disallow reloading this module so as to preserve the identities of the # classes defined here. @@ -54,7 +55,7 @@ class _NoValueType: __instance = None - def __new__(cls) -> "_NoValueType": + def __new__(cls) -> Self: # ensure that only one instance exists if not cls.__instance: cls.__instance = super().__new__(cls) diff --git a/duckdb/experimental/spark/_typing.py b/duckdb/experimental/spark/_typing.py index 1ed78ea8..de7f2fff 100644 --- a/duckdb/experimental/spark/_typing.py +++ b/duckdb/experimental/spark/_typing.py @@ -16,16 +16,16 @@ # specific language governing permissions and limitations # under the License. -from collections.abc import Iterable, Sized -from typing import Callable, TypeVar, Union +from collections.abc import Callable, Iterable, Sized +from typing import Literal, TypeVar from numpy import float32, float64, int32, int64, ndarray -from typing_extensions import Literal, Protocol, Self +from typing_extensions import Protocol, Self F = TypeVar("F", bound=Callable) T_co = TypeVar("T_co", covariant=True) -PrimitiveType = Union[bool, float, int, str] +PrimitiveType = bool | float | int | str NonUDFType = Literal[0] diff --git a/duckdb/experimental/spark/conf.py b/duckdb/experimental/spark/conf.py index 974115d6..9b2cc0eb 100644 --- a/duckdb/experimental/spark/conf.py +++ b/duckdb/experimental/spark/conf.py @@ -1,5 +1,3 @@ -from typing import Optional # noqa: D100 - from duckdb.experimental.spark.exception import ContributionsAcceptedError @@ -10,7 +8,7 @@ def __init__(self) -> None: # noqa: D107 def contains(self, key: str) -> bool: # noqa: D102 raise ContributionsAcceptedError - def get(self, key: str, defaultValue: Optional[str] = None) -> Optional[str]: # noqa: D102 + def get(self, key: str, defaultValue: str | None = None) -> str | None: # noqa: D102 raise ContributionsAcceptedError def getAll(self) -> list[tuple[str, str]]: # noqa: D102 @@ -26,7 +24,7 @@ def setAppName(self, value: str) -> "SparkConf": # noqa: D102 raise ContributionsAcceptedError def setExecutorEnv( # noqa: D102 - self, key: Optional[str] = None, value: Optional[str] = None, pairs: Optional[list[tuple[str, str]]] = None + self, key: str | None = None, value: str | None = None, pairs: list[tuple[str, str]] | None = None ) -> "SparkConf": raise ContributionsAcceptedError diff --git a/duckdb/experimental/spark/context.py b/duckdb/experimental/spark/context.py index c78bde65..311153b2 100644 --- a/duckdb/experimental/spark/context.py +++ b/duckdb/experimental/spark/context.py @@ -1,5 +1,3 @@ -from typing import Optional # noqa: D100 - import duckdb from duckdb import DuckDBPyConnection from duckdb.experimental.spark.conf import SparkConf @@ -20,7 +18,7 @@ def stop(self) -> None: # noqa: D102 self._connection.close() @classmethod - def getOrCreate(cls, conf: Optional[SparkConf] = None) -> "SparkContext": # noqa: D102 + def getOrCreate(cls, conf: SparkConf | None = None) -> "SparkContext": # noqa: D102 raise ContributionsAcceptedError @classmethod @@ -93,13 +91,13 @@ def dump_profiles(self, path: str) -> None: # noqa: D102 # def emptyRDD(self) -> duckdb.experimental.spark.rdd.RDD[typing.Any]: # pass - def getCheckpointDir(self) -> Optional[str]: # noqa: D102 + def getCheckpointDir(self) -> str | None: # noqa: D102 raise ContributionsAcceptedError def getConf(self) -> SparkConf: # noqa: D102 raise ContributionsAcceptedError - def getLocalProperty(self, key: str) -> Optional[str]: # noqa: D102 + def getLocalProperty(self, key: str) -> str | None: # noqa: D102 raise ContributionsAcceptedError # def hadoopFile(self, path: str, inputFormatClass: str, keyClass: str, valueClass: str, diff --git a/duckdb/experimental/spark/errors/error_classes.py b/duckdb/experimental/spark/errors/error_classes.py index 22055cbf..c43a5f18 100644 --- a/duckdb/experimental/spark/errors/error_classes.py +++ b/duckdb/experimental/spark/errors/error_classes.py @@ -1,4 +1,4 @@ -# ruff: noqa: D100, E501 +# ruff: noqa: E501 # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. diff --git a/duckdb/experimental/spark/errors/exceptions/base.py b/duckdb/experimental/spark/errors/exceptions/base.py index 2eae2a19..9a60512f 100644 --- a/duckdb/experimental/spark/errors/exceptions/base.py +++ b/duckdb/experimental/spark/errors/exceptions/base.py @@ -1,4 +1,4 @@ -from typing import Optional, cast # noqa: D100 +from typing import cast from ..utils import ErrorClassesReader @@ -8,11 +8,11 @@ class PySparkException(Exception): def __init__( # noqa: D107 self, - message: Optional[str] = None, + message: str | None = None, # The error class, decides the message format, must be one of the valid options listed in 'error_classes.py' - error_class: Optional[str] = None, + error_class: str | None = None, # The dictionary listing the arguments specified in the message (or the error_class) - message_parameters: Optional[dict[str, str]] = None, + message_parameters: dict[str, str] | None = None, ) -> None: # `message` vs `error_class` & `message_parameters` are mutually exclusive. assert (message is not None and (error_class is None and message_parameters is None)) or ( @@ -31,7 +31,7 @@ def __init__( # noqa: D107 self.error_class = error_class self.message_parameters = message_parameters - def getErrorClass(self) -> Optional[str]: + def getErrorClass(self) -> str | None: """Returns an error class as a string. .. versionadded:: 3.4.0 @@ -43,7 +43,7 @@ def getErrorClass(self) -> Optional[str]: """ return self.error_class - def getMessageParameters(self) -> Optional[dict[str, str]]: + def getMessageParameters(self) -> dict[str, str] | None: """Returns a message parameters as a dictionary. .. versionadded:: 3.4.0 diff --git a/duckdb/experimental/spark/errors/utils.py b/duckdb/experimental/spark/errors/utils.py index 8a71f3b0..f2962fc8 100644 --- a/duckdb/experimental/spark/errors/utils.py +++ b/duckdb/experimental/spark/errors/utils.py @@ -1,4 +1,4 @@ -# # noqa: D100 +# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. diff --git a/duckdb/experimental/spark/exception.py b/duckdb/experimental/spark/exception.py index c3a7c1b6..440b7819 100644 --- a/duckdb/experimental/spark/exception.py +++ b/duckdb/experimental/spark/exception.py @@ -1,14 +1,10 @@ -# ruff: noqa: D100 -from typing import Optional - - class ContributionsAcceptedError(NotImplementedError): """This method is not planned to be implemented, if you would like to implement this method or show your interest in this method to other members of the community, feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb. """ # noqa: D205 - def __init__(self, message: Optional[str] = None) -> None: # noqa: D107 + def __init__(self, message: str | None = None) -> None: # noqa: D107 doc = self.__class__.__doc__ if message: doc = message + "\n" + doc diff --git a/duckdb/experimental/spark/sql/_typing.py b/duckdb/experimental/spark/sql/_typing.py index caf0058c..cf0b15e1 100644 --- a/duckdb/experimental/spark/sql/_typing.py +++ b/duckdb/experimental/spark/sql/_typing.py @@ -16,18 +16,18 @@ # specific language governing permissions and limitations # under the License. +from collections.abc import Callable from typing import ( Any, - Callable, - Optional, TypeVar, - Union, ) try: from typing import Literal, Protocol except ImportError: - from typing_extensions import Literal, Protocol + from typing import Literal + + from typing_extensions import Protocol import datetime import decimal @@ -36,14 +36,14 @@ from . import types from .column import Column -ColumnOrName = Union[Column, str] +ColumnOrName = Column | str ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName) DecimalLiteral = decimal.Decimal -DateTimeLiteral = Union[datetime.datetime, datetime.date] +DateTimeLiteral = datetime.datetime | datetime.date LiteralType = PrimitiveType -AtomicDataTypeOrString = Union[types.AtomicType, str] -DataTypeOrString = Union[types.DataType, str] -OptionalPrimitiveType = Optional[PrimitiveType] +AtomicDataTypeOrString = types.AtomicType | str +DataTypeOrString = types.DataType | str +OptionalPrimitiveType = PrimitiveType | None AtomicValue = TypeVar( "AtomicValue", diff --git a/duckdb/experimental/spark/sql/catalog.py b/duckdb/experimental/spark/sql/catalog.py index 70fc7b18..f43bab59 100644 --- a/duckdb/experimental/spark/sql/catalog.py +++ b/duckdb/experimental/spark/sql/catalog.py @@ -1,25 +1,25 @@ -from typing import NamedTuple, Optional, Union # noqa: D100 +from typing import NamedTuple from .session import SparkSession class Database(NamedTuple): # noqa: D101 name: str - description: Optional[str] + description: str | None locationUri: str class Table(NamedTuple): # noqa: D101 name: str - database: Optional[str] - description: Optional[str] + database: str | None + description: str | None tableType: str isTemporary: bool class Column(NamedTuple): # noqa: D101 name: str - description: Optional[str] + description: str | None dataType: str nullable: bool isPartition: bool @@ -28,7 +28,7 @@ class Column(NamedTuple): # noqa: D101 class Function(NamedTuple): # noqa: D101 name: str - description: Optional[str] + description: str | None className: str isTemporary: bool @@ -55,7 +55,7 @@ def transform_to_table(x: list[str]) -> Table: tables = [transform_to_table(x) for x in res] return tables - def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Column]: # noqa: D102 + def listColumns(self, tableName: str, dbName: str | None = None) -> list[Column]: # noqa: D102 query = f""" select column_name, data_type, is_nullable from duckdb_columns() where table_name = '{tableName}' """ @@ -63,13 +63,13 @@ def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Colu query += f" and database_name = '{dbName}'" res = self._session.conn.sql(query).fetchall() - def transform_to_column(x: list[Union[str, bool]]) -> Column: + def transform_to_column(x: list[str | bool]) -> Column: return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False) columns = [transform_to_column(x) for x in res] return columns - def listFunctions(self, dbName: Optional[str] = None) -> list[Function]: # noqa: D102 + def listFunctions(self, dbName: str | None = None) -> list[Function]: # noqa: D102 raise NotImplementedError def setCurrentDatabase(self, dbName: str) -> None: # noqa: D102 diff --git a/duckdb/experimental/spark/sql/column.py b/duckdb/experimental/spark/sql/column.py index 661e4da7..e013a56d 100644 --- a/duckdb/experimental/spark/sql/column.py +++ b/duckdb/experimental/spark/sql/column.py @@ -1,5 +1,8 @@ -from collections.abc import Iterable # noqa: D100 -from typing import TYPE_CHECKING, Any, Callable, Union, cast +from collections.abc import ( + Callable, + Iterable, +) +from typing import TYPE_CHECKING, Any, Union, cast from ..exception import ContributionsAcceptedError from .types import DataType @@ -222,11 +225,11 @@ def otherwise(self, value: Union["Column", str]) -> "Column": # noqa: D102 expr = self.expr.otherwise(v) return Column(expr) - def cast(self, dataType: Union[DataType, str]) -> "Column": # noqa: D102 + def cast(self, dataType: DataType | str) -> "Column": # noqa: D102 internal_type = DuckDBPyType(dataType) if isinstance(dataType, str) else dataType.duckdb_type return Column(self.expr.cast(internal_type)) - def isin(self, *cols: Union[Iterable[Union["Column", str]], Union["Column", str]]) -> "Column": # noqa: D102 + def isin(self, *cols: Iterable[Union["Column", str]] | Union["Column", str]) -> "Column": # noqa: D102 if len(cols) == 1 and isinstance(cols[0], (list, set)): # Only one argument supplied, it's a list cols = cast("tuple", cols[0]) diff --git a/duckdb/experimental/spark/sql/conf.py b/duckdb/experimental/spark/sql/conf.py index e44f2566..75a77899 100644 --- a/duckdb/experimental/spark/sql/conf.py +++ b/duckdb/experimental/spark/sql/conf.py @@ -1,5 +1,3 @@ -from typing import Optional, Union # noqa: D100 - from duckdb import DuckDBPyConnection from duckdb.experimental.spark._globals import _NoValue, _NoValueType @@ -17,7 +15,7 @@ def isModifiable(self, key: str) -> bool: # noqa: D102 def unset(self, key: str) -> None: # noqa: D102 raise NotImplementedError - def get(self, key: str, default: Union[Optional[str], _NoValueType] = _NoValue) -> str: # noqa: D102 + def get(self, key: str, default: str | None | _NoValueType = _NoValue) -> str: # noqa: D102 raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/dataframe.py b/duckdb/experimental/spark/sql/dataframe.py index e7519e81..83b2dd09 100644 --- a/duckdb/experimental/spark/sql/dataframe.py +++ b/duckdb/experimental/spark/sql/dataframe.py @@ -1,11 +1,10 @@ -import uuid # noqa: D100 +import uuid +from collections.abc import Callable from functools import reduce from keyword import iskeyword from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, Union, cast, overload, @@ -206,7 +205,7 @@ def withColumns(self, *colsMap: dict[str, Column]) -> "DataFrame": # In case anything is remaining, these are new columns # that we need to add to the DataFrame - for col_name, col in zip(column_names, columns): + for col_name, col in zip(column_names, columns, strict=False): cols.append(col.expr.alias(col_name)) rel = self.relation.select(*cols) @@ -341,7 +340,7 @@ def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) ) return result - def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any) -> "DataFrame": # noqa: ANN401 + def sort(self, *cols: str | Column | list[str | Column], **kwargs: Any) -> "DataFrame": # noqa: ANN401 """Returns a new :class:`DataFrame` sorted by the specified column(s). Parameters @@ -458,7 +457,7 @@ def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: An if not ascending: columns = [c.desc() for c in columns] elif isinstance(ascending, list): - columns = [c if asc else c.desc() for asc, c in zip(ascending, columns)] + columns = [c if asc else c.desc() for asc, c in zip(ascending, columns, strict=False)] else: raise PySparkTypeError( error_class="NOT_BOOL_OR_LIST", @@ -471,7 +470,7 @@ def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: An orderBy = sort - def head(self, n: Optional[int] = None) -> Union[Optional[Row], list[Row]]: # noqa: D102 + def head(self, n: int | None = None) -> Row | None | list[Row]: # noqa: D102 if n is None: rs = self.head(1) return rs[0] if rs else None @@ -597,8 +596,8 @@ def __dir__(self) -> list[str]: # noqa: D105 def join( self, other: "DataFrame", - on: Optional[Union[str, list[str], Column, list[Column]]] = None, - how: Optional[str] = None, + on: str | list[str] | Column | list[Column] | None = None, + how: str | None = None, ) -> "DataFrame": """Joins with another :class:`DataFrame`, using the given join expression. @@ -871,12 +870,12 @@ def schema(self) -> StructType: return self._schema @overload - def __getitem__(self, item: Union[int, str]) -> Column: ... + def __getitem__(self, item: int | str) -> Column: ... @overload - def __getitem__(self, item: Union[Column, list, tuple]) -> "DataFrame": ... + def __getitem__(self, item: Column | list | tuple) -> "DataFrame": ... - def __getitem__(self, item: Union[int, str, Column, list, tuple]) -> Union[Column, "DataFrame"]: + def __getitem__(self, item: int | str | Column | list | tuple) -> Union[Column, "DataFrame"]: """Returns the column as a :class:`Column`. Examples: @@ -919,7 +918,7 @@ def __getattr__(self, name: str) -> Column: def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": ... @overload - def groupBy(self, __cols: Union[list[Column], list[str]]) -> "GroupedData": ... # noqa: PYI063 + def groupBy(self, __cols: list[Column] | list[str]) -> "GroupedData": ... # noqa: PYI063 def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] """Groups the :class:`DataFrame` using the specified columns, @@ -997,7 +996,7 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] def write(self) -> DataFrameWriter: # noqa: D102 return DataFrameWriter(self) - def printSchema(self, level: Optional[int] = None) -> None: + def printSchema(self, level: int | None = None) -> None: """Prints out the schema in the tree format. Parameters @@ -1067,8 +1066,7 @@ def union(self, other: "DataFrame") -> "DataFrame": unionAll = union def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame": - """Returns a new :class:`DataFrame` containing union of rows in this and another - :class:`DataFrame`. + """Returns a new :class:`DataFrame` containing union of rows in this and another :class:`DataFrame`. This is different from both `UNION ALL` and `UNION DISTINCT` in SQL. To do a SQL-style set union (that does deduplication of elements), use this function followed by :func:`distinct`. @@ -1121,15 +1119,27 @@ def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> | 1| 2| 3|NULL| |NULL| 4| 5| 6| +----+----+----+----+ - """ # noqa: D205 + """ if allowMissingColumns: - cols = [] - for col in self.relation.columns: - if col in other.relation.columns: - cols.append(col) - else: - cols.append(spark_sql_functions.lit(None)) - other = other.select(*cols) + df1 = self.select( + *self.relation.columns, + *[ + spark_sql_functions.lit(None).alias(c) + for c in other.relation.columns + if c not in self.relation.columns + ], + ) + + df2 = other.select( + *[ + spark_sql_functions.col(c) + if c in other.relation.columns + else spark_sql_functions.lit(None).alias(c) + for c in df1.relation.columns + ] + ) + + return df1.unionByName(df2, allowMissingColumns=False) else: other = other.select(*self.relation.columns) @@ -1251,7 +1261,7 @@ def exceptAll(self, other: "DataFrame") -> "DataFrame": """ # noqa: D205 return DataFrame(self.relation.except_(other.relation), self.session) - def dropDuplicates(self, subset: Optional[list[str]] = None) -> "DataFrame": + def dropDuplicates(self, subset: list[str] | None = None) -> "DataFrame": """Return a new :class:`DataFrame` with duplicate rows removed, optionally only considering certain columns. @@ -1360,7 +1370,8 @@ def _cast_types(self, *types) -> "DataFrame": assert types_count == len(existing_columns) cast_expressions = [ - f"{existing}::{target_type} as {existing}" for existing, target_type in zip(existing_columns, types) + f"{existing}::{target_type} as {existing}" + for existing, target_type in zip(existing_columns, types, strict=False) ] cast_expressions = ", ".join(cast_expressions) new_rel = self.relation.project(cast_expressions) @@ -1373,7 +1384,7 @@ def toDF(self, *cols) -> "DataFrame": # noqa: D102 raise PySparkValueError(message="Provided column names and number of columns in the DataFrame don't match") existing_columns = [ColumnExpression(x) for x in existing_columns] - projections = [existing.alias(new) for existing, new in zip(existing_columns, cols)] + projections = [existing.alias(new) for existing, new in zip(existing_columns, cols, strict=False)] new_rel = self.relation.project(*projections) return DataFrame(new_rel, self.session) diff --git a/duckdb/experimental/spark/sql/functions.py b/duckdb/experimental/spark/sql/functions.py index 49c475a4..71ff8c59 100644 --- a/duckdb/experimental/spark/sql/functions.py +++ b/duckdb/experimental/spark/sql/functions.py @@ -1,5 +1,6 @@ -import warnings # noqa: D100 -from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload +import warnings +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional, Union, overload from duckdb import ( CaseExpression, @@ -109,7 +110,7 @@ def ucase(str: "ColumnOrName") -> Column: return upper(str) -def when(condition: "Column", value: Union[Column, str]) -> Column: # noqa: D103 +def when(condition: "Column", value: Column | str) -> Column: # noqa: D103 if not isinstance(condition, Column): msg = "condition should be a Column" raise TypeError(msg) @@ -118,7 +119,7 @@ def when(condition: "Column", value: Union[Column, str]) -> Column: # noqa: D10 return Column(expr) -def _inner_expr_or_val(val: Union[Column, str]) -> Union[Column, str]: +def _inner_expr_or_val(val: Column | str) -> Column | str: return val.expr if isinstance(val, Column) else val @@ -126,7 +127,7 @@ def struct(*cols: Column) -> Column: # noqa: D103 return Column(FunctionExpression("struct_pack", *[_inner_expr_or_val(x) for x in cols])) -def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["ColumnOrName", ...]]]) -> Column: +def array(*cols: Union["ColumnOrName", list["ColumnOrName"] | tuple["ColumnOrName", ...]]) -> Column: r"""Creates a new array column. .. versionadded:: 1.4.0 @@ -449,7 +450,7 @@ def right(str: "ColumnOrName", len: "ColumnOrName") -> Column: ) -def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None) -> Column: +def levenshtein(left: "ColumnOrName", right: "ColumnOrName", threshold: int | None = None) -> Column: """Computes the Levenshtein distance of the two given strings. .. versionadded:: 1.5.0 @@ -766,7 +767,7 @@ def collect_list(col: "ColumnOrName") -> Column: return array_agg(col) -def array_append(col: "ColumnOrName", value: Union[Column, str]) -> Column: +def array_append(col: "ColumnOrName", value: Column | str) -> Column: """Collection function: returns an array of the elements in col1 along with the added element in col2 at the last of the array. @@ -800,7 +801,7 @@ def array_append(col: "ColumnOrName", value: Union[Column, str]) -> Column: return _invoke_function("list_append", _to_column_expr(col), _get_expr(value)) -def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Union[Column, str]) -> Column: +def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Column | str) -> Column: """Collection function: adds an item into a given array at a specified array index. Array indices start at 1, or start from the end if index is negative. Index above array size appends the array, or prepends the array if index is negative, @@ -893,7 +894,7 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Un ) -def array_contains(col: "ColumnOrName", value: Union[Column, str]) -> Column: +def array_contains(col: "ColumnOrName", value: Column | str) -> Column: """Collection function: returns null if the array is null, true if the array contains the given value, and false otherwise. @@ -1373,7 +1374,7 @@ def count(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("count", col) -def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: +def approx_count_distinct(col: "ColumnOrName", rsd: float | None = None) -> Column: """Aggregate function: returns a new :class:`~pyspark.sql.Column` for approximate distinct count of column `col`. @@ -1410,7 +1411,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C return _invoke_function_over_columns("approx_count_distinct", col) -def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: +def approxCountDistinct(col: "ColumnOrName", rsd: float | None = None) -> Column: """.. versionadded:: 1.3.0. .. versionchanged:: 3.4.0 @@ -1433,7 +1434,7 @@ def transform(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Col def transform( col: "ColumnOrName", - f: Union[Callable[[Column], Column], Callable[[Column, Column], Column]], + f: Callable[[Column], Column] | Callable[[Column, Column], Column], ) -> Column: """Returns an array of elements after applying a transformation to each element in the input array. @@ -2255,7 +2256,7 @@ def product(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("product", col) -def rand(seed: Optional[int] = None) -> Column: +def rand(seed: int | None = None) -> Column: """Generates a random column with independent and identically distributed (i.i.d.) samples uniformly distributed in [0.0, 1.0). @@ -2419,7 +2420,7 @@ def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: ) -def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: Optional[Union[int, Column]] = None) -> Column: +def regexp_extract_all(str: "ColumnOrName", regexp: "ColumnOrName", idx: int | Column | None = None) -> Column: r"""Extract all strings in the `str` that match the Java regex `regexp` and corresponding to the regex group index. @@ -4968,7 +4969,7 @@ def add_months(start: "ColumnOrName", months: Union["ColumnOrName", int]) -> Col return _invoke_function("date_add", _to_column_expr(start), FunctionExpression("to_months", months)).cast("date") -def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[str] = None) -> Column: +def array_join(col: "ColumnOrName", delimiter: str, null_replacement: str | None = None) -> Column: """Concatenates the elements of `column` using the `delimiter`. Null values are replaced with `null_replacement` if set, otherwise they are ignored. @@ -5136,7 +5137,7 @@ def array_size(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("len", col) -def array_sort(col: "ColumnOrName", comparator: Optional[Callable[[Column, Column], Column]] = None) -> Column: +def array_sort(col: "ColumnOrName", comparator: Callable[[Column, Column], Column] | None = None) -> Column: """Collection function: sorts the input array in ascending order. The elements of the input array must be orderable. Null elements will be placed at the end of the returned array. @@ -5592,7 +5593,7 @@ def zeroifnull(col: "ColumnOrName") -> Column: return coalesce(col, lit(0)) -def _to_date_or_timestamp(col: "ColumnOrName", spark_datatype: _types.DataType, format: Optional[str] = None) -> Column: +def _to_date_or_timestamp(col: "ColumnOrName", spark_datatype: _types.DataType, format: str | None = None) -> Column: if format is not None: raise ContributionsAcceptedError( "format is not yet supported as DuckDB and PySpark use a different way of specifying them." @@ -5601,7 +5602,7 @@ def _to_date_or_timestamp(col: "ColumnOrName", spark_datatype: _types.DataType, return Column(_to_column_expr(col)).cast(spark_datatype) -def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: +def to_date(col: "ColumnOrName", format: str | None = None) -> Column: """Converts a :class:`~pyspark.sql.Column` into :class:`pyspark.sql.types.DateType` using the optionally specified format. Specify formats according to `datetime pattern`_. By default, it follows casting rules to :class:`pyspark.sql.types.DateType` if the format @@ -5639,7 +5640,7 @@ def to_date(col: "ColumnOrName", format: Optional[str] = None) -> Column: return _to_date_or_timestamp(col, _types.DateType(), format) -def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: +def to_timestamp(col: "ColumnOrName", format: str | None = None) -> Column: """Converts a :class:`~pyspark.sql.Column` into :class:`pyspark.sql.types.TimestampType` using the optionally specified format. Specify formats according to `datetime pattern`_. By default, it follows casting rules to :class:`pyspark.sql.types.TimestampType` if the format diff --git a/duckdb/experimental/spark/sql/group.py b/duckdb/experimental/spark/sql/group.py index aa3e56d6..5f784453 100644 --- a/duckdb/experimental/spark/sql/group.py +++ b/duckdb/experimental/spark/sql/group.py @@ -1,4 +1,4 @@ -# # noqa: D100 +# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -15,7 +15,8 @@ # limitations under the License. # -from typing import TYPE_CHECKING, Callable, Union, overload +from collections.abc import Callable +from typing import TYPE_CHECKING, overload from ..exception import ContributionsAcceptedError from .column import Column @@ -319,7 +320,7 @@ def agg(self, *exprs: Column) -> DataFrame: ... @overload def agg(self, __exprs: dict[str, str]) -> DataFrame: ... # noqa: PYI063 - def agg(self, *exprs: Union[Column, dict[str, str]]) -> DataFrame: + def agg(self, *exprs: Column | dict[str, str]) -> DataFrame: """Compute aggregates and returns the result as a :class:`DataFrame`. The available aggregate functions can be: diff --git a/duckdb/experimental/spark/sql/readwriter.py b/duckdb/experimental/spark/sql/readwriter.py index eef99043..230d5d2a 100644 --- a/duckdb/experimental/spark/sql/readwriter.py +++ b/duckdb/experimental/spark/sql/readwriter.py @@ -1,11 +1,11 @@ -from typing import TYPE_CHECKING, Optional, Union, cast # noqa: D100 +from typing import TYPE_CHECKING, cast from ..errors import PySparkNotImplementedError, PySparkTypeError from ..exception import ContributionsAcceptedError from .types import StructType -PrimitiveType = Union[bool, float, int, str] -OptionalPrimitiveType = Optional[PrimitiveType] +PrimitiveType = bool | float | int | str +OptionalPrimitiveType = PrimitiveType | None if TYPE_CHECKING: from duckdb.experimental.spark.sql.dataframe import DataFrame @@ -23,9 +23,9 @@ def saveAsTable(self, table_name: str) -> None: # noqa: D102 def parquet( # noqa: D102 self, path: str, - mode: Optional[str] = None, - partitionBy: Union[str, list[str], None] = None, - compression: Optional[str] = None, + mode: str | None = None, + partitionBy: str | list[str] | None = None, + compression: str | None = None, ) -> None: relation = self.dataframe.relation if mode: @@ -38,23 +38,23 @@ def parquet( # noqa: D102 def csv( # noqa: D102 self, path: str, - mode: Optional[str] = None, - compression: Optional[str] = None, - sep: Optional[str] = None, - quote: Optional[str] = None, - escape: Optional[str] = None, - header: Optional[Union[bool, str]] = None, - nullValue: Optional[str] = None, - escapeQuotes: Optional[Union[bool, str]] = None, - quoteAll: Optional[Union[bool, str]] = None, - dateFormat: Optional[str] = None, - timestampFormat: Optional[str] = None, - ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None, - ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None, - charToEscapeQuoteEscaping: Optional[str] = None, - encoding: Optional[str] = None, - emptyValue: Optional[str] = None, - lineSep: Optional[str] = None, + mode: str | None = None, + compression: str | None = None, + sep: str | None = None, + quote: str | None = None, + escape: str | None = None, + header: bool | str | None = None, + nullValue: str | None = None, + escapeQuotes: bool | str | None = None, + quoteAll: bool | str | None = None, + dateFormat: str | None = None, + timestampFormat: str | None = None, + ignoreLeadingWhiteSpace: bool | str | None = None, + ignoreTrailingWhiteSpace: bool | str | None = None, + charToEscapeQuoteEscaping: str | None = None, + encoding: str | None = None, + emptyValue: str | None = None, + lineSep: str | None = None, ) -> None: if mode not in (None, "overwrite"): raise NotImplementedError @@ -92,9 +92,9 @@ def __init__(self, session: "SparkSession") -> None: # noqa: D107 def load( # noqa: D102 self, - path: Optional[Union[str, list[str]]] = None, - format: Optional[str] = None, - schema: Optional[Union[StructType, str]] = None, + path: str | list[str] | None = None, + format: str | None = None, + schema: StructType | str | None = None, **options: OptionalPrimitiveType, ) -> "DataFrame": from duckdb.experimental.spark.sql.dataframe import DataFrame @@ -129,40 +129,40 @@ def load( # noqa: D102 def csv( # noqa: D102 self, - path: Union[str, list[str]], - schema: Optional[Union[StructType, str]] = None, - sep: Optional[str] = None, - encoding: Optional[str] = None, - quote: Optional[str] = None, - escape: Optional[str] = None, - comment: Optional[str] = None, - header: Optional[Union[bool, str]] = None, - inferSchema: Optional[Union[bool, str]] = None, - ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = None, - ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = None, - nullValue: Optional[str] = None, - nanValue: Optional[str] = None, - positiveInf: Optional[str] = None, - negativeInf: Optional[str] = None, - dateFormat: Optional[str] = None, - timestampFormat: Optional[str] = None, - maxColumns: Optional[Union[int, str]] = None, - maxCharsPerColumn: Optional[Union[int, str]] = None, - maxMalformedLogPerPartition: Optional[Union[int, str]] = None, - mode: Optional[str] = None, - columnNameOfCorruptRecord: Optional[str] = None, - multiLine: Optional[Union[bool, str]] = None, - charToEscapeQuoteEscaping: Optional[str] = None, - samplingRatio: Optional[Union[float, str]] = None, - enforceSchema: Optional[Union[bool, str]] = None, - emptyValue: Optional[str] = None, - locale: Optional[str] = None, - lineSep: Optional[str] = None, - pathGlobFilter: Optional[Union[bool, str]] = None, - recursiveFileLookup: Optional[Union[bool, str]] = None, - modifiedBefore: Optional[Union[bool, str]] = None, - modifiedAfter: Optional[Union[bool, str]] = None, - unescapedQuoteHandling: Optional[str] = None, + path: str | list[str], + schema: StructType | str | None = None, + sep: str | None = None, + encoding: str | None = None, + quote: str | None = None, + escape: str | None = None, + comment: str | None = None, + header: bool | str | None = None, + inferSchema: bool | str | None = None, + ignoreLeadingWhiteSpace: bool | str | None = None, + ignoreTrailingWhiteSpace: bool | str | None = None, + nullValue: str | None = None, + nanValue: str | None = None, + positiveInf: str | None = None, + negativeInf: str | None = None, + dateFormat: str | None = None, + timestampFormat: str | None = None, + maxColumns: int | str | None = None, + maxCharsPerColumn: int | str | None = None, + maxMalformedLogPerPartition: int | str | None = None, + mode: str | None = None, + columnNameOfCorruptRecord: str | None = None, + multiLine: bool | str | None = None, + charToEscapeQuoteEscaping: str | None = None, + samplingRatio: float | str | None = None, + enforceSchema: bool | str | None = None, + emptyValue: str | None = None, + locale: str | None = None, + lineSep: str | None = None, + pathGlobFilter: bool | str | None = None, + recursiveFileLookup: bool | str | None = None, + modifiedBefore: bool | str | None = None, + modifiedAfter: bool | str | None = None, + unescapedQuoteHandling: str | None = None, ) -> "DataFrame": if not isinstance(path, str): raise NotImplementedError @@ -263,31 +263,31 @@ def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame def json( self, - path: Union[str, list[str]], - schema: Optional[Union[StructType, str]] = None, - primitivesAsString: Optional[Union[bool, str]] = None, - prefersDecimal: Optional[Union[bool, str]] = None, - allowComments: Optional[Union[bool, str]] = None, - allowUnquotedFieldNames: Optional[Union[bool, str]] = None, - allowSingleQuotes: Optional[Union[bool, str]] = None, - allowNumericLeadingZero: Optional[Union[bool, str]] = None, - allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = None, - mode: Optional[str] = None, - columnNameOfCorruptRecord: Optional[str] = None, - dateFormat: Optional[str] = None, - timestampFormat: Optional[str] = None, - multiLine: Optional[Union[bool, str]] = None, - allowUnquotedControlChars: Optional[Union[bool, str]] = None, - lineSep: Optional[str] = None, - samplingRatio: Optional[Union[float, str]] = None, - dropFieldIfAllNull: Optional[Union[bool, str]] = None, - encoding: Optional[str] = None, - locale: Optional[str] = None, - pathGlobFilter: Optional[Union[bool, str]] = None, - recursiveFileLookup: Optional[Union[bool, str]] = None, - modifiedBefore: Optional[Union[bool, str]] = None, - modifiedAfter: Optional[Union[bool, str]] = None, - allowNonNumericNumbers: Optional[Union[bool, str]] = None, + path: str | list[str], + schema: StructType | str | None = None, + primitivesAsString: bool | str | None = None, + prefersDecimal: bool | str | None = None, + allowComments: bool | str | None = None, + allowUnquotedFieldNames: bool | str | None = None, + allowSingleQuotes: bool | str | None = None, + allowNumericLeadingZero: bool | str | None = None, + allowBackslashEscapingAnyCharacter: bool | str | None = None, + mode: str | None = None, + columnNameOfCorruptRecord: str | None = None, + dateFormat: str | None = None, + timestampFormat: str | None = None, + multiLine: bool | str | None = None, + allowUnquotedControlChars: bool | str | None = None, + lineSep: str | None = None, + samplingRatio: float | str | None = None, + dropFieldIfAllNull: bool | str | None = None, + encoding: str | None = None, + locale: str | None = None, + pathGlobFilter: bool | str | None = None, + recursiveFileLookup: bool | str | None = None, + modifiedBefore: bool | str | None = None, + modifiedAfter: bool | str | None = None, + allowNonNumericNumbers: bool | str | None = None, ) -> "DataFrame": """Loads JSON files and returns the results as a :class:`DataFrame`. diff --git a/duckdb/experimental/spark/sql/session.py b/duckdb/experimental/spark/sql/session.py index b05b9705..c407a9f1 100644 --- a/duckdb/experimental/spark/sql/session.py +++ b/duckdb/experimental/spark/sql/session.py @@ -1,6 +1,6 @@ -import uuid # noqa: D100 +import uuid from collections.abc import Iterable, Sized -from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union +from typing import TYPE_CHECKING, Any, NoReturn, Union import duckdb @@ -38,7 +38,7 @@ def _combine_data_and_schema(data: Iterable[Any], schema: StructType) -> list[du new_data = [] for row in data: - new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row, [y.dataType for y in schema])] + new_row = [Value(x, dtype.duckdb_type) for x, dtype in zip(row, [y.dataType for y in schema], strict=False)] new_data.append(new_row) return new_data @@ -113,7 +113,7 @@ def construct_parameters(tuples: Iterable) -> list[list]: return DataFrame(rel, self) def _createDataFrameFromPandas( - self, data: "PandasDataFrame", types: Union[list[str], None], names: Union[list[str], None] + self, data: "PandasDataFrame", types: list[str] | None, names: list[str] | None ) -> DataFrame: df = self._create_dataframe(data) @@ -128,8 +128,8 @@ def _createDataFrameFromPandas( def createDataFrame( # noqa: D102 self, data: Union["PandasDataFrame", Iterable[Any]], - schema: Optional[Union[StructType, list[str]]] = None, - samplingRatio: Optional[float] = None, + schema: StructType | list[str] | None = None, + samplingRatio: float | None = None, verifySchema: bool = True, ) -> DataFrame: if samplingRatio: @@ -194,9 +194,9 @@ def newSession(self) -> "SparkSession": # noqa: D102 def range( # noqa: D102 self, start: int, - end: Optional[int] = None, + end: int | None = None, step: int = 1, - numPartitions: Optional[int] = None, + numPartitions: int | None = None, ) -> "DataFrame": if numPartitions: raise ContributionsAcceptedError @@ -281,9 +281,9 @@ def getOrCreate(self) -> "SparkSession": # noqa: D102 def config( # noqa: D102 self, - key: Optional[str] = None, - value: Optional[Any] = None, # noqa: ANN401 - conf: Optional[SparkConf] = None, + key: str | None = None, + value: Any | None = None, # noqa: ANN401 + conf: SparkConf | None = None, ) -> "SparkSession.Builder": return self diff --git a/duckdb/experimental/spark/sql/streaming.py b/duckdb/experimental/spark/sql/streaming.py index 08b7cc30..e40bfbf4 100644 --- a/duckdb/experimental/spark/sql/streaming.py +++ b/duckdb/experimental/spark/sql/streaming.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, Union # noqa: D100 +from typing import TYPE_CHECKING from .types import StructType @@ -6,8 +6,8 @@ from .dataframe import DataFrame from .session import SparkSession -PrimitiveType = Union[bool, float, int, str] -OptionalPrimitiveType = Optional[PrimitiveType] +PrimitiveType = bool | float | int | str +OptionalPrimitiveType = PrimitiveType | None class DataStreamWriter: # noqa: D101 @@ -25,9 +25,9 @@ def __init__(self, session: "SparkSession") -> None: # noqa: D107 def load( # noqa: D102 self, - path: Optional[str] = None, - format: Optional[str] = None, - schema: Union[StructType, str, None] = None, + path: str | None = None, + format: str | None = None, + schema: StructType | str | None = None, **options: OptionalPrimitiveType, ) -> "DataFrame": raise NotImplementedError diff --git a/duckdb/experimental/spark/sql/type_utils.py b/duckdb/experimental/spark/sql/type_utils.py index 15217788..2e15e38b 100644 --- a/duckdb/experimental/spark/sql/type_utils.py +++ b/duckdb/experimental/spark/sql/type_utils.py @@ -1,4 +1,4 @@ -from typing import cast # noqa: D100 +from typing import cast from duckdb.sqltypes import DuckDBPyType @@ -23,6 +23,7 @@ StringType, StructField, StructType, + TimeNSType, TimeNTZType, TimestampMillisecondNTZType, TimestampNanosecondNTZType, @@ -36,6 +37,7 @@ UnsignedLongType, UnsignedShortType, UUIDType, + VariantType, ) _sqltype_to_spark_class = { @@ -56,6 +58,7 @@ "uuid": UUIDType, "date": DateType, "time": TimeNTZType, + "time_ns": TimeNSType, "time with time zone": TimeType, "timestamp": TimestampNTZType, "timestamp with time zone": TimestampType, @@ -72,6 +75,7 @@ "float": FloatType, "double": DoubleType, "decimal": DecimalType, + "variant": VariantType, } @@ -109,5 +113,5 @@ def convert_type(dtype: DuckDBPyType) -> DataType: # noqa: D103 def duckdb_to_spark_schema(names: list[str], types: list[DuckDBPyType]) -> StructType: # noqa: D103 - fields = [StructField(name, dtype) for name, dtype in zip(names, [convert_type(x) for x in types])] + fields = [StructField(name, dtype) for name, dtype in zip(names, [convert_type(x) for x in types], strict=False)] return StructType(fields) diff --git a/duckdb/experimental/spark/sql/types.py b/duckdb/experimental/spark/sql/types.py index 2e87f9ff..5bfff09f 100644 --- a/duckdb/experimental/spark/sql/types.py +++ b/duckdb/experimental/spark/sql/types.py @@ -1,4 +1,3 @@ -# ruff: noqa: D100 # This code is based on code from Apache Spark under the license found in the LICENSE # file located in the 'spark' folder. @@ -10,16 +9,9 @@ from builtins import tuple from collections.abc import Iterator, Mapping from types import MappingProxyType -from typing import ( - Any, - ClassVar, - NoReturn, - Optional, - TypeVar, - Union, - cast, - overload, -) +from typing import Any, ClassVar, NoReturn, TypeVar, Union, cast, overload + +from typing_extensions import Self import duckdb from duckdb.sqltypes import DuckDBPyType @@ -51,6 +43,7 @@ "StringType", "StructField", "StructType", + "TimeNSType", "TimeNTZType", "TimeType", "TimestampMillisecondNTZType", @@ -64,6 +57,7 @@ "UnsignedIntegerType", "UnsignedLongType", "UnsignedShortType", + "VariantType", ] @@ -92,7 +86,7 @@ def typeName(cls) -> str: # noqa: D102 def simpleString(self) -> str: # noqa: D102 return self.typeName() - def jsonValue(self) -> Union[str, dict[str, Any]]: # noqa: D102 + def jsonValue(self) -> str | dict[str, Any]: # noqa: D102 raise ContributionsAcceptedError def json(self) -> str: # noqa: D102 @@ -194,6 +188,13 @@ def __init__(self) -> None: # noqa: D107 super().__init__(DuckDBPyType("BOOLEAN")) +class VariantType(AtomicType, metaclass=DataTypeSingleton): + """Variant (semi-structured) data type.""" + + def __init__(self) -> None: # noqa: D107 + super().__init__(DuckDBPyType("VARIANT")) + + class DateType(AtomicType, metaclass=DataTypeSingleton): """Date (datetime.date) data type.""" @@ -505,6 +506,16 @@ def simpleString(self) -> str: # noqa: D102 return "time" +class TimeNSType(IntegralType): + """Time NS (datetime.time) data type without timezone information.""" + + def __init__(self) -> None: # noqa: D107 + super().__init__(DuckDBPyType("TIME_NS")) + + def simpleString(self) -> str: # noqa: D102 + return "time_ns" + + class DayTimeIntervalType(AtomicType): """DayTimeIntervalType (datetime.timedelta).""" @@ -522,9 +533,9 @@ class DayTimeIntervalType(AtomicType): } ) - _inverted_fields: Mapping[int, str] = MappingProxyType(dict(zip(_fields.values(), _fields.keys()))) + _inverted_fields: Mapping[int, str] = MappingProxyType(dict(zip(_fields.values(), _fields.keys(), strict=False))) - def __init__(self, startField: Optional[int] = None, endField: Optional[int] = None) -> None: # noqa: D107 + def __init__(self, startField: int | None = None, endField: int | None = None) -> None: # noqa: D107 super().__init__(DuckDBPyType("INTERVAL")) if startField is None and endField is None: # Default matched to scala side. @@ -557,11 +568,11 @@ def __repr__(self) -> str: # noqa: D105 def needConversion(self) -> bool: # noqa: D102 return True - def toInternal(self, dt: datetime.timedelta) -> Optional[int]: # noqa: D102 + def toInternal(self, dt: datetime.timedelta) -> int | None: # noqa: D102 if dt is not None: return (math.floor(dt.total_seconds()) * 1000000) + dt.microseconds - def fromInternal(self, micros: int) -> Optional[datetime.timedelta]: # noqa: D102 + def fromInternal(self, micros: int) -> datetime.timedelta | None: # noqa: D102 if micros is not None: return datetime.timedelta(microseconds=micros) @@ -599,12 +610,12 @@ def __repr__(self) -> str: # noqa: D105 def needConversion(self) -> bool: # noqa: D102 return self.elementType.needConversion() - def toInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: # noqa: D102 + def toInternal(self, obj: list[T | None]) -> list[T | None]: # noqa: D102 if not self.needConversion(): return obj return obj and [self.elementType.toInternal(v) for v in obj] - def fromInternal(self, obj: list[Optional[T]]) -> list[Optional[T]]: # noqa: D102 + def fromInternal(self, obj: list[T | None]) -> list[T | None]: # noqa: D102 if not self.needConversion(): return obj return obj and [self.elementType.fromInternal(v) for v in obj] @@ -651,12 +662,12 @@ def __repr__(self) -> str: # noqa: D105 def needConversion(self) -> bool: # noqa: D102 return self.keyType.needConversion() or self.valueType.needConversion() - def toInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: # noqa: D102 + def toInternal(self, obj: dict[T, U | None]) -> dict[T, U | None]: # noqa: D102 if not self.needConversion(): return obj return obj and {self.keyType.toInternal(k): self.valueType.toInternal(v) for k, v in obj.items()} - def fromInternal(self, obj: dict[T, Optional[U]]) -> dict[T, Optional[U]]: # noqa: D102 + def fromInternal(self, obj: dict[T, U | None]) -> dict[T, U | None]: # noqa: D102 if not self.needConversion(): return obj return obj and {self.keyType.fromInternal(k): self.valueType.fromInternal(v) for k, v in obj.items()} @@ -689,7 +700,7 @@ def __init__( # noqa: D107 name: str, dataType: DataType, nullable: bool = True, - metadata: Optional[dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> None: super().__init__(dataType.duckdb_type) assert isinstance(dataType, DataType), f"dataType {dataType} should be an instance of {DataType}" @@ -748,9 +759,9 @@ class StructType(DataType): """ def _update_internal_duckdb_type(self) -> None: - self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields]))) + self.duckdb_type = duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields], strict=False))) - def __init__(self, fields: Optional[list[StructField]] = None) -> None: # noqa: D107 + def __init__(self, fields: list[StructField] | None = None) -> None: # noqa: D107 if not fields: self.fields = [] self.names = [] @@ -761,15 +772,15 @@ def __init__(self, fields: Optional[list[StructField]] = None) -> None: # noqa: # Precalculated list of fields that need conversion with fromInternal/toInternal functions self._needConversion = [f.needConversion() for f in self] self._needSerializeAnyField = any(self._needConversion) - super().__init__(duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields])))) + super().__init__(duckdb.struct_type(dict(zip(self.names, [x.duckdb_type for x in self.fields], strict=False)))) @overload def add( self, field: str, - data_type: Union[str, DataType], + data_type: str | DataType, nullable: bool = True, - metadata: Optional[dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> "StructType": ... @overload @@ -777,10 +788,10 @@ def add(self, field: StructField) -> "StructType": ... def add( self, - field: Union[str, StructField], - data_type: Optional[Union[str, DataType]] = None, + field: str | StructField, + data_type: str | DataType | None = None, nullable: bool = True, - metadata: Optional[dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> "StructType": r"""Construct a :class:`StructType` by adding new elements to it, to define the schema. The method accepts either: @@ -846,7 +857,7 @@ def __len__(self) -> int: """Return the number of fields.""" return len(self.fields) - def __getitem__(self, key: Union[str, int]) -> StructField: + def __getitem__(self, key: str | int) -> StructField: """Access fields by name or slice.""" if isinstance(key, str): for field in self: @@ -894,7 +905,7 @@ def fieldNames(self) -> list[str]: """ return list(self.names) - def treeString(self, level: Optional[int] = None) -> str: + def treeString(self, level: int | None = None) -> str: """Returns a string representation of the schema in tree format. Parameters @@ -915,7 +926,7 @@ def treeString(self, level: Optional[int] = None) -> str: |-- age: integer (nullable = true) """ - def _tree_string(schema: "StructType", depth: int = 0, max_depth: Optional[int] = None) -> list[str]: + def _tree_string(schema: "StructType", depth: int = 0, max_depth: int | None = None) -> list[str]: """Recursively build tree string lines.""" lines = [] if depth == 0: @@ -978,15 +989,17 @@ def toInternal(self, obj: tuple) -> tuple: # noqa: D102 if isinstance(obj, dict): return tuple( f.toInternal(obj.get(n)) if c else obj.get(n) - for n, f, c in zip(self.names, self.fields, self._needConversion) + for n, f, c in zip(self.names, self.fields, self._needConversion, strict=False) ) elif isinstance(obj, (tuple, list)): - return tuple(f.toInternal(v) if c else v for f, v, c in zip(self.fields, obj, self._needConversion)) + return tuple( + f.toInternal(v) if c else v for f, v, c in zip(self.fields, obj, self._needConversion, strict=False) + ) elif hasattr(obj, "__dict__"): d = obj.__dict__ return tuple( f.toInternal(d.get(n)) if c else d.get(n) - for n, f, c in zip(self.names, self.fields, self._needConversion) + for n, f, c in zip(self.names, self.fields, self._needConversion, strict=False) ) else: msg = f"Unexpected tuple {obj!r} with StructType" @@ -1010,10 +1023,12 @@ def fromInternal(self, obj: tuple) -> "Row": # noqa: D102 # it's already converted by pickler return obj - values: Union[tuple, list] + values: tuple | list if self._needSerializeAnyField: # Only calling fromInternal function for fields that need conversion - values = [f.fromInternal(v) if c else v for f, v, c in zip(self.fields, obj, self._needConversion)] + values = [ + f.fromInternal(v) if c else v for f, v, c in zip(self.fields, obj, self._needConversion, strict=False) + ] else: values = obj return _create_row(self.names, values) @@ -1110,19 +1125,19 @@ def __eq__(self, other: object) -> bool: ] _all_atomic_types: dict[str, type[DataType]] = {t.typeName(): t for t in _atomic_types} -_complex_types: list[type[Union[ArrayType, MapType, StructType]]] = [ +_complex_types: list[type[ArrayType | MapType | StructType]] = [ ArrayType, MapType, StructType, ] -_all_complex_types: dict[str, type[Union[ArrayType, MapType, StructType]]] = {v.typeName(): v for v in _complex_types} +_all_complex_types: dict[str, type[ArrayType | MapType | StructType]] = {v.typeName(): v for v in _complex_types} _FIXED_DECIMAL = re.compile(r"decimal\(\s*(\d+)\s*,\s*(-?\d+)\s*\)") _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") -def _create_row(fields: Union["Row", list[str]], values: Union[tuple[Any, ...], list[Any]]) -> "Row": +def _create_row(fields: Union["Row", list[str]], values: tuple[Any, ...] | list[Any]) -> "Row": row = Row(*values) row.__fields__ = fields return row @@ -1188,7 +1203,7 @@ def __new__(cls, *args: str) -> "Row": ... @overload def __new__(cls, **kwargs: Any) -> "Row": ... # noqa: ANN401 - def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # noqa: D102 + def __new__(cls, *args: str | None, **kwargs: Any | None) -> Self: # noqa: D102 if args and kwargs: msg = "Can not use both args and kwargs to create Row" raise ValueError(msg) @@ -1197,9 +1212,8 @@ def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # noq row = tuple.__new__(cls, list(kwargs.values())) row.__fields__ = list(kwargs.keys()) return row - else: - # create row class or objects - return tuple.__new__(cls, args) + # create row class or objects + return tuple.__new__(cls, args) def asDict(self, recursive: bool = False) -> dict[str, Any]: """Return as a dict. @@ -1233,7 +1247,7 @@ def asDict(self, recursive: bool = False) -> dict[str, Any]: if recursive: - def conv(obj: Union[Row, list, dict, object]) -> Union[list, dict, object]: + def conv(obj: Row | list | dict | object) -> list | dict | object: if isinstance(obj, Row): return obj.asDict(True) elif isinstance(obj, list): @@ -1243,9 +1257,9 @@ def conv(obj: Union[Row, list, dict, object]) -> Union[list, dict, object]: else: return obj - return dict(zip(self.__fields__, (conv(o) for o in self))) + return dict(zip(self.__fields__, (conv(o) for o in self), strict=False)) else: - return dict(zip(self.__fields__, self)) + return dict(zip(self.__fields__, self, strict=False)) def __contains__(self, item: Any) -> bool: # noqa: D105, ANN401 if hasattr(self, "__fields__"): @@ -1295,7 +1309,7 @@ def __setattr__(self, key: Any, value: Any) -> None: # noqa: D105, ANN401 def __reduce__( self, - ) -> Union[str, tuple[Any, ...]]: + ) -> str | tuple[Any, ...]: """Returns a tuple so Python knows how to pickle Row.""" if hasattr(self, "__fields__"): return (_create_row, (self.__fields__, tuple(self))) @@ -1305,6 +1319,6 @@ def __reduce__( def __repr__(self) -> str: """Printable representation of Row used in Python REPL.""" if hasattr(self, "__fields__"): - return "Row({})".format(", ".join(f"{k}={v!r}" for k, v in zip(self.__fields__, tuple(self)))) + return "Row({})".format(", ".join(f"{k}={v!r}" for k, v in zip(self.__fields__, tuple(self), strict=False))) else: return "".format(", ".join(f"{field!r}" for field in self)) diff --git a/duckdb/experimental/spark/sql/udf.py b/duckdb/experimental/spark/sql/udf.py index 7437ed6b..c22f6be9 100644 --- a/duckdb/experimental/spark/sql/udf.py +++ b/duckdb/experimental/spark/sql/udf.py @@ -1,12 +1,15 @@ -# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ # noqa: D100 -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +# https://sparkbyexamples.com/pyspark/pyspark-udf-user-defined-function/ +from typing import TYPE_CHECKING, Any, Optional, TypeVar + +if TYPE_CHECKING: + from collections.abc import Callable from .types import DataType if TYPE_CHECKING: from .session import SparkSession -DataTypeOrString = Union[DataType, str] +DataTypeOrString = DataType | str UserDefinedFunctionLike = TypeVar("UserDefinedFunctionLike") @@ -17,7 +20,7 @@ def __init__(self, sparkSession: "SparkSession") -> None: # noqa: D107 def register( # noqa: D102 self, name: str, - f: Union[Callable[..., Any], "UserDefinedFunctionLike"], + f: "Callable[..., Any] | UserDefinedFunctionLike", returnType: Optional["DataTypeOrString"] = None, ) -> "UserDefinedFunctionLike": self.sparkSession.conn.create_function(name, f, return_type=returnType) diff --git a/duckdb/functional/__init__.py b/duckdb/functional/__init__.py deleted file mode 100644 index 5114629b..00000000 --- a/duckdb/functional/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""DuckDB function constants and types. DEPRECATED: please use `duckdb.func` instead.""" - -import warnings - -from duckdb.func import ARROW, DEFAULT, NATIVE, SPECIAL, FunctionNullHandling, PythonUDFType - -__all__ = ["ARROW", "DEFAULT", "NATIVE", "SPECIAL", "FunctionNullHandling", "PythonUDFType"] - -warnings.warn( - "`duckdb.functional` is deprecated and will be removed in a future version. Please use `duckdb.func` instead.", - DeprecationWarning, - stacklevel=2, -) diff --git a/duckdb/polars_io.py b/duckdb/polars_io.py index a7ed84ff..b5bb6a4e 100644 --- a/duckdb/polars_io.py +++ b/duckdb/polars_io.py @@ -130,6 +130,14 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str: ) return str(int_literal) + if node_type == "Float": + # Direct float literals + float_literal = tree[node_type] + assert isinstance(float_literal, (float, int, str)), ( + f"The value of a Float should be a float, int or str but got {type(float_literal)}" + ) + return str(float_literal) + if node_type == "Function": # Handle boolean functions like IsNull, IsNotNull func_tree = tree[node_type] @@ -159,6 +167,25 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str: msg = f"Unsupported function type: {func_dict}" raise NotImplementedError(msg) + if node_type == "Cast": + cast_tree = tree[node_type] + assert isinstance(cast_tree, dict), f"A {node_type} should be a dict but got {type(cast_tree)}" + options = cast_tree.get("options") + if options == "Strict": + # Strict casts on literals (e.g. pl.lit(1, dtype=pl.Int8)) are safe to unwrap — + # the value is known at expression creation time. Strict casts on columns + # (e.g. pl.col("a").cast(pl.Int64)) are semantically meaningful and must not be dropped. + cast_expr = cast_tree.get("expr", {}) + if not isinstance(cast_expr, dict) or "Literal" not in cast_expr: + msg = "Strict cast on non-literal expression cannot be pushed down" + raise NotImplementedError(msg) + elif options != "NonStrict": + msg = f"Only NonStrict/Strict casts can be safely unwrapped, got {options!r}" + raise NotImplementedError(msg) + cast_expr = cast_tree["expr"] + assert isinstance(cast_expr, dict), f"A {node_type} should be a dict but got {type(cast_expr)}" + return _pl_tree_to_sql(cast_expr) + if node_type == "Scalar": # Detect format: old style (dtype/value) or new style (direct type key) scalar_tree = tree[node_type] @@ -197,10 +224,12 @@ def _pl_tree_to_sql(tree: _ExpressionTree) -> str: "Int16", "Int32", "Int64", + "Int128", "UInt8", "UInt16", "UInt32", "UInt64", + "UInt128", "Float32", "Float64", "Boolean", @@ -270,10 +299,7 @@ def source_generator( # Try to pushdown filter, if one exists if duck_predicate is not None: relation_final = relation_final.filter(duck_predicate) - if batch_size is None: - results = relation_final.fetch_arrow_reader() - else: - results = relation_final.fetch_arrow_reader(batch_size) + results = relation_final.to_arrow_reader() if batch_size is None else relation_final.to_arrow_reader(batch_size) for record_batch in iter(results.read_next_batch, None): if predicate is not None and duck_predicate is None: diff --git a/duckdb/query_graph/__init__.py b/duckdb/query_graph/__init__.py new file mode 100644 index 00000000..340dd8d3 --- /dev/null +++ b/duckdb/query_graph/__init__.py @@ -0,0 +1,3 @@ +from .__main__ import ProfilingInfo # noqa: D104 + +__all__ = ["ProfilingInfo"] diff --git a/duckdb/query_graph/__main__.py b/duckdb/query_graph/__main__.py index d4851694..dd4ff959 100644 --- a/duckdb/query_graph/__main__.py +++ b/duckdb/query_graph/__main__.py @@ -5,80 +5,275 @@ from functools import reduce from pathlib import Path +from duckdb import DuckDBPyConnection + qgraph_css = """ -.styled-table { - border-collapse: collapse; - margin: 25px 0; - font-size: 0.9em; - font-family: sans-serif; - min-width: 400px; - box-shadow: 0 0 20px rgba(0, 0, 0, 0.15); +:root { + --text-primary-color: #0d0d0d; + --text-secondary-color: #444; + --doc-codebox-border-color: #e6e6e6; + --doc-codebox-background-color: #f7f7f7; + --doc-scrollbar-bg: #e6e6e6; + --doc-scrollbar-slider: #ccc; + --duckdb-accent: #009982; + --duckdb-accent-light: #00b89a; + --card-bg: #fff; + --border-radius: 8px; + --shadow: 0 4px 14px rgba(0,0,0,0.05); } -.styled-table thead tr { - background-color: #009879; - color: #ffffff; - text-align: left; + +html, body { + margin: 0; + padding: 0; + font-family: Inter, system-ui, -apple-system, "Segoe UI", Roboto, sans-serif; + color: var(--text-primary-color); + background: #fafafa; + line-height: 1.55; } -.styled-table th, -.styled-table td { - padding: 12px 15px; + +.container { + max-width: 1000px; + margin: 40px auto; + padding: 0 20px; } -.styled-table tbody tr { - border-bottom: 1px solid #dddddd; + +header { + display: flex; + align-items: center; + gap: 10px; + margin-bottom: 5px; } -.styled-table tbody tr:nth-of-type(even) { - background-color: #f3f3f3; +header img { + width: 100px; + height: 100px; } -.styled-table tbody tr:last-of-type { - border-bottom: 2px solid #009879; +header h1 { + font-size: 1.5rem; + font-weight: 600; + margin: 0; + color: var(--text-primary-color); } -.node-body { - font-size:15px; +/* === Table Styling (DuckDB documentation style, flat header) === */ +table { + border-collapse: collapse; + width: 100%; + margin-bottom: 20px; + text-align: left; + font-variant-numeric: tabular-nums; + border: 1px solid var(--doc-codebox-border-color); + border-radius: var(--border-radius); + overflow: hidden; + box-shadow: var(--shadow); + background: var(--card-bg); +} + +thead { + background-color: var(--duckdb-accent); + color: white; +} + +th, td { + padding: 10px 12px; + font-size: 14px; + vertical-align: top; +} + +th { + font-weight: 700; +} + +tbody tr { + border-bottom: 1px solid var(--doc-codebox-border-color); +} + +tbody tr:last-child td { + border-bottom: none; +} + +tbody tr:hover { + background: var(--doc-codebox-border-color); +} + +tbody tr.phase-details-row { + border-bottom: none; } + +tbody tr.phase-details-row:hover { + background: transparent; +} + +tbody tr.phase-details-row details summary { + font-size: 12px; + padding: 4px 0; +} + +tbody tr.phase-details-row details[open] summary { + margin-bottom: 4px; +} + +/* === Chart/Card Section === */ +.chart { + padding: 20px; + border: 1px solid var(--doc-codebox-border-color); + border-radius: var(--border-radius); + background: var(--card-bg); + box-shadow: var(--shadow); + overflow: visible; +} + +/* === Tree Layout Styling === */ +.tf-tree { + overflow-x: visible; + overflow-y: visible; + padding-top: 20px; +} + .tf-nc { - position: relative; - width: 180px; - text-align: center; - background-color: #fff100; + background: var(--card-bg); + border: 1px solid var(--doc-codebox-border-color); + border-radius: var(--border-radius); + padding: 6px; + display: inline-block; +} + +.node-body { + font-size: 13px; + text-align: left; + padding: 10px; + white-space: nowrap; } -.custom-tooltip { - position: relative; + +.node-body p { + margin: 2px 0; +} + +.node-details { + white-space: nowrap; + overflow: visible; display: inline-block; } -.tooltip-text { - visibility: hidden; - background-color: #333; - color: #fff; +/* === Metric Boxes === */ +.chart .metrics-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(150px, 1fr)); + gap: 16px; + margin-bottom: 20px; +} + +.chart .metric-box { + background: var(--card-bg); + border: 1px solid var(--doc-codebox-border-color); + border-radius: var(--border-radius); + box-shadow: var(--shadow); + padding: 12px 16px; text-align: center; - padding: 0px; - border-radius: 1px; + transition: transform 0.2s ease, box-shadow 0.2s ease; +} + +.chart .metric-box:hover { + transform: translateY(-2px); + box-shadow: 0 6px 18px rgba(0, 0, 0, 0.08); +} + +.chart .metric-title { + font-size: 13px; + color: var(--text-secondary-color); + margin-bottom: 4px; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +.chart .metric-value { + font-size: 18px; + font-weight: 600; + color: var(--duckdb-accent); +} - /* Positioning */ - position: absolute; - z-index: 1; - bottom: 100%; - left: 50%; - transform: translateX(-50%); - margin-bottom: 8px; - /* Tooltip Arrow */ - width: 400px; +/* === SQL Query Block === */ +.chart.sql-block { + background: var(--doc-codebox-background-color); + border: 1px solid var(--doc-codebox-border-color); + border-radius: var(--border-radius); + box-shadow: var(--shadow); + padding: 16px; + overflow-x: auto; + margin-top: 20px; +} + +.chart.sql-block pre { + margin: 0; + font-family: "JetBrains Mono", "Fira Code", Consolas, monospace; + font-size: 13.5px; + line-height: 1.5; + color: var(--text-primary-color); + white-space: pre; +} + +.chart.sql-block code { + color: var(--duckdb-accent); + font-weight: 500; +} + + +/* === Links, Typography, and Consistency === */ +a { + color: var(--duckdb-accent); + text-decoration: underline; + transition: color 0.3s; +} + +a:hover { + color: black; +} + +strong { + font-weight: 600; } -.custom-tooltip:hover .tooltip-text { - visibility: visible; +/* === Dark Mode Support === */ +@media (prefers-color-scheme: dark) { + :root { + --text-primary-color: #e6e6e6; + --text-secondary-color: #b3b3b3; + --doc-codebox-border-color: #2a2a2a; + --doc-codebox-background-color: #1e1e1e; + --card-bg: #111; + } + body { + background: #0b0b0b; + } + thead { + background-color: var(--duckdb-accent); + } + tbody tr:hover { + background: #222; + } + + /* Fix tree node text visibility in dark mode */ + .tf-nc .node-body, + .tf-nc .node-body p, + .tf-nc .node-details { + color: #1a1a1a !important; + } + + /* Fix metric title visibility in dark mode */ + .chart .metric-title { + color: #b3b3b3; + } } -""" +""" # noqa: W293 class NodeTiming: # noqa: D101 - def __init__(self, phase: str, time: float) -> None: # noqa: D107 + def __init__(self, phase: str, time: float, depth: int) -> None: # noqa: D107 self.phase = phase self.time = time + self.depth = depth # percentage is determined later. self.percentage = 0 @@ -88,7 +283,7 @@ def calculate_percentage(self, total_time: float) -> None: # noqa: D102 def combine_timing(self, r: "NodeTiming") -> "NodeTiming": # noqa: D102 # TODO: can only add timings for same-phase nodes # noqa: TD002, TD003 total_time = self.time + r.time - return NodeTiming(self.phase, total_time) + return NodeTiming(self.phase, total_time, self.depth) class AllTimings: # noqa: D101 @@ -124,200 +319,319 @@ def open_utf8(fpath: str, flags: str) -> object: # noqa: D103 return Path(fpath).open(mode=flags, encoding="utf8") -def get_child_timings(top_node: object, query_timings: object) -> str: # noqa: D103 - node_timing = NodeTiming(top_node["operator_type"], float(top_node["operator_timing"])) - query_timings.add_node_timing(node_timing) - for child in top_node["children"]: - get_child_timings(child, query_timings) - - -def get_pink_shade_hex(fraction: float) -> str: # noqa: D103 - fraction = max(0, min(1, fraction)) - - # Define the RGB values for very light pink (almost white) and dark pink - light_pink = (255, 250, 250) # Very light pink - dark_pink = (255, 20, 147) # Dark pink - - # Calculate the RGB values for the given fraction - r = int(light_pink[0] + (dark_pink[0] - light_pink[0]) * fraction) - g = int(light_pink[1] + (dark_pink[1] - light_pink[1]) * fraction) - b = int(light_pink[2] + (dark_pink[2] - light_pink[2]) * fraction) - - # Return as hexadecimal color code - return f"#{r:02x}{g:02x}{b:02x}" - - -def get_node_body(name: str, result: str, cpu_time: float, card: int, est: int, width: int, extra_info: str) -> str: # noqa: D103 - node_style = f"background-color: {get_pink_shade_hex(float(result) / cpu_time)};" - - body = f'' - body += '
' - new_name = "BRIDGE" if (name == "INVALID") else name.replace("_", " ") - formatted_num = f"{float(result):.4f}" - body += f"

{new_name}

time: {formatted_num} seconds

" - body += f' {extra_info} ' - if width > 0: - body += f"

cardinality: {card}

" - body += f"

estimate: {est}

" - body += f"

width: {width} bytes

" - # TODO: Expand on timing. Usually available from a detailed profiling # noqa: TD002, TD003 - body += "
" - body += "
" - return body - - -def generate_tree_recursive(json_graph: object, cpu_time: float) -> str: # noqa: D103 - node_prefix_html = "
  • " - node_suffix_html = "
  • " - - extra_info = "" - estimate = 0 - for key in json_graph["extra_info"]: - value = json_graph["extra_info"][key] - if key == "Estimated Cardinality": - estimate = int(value) +class ProfilingInfo: # noqa: D101 + def __init__(self, conn: DuckDBPyConnection | None = None, from_file: str | None = None) -> None: # noqa: D107 + self.conn = conn + self.from_file = from_file + + def to_json(self) -> str: # noqa: D102 + if self.from_file is not None: + with open_utf8(self.from_file, "r") as f: + return f.read() + + return self.conn.get_profiling_information(format="json") + + def to_pydict(self) -> dict: # noqa: D102 + return json.loads(self.to_json()) + + def to_html(self, output_file: str = "profile.html") -> str: # noqa: D102 + profiling_info_text = self.to_json() + html_output = self._translate_json_to_html(input_text=profiling_info_text, output_file=output_file) + return html_output + + def _get_child_timings(self, top_node: object, query_timings: object, depth: int = 0) -> str: + node_timing = NodeTiming(top_node["operator_type"], float(top_node["operator_timing"]), depth) + query_timings.add_node_timing(node_timing) + for child in top_node["children"]: + self._get_child_timings(child, query_timings, depth + 1) + + @staticmethod + def _get_f7fff0_shade_hex(fraction: float) -> str: + """Returns a shade between very light (#f7fff0) and a slightly darker green-yellow, + depending on the fraction (0..1). + """ # noqa: D205 + fraction = max(0, min(1, fraction)) + + # Define RGB for light and dark end + light_color = (247, 255, 240) # #f7fff0 + dark_color = (200, 255, 150) # slightly darker/more saturated green-yellow + + # Interpolate RGB channels + r = int(light_color[0] + (dark_color[0] - light_color[0]) * fraction) + g = int(light_color[1] + (dark_color[1] - light_color[1]) * fraction) + b = int(light_color[2] + (dark_color[2] - light_color[2]) * fraction) + + return f"#{r:02x}{g:02x}{b:02x}" + + def _get_node_body( + self, name: str, result: str, cpu_time: float, card: int, est: int, result_size: int, extra_info: str + ) -> str: + """Generate the HTML body for a single node in the tree.""" + node_style = f"background-color: {self._get_f7fff0_shade_hex(float(result) / cpu_time)};" + new_name = "BRIDGE" if (name == "INVALID") else name.replace("_", " ") + formatted_num = f"{float(result):.4f}" + + body = f'' + body += '
    ' + body += f"

    {new_name}

    " + if result_size > 0: + body += f"

    time: {formatted_num}s

    " + body += f"

    cardinality: {card}

    " + body += f"

    estimate: {est}

    " + body += f"

    result size: {result_size} bytes

    " + body += "
    " + body += "Extra info" + body += '
    ' + body += f"

    {extra_info}

    " + # TODO: Expand on timing. Usually available from a detailed profiling # noqa: TD002, TD003 + body += "
    " + body += "
    " + body += "
    " + body += "
    " + return body + + def _generate_tree_recursive(self, json_graph: object, cpu_time: float) -> str: + node_prefix_html = "
  • " + node_suffix_html = "
  • " + + extra_info = "" + estimate = 0 + for key in json_graph["extra_info"]: + value = json_graph["extra_info"][key] + if key == "Estimated Cardinality": + estimate = int(value) + else: + extra_info += f"{key}: {value}
    " + + # get rid of some typically long names + extra_info = re.sub(r"__internal_\s*", "__", extra_info) + extra_info = re.sub(r"compress_integral\s*", "compress", extra_info) + + node_body = self._get_node_body( + json_graph["operator_type"], + json_graph["operator_timing"], + cpu_time, + json_graph["operator_cardinality"], + estimate, + json_graph["result_set_size"], + re.sub(r",\s*", ", ", extra_info), + ) + + children_html = "" + if len(json_graph["children"]) >= 1: + children_html += "
      " + for child in json_graph["children"]: + children_html += self._generate_tree_recursive(child, cpu_time) + children_html += "
    " + return node_prefix_html + node_body + children_html + node_suffix_html + + # For generating the table in the top left with expandable phases + def _generate_timing_html(self, graph_json: object, query_timings: object) -> object: + """Generates timing HTML table with expandable phases.""" + json_graph = json.loads(graph_json) + self._gather_timing_information(json_graph, query_timings) + table_head = """ + + + + + + + + """ + + table_body = "" + table_end = "
    PhaseTime (s)Percentage
    " + + execution_time = query_timings.get_sum_of_all_timings() + + all_phases = query_timings.get_phases() + query_timings.add_node_timing(NodeTiming("Execution Time (CPU)", execution_time, None)) + all_phases = ["Execution Time (CPU)", *all_phases] + + for phase in all_phases: + summarized_phase = query_timings.get_summary_phase_timings(phase) + summarized_phase.calculate_percentage(execution_time) + phase_column = f"{phase}" if phase == "Execution Time (CPU)" else phase + + # Main phase row + table_body += f""" + + {phase_column} + {round(summarized_phase.time, 8)} + {str(summarized_phase.percentage * 100)[:6]}% + + """ + + # Add expandable details for individual nodes (except for Execution Time) + if phase != "Execution Time (CPU)": + phase_timings = query_timings.get_phase_timings(phase) + if len(phase_timings) > 1: # Only show details if there are multiple nodes + table_body += f""" + + +
    + + Show {len(phase_timings)} nodes + + + + """ + for node_timing in sorted(phase_timings, key=lambda x: x.time, reverse=True): + node_timing.calculate_percentage(execution_time) + depth_indent = " " * (node_timing.depth * 4) + table_body += f""" + + + + + + """ # noqa: E501 + table_body += """ + +
    {depth_indent}↳ Depth {node_timing.depth}{round(node_timing.time, 8)}{str(node_timing.percentage * 100)[:6]}%
    +
    + + + """ + + table_body += table_end + return table_head + table_body + + @staticmethod + def _generate_metric_grid_html(graph_json: str) -> str: + json_graph = json.loads(graph_json) + metrics = { + "Execution Time (s)": f"{float(json_graph.get('latency', 'N/A')):.4f}", + "Total GB Read": f"{float(json_graph.get('total_bytes_read', 'N/A')) / (1024**3):.4f}" + if json_graph.get("total_bytes_read", "N/A") != "N/A" + else "N/A", + "Total GB Written": f"{float(json_graph.get('total_bytes_written', 'N/A')) / (1024**3):.4f}" + if json_graph.get("total_bytes_written", "N/A") != "N/A" + else "N/A", + "Peak Memory (GB)": f"{float(json_graph.get('system_peak_buffer_memory', 'N/A')) / (1024**3):.4f}" + if json_graph.get("system_peak_buffer_memory", "N/A") != "N/A" + else "N/A", + "Rows Scanned": f"{json_graph.get('cumulative_rows_scanned', 'N/A'):,}" + if json_graph.get("cumulative_rows_scanned", "N/A") != "N/A" + else "N/A", + } + metric_grid_html = """
    """ + for key in metrics: + metric_grid_html += f""" +
    +
    {key}
    +
    {metrics[key]}
    +
    + """ + metric_grid_html += "
    " + return metric_grid_html + + @staticmethod + def _generate_sql_query_html(graph_json: str) -> str: + json_graph = json.loads(graph_json) + sql_query = json_graph.get("query_name", "N/A") + sql_html = f""" +
    SQL Query +
    +
    
    +    {sql_query}
    +            
    +
    +

    + """ + return sql_html + + def _generate_tree_html(self, graph_json: object) -> str: + json_graph = json.loads(graph_json) + cpu_time = float(json_graph["cpu_time"]) + tree_prefix = '
    \n
      ' + tree_suffix = "
    " + # first level of json is general overview + # TODO: make sure json output first level always has only 1 level # noqa: TD002, TD003 + tree_body = self._generate_tree_recursive(json_graph["children"][0], cpu_time) + return tree_prefix + tree_body + tree_suffix + + def _generate_ipython(self, json_input: str) -> str: + from IPython.core.display import HTML + + html_output = self._generate_html(json_input, False) + + return HTML( + ( + '\n ${CSS}\n ${LIBRARIES}\n
    \n ${CHART_SCRIPT}\n ' + ) + .replace("${CSS}", html_output["css"]) + .replace("${CHART_SCRIPT}", html_output["chart_script"]) + .replace("${LIBRARIES}", html_output["libraries"]) + ) + + @staticmethod + def _generate_style_html(graph_json: str, include_meta_info: bool) -> None: # noqa: FBT001 + treeflex_css = '\n' + libraries = '\n' # noqa: E501 + return {"treeflex_css": treeflex_css, "duckdb_css": qgraph_css, "libraries": libraries, "chart_script": ""} + + def _gather_timing_information(self, json: str, query_timings: object) -> None: + # add up all of the times + # measure each time as a percentage of the total time. + # then you can return a list of [phase, time, percentage] + self._get_child_timings(json["children"][0], query_timings) + + def _translate_json_to_html( + self, input_file: str | None = None, input_text: str | None = None, output_file: str = "profile.html" + ) -> None: + query_timings = AllTimings() + if input_text is not None: + text = input_text + elif input_file is not None: + with open_utf8(input_file, "r") as f: + text = f.read() else: - extra_info += f"{key}: {value}
    " - cardinality = json_graph["operator_cardinality"] - width = int(json_graph["result_set_size"] / max(1, cardinality)) - - # get rid of some typically long names - extra_info = re.sub(r"__internal_\s*", "__", extra_info) - extra_info = re.sub(r"compress_integral\s*", "compress", extra_info) - - node_body = get_node_body( - json_graph["operator_type"], - json_graph["operator_timing"], - cpu_time, - cardinality, - estimate, - width, - re.sub(r",\s*", ", ", extra_info), - ) - - children_html = "" - if len(json_graph["children"]) >= 1: - children_html += "
      " - for child in json_graph["children"]: - children_html += generate_tree_recursive(child, cpu_time) - children_html += "
    " - return node_prefix_html + node_body + children_html + node_suffix_html - - -# For generating the table in the top left. -def generate_timing_html(graph_json: object, query_timings: object) -> object: # noqa: D103 - json_graph = json.loads(graph_json) - gather_timing_information(json_graph, query_timings) - total_time = float(json_graph.get("operator_timing") or json_graph.get("latency")) - table_head = """ - - - - - - - - """ - - table_body = "" - table_end = "
    PhaseTimePercentage
    " - - execution_time = query_timings.get_sum_of_all_timings() - - all_phases = query_timings.get_phases() - query_timings.add_node_timing(NodeTiming("TOTAL TIME", total_time)) - query_timings.add_node_timing(NodeTiming("Execution Time", execution_time)) - all_phases = ["TOTAL TIME", "Execution Time", *all_phases] - for phase in all_phases: - summarized_phase = query_timings.get_summary_phase_timings(phase) - summarized_phase.calculate_percentage(total_time) - phase_column = f"{phase}" if phase == "TOTAL TIME" or phase == "Execution Time" else phase - table_body += f""" - - {phase_column} - {summarized_phase.time} - {str(summarized_phase.percentage * 100)[:6]}% - -""" - table_body += table_end - return table_head + table_body - - -def generate_tree_html(graph_json: object) -> str: # noqa: D103 - json_graph = json.loads(graph_json) - cpu_time = float(json_graph["cpu_time"]) - tree_prefix = '
    \n
      ' - tree_suffix = "
    " - # first level of json is general overview - # TODO: make sure json output first level always has only 1 level # noqa: TD002, TD003 - tree_body = generate_tree_recursive(json_graph["children"][0], cpu_time) - return tree_prefix + tree_body + tree_suffix - - -def generate_ipython(json_input: str) -> str: # noqa: D103 - from IPython.core.display import HTML - - html_output = generate_html(json_input, False) # noqa: F821 - - return HTML( - ('\n ${CSS}\n ${LIBRARIES}\n
    \n ${CHART_SCRIPT}\n ') - .replace("${CSS}", html_output["css"]) - .replace("${CHART_SCRIPT}", html_output["chart_script"]) - .replace("${LIBRARIES}", html_output["libraries"]) - ) - - -def generate_style_html(graph_json: str, include_meta_info: bool) -> None: # noqa: D103, FBT001 - treeflex_css = '\n' - css = "\n" - return {"treeflex_css": treeflex_css, "duckdb_css": css, "libraries": "", "chart_script": ""} - - -def gather_timing_information(json: str, query_timings: object) -> None: # noqa: D103 - # add up all of the times - # measure each time as a percentage of the total time. - # then you can return a list of [phase, time, percentage] - get_child_timings(json["children"][0], query_timings) - - -def translate_json_to_html(input_file: str, output_file: str) -> None: # noqa: D103 - query_timings = AllTimings() - with open_utf8(input_file, "r") as f: - text = f.read() - - html_output = generate_style_html(text, True) - timing_table = generate_timing_html(text, query_timings) - tree_output = generate_tree_html(text) - - # finally create and write the html - with open_utf8(output_file, "w+") as f: - html = """ - - - - - Query Profile Graph for Query - ${TREEFLEX_CSS} - - - -
    -
    - ${TIMING_TABLE} -
    - ${TREE} - - -""" - html = html.replace("${TREEFLEX_CSS}", html_output["treeflex_css"]) - html = html.replace("${DUCKDB_CSS}", html_output["duckdb_css"]) - html = html.replace("${TIMING_TABLE}", timing_table) - html = html.replace("${TREE}", tree_output) - f.write(html) + print("please provide either input file or input text") + exit(1) + html_output = self._generate_style_html(text, True) + highlight_metric_grid = self._generate_metric_grid_html(text) + timing_table = self._generate_timing_html(text, query_timings) + tree_output = self._generate_tree_html(text) + sql_query_html = self._generate_sql_query_html(text) + # finally create and write the html + with open_utf8(output_file, "w+") as f: + html = """ + + + + + Query Profile Graph for Query + ${TREEFLEX_CSS} + + + +
    +
    + DuckDB Logo +

    Query Profile Graph

    +
    +
    + ${METRIC_GRID} +
    +
    + ${SQL_QUERY} + ${TIMING_TABLE} +
    + ${TREE} + + + """ # noqa: E501 + html = html.replace("${TREEFLEX_CSS}", html_output["treeflex_css"]) + html = html.replace("${DUCKDB_CSS}", html_output["duckdb_css"]) + html = html.replace("${METRIC_GRID}", highlight_metric_grid) + html = html.replace("${SQL_QUERY}", sql_query_html) + html = html.replace("${TIMING_TABLE}", timing_table) + html = html.replace("${TREE}", tree_output) + f.write(html) def main() -> None: # noqa: D103 @@ -326,7 +640,7 @@ def main() -> None: # noqa: D103 description="""Given a json profile output, generate a html file showing the query graph and timings of operators""", ) - parser.add_argument("profile_input", help="profile input in json") + parser.add_argument("--profile_input", help="profile input in json") parser.add_argument("--out", required=False, default=False) parser.add_argument("--open", required=False, action="store_true", default=True) args = parser.parse_args() @@ -347,8 +661,8 @@ def main() -> None: # noqa: D103 exit(1) open_output = args.open - - translate_json_to_html(input, output) + profiling_info = ProfilingInfo(from_file=input) + profiling_info.to_html(output_file=output) if open_output: webbrowser.open(f"file://{Path(output).resolve()}", new=2) diff --git a/duckdb/sqltypes/__init__.py b/duckdb/sqltypes/__init__.py index 38917ce3..4a742e74 100644 --- a/duckdb/sqltypes/__init__.py +++ b/duckdb/sqltypes/__init__.py @@ -14,6 +14,7 @@ SMALLINT, SQLNULL, TIME, + TIME_NS, TIME_TZ, TIMESTAMP, TIMESTAMP_MS, @@ -28,6 +29,7 @@ UTINYINT, UUID, VARCHAR, + VARIANT, DuckDBPyType, ) @@ -50,6 +52,7 @@ "TIMESTAMP_NS", "TIMESTAMP_S", "TIMESTAMP_TZ", + "TIME_NS", "TIME_TZ", "TINYINT", "UBIGINT", @@ -59,5 +62,6 @@ "UTINYINT", "UUID", "VARCHAR", + "VARIANT", "DuckDBPyType", ] diff --git a/duckdb/typing/__init__.py b/duckdb/typing/__init__.py deleted file mode 100644 index 4c29047b..00000000 --- a/duckdb/typing/__init__.py +++ /dev/null @@ -1,71 +0,0 @@ -"""DuckDB's SQL types. DEPRECATED. Please use `duckdb.sqltypes` instead.""" - -import warnings - -from duckdb.sqltypes import ( - BIGINT, - BIT, - BLOB, - BOOLEAN, - DATE, - DOUBLE, - FLOAT, - HUGEINT, - INTEGER, - INTERVAL, - SMALLINT, - SQLNULL, - TIME, - TIME_TZ, - TIMESTAMP, - TIMESTAMP_MS, - TIMESTAMP_NS, - TIMESTAMP_S, - TIMESTAMP_TZ, - TINYINT, - UBIGINT, - UHUGEINT, - UINTEGER, - USMALLINT, - UTINYINT, - UUID, - VARCHAR, - DuckDBPyType, -) - -__all__ = [ - "BIGINT", - "BIT", - "BLOB", - "BOOLEAN", - "DATE", - "DOUBLE", - "FLOAT", - "HUGEINT", - "INTEGER", - "INTERVAL", - "SMALLINT", - "SQLNULL", - "TIME", - "TIMESTAMP", - "TIMESTAMP_MS", - "TIMESTAMP_NS", - "TIMESTAMP_S", - "TIMESTAMP_TZ", - "TIME_TZ", - "TINYINT", - "UBIGINT", - "UHUGEINT", - "UINTEGER", - "USMALLINT", - "UTINYINT", - "UUID", - "VARCHAR", - "DuckDBPyType", -] - -warnings.warn( - "`duckdb.typing` is deprecated and will be removed in a future version. Please use `duckdb.sqltypes` instead.", - DeprecationWarning, - stacklevel=2, -) diff --git a/duckdb_packaging/_versioning.py b/duckdb_packaging/_versioning.py index 0a5eb66b..0ec984f3 100644 --- a/duckdb_packaging/_versioning.py +++ b/duckdb_packaging/_versioning.py @@ -9,7 +9,6 @@ import pathlib import re import subprocess -from typing import Optional VERSION_RE = re.compile( r"^(?P[0-9]+)\.(?P[0-9]+)\.(?P[0-9]+)(?:rc(?P[0-9]+)|\.post(?P[0-9]+))?$" @@ -100,7 +99,7 @@ def pep440_to_git_tag(version: str) -> str: return f"v{version}" -def get_current_version() -> Optional[str]: +def get_current_version() -> str | None: """Get the current version from git tags. Returns: @@ -115,7 +114,7 @@ def get_current_version() -> Optional[str]: return None -def create_git_tag(version: str, message: Optional[str] = None, repo_path: Optional[pathlib.Path] = None) -> None: +def create_git_tag(version: str, message: str | None = None, repo_path: pathlib.Path | None = None) -> None: """Create a git tag for the given version. Args: @@ -148,10 +147,10 @@ def strip_post_from_version(version: str) -> str: def get_git_describe( - repo_path: Optional[pathlib.Path] = None, + repo_path: pathlib.Path | None = None, since_major: bool = False, # noqa: FBT001 since_minor: bool = False, # noqa: FBT001 -) -> Optional[str]: +) -> str | None: """Get git describe output for version determination. Returns: diff --git a/duckdb_packaging/build_backend.py b/duckdb_packaging/build_backend.py index 799a43c9..114b81f3 100644 --- a/duckdb_packaging/build_backend.py +++ b/duckdb_packaging/build_backend.py @@ -16,7 +16,6 @@ import subprocess import sys from pathlib import Path -from typing import Optional, Union from scikit_build_core.build import ( build_editable, @@ -132,7 +131,7 @@ def _read_duckdb_long_version() -> str: return _version_file_path().read_text(encoding="utf-8").strip() -def _skbuild_config_add(key: str, value: Union[list, str], config_settings: dict[str, Union[list[str], str]]) -> None: +def _skbuild_config_add(key: str, value: list | str, config_settings: dict[str, list[str] | str]) -> None: """Add or modify a configuration setting for scikit-build-core. This function handles adding values to scikit-build-core configuration settings, @@ -179,7 +178,7 @@ def _skbuild_config_add(key: str, value: Union[list, str], config_settings: dict raise RuntimeError(msg) -def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[list[str], str]]] = None) -> str: +def build_sdist(sdist_directory: str, config_settings: dict[str, list[str] | str] | None = None) -> str: """Build a source distribution using the DuckDB submodule. This function extracts the DuckDB version from either the git submodule and saves it @@ -210,8 +209,8 @@ def build_sdist(sdist_directory: str, config_settings: Optional[dict[str, Union[ def build_wheel( wheel_directory: str, - config_settings: Optional[dict[str, Union[list[str], str]]] = None, - metadata_directory: Optional[str] = None, + config_settings: dict[str, list[str] | str] | None = None, + metadata_directory: str | None = None, ) -> str: """Build a wheel from either git submodule or extracted sdist sources. diff --git a/duckdb_packaging/pypi_cleanup.py b/duckdb_packaging/pypi_cleanup.py index 094df741..53990cdb 100644 --- a/duckdb_packaging/pypi_cleanup.py +++ b/duckdb_packaging/pypi_cleanup.py @@ -19,7 +19,6 @@ from collections.abc import Generator from enum import Enum from html.parser import HTMLParser -from typing import Optional, Union from urllib.parse import urlparse import pyotp @@ -173,7 +172,7 @@ def session_with_retries() -> Generator[Session, None, None]: yield session -def load_credentials() -> tuple[Optional[str], Optional[str]]: +def load_credentials() -> tuple[str | None, str | None]: """Load credentials from environment variables.""" password = os.getenv("PYPI_CLEANUP_PASSWORD") otp = os.getenv("PYPI_CLEANUP_OTP") @@ -200,7 +199,7 @@ def __init__(self, target: str) -> None: # noqa: D107 self.csrf = None # Result value from all forms on page self._in_form = False # Currently parsing a form with an action we're interested in - def handle_starttag(self, tag: str, attrs: list[tuple[str, Union[str, None]]]) -> None: # noqa: D102 + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: # noqa: D102 if not self.csrf: if tag == "form": attrs = dict(attrs) @@ -232,9 +231,9 @@ def __init__( # noqa: D107 index_url: str, mode: CleanMode, max_dev_releases: int = _DEFAULT_MAX_NIGHTLIES, - username: Optional[str] = None, - password: Optional[str] = None, - otp: Optional[str] = None, + username: str | None = None, + password: str | None = None, + otp: str | None = None, ) -> None: parsed_url = urlparse(index_url) self._index_url = parsed_url.geturl().rstrip("/") diff --git a/external/duckdb b/external/duckdb index 2e305aac..8c56048c 160000 --- a/external/duckdb +++ b/external/duckdb @@ -1 +1 @@ -Subproject commit 2e305aac809ef22e818024be1e86a1d0ee0d2863 +Subproject commit 8c56048c85fdece49952fb749d1964bee522420a diff --git a/pyproject.toml b/pyproject.toml index 987f2c07..12c853dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ dynamic = ["version"] description = "DuckDB in-process database" readme = "README.md" keywords = ["DuckDB", "Database", "SQL", "OLAP"] -requires-python = ">=3.9.0" +requires-python = ">=3.10.0" classifiers = [ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", @@ -25,7 +25,6 @@ classifiers = [ "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -176,12 +175,12 @@ exclude = [ # # This section has dependency groups for testing and development. Tread carefully, the current setup makes sure that # test dependencies can be installed on as many platforms we build wheel for. Especially picky are: -# - tensorflow: we can only run tests on cp39-cp311, for osx there is no tensorflow-cpu, for windows we need +# - tensorflow: we can only run tests on cp310-cp311, for osx there is no tensorflow-cpu, for windows we need # tensorflow-cpu-aws and there is no distribution availalbe for Linux aarch64. # - torch: since we can't use gpu acceleration, we need to rely on torch-cpu, which isn't available on pypi. We use # `tool.uv.index` and `tool.uv.sources` to make sure the official pytorch index is used. Even there, we don't # have a wheel available for x86_64 OSX + cp313. -# - numpy: tensorflow doesn't play nice with numpy>2 so for every platform that can run tensorflow (cp39-cp311) we use +# - numpy: tensorflow doesn't play nice with numpy>2 so for every platform that can run tensorflow (cp310-cp311) we use # numpy<2. numpy<2 has no wheels for cp31[2|3], meaning an sdist will be used. However, on Windows amd64 + # cp313 this results in a segfault / access violation. To get around this, we install numpy>=2 on all >=cp312 # platforms. Then for windows arm64, for which there is no tensorflow, we only allow numpy>=2.3 because that @@ -195,20 +194,20 @@ default-groups = ["dev"] # build wheels for. # See https://docs.astral.sh/uv/concepts/resolution/#universal-resolution environments = [ # no need to resolve packages beyond these platforms with uv... - "python_version >= '3.9' and sys_platform == 'darwin' and platform_machine == 'arm64'", - "python_version >= '3.9' and sys_platform == 'darwin' and platform_machine == 'x86_64'", - "python_version >= '3.9' and sys_platform == 'win32' and platform_machine == 'AMD64'", + "python_version >= '3.10' and sys_platform == 'darwin' and platform_machine == 'arm64'", + "python_version >= '3.10' and sys_platform == 'darwin' and platform_machine == 'x86_64'", + "python_version >= '3.10' and sys_platform == 'win32' and platform_machine == 'AMD64'", "python_version >= '3.11' and sys_platform == 'win32' and platform_machine == 'ARM64'", - "python_version >= '3.9' and sys_platform == 'linux' and platform_machine == 'x86_64'", - "python_version >= '3.9' and sys_platform == 'linux' and platform_machine == 'aarch64'", + "python_version >= '3.10' and sys_platform == 'linux' and platform_machine == 'x86_64'", + "python_version >= '3.10' and sys_platform == 'linux' and platform_machine == 'aarch64'", ] required-environments = [ # ... but do always resolve for all of them - "python_version >= '3.9' and sys_platform == 'darwin' and platform_machine == 'arm64'", - "python_version >= '3.9' and sys_platform == 'darwin' and platform_machine == 'x86_64'", - "python_version >= '3.9' and sys_platform == 'win32' and platform_machine == 'AMD64'", + "python_version >= '3.10' and sys_platform == 'darwin' and platform_machine == 'arm64'", + "python_version >= '3.10' and sys_platform == 'darwin' and platform_machine == 'x86_64'", + "python_version >= '3.10' and sys_platform == 'win32' and platform_machine == 'AMD64'", "python_version >= '3.11' and sys_platform == 'win32' and platform_machine == 'ARM64'", - "python_version >= '3.9' and sys_platform == 'linux' and platform_machine == 'x86_64'", - "python_version >= '3.9' and sys_platform == 'linux' and platform_machine == 'aarch64'", + "python_version >= '3.10' and sys_platform == 'linux' and platform_machine == 'x86_64'", + "python_version >= '3.10' and sys_platform == 'linux' and platform_machine == 'aarch64'", ] # We just need pytorch for tests, wihtout GPU acceleration. PyPI doesn't host a cpu-only version for Linux, so we have @@ -234,7 +233,8 @@ stubdeps = [ # dependencies used for typehints in the stubs "typing-extensions", ] test = [ # dependencies used for running tests - "adbc-driver-manager; sys_platform != 'win32' or platform_machine != 'ARM64'", + "adbc-driver-manager>=1.10.0; python_version >= '3.10' and (sys_platform != 'win32' or platform_machine != 'ARM64')", + "adbc-driver-manager>=1.7.0; python_version < '3.10' and (sys_platform != 'win32' or platform_machine != 'ARM64')", "pytest", "pytest-reraise", "pytest-timeout", @@ -252,8 +252,10 @@ test = [ # dependencies used for running tests "requests", "urllib3", "fsspec>=2022.11.0; sys_platform != 'win32' or platform_machine != 'ARM64'", - "pandas>=2.0.0", - "pyarrow>=18.0.0; sys_platform != 'win32' or platform_machine != 'ARM64'", + "pandas>=3.0.0; python_version > '3.10'", + "pandas<3.0.0; python_version < '3.11'", + "pyarrow>=23.0.0; python_version >= '3.10' and (sys_platform != 'win32' or platform_machine != 'ARM64')", + "pyarrow>=18.0.0; python_version < '3.10' and (sys_platform != 'win32' or platform_machine != 'ARM64')", "torch>=2.2.2; python_version < '3.14' and ( sys_platform != 'darwin' or platform_machine != 'x86_64' or python_version < '3.13' ) and ( sys_platform != 'win32' or platform_machine != 'ARM64' or python_version > '3.11' )", "tensorflow==2.14.0; sys_platform == 'darwin' and python_version < '3.12'", "tensorflow-cpu>=2.14.0; sys_platform == 'linux' and platform_machine != 'aarch64' and python_version < '3.12'", @@ -327,7 +329,7 @@ packages = ["duckdb", "_duckdb"] strict = true warn_unreachable = true pretty = true -python_version = "3.9" +python_version = "3.10" exclude = [ "duckdb/experimental/", # not checking the pyspark API "duckdb/query_graph/", # old and unmaintained (should probably remove) @@ -360,7 +362,7 @@ source = ["duckdb"] [tool.ruff] line-length = 120 indent-width = 4 -target-version = "py39" +target-version = "py310" fix = true exclude = ['external/duckdb', 'sqllogic'] @@ -409,6 +411,8 @@ strict = true [tool.ruff.lint.per-file-ignores] "duckdb/experimental/spark/**.py" = [ + # No need for moduledocstrings for spark + 'D100', # Ignore boolean positional args in the Spark API 'FBT001' ] diff --git a/scripts/cache_data.json b/scripts/cache_data.json index b1f1df3e..fea6034d 100644 --- a/scripts/cache_data.json +++ b/scripts/cache_data.json @@ -20,7 +20,8 @@ "pyarrow.decimal32", "pyarrow.decimal64", "pyarrow.decimal128" - ] + ], + "required": false }, "pyarrow.dataset": { "type": "module", @@ -29,7 +30,8 @@ "children": [ "pyarrow.dataset.Scanner", "pyarrow.dataset.Dataset" - ] + ], + "required": false }, "pyarrow.dataset.Scanner": { "type": "attribute", @@ -528,7 +530,9 @@ "name": "polars", "children": [ "polars.DataFrame", - "polars.LazyFrame" + "polars.LazyFrame", + "polars.col", + "polars.lit" ], "required": false }, @@ -806,5 +810,17 @@ "full_path": "typing.Union", "name": "Union", "children": [] + }, + "polars.col": { + "type": "attribute", + "full_path": "polars.col", + "name": "col", + "children": [] + }, + "polars.lit": { + "type": "attribute", + "full_path": "polars.lit", + "name": "lit", + "children": [] } } \ No newline at end of file diff --git a/scripts/connection_methods.json b/scripts/connection_methods.json index a87b992f..56398af0 100644 --- a/scripts/connection_methods.json +++ b/scripts/connection_methods.json @@ -395,7 +395,7 @@ "return": "polars.DataFrame" }, { - "name": "fetch_arrow_table", + "name": "arrow_table", "function": "FetchArrow", "docs": "Fetch a result as Arrow table following execute()", "args": [ @@ -1093,5 +1093,30 @@ } ], "return": "None" + }, + { + "name": "get_profiling_information", + "function": "GetProfilingInformation", + "docs": "Get profiling information for a query", + "args": [ + { + "name": "format", + "default": "JSON", + "type": "Optional[str]" + } + ], + "return": "str" + }, + { + "name": "enable_profiling", + "function": "EnableProfiling", + "docs": "Enable profiling for a connection", + "return": "None" + }, + { + "name": "disable_profiling", + "function": "DisableProfiling", + "docs": "Disable profiling for a connection", + "return": "None" } ] diff --git a/scripts/generate_connection_stubs.py b/scripts/generate_connection_stubs.py index d542a047..76c19b36 100644 --- a/scripts/generate_connection_stubs.py +++ b/scripts/generate_connection_stubs.py @@ -5,7 +5,7 @@ os.chdir(Path(__file__).parent) JSON_PATH = "connection_methods.json" -DUCKDB_STUBS_FILE = Path("..") / "duckdb" / "__init__.pyi" +DUCKDB_STUBS_FILE = Path("..") / "_duckdb-stubs" / "__init__.pyi" START_MARKER = " # START OF CONNECTION METHODS" END_MARKER = " # END OF CONNECTION METHODS" diff --git a/scripts/generate_import_cache_json.py b/scripts/generate_import_cache_json.py index dd8c3d5c..389866c5 100644 --- a/scripts/generate_import_cache_json.py +++ b/scripts/generate_import_cache_json.py @@ -46,7 +46,7 @@ def __init__(self, full_path) -> None: self.type = "module" self.name = parts[-1] self.full_path = full_path - self.items: dict[str, Union[ImportCacheAttribute, ImportCacheModule]] = {} + self.items: dict[str, ImportCacheAttribute | ImportCacheModule] = {} def add_item(self, item: Union[ImportCacheAttribute, "ImportCacheModule"]): assert self.full_path != item.full_path @@ -111,7 +111,7 @@ def get_module(self, module_name: str) -> ImportCacheModule: raise ValueError(msg) return self.modules[module_name] - def get_item(self, item_name: str) -> Union[ImportCacheModule, ImportCacheAttribute]: + def get_item(self, item_name: str) -> ImportCacheModule | ImportCacheAttribute: parts = item_name.split(".") if len(parts) == 1: return self.get_module(item_name) diff --git a/scripts/get_cpp_methods.py b/scripts/get_cpp_methods.py index a86b609e..0a77192a 100644 --- a/scripts/get_cpp_methods.py +++ b/scripts/get_cpp_methods.py @@ -1,7 +1,7 @@ # Requires `python3 -m pip install cxxheaderparser pcpp` +from collections.abc import Callable from enum import Enum from pathlib import Path -from typing import Callable import cxxheaderparser.parser import cxxheaderparser.preprocessor diff --git a/scripts/imports.py b/scripts/imports.py index ee4c4597..26e0394b 100644 --- a/scripts/imports.py +++ b/scripts/imports.py @@ -109,6 +109,8 @@ polars.DataFrame polars.LazyFrame +polars.col +polars.lit import duckdb import duckdb.filesystem diff --git a/src/duckdb_py/arrow/CMakeLists.txt b/src/duckdb_py/arrow/CMakeLists.txt index 9a9188b8..2f92f09b 100644 --- a/src/duckdb_py/arrow/CMakeLists.txt +++ b/src/duckdb_py/arrow/CMakeLists.txt @@ -1,5 +1,6 @@ # this is used for clang-tidy checks -add_library(python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp - pyarrow_filter_pushdown.cpp) +add_library( + python_arrow OBJECT arrow_array_stream.cpp arrow_export_utils.cpp + polars_filter_pushdown.cpp pyarrow_filter_pushdown.cpp) target_link_libraries(python_arrow PRIVATE _duckdb_dependencies) diff --git a/src/duckdb_py/arrow/arrow_array_stream.cpp b/src/duckdb_py/arrow/arrow_array_stream.cpp index f9cfd1bb..4f438dec 100644 --- a/src/duckdb_py/arrow/arrow_array_stream.cpp +++ b/src/duckdb_py/arrow/arrow_array_stream.cpp @@ -1,4 +1,5 @@ #include "duckdb_python/arrow/arrow_array_stream.hpp" +#include "duckdb_python/arrow/polars_filter_pushdown.hpp" #include "duckdb_python/arrow/pyarrow_filter_pushdown.hpp" #include "duckdb_python/pyconnection/pyconnection.hpp" @@ -27,15 +28,15 @@ void VerifyArrowDatasetLoaded() { } } -py::object PythonTableArrowArrayStreamFactory::ProduceScanner(DBConfig &config, py::object &arrow_scanner, - py::handle &arrow_obj_handle, +py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle, ArrowStreamParameters ¶meters, const ClientProperties &client_properties) { D_ASSERT(!py::isinstance(arrow_obj_handle)); ArrowSchemaWrapper schema; PythonTableArrowArrayStreamFactory::GetSchemaInternal(arrow_obj_handle, schema); ArrowTableSchema arrow_table; - ArrowTableFunction::PopulateArrowTableSchema(config, arrow_table, schema.arrow_schema); + ArrowTableFunction::PopulateArrowTableSchema(*client_properties.client_context.get_mutable(), arrow_table, + schema.arrow_schema); auto filters = parameters.filters; auto &column_list = parameters.projected_columns.columns; @@ -64,52 +65,128 @@ unique_ptr PythonTableArrowArrayStreamFactory::Produce( auto factory = static_cast(reinterpret_cast(factory_ptr)); // NOLINT D_ASSERT(factory->arrow_object); py::handle arrow_obj_handle(factory->arrow_object); - auto arrow_object_type = DuckDBPyConnection::GetArrowType(arrow_obj_handle); + auto arrow_object_type = factory->cached_arrow_type; + + if (arrow_object_type == PyArrowObjectType::PolarsLazyFrame) { + py::object lf = py::reinterpret_borrow(arrow_obj_handle); + + auto filters = parameters.filters; + bool filters_pushed = false; + + // Translate DuckDB filters to Polars expressions and push into the lazy plan + if (filters && !filters->filters.empty()) { + try { + auto filter_expr = PolarsFilterPushdown::TransformFilter( + *filters, parameters.projected_columns.projection_map, parameters.projected_columns.filter_to_col, + factory->client_properties); + if (!filter_expr.is(py::none())) { + lf = lf.attr("filter")(filter_expr); + filters_pushed = true; + } + } catch (...) { + // Fallback: DuckDB handles filtering post-scan + } + } + + // If no filters were pushed and we have a cached Arrow table, reuse it. This avoids re-reading from source and + // re-converting on repeated unfiltered scans. + py::object arrow_table; + if (!filters_pushed && factory->cached_arrow_table.ptr() != nullptr) { + arrow_table = factory->cached_arrow_table; + } else { + arrow_table = lf.attr("collect")().attr("to_arrow")(); + // Cache only unfiltered results (filtered results are partial) + if (!filters_pushed) { + factory->cached_arrow_table = arrow_table; + } + } + + // Apply column projection + auto &column_list = parameters.projected_columns.columns; + if (!column_list.empty()) { + arrow_table = arrow_table.attr("select")(py::cast(column_list)); + } + + auto capsule_obj = arrow_table.attr("__arrow_c_stream__")(); + auto capsule = py::reinterpret_borrow(capsule_obj); + auto stream = capsule.get_pointer(); + auto res = make_uniq(); + res->arrow_array_stream = *stream; + stream->release = nullptr; + return res; + } + + if (arrow_object_type == PyArrowObjectType::PyCapsuleInterface || arrow_object_type == PyArrowObjectType::Table) { + py::object capsule_obj = arrow_obj_handle.attr("__arrow_c_stream__")(); + auto capsule = py::reinterpret_borrow(capsule_obj); + auto stream = capsule.get_pointer(); + if (!stream->release) { + throw InvalidInputException( + "The __arrow_c_stream__() method returned a released stream. " + "If this object is single-use, implement __arrow_c_schema__() or expose a .schema attribute " + "with _export_to_c() so that DuckDB can extract the schema without consuming the stream."); + } + + auto &import_cache_check = *DuckDBPyConnection::ImportCache(); + if (import_cache_check.pyarrow.dataset()) { + // Tier A: full pushdown via pyarrow.dataset + // Import as RecordBatchReader, feed through Scanner.from_batches for projection/filter pushdown. + auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib"); + auto import_func = pyarrow_lib_module.attr("RecordBatchReader").attr("_import_from_c"); + py::object reader = import_func(reinterpret_cast(stream)); + // _import_from_c takes ownership of the stream; null out to prevent capsule double-free + stream->release = nullptr; + auto &import_cache = *DuckDBPyConnection::ImportCache(); + py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches"); + py::handle reader_handle = reader; + auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_properties); + auto record_batches = scanner.attr("to_reader")(); + auto res = make_uniq(); + auto export_to_c = record_batches.attr("_export_to_c"); + export_to_c(reinterpret_cast(&res->arrow_array_stream)); + return res; + } else { + // Tier B: no pyarrow.dataset, return raw stream (no pushdown) + // DuckDB applies projection/filter post-scan via arrow_scan_dumb + auto res = make_uniq(); + res->arrow_array_stream = *stream; + stream->release = nullptr; + return res; + } + } if (arrow_object_type == PyArrowObjectType::PyCapsule) { auto res = make_uniq(); auto capsule = py::reinterpret_borrow(arrow_obj_handle); auto stream = capsule.get_pointer(); if (!stream->release) { - throw InternalException("ArrowArrayStream was released by another thread/library"); + throw InvalidInputException("This ArrowArrayStream has already been consumed and cannot be scanned again."); } res->arrow_array_stream = *stream; stream->release = nullptr; return res; } + // Scanner and Dataset: require pyarrow.dataset for pushdown + VerifyArrowDatasetLoaded(); auto &import_cache = *DuckDBPyConnection::ImportCache(); py::object scanner; py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches"); switch (arrow_object_type) { - case PyArrowObjectType::Table: { - auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset"); - auto dataset = arrow_dataset(arrow_obj_handle); - py::object arrow_scanner = dataset.attr("__class__").attr("scanner"); - scanner = ProduceScanner(factory->config, arrow_scanner, dataset, parameters, factory->client_properties); - break; - } - case PyArrowObjectType::RecordBatchReader: { - scanner = ProduceScanner(factory->config, arrow_batch_scanner, arrow_obj_handle, parameters, - factory->client_properties); - break; - } case PyArrowObjectType::Scanner: { // If it's a scanner we have to turn it to a record batch reader, and then a scanner again since we can't stack // scanners on arrow Otherwise pushed-down projections and filters will disappear like tears in the rain auto record_batches = arrow_obj_handle.attr("to_reader")(); - scanner = ProduceScanner(factory->config, arrow_batch_scanner, record_batches, parameters, - factory->client_properties); + scanner = ProduceScanner(arrow_batch_scanner, record_batches, parameters, factory->client_properties); break; } case PyArrowObjectType::Dataset: { py::object arrow_scanner = arrow_obj_handle.attr("__class__").attr("scanner"); - scanner = - ProduceScanner(factory->config, arrow_scanner, arrow_obj_handle, parameters, factory->client_properties); + scanner = ProduceScanner(arrow_scanner, arrow_obj_handle, parameters, factory->client_properties); break; } default: { - auto py_object_type = string(py::str(arrow_obj_handle.get_type().attr("__name__"))); + auto py_object_type = string(py::str(py::type::of(arrow_obj_handle).attr("__name__"))); throw InvalidInputException("Object of type '%s' is not a recognized Arrow object", py_object_type); } } @@ -122,46 +199,100 @@ unique_ptr PythonTableArrowArrayStreamFactory::Produce( } void PythonTableArrowArrayStreamFactory::GetSchemaInternal(py::handle arrow_obj_handle, ArrowSchemaWrapper &schema) { + // PyCapsule (from bare capsule Produce path) if (py::isinstance(arrow_obj_handle)) { auto capsule = py::reinterpret_borrow(arrow_obj_handle); auto stream = capsule.get_pointer(); if (!stream->release) { - throw InternalException("ArrowArrayStream was released by another thread/library"); + throw InvalidInputException("This ArrowArrayStream has already been consumed and cannot be scanned again."); + } + if (stream->get_schema(stream, &schema.arrow_schema)) { + throw InvalidInputException("Failed to get Arrow schema from stream: %s", + stream->get_last_error ? stream->get_last_error(stream) : "unknown error"); } - stream->get_schema(stream, &schema.arrow_schema); - return; - } - - auto table_class = py::module::import("pyarrow").attr("Table"); - if (py::isinstance(arrow_obj_handle, table_class)) { - auto obj_schema = arrow_obj_handle.attr("schema"); - auto export_to_c = obj_schema.attr("_export_to_c"); - export_to_c(reinterpret_cast(&schema.arrow_schema)); return; } + // Scanner: use projected_schema; everything else (RecordBatchReader, Dataset): use .schema VerifyArrowDatasetLoaded(); - auto &import_cache = *DuckDBPyConnection::ImportCache(); - auto scanner_class = import_cache.pyarrow.dataset.Scanner(); - - if (py::isinstance(arrow_obj_handle, scanner_class)) { + if (py::isinstance(arrow_obj_handle, import_cache.pyarrow.dataset.Scanner())) { auto obj_schema = arrow_obj_handle.attr("projected_schema"); - auto export_to_c = obj_schema.attr("_export_to_c"); - export_to_c(reinterpret_cast(&schema)); + obj_schema.attr("_export_to_c")(reinterpret_cast(&schema.arrow_schema)); } else { auto obj_schema = arrow_obj_handle.attr("schema"); - auto export_to_c = obj_schema.attr("_export_to_c"); - export_to_c(reinterpret_cast(&schema)); + obj_schema.attr("_export_to_c")(reinterpret_cast(&schema.arrow_schema)); } } void PythonTableArrowArrayStreamFactory::GetSchema(uintptr_t factory_ptr, ArrowSchemaWrapper &schema) { - py::gil_scoped_acquire acquire; auto factory = static_cast(reinterpret_cast(factory_ptr)); // NOLINT + + // Fast path: return cached schema without GIL or Python calls + if (factory->schema_cached) { + schema.arrow_schema = factory->cached_schema; // struct copy + schema.arrow_schema.release = nullptr; // non-owning copy + return; + } + + py::gil_scoped_acquire acquire; D_ASSERT(factory->arrow_object); py::handle arrow_obj_handle(factory->arrow_object); + + auto type = factory->cached_arrow_type; + if (type == PyArrowObjectType::PolarsLazyFrame) { + // head(0).collect().to_arrow() gives the Arrow-exported schema (e.g. large_string) without materializing data. + // collect_schema() would give Polars-native types (e.g. string_view) that don't match the actual export. + const auto empty_arrow = arrow_obj_handle.attr("head")(0).attr("collect")().attr("to_arrow")(); + const auto schema_capsule = empty_arrow.attr("schema").attr("__arrow_c_schema__")(); + const auto capsule = py::reinterpret_borrow(schema_capsule); + const auto arrow_schema = capsule.get_pointer(); + factory->cached_schema = *arrow_schema; + arrow_schema->release = nullptr; + factory->schema_cached = true; + schema.arrow_schema = factory->cached_schema; + schema.arrow_schema.release = nullptr; + return; + } + if (type == PyArrowObjectType::PyCapsuleInterface || type == PyArrowObjectType::Table) { + // Get __arrow_c_schema__ if it exists + if (py::hasattr(arrow_obj_handle, "__arrow_c_schema__")) { + auto schema_capsule = arrow_obj_handle.attr("__arrow_c_schema__")(); + auto capsule = py::reinterpret_borrow(schema_capsule); + auto arrow_schema = capsule.get_pointer(); + factory->cached_schema = *arrow_schema; // factory takes ownership + arrow_schema->release = nullptr; + factory->schema_cached = true; + schema.arrow_schema = factory->cached_schema; // non-owning copy + schema.arrow_schema.release = nullptr; + return; + } + // Otherwise try to use .schema with _export_to_c + if (py::hasattr(arrow_obj_handle, "schema")) { + auto obj_schema = arrow_obj_handle.attr("schema"); + if (py::hasattr(obj_schema, "_export_to_c")) { + obj_schema.attr("_export_to_c")(reinterpret_cast(&schema.arrow_schema)); + return; + } + } + // Fallback: create a temporary stream just for the schema (consumes single-use streams!) + auto stream_capsule = arrow_obj_handle.attr("__arrow_c_stream__")(); + auto capsule = py::reinterpret_borrow(stream_capsule); + auto stream = capsule.get_pointer(); + if (stream->get_schema(stream, &schema.arrow_schema)) { + throw InvalidInputException("Failed to get Arrow schema from stream: %s", + stream->get_last_error ? stream->get_last_error(stream) : "unknown error"); + } + return; // stream_capsule goes out of scope, stream released by capsule destructor + } GetSchemaInternal(arrow_obj_handle, schema); + + // Cache for Table and Dataset (immutable schema) + if (type == PyArrowObjectType::Table || type == PyArrowObjectType::Dataset) { + factory->cached_schema = schema.arrow_schema; // factory takes ownership + schema.arrow_schema.release = nullptr; // caller gets non-owning copy + factory->schema_cached = true; + } } } // namespace duckdb diff --git a/src/duckdb_py/arrow/polars_filter_pushdown.cpp b/src/duckdb_py/arrow/polars_filter_pushdown.cpp new file mode 100644 index 00000000..493189a3 --- /dev/null +++ b/src/duckdb_py/arrow/polars_filter_pushdown.cpp @@ -0,0 +1,161 @@ +#include "duckdb_python/arrow/polars_filter_pushdown.hpp" + +#include "duckdb/planner/filter/in_filter.hpp" +#include "duckdb/planner/filter/optional_filter.hpp" +#include "duckdb/planner/filter/conjunction_filter.hpp" +#include "duckdb/planner/filter/constant_filter.hpp" +#include "duckdb/planner/filter/struct_filter.hpp" +#include "duckdb/planner/table_filter.hpp" + +#include "duckdb_python/pyconnection/pyconnection.hpp" +#include "duckdb_python/python_objects.hpp" + +namespace duckdb { + +static py::object TransformFilterRecursive(TableFilter &filter, py::object col_expr, + const ClientProperties &client_properties) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + + switch (filter.filter_type) { + case TableFilterType::CONSTANT_COMPARISON: { + auto &constant_filter = filter.Cast(); + auto &constant = constant_filter.constant; + auto &constant_type = constant.type(); + + // Check for NaN + bool is_nan = false; + if (constant_type.id() == LogicalTypeId::FLOAT) { + is_nan = Value::IsNan(constant.GetValue()); + } else if (constant_type.id() == LogicalTypeId::DOUBLE) { + is_nan = Value::IsNan(constant.GetValue()); + } + + if (is_nan) { + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return col_expr.attr("is_nan")(); + case ExpressionType::COMPARE_LESSTHAN: + case ExpressionType::COMPARE_NOTEQUAL: + return col_expr.attr("is_nan")().attr("__invert__")(); + case ExpressionType::COMPARE_GREATERTHAN: + return import_cache.polars.lit()(false); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return import_cache.polars.lit()(true); + default: + return py::none(); + } + } + + // Convert DuckDB Value to Python object + auto py_value = PythonObject::FromValue(constant, constant_type, client_properties); + + switch (constant_filter.comparison_type) { + case ExpressionType::COMPARE_EQUAL: + return col_expr.attr("__eq__")(py_value); + case ExpressionType::COMPARE_LESSTHAN: + return col_expr.attr("__lt__")(py_value); + case ExpressionType::COMPARE_GREATERTHAN: + return col_expr.attr("__gt__")(py_value); + case ExpressionType::COMPARE_LESSTHANOREQUALTO: + return col_expr.attr("__le__")(py_value); + case ExpressionType::COMPARE_GREATERTHANOREQUALTO: + return col_expr.attr("__ge__")(py_value); + case ExpressionType::COMPARE_NOTEQUAL: + return col_expr.attr("__ne__")(py_value); + default: + return py::none(); + } + } + case TableFilterType::IS_NULL: { + return col_expr.attr("is_null")(); + } + case TableFilterType::IS_NOT_NULL: { + return col_expr.attr("is_not_null")(); + } + case TableFilterType::CONJUNCTION_AND: { + auto &and_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < and_filter.child_filters.size(); i++) { + auto child_expression = TransformFilterRecursive(*and_filter.child_filters[i], col_expr, client_properties); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; + } + case TableFilterType::CONJUNCTION_OR: { + auto &or_filter = filter.Cast(); + py::object expression = py::none(); + for (idx_t i = 0; i < or_filter.child_filters.size(); i++) { + auto child_expression = TransformFilterRecursive(*or_filter.child_filters[i], col_expr, client_properties); + if (child_expression.is(py::none())) { + // Can't skip children in OR + return py::none(); + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__or__")(child_expression); + } + } + return expression; + } + case TableFilterType::STRUCT_EXTRACT: { + auto &struct_filter = filter.Cast(); + auto child_col = col_expr.attr("struct").attr("field")(struct_filter.child_name); + return TransformFilterRecursive(*struct_filter.child_filter, child_col, client_properties); + } + case TableFilterType::IN_FILTER: { + auto &in_filter = filter.Cast(); + py::list py_values; + for (const auto &value : in_filter.values) { + py_values.append(PythonObject::FromValue(value, value.type(), client_properties)); + } + return col_expr.attr("is_in")(py_values); + } + case TableFilterType::OPTIONAL_FILTER: { + auto &optional_filter = filter.Cast(); + if (!optional_filter.child_filter) { + return py::none(); + } + return TransformFilterRecursive(*optional_filter.child_filter, col_expr, client_properties); + } + default: + // We skip DYNAMIC_FILTER, EXPRESSION_FILTER, BLOOM_FILTER + return py::none(); + } +} + +py::object PolarsFilterPushdown::TransformFilter(const TableFilterSet &filter_collection, + unordered_map &columns, + const unordered_map &filter_to_col, + const ClientProperties &client_properties) { + auto &import_cache = *DuckDBPyConnection::ImportCache(); + auto &filters_map = filter_collection.filters; + + py::object expression = py::none(); + for (auto &it : filters_map) { + auto column_idx = it.first; + auto &column_name = columns[column_idx]; + auto col_expr = import_cache.polars.col()(column_name); + + auto child_expression = TransformFilterRecursive(*it.second, col_expr, client_properties); + if (child_expression.is(py::none())) { + continue; + } + if (expression.is(py::none())) { + expression = std::move(child_expression); + } else { + expression = expression.attr("__and__")(child_expression); + } + } + return expression; +} + +} // namespace duckdb diff --git a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp index 66a6e3fa..af05789a 100644 --- a/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp +++ b/src/duckdb_py/arrow/pyarrow_filter_pushdown.cpp @@ -160,6 +160,15 @@ py::object GetScalar(Value &constant, const string &timezone_config, const Arrow } } +static py::list TransformInList(const InFilter &in) { + py::list res; + ClientProperties default_properties; + for (auto &val : in.values) { + res.append(PythonObject::FromValue(val, val.type(), default_properties)); + } + return res; +} + py::object TransformFilterRecursive(TableFilter &filter, vector column_ref, const string &timezone_config, const ArrowType &type) { auto &import_cache = *DuckDBPyConnection::ImportCache(); @@ -282,17 +291,9 @@ py::object TransformFilterRecursive(TableFilter &filter, vector column_r } case TableFilterType::IN_FILTER: { auto &in_filter = filter.Cast(); - ConjunctionOrFilter or_filter; - value_set_t unique_values; - for (const auto &value : in_filter.values) { - if (unique_values.find(value) == unique_values.end()) { - unique_values.insert(value); - } - } - for (const auto &value : unique_values) { - or_filter.child_filters.push_back(make_uniq(ExpressionType::COMPARE_EQUAL, value)); - } - return TransformFilterRecursive(or_filter, column_ref, timezone_config, type); + auto constant_field = field(py::tuple(py::cast(column_ref))); + auto in_list = TransformInList(in_filter); + return constant_field.attr("isin")(std::move(in_list)); } case TableFilterType::DYNAMIC_FILTER: { //! Ignore dynamic filters for now, not necessary for correctness diff --git a/src/duckdb_py/common/exceptions.cpp b/src/duckdb_py/common/exceptions.cpp index 51de2bdf..5bf744f1 100644 --- a/src/duckdb_py/common/exceptions.cpp +++ b/src/duckdb_py/common/exceptions.cpp @@ -350,13 +350,19 @@ void RegisterExceptions(const py::module &m) { auto io_exception = py::register_exception(m, "IOException", operational_error).ptr(); py::register_exception(m, "SerializationException", operational_error); - static py::exception HTTP_EXCEPTION(m, "HTTPException", io_exception); - const auto string_type = py::type::of(py::str()); - const auto Dict = py::module_::import("typing").attr("Dict"); - HTTP_EXCEPTION.attr("__annotations__") = - py::dict(py::arg("status_code") = py::type::of(py::int_()), py::arg("body") = string_type, - py::arg("reason") = string_type, py::arg("headers") = Dict[py::make_tuple(string_type, string_type)]); - HTTP_EXCEPTION.doc() = "Thrown when an error occurs in the httpfs extension, or whilst downloading an extension."; + // Use a raw pointer to avoid destructor running after Python finalization. + // The module holds a reference to the exception type, keeping it alive. + static PyObject *HTTP_EXCEPTION = nullptr; + { + auto http_exc = py::register_exception(m, "HTTPException", io_exception); + HTTP_EXCEPTION = http_exc.ptr(); + const auto string_type = py::type::of(py::str()); + const auto Dict = py::module_::import("typing").attr("Dict"); + http_exc.attr("__annotations__") = py::dict( + py::arg("status_code") = py::type::of(py::int_()), py::arg("body") = string_type, + py::arg("reason") = string_type, py::arg("headers") = Dict[py::make_tuple(string_type, string_type)]); + http_exc.doc() = "Thrown when an error occurs in the httpfs extension, or whilst downloading an extension."; + } // IntegrityError auto integrity_error = py::register_exception(m, "IntegrityError", db_error).ptr(); @@ -388,7 +394,7 @@ void RegisterExceptions(const py::module &m) { } catch (const duckdb::Exception &ex) { duckdb::ErrorData error(ex); UnsetPythonException(); - PyThrowException(error, HTTP_EXCEPTION.ptr()); + PyThrowException(error, HTTP_EXCEPTION); } catch (const py::builtin_exception &ex) { // These represent Python exceptions, we don't want to catch these throw; @@ -399,7 +405,7 @@ void RegisterExceptions(const py::module &m) { throw; } UnsetPythonException(); - PyThrowException(error, HTTP_EXCEPTION.ptr()); + PyThrowException(error, HTTP_EXCEPTION); } }); } diff --git a/src/duckdb_py/duckdb_python.cpp b/src/duckdb_py/duckdb_python.cpp index 1dd3ba17..eea21519 100644 --- a/src/duckdb_py/duckdb_python.cpp +++ b/src/duckdb_py/duckdb_python.cpp @@ -124,6 +124,34 @@ static void InitializeConnectionMethods(py::module_ &m) { }, "Check if a filesystem with the provided name is currently registered", py::arg("name"), py::kw_only(), py::arg("connection") = py::none()); + m.def( + "get_profiling_information", + [](const py::str &format, shared_ptr conn = nullptr) { + if (!conn) { + conn = DuckDBPyConnection::DefaultConnection(); + } + return conn->GetProfilingInformation(format); + }, + "Get profiling information from a query", py::kw_only(), py::arg("format") = "json", + py::arg("connection") = py::none()); + m.def( + "enable_profiling", + [](shared_ptr conn = nullptr) { + if (!conn) { + conn = DuckDBPyConnection::DefaultConnection(); + } + return conn->EnableProfiling(); + }, + "Enable profiling for the current connection", py::kw_only(), py::arg("connection") = py::none()); + m.def( + "disable_profiling", + [](shared_ptr conn = nullptr) { + if (!conn) { + conn = DuckDBPyConnection::DefaultConnection(); + } + return conn->DisableProfiling(); + }, + "Disable profiling for the current connection", py::kw_only(), py::arg("connection") = py::none()); m.def( "create_function", [](const string &name, const py::function &udf, const py::object &arguments = py::none(), @@ -418,31 +446,45 @@ static void InitializeConnectionMethods(py::module_ &m) { "Fetch a result as Polars DataFrame following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), py::arg("lazy") = false, py::arg("connection") = py::none()); m.def( - "fetch_arrow_table", - [](idx_t rows_per_batch, shared_ptr conn = nullptr) { + "to_arrow_table", + [](idx_t batch_size, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } - return conn->FetchArrow(rows_per_batch); + return conn->FetchArrow(batch_size); }, - "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), + "Fetch a result as Arrow table following execute()", py::arg("batch_size") = 1000000, py::kw_only(), py::arg("connection") = py::none()); m.def( - "fetch_record_batch", - [](const idx_t rows_per_batch, shared_ptr conn = nullptr) { + "to_arrow_reader", + [](idx_t batch_size, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } - return conn->FetchRecordBatchReader(rows_per_batch); + return conn->FetchRecordBatchReader(batch_size); }, - "Fetch an Arrow RecordBatchReader following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), + "Fetch an Arrow RecordBatchReader following execute()", py::arg("batch_size") = 1000000, py::kw_only(), py::arg("connection") = py::none()); m.def( - "arrow", + "fetch_arrow_table", + [](idx_t rows_per_batch, shared_ptr conn = nullptr) { + if (!conn) { + conn = DuckDBPyConnection::DefaultConnection(); + } + PyErr_WarnEx(PyExc_DeprecationWarning, "fetch_arrow_table() is deprecated, use to_arrow_table() instead.", + 0); + return conn->FetchArrow(rows_per_batch); + }, + "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), + py::arg("connection") = py::none()); + m.def( + "fetch_record_batch", [](const idx_t rows_per_batch, shared_ptr conn = nullptr) { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } + PyErr_WarnEx(PyExc_DeprecationWarning, "fetch_record_batch() is deprecated, use to_arrow_reader() instead.", + 0); return conn->FetchRecordBatchReader(rows_per_batch); }, "Fetch an Arrow RecordBatchReader following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), @@ -929,14 +971,14 @@ static void InitializeConnectionMethods(py::module_ &m) { // We define these "wrapper" methods manually because they are overloaded m.def( "arrow", - [](idx_t rows_per_batch, shared_ptr conn) -> duckdb::pyarrow::Table { + [](idx_t rows_per_batch, shared_ptr conn) -> duckdb::pyarrow::RecordBatchReader { if (!conn) { conn = DuckDBPyConnection::DefaultConnection(); } - return conn->FetchArrow(rows_per_batch); + return conn->FetchRecordBatchReader(rows_per_batch); }, - "Fetch a result as Arrow table following execute()", py::arg("rows_per_batch") = 1000000, py::kw_only(), - py::arg("connection") = py::none()); + "Alias of to_arrow_reader(). We recommend using to_arrow_reader() instead.", + py::arg("rows_per_batch") = 1000000, py::kw_only(), py::arg("connection") = py::none()); m.def( "arrow", [](py::object &arrow_object, shared_ptr conn) -> unique_ptr { diff --git a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp index a5895b4a..90974ff7 100644 --- a/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp +++ b/src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp @@ -54,12 +54,12 @@ class Table : public py::object { enum class PyArrowObjectType { Invalid, Table, - RecordBatchReader, Scanner, Dataset, PyCapsule, PyCapsuleInterface, - MessageReader + MessageReader, + PolarsLazyFrame }; void TransformDuckToArrowChunk(ArrowSchema &arrow_schema, ArrowArray &data, py::list &batches); @@ -69,8 +69,20 @@ PyArrowObjectType GetArrowType(const py::handle &obj); class PythonTableArrowArrayStreamFactory { public: explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const ClientProperties &client_properties_p, - DBConfig &config) - : arrow_object(arrow_table), client_properties(client_properties_p), config(config) {}; + PyArrowObjectType arrow_type_p) + : arrow_object(arrow_table), client_properties(client_properties_p), cached_arrow_type(arrow_type_p) { + cached_schema.release = nullptr; + } + + ~PythonTableArrowArrayStreamFactory() { + if (cached_arrow_table.ptr() != nullptr) { + py::gil_scoped_acquire acquire; + cached_arrow_table = py::object(); + } + if (cached_schema.release) { + cached_schema.release(&cached_schema); + } + } //! Produces an Arrow Scanner, should be only called once when initializing Scan States static unique_ptr Produce(uintptr_t factory, ArrowStreamParameters ¶meters); @@ -83,10 +95,17 @@ class PythonTableArrowArrayStreamFactory { PyObject *arrow_object; const ClientProperties client_properties; - DBConfig &config; + const PyArrowObjectType cached_arrow_type; + + //! Cached Arrow table from an unfiltered .collect().to_arrow() on a LazyFrame. + //! Avoids re-reading from source and re-converting on repeated scans without filters. + py::object cached_arrow_table; private: - static py::object ProduceScanner(DBConfig &config, py::object &arrow_scanner, py::handle &arrow_obj_handle, + ArrowSchema cached_schema; + bool schema_cached = false; + + static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle, ArrowStreamParameters ¶meters, const ClientProperties &client_properties); }; } // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/arrow/polars_filter_pushdown.hpp b/src/duckdb_py/include/duckdb_python/arrow/polars_filter_pushdown.hpp new file mode 100644 index 00000000..adf485c9 --- /dev/null +++ b/src/duckdb_py/include/duckdb_python/arrow/polars_filter_pushdown.hpp @@ -0,0 +1,24 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb_python/arrow/polars_filter_pushdown.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/unordered_map.hpp" +#include "duckdb/planner/table_filter.hpp" +#include "duckdb/main/client_properties.hpp" +#include "duckdb_python/pybind11/pybind_wrapper.hpp" + +namespace duckdb { + +struct PolarsFilterPushdown { + static py::object TransformFilter(const TableFilterSet &filter_collection, unordered_map &columns, + const unordered_map &filter_to_col, + const ClientProperties &client_properties); +}; + +} // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/expression/pyexpression.hpp b/src/duckdb_py/include/duckdb_python/expression/pyexpression.hpp index 39dd252f..43c0c5c3 100644 --- a/src/duckdb_py/include/duckdb_python/expression/pyexpression.hpp +++ b/src/duckdb_py/include/duckdb_python/expression/pyexpression.hpp @@ -12,6 +12,7 @@ #include "duckdb.hpp" #include "duckdb/common/string.hpp" #include "duckdb/parser/parsed_expression.hpp" +#include "duckdb/parser/expression/case_expression.hpp" #include "duckdb/parser/expression/constant_expression.hpp" #include "duckdb/parser/expression/columnref_expression.hpp" #include "duckdb/parser/expression/function_expression.hpp" diff --git a/src/duckdb_py/include/duckdb_python/import_cache/modules/polars_module.hpp b/src/duckdb_py/include/duckdb_python/import_cache/modules/polars_module.hpp index c22173b4..17f746fb 100644 --- a/src/duckdb_py/include/duckdb_python/import_cache/modules/polars_module.hpp +++ b/src/duckdb_py/include/duckdb_python/import_cache/modules/polars_module.hpp @@ -26,13 +26,17 @@ struct PolarsCacheItem : public PythonImportCacheItem { static constexpr const char *Name = "polars"; public: - PolarsCacheItem() : PythonImportCacheItem("polars"), DataFrame("DataFrame", this), LazyFrame("LazyFrame", this) { + PolarsCacheItem() + : PythonImportCacheItem("polars"), DataFrame("DataFrame", this), LazyFrame("LazyFrame", this), col("col", this), + lit("lit", this) { } ~PolarsCacheItem() override { } PythonImportCacheItem DataFrame; PythonImportCacheItem LazyFrame; + PythonImportCacheItem col; + PythonImportCacheItem lit; protected: bool IsRequired() const override final { diff --git a/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp b/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp index ff9d9ebc..642262e8 100644 --- a/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp +++ b/src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp @@ -46,6 +46,11 @@ struct PyarrowDatasetCacheItem : public PythonImportCacheItem { PythonImportCacheItem Scanner; PythonImportCacheItem Dataset; + +protected: + bool IsRequired() const override final { + return false; + } }; struct PyarrowCacheItem : public PythonImportCacheItem { @@ -80,6 +85,11 @@ struct PyarrowCacheItem : public PythonImportCacheItem { PythonImportCacheItem decimal32; PythonImportCacheItem decimal64; PythonImportCacheItem decimal128; + +protected: + bool IsRequired() const override final { + return false; + } }; } // namespace duckdb diff --git a/src/duckdb_py/include/duckdb_python/numpy/numpy_type.hpp b/src/duckdb_py/include/duckdb_python/numpy/numpy_type.hpp index 982f00ec..d58bc139 100644 --- a/src/duckdb_py/include/duckdb_python/numpy/numpy_type.hpp +++ b/src/duckdb_py/include/duckdb_python/numpy/numpy_type.hpp @@ -18,25 +18,28 @@ namespace duckdb { // Pandas Specific Types (e.g., categorical, datetime_tz,...) enum class NumpyNullableType : uint8_t { //! NumPy dtypes - BOOL, //! bool_, bool8 - INT_8, //! byte, int8 - UINT_8, //! ubyte, uint8 - INT_16, //! int16, short - UINT_16, //! uint16, ushort - INT_32, //! int32, intc - UINT_32, //! uint32, uintc, - INT_64, //! int64, int0, int_, intp, matrix - UINT_64, //! uint64, uint, uint0, uintp - FLOAT_16, //! float16, half - FLOAT_32, //! float32, single - FLOAT_64, //! float64, float_, double - OBJECT, //! object - UNICODE, //! { py::list ListFilesystems(); bool FileSystemIsRegistered(const string &name); + // Profiling info + py::str GetProfilingInformation(const py::str &format = "json"); + void EnableProfiling(); + void DisableProfiling(); + //! Default connection to an in-memory database static DefaultConnectionHolder default_connection; //! Caches and provides an interface to get frequently used modules+subtypes @@ -350,6 +355,8 @@ struct DuckDBPyConnection : public enable_shared_from_this { static unique_ptr CompletePendingQuery(PendingQueryResult &pending_query); private: + unique_ptr CreateRelation(shared_ptr rel); + unique_ptr CreateRelation(shared_ptr result); PathLike GetPathLike(const py::object &object); ScalarFunction CreateScalarUDF(const string &name, const py::function &udf, const py::object ¶meters, const shared_ptr &return_type, bool vectorized, diff --git a/src/duckdb_py/include/duckdb_python/pyrelation.hpp b/src/duckdb_py/include/duckdb_python/pyrelation.hpp index b1975e7f..50f39b5f 100644 --- a/src/duckdb_py/include/duckdb_python/pyrelation.hpp +++ b/src/duckdb_py/include/duckdb_python/pyrelation.hpp @@ -262,6 +262,10 @@ struct DuckDBPyRelation { bool ContainsColumnByName(const string &name) const; + void SetConnectionOwner(py::object owner); + unique_ptr DeriveRelation(shared_ptr new_rel); + unique_ptr DeriveRelation(shared_ptr result); + private: string ToStringInternal(const BoxRendererConfig &config, bool invalidate_cache = false); string GenerateExpressionList(const string &function_name, const string &aggregated_columns, @@ -284,6 +288,9 @@ struct DuckDBPyRelation { unique_ptr ExecuteInternal(bool stream_result = false); private: + //! Prevents GC of the parent DuckDBPyConnection. + //! Declared first so it is destroyed last (reverse declaration order). + py::object connection_owner; //! Whether the relation has been executed at least once bool executed; shared_ptr rel; diff --git a/src/duckdb_py/include/duckdb_python/python_conversion.hpp b/src/duckdb_py/include/duckdb_python/python_conversion.hpp index d3bfadba..bad518ef 100644 --- a/src/duckdb_py/include/duckdb_python/python_conversion.hpp +++ b/src/duckdb_py/include/duckdb_python/python_conversion.hpp @@ -45,7 +45,7 @@ enum class PythonObjectType { PythonObjectType GetPythonObjectType(py::handle &ele); -bool TryTransformPythonNumeric(Value &res, py::handle ele, const LogicalType &target_type = LogicalType::UNKNOWN); +LogicalType SniffPythonIntegerType(py::handle ele); bool DictionaryHasMapFormat(const PyDictionary &dict); void TransformPythonObject(py::handle ele, Vector &vector, idx_t result_offset, bool nan_as_null = true); Value TransformPythonValue(py::handle ele, const LogicalType &target_type = LogicalType::UNKNOWN, diff --git a/src/duckdb_py/native/python_conversion.cpp b/src/duckdb_py/native/python_conversion.cpp index caa409c5..a56ea73f 100644 --- a/src/duckdb_py/native/python_conversion.cpp +++ b/src/duckdb_py/native/python_conversion.cpp @@ -13,6 +13,36 @@ namespace duckdb { +// Like DefaultCastAs, but handles UNION targets by finding the first compatible member. DefaultCastAs raises a +// Conversion Error when multiple UNION members have the same type (e.g. UNION(u1 DOUBLE, u2 DOUBLE)), so for UNION +// targets we resolve the member ourselves. +static Value CastToTarget(Value val, const LogicalType &target_type) { + if (target_type.id() != LogicalTypeId::UNION) { + return val.DefaultCastAs(target_type); + } + + auto member_count = UnionType::GetMemberCount(target_type); + auto &source_type = val.type(); + + // First pass: if there's an exact type match we use that + for (idx_t i = 0; i < member_count; i++) { + if (UnionType::GetMemberType(target_type, i) == source_type) { + return Value::UNION(UnionType::CopyMemberTypes(target_type), NumericCast(i), std::move(val)); + } + } + + // Second pass: if there's a type we can implicitly cast to, we do that + for (idx_t i = 0; i < member_count; i++) { + auto member_type = UnionType::GetMemberType(target_type, i); + Value candidate = val; + if (candidate.DefaultTryCastAs(member_type)) { + return Value::UNION(UnionType::CopyMemberTypes(target_type), NumericCast(i), std::move(candidate)); + } + } + throw ConversionException("Could not convert value of type %s to %s", source_type.ToString(), + target_type.ToString()); +} + static Value EmptyMapValue() { auto map_type = LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL); return Value::MAP(ListType::GetChildType(map_type), vector()); @@ -92,7 +122,7 @@ Value TransformDictionaryToStruct(const PyDictionary &dict, const LogicalType &t child_list_t struct_values; for (idx_t i = 0; i < dict.len; i++) { auto &key = struct_target ? StructType::GetChildName(target_type, i) : struct_keys[i]; - auto value_index = key_mapping[key]; + auto value_index = struct_target ? key_mapping[key] : i; auto &child_type = struct_target ? StructType::GetChildType(target_type, i) : LogicalType::UNKNOWN; auto val = TransformPythonValue(dict.values.attr("__getitem__")(value_index), child_type); struct_values.emplace_back(make_pair(std::move(key), std::move(val))); @@ -230,150 +260,108 @@ Value TransformTupleToStruct(py::handle ele, const LogicalType &target_type = Lo return result; } -bool TryTransformPythonIntegerToDouble(Value &res, py::handle ele) { - double number = PyLong_AsDouble(ele.ptr()); - if (number == -1.0 && PyErr_Occurred()) { +// Tries to convert a Python integer that overflows int64/uint64 into a HUGEINT or UHUGEINT Value +// by decomposing it into upper and lower 64-bit components. Tries HUGEINT first; falls back to +// UHUGEINT for large positive values. Returns false if the value doesn't fit in 128 bits. +static bool TryTransformPythonLongToHugeInt(py::handle ele, const LogicalType &target_type, Value &result) { + auto ptr = ele.ptr(); + + // Extract lower 64 bits (two's complement, works for negative values too) + uint64_t lower = PyLong_AsUnsignedLongLongMask(ptr); + if (lower == static_cast(-1) && PyErr_Occurred()) { + PyErr_Clear(); + return false; + } + + // Extract upper bits by right-shifting by 64 + py::int_ shift_amount(64); + py::object upper_obj = py::reinterpret_steal(PyNumber_Rshift(ptr, shift_amount.ptr())); + + // Try signed 128-bit (hugeint) first + int overflow; + int64_t upper_signed = PyLong_AsLongLongAndOverflow(upper_obj.ptr(), &overflow); + if (overflow == 0 && !(upper_signed == -1 && PyErr_Occurred())) { + auto val = Value::HUGEINT(hugeint_t {upper_signed, lower}); + if (target_type.id() == LogicalTypeId::UNKNOWN || target_type.id() == LogicalTypeId::HUGEINT) { + result = val; + } else { + result = CastToTarget(std::move(val), target_type); + } + return true; + } + PyErr_Clear(); + + // Try unsigned 128-bit (uhugeint) + uint64_t upper_unsigned = PyLong_AsUnsignedLongLong(upper_obj.ptr()); + if (PyErr_Occurred()) { PyErr_Clear(); return false; } - res = Value::DOUBLE(number); + + auto val = Value::UHUGEINT(uhugeint_t {upper_unsigned, lower}); + if (target_type.id() == LogicalTypeId::UNKNOWN || target_type.id() == LogicalTypeId::UHUGEINT) { + result = val; + } else { + result = CastToTarget(std::move(val), target_type); + } return true; } -void TransformPythonUnsigned(uint64_t value, Value &res) { - if (value > (uint64_t)std::numeric_limits::max()) { - res = Value::UBIGINT(value); - } else if (value > (int64_t)std::numeric_limits::max()) { - res = Value::UINTEGER(value); - } else if (value > (int64_t)std::numeric_limits::max()) { - res = Value::USMALLINT(value); - } else { - res = Value::UTINYINT(value); +// Throwing wrapper for contexts that require a result (e.g. prepared statement parameters). +static Value TransformPythonLongToHugeInt(py::handle ele, const LogicalType &target_type) { + Value result; + if (!TryTransformPythonLongToHugeInt(ele, target_type, result)) { + throw InvalidInputException("Python integer too large for 128-bit integer type: %s", std::string(py::str(ele))); } + return result; } -bool TrySniffPythonNumeric(Value &res, int64_t value) { +// Picks the tightest DuckDB integer type (>=INT32) for an int64 value when no target type is specified. +static Value SniffIntegerValue(int64_t value) { if (value < (int64_t)std::numeric_limits::min() || value > (int64_t)std::numeric_limits::max()) { - res = Value::BIGINT(value); - } else { - // To match default duckdb behavior, numeric values without a specified type should not become a smaller type - // than INT32 - res = Value::INTEGER(value); + return Value::BIGINT(value); } - return true; + return Value::INTEGER(value); } -// TODO: add support for HUGEINT -bool TryTransformPythonNumeric(Value &res, py::handle ele, const LogicalType &target_type) { +// Sniffs the tightest DuckDB integer type for a Python integer. +// Progressively widens: int64 → uint64 → hugeint → uhugeint. +// Returns SQLNULL if the value doesn't fit in any DuckDB integer type (> 128-bit). +LogicalType SniffPythonIntegerType(py::handle ele) { auto ptr = ele.ptr(); + // Step 1: Try int64 int overflow; - int64_t value = PyLong_AsLongLongAndOverflow(ptr, &overflow); - if (overflow == -1) { - PyErr_Clear(); - if (target_type.id() == LogicalTypeId::BIGINT) { - throw InvalidInputException(StringUtil::Format("Failed to cast value: Python value '%s' to INT64", - std::string(pybind11::str(ele)))); - } - auto cast_as = target_type.id() == LogicalTypeId::UNKNOWN ? LogicalType::HUGEINT : target_type; - auto numeric_string = std::string(py::str(ele)); - res = Value(numeric_string).DefaultCastAs(cast_as); - return true; - } else if (overflow == 1) { - if (target_type.InternalType() == PhysicalType::INT64) { - throw InvalidInputException(StringUtil::Format("Failed to cast value: Python value '%s' to INT64", - std::string(pybind11::str(ele)))); - } - uint64_t unsigned_value = PyLong_AsUnsignedLongLong(ptr); - if (PyErr_Occurred()) { - PyErr_Clear(); - return TryTransformPythonIntegerToDouble(res, ele); - } else { - TransformPythonUnsigned(unsigned_value, res); - } - PyErr_Clear(); - return true; - } else if (value == -1 && PyErr_Occurred()) { - return false; - } + const int64_t value = PyLong_AsLongLongAndOverflow(ptr, &overflow); - // The value is int64_t or smaller - - switch (target_type.id()) { - case LogicalTypeId::UNKNOWN: - return TrySniffPythonNumeric(res, value); - case LogicalTypeId::HUGEINT: { - res = Value::HUGEINT(value); - return true; - } - case LogicalTypeId::UHUGEINT: { - if (value < 0) { - return false; - } - res = Value::UHUGEINT(value); - return true; - } - case LogicalTypeId::BIGINT: { - res = Value::BIGINT(value); - return true; - } - case LogicalTypeId::INTEGER: { - if (value < NumericLimits::Minimum() || value > NumericLimits::Maximum()) { - return false; - } - res = Value::INTEGER(value); - return true; - } - case LogicalTypeId::SMALLINT: { - if (value < NumericLimits::Minimum() || value > NumericLimits::Maximum()) { - return false; - } - res = Value::SMALLINT(value); - return true; - } - case LogicalTypeId::TINYINT: { - if (value < NumericLimits::Minimum() || value > NumericLimits::Maximum()) { - return false; - } - res = Value::TINYINT(value); - return true; - } - case LogicalTypeId::UBIGINT: { - if (value < 0) { - return false; - } - res = Value::UBIGINT(value); - return true; - } - case LogicalTypeId::UINTEGER: { - if (value < 0 || value > (int64_t)NumericLimits::Maximum()) { - return false; - } - res = Value::UINTEGER(value); - return true; - } - case LogicalTypeId::USMALLINT: { - if (value < 0 || value > (int64_t)NumericLimits::Maximum()) { - return false; + if (overflow == 0) { + if (value == -1 && PyErr_Occurred()) { + PyErr_Clear(); + return LogicalType::SQLNULL; } - res = Value::USMALLINT(value); - return true; - } - case LogicalTypeId::UTINYINT: { - if (value < 0 || value > (int64_t)NumericLimits::Maximum()) { - return false; + if (value < static_cast(std::numeric_limits::min()) || + value > static_cast(std::numeric_limits::max())) { + return LogicalType::BIGINT; } - res = Value::UTINYINT(value); - return true; + return LogicalType::INTEGER; } - default: { - if (!TrySniffPythonNumeric(res, value)) { - return false; + PyErr_Clear(); + + // Step 2: For positive overflow, try uint64 + if (overflow == 1) { + (void)PyLong_AsUnsignedLongLong(ptr); + if (!PyErr_Occurred()) { + return LogicalType::UBIGINT; } - res = res.DefaultCastAs(target_type, true); - return true; + PyErr_Clear(); } + + // Step 3: Try 128-bit (hugeint/uhugeint) + Value res; + if (!TryTransformPythonLongToHugeInt(ele, LogicalType::UNKNOWN, res)) { + return LogicalType::SQLNULL; } + return res.type(); } Value TransformDictionary(const PyDictionary &dict) { @@ -476,33 +464,22 @@ struct PythonValueConversion { target_type.ToString()); } default: - throw ConversionException("Could not convert 'float' to type %s", target_type.ToString()); + result = CastToTarget(Value::DOUBLE(val), target_type); + break; } } - static void HandleLongAsDouble(Value &result, const LogicalType &target_type, double val) { - auto cast_as = target_type.id() == LogicalTypeId::UNKNOWN ? LogicalType::DOUBLE : target_type; - result = Value::DOUBLE(val).DefaultCastAs(cast_as); + static void HandleLongOverflow(Value &result, const LogicalType &target_type, py::handle ele) { + result = TransformPythonLongToHugeInt(ele, target_type); } static void HandleUnsignedBigint(Value &result, const LogicalType &target_type, uint64_t val) { auto cast_as = target_type.id() == LogicalTypeId::UNKNOWN ? LogicalType::UBIGINT : target_type; - result = Value::UBIGINT(val).DefaultCastAs(cast_as); + result = CastToTarget(Value::UBIGINT(val), cast_as); } static void HandleBigint(Value &res, const LogicalType &target_type, int64_t value) { - switch (target_type.id()) { - case LogicalTypeId::UNKNOWN: { - if (value < (int64_t)std::numeric_limits::min() || - value > (int64_t)std::numeric_limits::max()) { - res = Value::BIGINT(value); - } else { - // To match default duckdb behavior, numeric values without a specified type should not become a smaller - // type than INT32 - res = Value::INTEGER(value); - } - break; - } - default: - res = Value::BIGINT(value).DefaultCastAs(target_type); - break; + if (target_type.id() == LogicalTypeId::UNKNOWN) { + res = SniffIntegerValue(value); + } else { + res = CastToTarget(SniffIntegerValue(value), target_type); } } @@ -511,7 +488,7 @@ struct PythonValueConversion { (target_type.id() == LogicalTypeId::VARCHAR && !target_type.HasAlias())) { result = Value(value); } else { - result = Value(value).DefaultCastAs(target_type); + result = CastToTarget(Value(value), target_type); } } @@ -606,7 +583,7 @@ struct PythonValueConversion { auto type = ele.attr("type"); shared_ptr internal_type; if (!py::try_cast>(type, internal_type)) { - string actual_type = py::str(type.get_type()); + string actual_type = py::str(py::type::of(type)); throw InvalidInputException("The 'type' of a Value should be of type DuckDBPyType, not '%s'", actual_type); } @@ -648,13 +625,13 @@ struct PythonVectorConversion { break; } default: - throw TypeMismatchException( - LogicalType::DOUBLE, result.GetType(), - "Python Conversion Failure: Expected a value of type %s, but got a value of type double"); + FallbackValueConversion(result, result_offset, CastToTarget(Value::DOUBLE(val), result.GetType())); + break; } } - static void HandleLongAsDouble(Vector &result, const idx_t &result_offset, double val) { - FallbackValueConversion(result, result_offset, Value::DOUBLE(val)); + static void HandleLongOverflow(Vector &result, const idx_t &result_offset, py::handle ele) { + Value result_val = TransformPythonLongToHugeInt(ele, result.GetType()); + FallbackValueConversion(result, result_offset, std::move(result_val)); } static void HandleUnsignedBigint(Vector &result, const idx_t &result_offset, uint64_t value) { // this code path is only called for values in the range of [INT64_MAX...UINT64_MAX] @@ -669,7 +646,7 @@ struct PythonVectorConversion { FlatVector::GetData(result)[result_offset] = value; break; default: - FallbackValueConversion(result, result_offset, Value::UBIGINT(value)); + FallbackValueConversion(result, result_offset, CastToTarget(Value::UBIGINT(value), result.GetType())); break; } } @@ -740,7 +717,7 @@ struct PythonVectorConversion { break; } default: - FallbackValueConversion(result, result_offset, Value::BIGINT(value)); + FallbackValueConversion(result, result_offset, CastToTarget(Value::BIGINT(value), result.GetType())); break; } } @@ -953,25 +930,20 @@ void TransformPythonObjectInternal(py::handle ele, A &result, const B ¶m, bo default: break; } - if (overflow == 1) { + if (overflow == 1) { // value is > LLONG_MAX uint64_t unsigned_value = PyLong_AsUnsignedLongLong(ptr); if (!PyErr_Occurred()) { // value does not fit within an int64, but it fits within a uint64 OP::HandleUnsignedBigint(result, param, unsigned_value); break; } + PyErr_Clear(); if (conversion_target.id() == LogicalTypeId::UBIGINT) { throw InvalidInputException("Python Conversion Failure: Value out of range for type %s", conversion_target); } - PyErr_Clear(); - } - double number = PyLong_AsDouble(ele.ptr()); - if (number == -1.0 && PyErr_Occurred()) { - PyErr_Clear(); - throw InvalidInputException("An error occurred attempting to convert a python integer"); } - OP::HandleLongAsDouble(result, param, number); + OP::HandleLongOverflow(result, param, ele); } else if (value == -1 && PyErr_Occurred()) { throw InvalidInputException("An error occurred attempting to convert a python integer"); } else { @@ -1062,7 +1034,7 @@ void TransformPythonObjectInternal(py::handle ele, A &result, const B ¶m, bo } case PythonObjectType::Other: throw NotImplementedException("Unable to transform python value of type '%s' to DuckDB LogicalType", - py::str(ele.get_type()).cast()); + py::str(py::type::of(ele)).cast()); default: throw InternalException("Object type recognized but not implemented!"); } diff --git a/src/duckdb_py/native/python_objects.cpp b/src/duckdb_py/native/python_objects.cpp index 21aa281f..55990f7c 100644 --- a/src/duckdb_py/native/python_objects.cpp +++ b/src/duckdb_py/native/python_objects.cpp @@ -13,6 +13,9 @@ #include "datetime.h" // Python datetime initialize #1 +#include +#include + namespace duckdb { PyDictionary::PyDictionary(py::object dict) { @@ -445,6 +448,7 @@ static bool KeyIsHashable(const LogicalType &type) { case LogicalTypeId::LIST: case LogicalTypeId::ARRAY: case LogicalTypeId::MAP: + case LogicalTypeId::VARIANT: return false; case LogicalTypeId::UNION: { idx_t count = UnionType::GetMemberCount(type); @@ -585,13 +589,20 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type, auto tmp_datetime_with_tz = import_cache.datetime.datetime.combine()(tmp_datetime, py_time, timezone_offset); return tmp_datetime_with_tz.attr("timetz")(); } - case LogicalTypeId::TIME: { + case LogicalTypeId::TIME: + case LogicalTypeId::TIME_NS: { D_ASSERT(type.InternalType() == PhysicalType::INT64); - int32_t hour, min, sec, microsec; - auto time = val.GetValueUnsafe(); - duckdb::Time::Convert(time, hour, min, sec, microsec); + int32_t hour, min, sec, usec; + dtime_t time; + if (type.id() == LogicalTypeId::TIME) { + time = val.GetValueUnsafe(); + } else { + // Python's datetime doesn't support nanoseconds, we convert to micros. + time = val.GetValueUnsafe().time(); + } + duckdb::Time::Convert(time, hour, min, sec, usec); try { - auto pytime = PyTime_FromTime(hour, min, sec, microsec); + auto pytime = PyTime_FromTime(hour, min, sec, usec); if (!pytime) { throw py::error_already_set(); } @@ -693,6 +704,14 @@ py::object PythonObject::FromValue(const Value &val, const LogicalType &type, return import_cache.datetime.timedelta()(py::arg("days") = days, py::arg("microseconds") = interval_value.micros); } + case LogicalTypeId::VARIANT: { + Vector tmp(val); + RecursiveUnifiedVectorFormat format; + Vector::RecursiveToUnifiedFormat(tmp, 1, format); + UnifiedVariantVectorData vector_data(format); + auto variant_val = VariantUtils::ConvertVariantToValue(vector_data, 0, 0); + return FromValue(variant_val, variant_val.type(), client_properties); + } default: throw NotImplementedException("Unsupported type: \"%s\"", type.ToString()); diff --git a/src/duckdb_py/numpy/array_wrapper.cpp b/src/duckdb_py/numpy/array_wrapper.cpp index baf1cd70..ac3f7991 100644 --- a/src/duckdb_py/numpy/array_wrapper.cpp +++ b/src/duckdb_py/numpy/array_wrapper.cpp @@ -11,6 +11,8 @@ #include "duckdb_python/pyresult.hpp" #include "duckdb/common/types/uuid.hpp" +#include + namespace duckdb { namespace duckdb_py_convert { @@ -112,7 +114,7 @@ struct IntervalConvert { template static int64_t ConvertValue(interval_t val, NumpyAppendData &append_data) { (void)append_data; - return Interval::GetNanoseconds(val); + return Interval::GetMicro(val); } template @@ -140,6 +142,24 @@ struct TimeConvert { } }; +struct TimeNSConvert { + template + static PyObject *ConvertValue(dtime_ns_t val, NumpyAppendData &append_data) { + auto &client_properties = append_data.client_properties; + auto value = Value::TIME_NS(val); + auto py_obj = PythonObject::FromValue(value, LogicalType::TIME_NS, client_properties); + // Release ownership of the PyObject* without decreasing refcount + // this returns a handle, of which we take the ptr to get the PyObject* + return py_obj.release().ptr(); + } + + template + static NUMPY_T NullValue(bool &set_mask) { + set_mask = true; + return nullptr; + } +}; + struct StringConvert { template static PyObject *ConvertValue(string_t val, NumpyAppendData &append_data) { @@ -284,6 +304,19 @@ struct UnionConvert { } }; +struct VariantConvert { + static py::object ConvertValue(Vector &input, idx_t chunk_offset, NumpyAppendData &append_data) { + auto &client_properties = append_data.client_properties; + auto val = input.GetValue(chunk_offset); + Vector tmp(val); + RecursiveUnifiedVectorFormat format; + Vector::RecursiveToUnifiedFormat(tmp, 1, format); + UnifiedVariantVectorData vector_data(format); + auto variant_val = VariantUtils::ConvertVariantToValue(vector_data, 0, 0); + return PythonObject::FromValue(variant_val, variant_val.type(), client_properties); + } +}; + struct MapConvert { static py::dict ConvertValue(Vector &input, idx_t chunk_offset, NumpyAppendData &append_data) { auto &client_properties = append_data.client_properties; @@ -639,6 +672,9 @@ void ArrayWrapper::Append(idx_t current_offset, Vector &input, idx_t source_size case LogicalTypeId::TIME: may_have_null = ConvertColumn(append_data); break; + case LogicalTypeId::TIME_NS: + may_have_null = ConvertColumn(append_data); + break; case LogicalTypeId::INTERVAL: may_have_null = ConvertColumn(append_data); break; @@ -666,6 +702,9 @@ void ArrayWrapper::Append(idx_t current_offset, Vector &input, idx_t source_size case LogicalTypeId::STRUCT: may_have_null = ConvertNested(append_data); break; + case LogicalTypeId::VARIANT: + may_have_null = ConvertNested(append_data); + break; case LogicalTypeId::UUID: may_have_null = ConvertColumn(append_data); break; diff --git a/src/duckdb_py/numpy/numpy_scan.cpp b/src/duckdb_py/numpy/numpy_scan.cpp index 0117eaae..b1cd6e60 100644 --- a/src/duckdb_py/numpy/numpy_scan.cpp +++ b/src/duckdb_py/numpy/numpy_scan.cpp @@ -302,7 +302,10 @@ void NumpyScan::Scan(PandasColumnBindData &bind_data, idx_t count, idx_t offset, } break; } - case NumpyNullableType::TIMEDELTA: { + case NumpyNullableType::TIMEDELTA_NS: + case NumpyNullableType::TIMEDELTA_US: + case NumpyNullableType::TIMEDELTA_MS: + case NumpyNullableType::TIMEDELTA_S: { auto src_ptr = reinterpret_cast(array.data()); auto tgt_ptr = FlatVector::GetData(out); auto &mask = FlatVector::Validity(out); @@ -314,7 +317,25 @@ void NumpyScan::Scan(PandasColumnBindData &bind_data, idx_t count, idx_t offset, mask.SetInvalid(row); continue; } - int64_t micro = src_ptr[source_idx] / 1000; + + int64_t micro; + switch (bind_data.numpy_type.type) { + case NumpyNullableType::TIMEDELTA_NS: + micro = src_ptr[source_idx] / 1000; // ns -> us + break; + case NumpyNullableType::TIMEDELTA_US: + micro = src_ptr[source_idx]; // already us + break; + case NumpyNullableType::TIMEDELTA_MS: + micro = src_ptr[source_idx] * 1000; // ms -> us + break; + case NumpyNullableType::TIMEDELTA_S: + micro = src_ptr[source_idx] * 1000000; // s -> us + break; + default: + throw InternalException("Unexpected timedelta type"); + } + int64_t days = micro / Interval::MICROS_PER_DAY; micro = micro % Interval::MICROS_PER_DAY; int64_t months = days / Interval::DAYS_PER_MONTH; diff --git a/src/duckdb_py/numpy/raw_array_wrapper.cpp b/src/duckdb_py/numpy/raw_array_wrapper.cpp index 5d73685b..0d8553de 100644 --- a/src/duckdb_py/numpy/raw_array_wrapper.cpp +++ b/src/duckdb_py/numpy/raw_array_wrapper.cpp @@ -48,6 +48,7 @@ static idx_t GetNumpyTypeWidth(const LogicalType &type) { case LogicalTypeId::TIMESTAMP_TZ: return sizeof(int64_t); case LogicalTypeId::TIME: + case LogicalTypeId::TIME_NS: case LogicalTypeId::TIME_TZ: case LogicalTypeId::VARCHAR: case LogicalTypeId::BIT: @@ -59,6 +60,7 @@ static idx_t GetNumpyTypeWidth(const LogicalType &type) { case LogicalTypeId::UNION: case LogicalTypeId::UUID: case LogicalTypeId::ARRAY: + case LogicalTypeId::VARIANT: return sizeof(PyObject *); default: throw NotImplementedException("Unsupported type \"%s\" for DuckDB -> NumPy conversion", type.ToString()); @@ -108,8 +110,9 @@ string RawArrayWrapper::DuckDBToNumpyDtype(const LogicalType &type) { case LogicalTypeId::DATE: return "datetime64[us]"; case LogicalTypeId::INTERVAL: - return "timedelta64[ns]"; + return "timedelta64[us]"; case LogicalTypeId::TIME: + case LogicalTypeId::TIME_NS: case LogicalTypeId::TIME_TZ: case LogicalTypeId::VARCHAR: case LogicalTypeId::BIT: @@ -120,6 +123,7 @@ string RawArrayWrapper::DuckDBToNumpyDtype(const LogicalType &type) { case LogicalTypeId::UNION: case LogicalTypeId::UUID: case LogicalTypeId::ARRAY: + case LogicalTypeId::VARIANT: return "object"; case LogicalTypeId::ENUM: { auto size = EnumType::GetSize(type); diff --git a/src/duckdb_py/numpy/type.cpp b/src/duckdb_py/numpy/type.cpp index 92ac4785..3d8d9096 100644 --- a/src/duckdb_py/numpy/type.cpp +++ b/src/duckdb_py/numpy/type.cpp @@ -58,11 +58,23 @@ static NumpyNullableType ConvertNumpyTypeInternal(const string &col_type_str) { if (col_type_str == "string") { return NumpyNullableType::STRING; } + if (col_type_str == "str") { + return NumpyNullableType::STRING; + } if (col_type_str == "object") { return NumpyNullableType::OBJECT; } if (col_type_str == "timedelta64[ns]") { - return NumpyNullableType::TIMEDELTA; + return NumpyNullableType::TIMEDELTA_NS; + } + if (col_type_str == "timedelta64[us]") { + return NumpyNullableType::TIMEDELTA_US; + } + if (col_type_str == "timedelta64[ms]") { + return NumpyNullableType::TIMEDELTA_MS; + } + if (col_type_str == "timedelta64[s]") { + return NumpyNullableType::TIMEDELTA_S; } // We use 'StartsWith' because it might have ', tz' at the end, indicating timezone if (StringUtil::StartsWith(col_type_str, "datetime64[ns")) { @@ -140,7 +152,10 @@ LogicalType NumpyToLogicalType(const NumpyType &col_type) { return LogicalType::VARCHAR; case NumpyNullableType::OBJECT: return LogicalType::VARCHAR; - case NumpyNullableType::TIMEDELTA: + case NumpyNullableType::TIMEDELTA_NS: + case NumpyNullableType::TIMEDELTA_US: + case NumpyNullableType::TIMEDELTA_MS: + case NumpyNullableType::TIMEDELTA_S: return LogicalType::INTERVAL; case NumpyNullableType::DATETIME_MS: { if (col_type.has_timezone) { diff --git a/src/duckdb_py/pandas/analyzer.cpp b/src/duckdb_py/pandas/analyzer.cpp index ee264524..a91bff51 100644 --- a/src/duckdb_py/pandas/analyzer.cpp +++ b/src/duckdb_py/pandas/analyzer.cpp @@ -363,12 +363,11 @@ LogicalType PandasAnalyzer::GetItemType(py::object ele, bool &can_convert) { case PythonObjectType::Bool: return LogicalType::BOOLEAN; case PythonObjectType::Integer: { - Value integer; - if (!TryTransformPythonNumeric(integer, ele)) { + auto type = SniffPythonIntegerType(ele); + if (type.id() == LogicalTypeId::SQLNULL) { can_convert = false; - return LogicalType::SQLNULL; } - return integer.type(); + return type; } case PythonObjectType::Float: if (std::isnan(PyFloat_AsDouble(ele.ptr()))) { diff --git a/src/duckdb_py/pyconnection.cpp b/src/duckdb_py/pyconnection.cpp index 11a7ea9d..6883ba45 100644 --- a/src/duckdb_py/pyconnection.cpp +++ b/src/duckdb_py/pyconnection.cpp @@ -3,6 +3,7 @@ #include "duckdb/catalog/default/default_types.hpp" #include "duckdb/common/arrow/arrow.hpp" #include "duckdb/common/enums/file_compression_type.hpp" +#include "duckdb/common/enums/profiler_format.hpp" #include "duckdb/common/printer.hpp" #include "duckdb/common/types.hpp" #include "duckdb/common/types/vector.hpp" @@ -82,6 +83,20 @@ DuckDBPyConnection::~DuckDBPyConnection() { } } +unique_ptr DuckDBPyConnection::CreateRelation(shared_ptr rel) { + auto py_rel = make_uniq(std::move(rel)); + py::gil_scoped_acquire gil; + py_rel->SetConnectionOwner(py::cast(shared_from_this())); + return py_rel; +} + +unique_ptr DuckDBPyConnection::CreateRelation(shared_ptr result) { + auto py_rel = make_uniq(std::move(result)); + py::gil_scoped_acquire gil; + py_rel->SetConnectionOwner(py::cast(shared_from_this())); + return py_rel; +} + void DuckDBPyConnection::DetectEnvironment() { // Get the formatted Python version py::module_ sys = py::module_::import("sys"); @@ -202,11 +217,28 @@ static void InitializeConnectionMethods(py::class_ DuckDBPyConnection::ExecuteMany(const py::object } // Set the internal 'result' object if (query_result) { - auto py_result = make_uniq(std::move(query_result)); - con.SetResult(make_uniq(std::move(py_result))); + // Don't use CreateRelation here — the result is stored inside the connection, + // so setting connection_owner would create a ref cycle (connection → result → connection). + con.SetResult(make_uniq(make_shared_ptr(std::move(query_result)))); } return shared_from_this(); @@ -656,8 +728,9 @@ shared_ptr DuckDBPyConnection::Execute(const py::object &que // Set the internal 'result' object if (res) { - auto py_result = make_uniq(std::move(res)); - con.SetResult(make_uniq(std::move(py_result))); + // Don't use CreateRelation here — the result is stored inside the connection, + // so setting connection_owner would create a ref cycle (connection → result → connection). + con.SetResult(make_uniq(make_shared_ptr(std::move(res)))); } return shared_from_this(); } @@ -715,7 +788,7 @@ static void ParseMultiFileOptions(named_parameter_map_t &options, const Optional if (!py::none().is(hive_partitioning)) { if (!py::isinstance(hive_partitioning)) { - string actual_type = py::str(hive_partitioning.get_type()); + string actual_type = py::str(py::type::of(hive_partitioning)); throw BinderException("read_json only accepts 'hive_partitioning' as a boolean, not '%s'", actual_type); } auto val = TransformPythonValue(hive_partitioning, LogicalTypeId::BOOLEAN); @@ -724,7 +797,7 @@ static void ParseMultiFileOptions(named_parameter_map_t &options, const Optional if (!py::none().is(union_by_name)) { if (!py::isinstance(union_by_name)) { - string actual_type = py::str(union_by_name.get_type()); + string actual_type = py::str(py::type::of(union_by_name)); throw BinderException("read_json only accepts 'union_by_name' as a boolean, not '%s'", actual_type); } auto val = TransformPythonValue(union_by_name, LogicalTypeId::BOOLEAN); @@ -733,7 +806,7 @@ static void ParseMultiFileOptions(named_parameter_map_t &options, const Optional if (!py::none().is(hive_types_autocast)) { if (!py::isinstance(hive_types_autocast)) { - string actual_type = py::str(hive_types_autocast.get_type()); + string actual_type = py::str(py::type::of(hive_types_autocast)); throw BinderException("read_json only accepts 'hive_types_autocast' as a boolean, not '%s'", actual_type); } auto val = TransformPythonValue(hive_types_autocast, LogicalTypeId::BOOLEAN); @@ -772,11 +845,11 @@ unique_ptr DuckDBPyConnection::ReadJSON( auto &column_name = kv.first; auto &type = kv.second; if (!py::isinstance(column_name)) { - string actual_type = py::str(column_name.get_type()); + string actual_type = py::str(py::type::of(column_name)); throw BinderException("The provided column name must be a str, not of type '%s'", actual_type); } if (!py::isinstance(type)) { - string actual_type = py::str(column_name.get_type()); + string actual_type = py::str(py::type::of(column_name)); throw BinderException("The provided column type must be a str, not of type '%s'", actual_type); } struct_fields.emplace_back(py::str(column_name), Value(py::str(type))); @@ -787,7 +860,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(records)) { if (!py::isinstance(records)) { - string actual_type = py::str(records.get_type()); + string actual_type = py::str(py::type::of(records)); throw BinderException("read_json only accepts 'records' as a string, not '%s'", actual_type); } auto records_s = py::reinterpret_borrow(records); @@ -797,7 +870,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(format)) { if (!py::isinstance(format)) { - string actual_type = py::str(format.get_type()); + string actual_type = py::str(py::type::of(format)); throw BinderException("read_json only accepts 'format' as a string, not '%s'", actual_type); } auto format_s = py::reinterpret_borrow(format); @@ -807,7 +880,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(date_format)) { if (!py::isinstance(date_format)) { - string actual_type = py::str(date_format.get_type()); + string actual_type = py::str(py::type::of(date_format)); throw BinderException("read_json only accepts 'date_format' as a string, not '%s'", actual_type); } auto date_format_s = py::reinterpret_borrow(date_format); @@ -817,7 +890,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(timestamp_format)) { if (!py::isinstance(timestamp_format)) { - string actual_type = py::str(timestamp_format.get_type()); + string actual_type = py::str(py::type::of(timestamp_format)); throw BinderException("read_json only accepts 'timestamp_format' as a string, not '%s'", actual_type); } auto timestamp_format_s = py::reinterpret_borrow(timestamp_format); @@ -827,7 +900,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(compression)) { if (!py::isinstance(compression)) { - string actual_type = py::str(compression.get_type()); + string actual_type = py::str(py::type::of(compression)); throw BinderException("read_json only accepts 'compression' as a string, not '%s'", actual_type); } auto compression_s = py::reinterpret_borrow(compression); @@ -837,7 +910,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(sample_size)) { if (!py::isinstance(sample_size)) { - string actual_type = py::str(sample_size.get_type()); + string actual_type = py::str(py::type::of(sample_size)); throw BinderException("read_json only accepts 'sample_size' as an integer, not '%s'", actual_type); } options["sample_size"] = Value::INTEGER(py::int_(sample_size)); @@ -845,7 +918,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(maximum_depth)) { if (!py::isinstance(maximum_depth)) { - string actual_type = py::str(maximum_depth.get_type()); + string actual_type = py::str(py::type::of(maximum_depth)); throw BinderException("read_json only accepts 'maximum_depth' as an integer, not '%s'", actual_type); } options["maximum_depth"] = Value::INTEGER(py::int_(maximum_depth)); @@ -853,7 +926,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(maximum_object_size)) { if (!py::isinstance(maximum_object_size)) { - string actual_type = py::str(maximum_object_size.get_type()); + string actual_type = py::str(py::type::of(maximum_object_size)); throw BinderException("read_json only accepts 'maximum_object_size' as an unsigned integer, not '%s'", actual_type); } @@ -863,7 +936,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(ignore_errors)) { if (!py::isinstance(ignore_errors)) { - string actual_type = py::str(ignore_errors.get_type()); + string actual_type = py::str(py::type::of(ignore_errors)); throw BinderException("read_json only accepts 'ignore_errors' as a boolean, not '%s'", actual_type); } auto val = TransformPythonValue(ignore_errors, LogicalTypeId::BOOLEAN); @@ -872,7 +945,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(convert_strings_to_integers)) { if (!py::isinstance(convert_strings_to_integers)) { - string actual_type = py::str(convert_strings_to_integers.get_type()); + string actual_type = py::str(py::type::of(convert_strings_to_integers)); throw BinderException("read_json only accepts 'convert_strings_to_integers' as a boolean, not '%s'", actual_type); } @@ -882,7 +955,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(field_appearance_threshold)) { if (!py::isinstance(field_appearance_threshold)) { - string actual_type = py::str(field_appearance_threshold.get_type()); + string actual_type = py::str(py::type::of(field_appearance_threshold)); throw BinderException("read_json only accepts 'field_appearance_threshold' as a float, not '%s'", actual_type); } @@ -892,7 +965,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(map_inference_threshold)) { if (!py::isinstance(map_inference_threshold)) { - string actual_type = py::str(map_inference_threshold.get_type()); + string actual_type = py::str(py::type::of(map_inference_threshold)); throw BinderException("read_json only accepts 'map_inference_threshold' as an integer, not '%s'", actual_type); } @@ -902,7 +975,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (!py::none().is(maximum_sample_files)) { if (!py::isinstance(maximum_sample_files)) { - string actual_type = py::str(maximum_sample_files.get_type()); + string actual_type = py::str(py::type::of(maximum_sample_files)); throw BinderException("read_json only accepts 'maximum_sample_files' as an integer, not '%s'", actual_type); } auto val = TransformPythonValue(maximum_sample_files, LogicalTypeId::BIGINT); @@ -925,7 +998,7 @@ unique_ptr DuckDBPyConnection::ReadJSON( if (file_like_object_wrapper) { read_json_relation->AddExternalDependency(std::move(file_like_object_wrapper)); } - return make_uniq(std::move(read_json_relation)); + return CreateRelation(std::move(read_json_relation)); } PathLike DuckDBPyConnection::GetPathLike(const py::object &object) { @@ -1355,7 +1428,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(lineterminator)) { PythonCSVLineTerminator::Type new_line_type; if (!py::try_cast(lineterminator, new_line_type)) { - string actual_type = py::str(lineterminator.get_type()); + string actual_type = py::str(py::type::of(lineterminator)); throw BinderException("read_csv only accepts 'lineterminator' as a string or CSVLineTerminator, not '%s'", actual_type); } @@ -1364,7 +1437,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(max_line_size)) { if (!py::isinstance(max_line_size) && !py::isinstance(max_line_size)) { - string actual_type = py::str(max_line_size.get_type()); + string actual_type = py::str(py::type::of(max_line_size)); throw BinderException("read_csv only accepts 'max_line_size' as a string or an integer, not '%s'", actual_type); } @@ -1374,7 +1447,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(auto_type_candidates)) { if (!py::isinstance(auto_type_candidates)) { - string actual_type = py::str(auto_type_candidates.get_type()); + string actual_type = py::str(py::type::of(auto_type_candidates)); throw BinderException("read_csv only accepts 'auto_type_candidates' as a list[str], not '%s'", actual_type); } auto val = TransformPythonValue(auto_type_candidates, LogicalType::LIST(LogicalTypeId::VARCHAR)); @@ -1383,7 +1456,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(ignore_errors)) { if (!py::isinstance(ignore_errors)) { - string actual_type = py::str(ignore_errors.get_type()); + string actual_type = py::str(py::type::of(ignore_errors)); throw BinderException("read_csv only accepts 'ignore_errors' as a bool, not '%s'", actual_type); } auto val = TransformPythonValue(ignore_errors, LogicalTypeId::BOOLEAN); @@ -1392,7 +1465,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(store_rejects)) { if (!py::isinstance(store_rejects)) { - string actual_type = py::str(store_rejects.get_type()); + string actual_type = py::str(py::type::of(store_rejects)); throw BinderException("read_csv only accepts 'store_rejects' as a bool, not '%s'", actual_type); } auto val = TransformPythonValue(store_rejects, LogicalTypeId::BOOLEAN); @@ -1401,7 +1474,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(rejects_table)) { if (!py::isinstance(rejects_table)) { - string actual_type = py::str(rejects_table.get_type()); + string actual_type = py::str(py::type::of(rejects_table)); throw BinderException("read_csv only accepts 'rejects_table' as a string, not '%s'", actual_type); } auto val = TransformPythonValue(rejects_table, LogicalTypeId::VARCHAR); @@ -1410,7 +1483,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(rejects_scan)) { if (!py::isinstance(rejects_scan)) { - string actual_type = py::str(rejects_scan.get_type()); + string actual_type = py::str(py::type::of(rejects_scan)); throw BinderException("read_csv only accepts 'rejects_scan' as a string, not '%s'", actual_type); } auto val = TransformPythonValue(rejects_scan, LogicalTypeId::VARCHAR); @@ -1419,7 +1492,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(rejects_limit)) { if (!py::isinstance(rejects_limit)) { - string actual_type = py::str(rejects_limit.get_type()); + string actual_type = py::str(py::type::of(rejects_limit)); throw BinderException("read_csv only accepts 'rejects_limit' as an int, not '%s'", actual_type); } auto val = TransformPythonValue(rejects_limit, LogicalTypeId::BIGINT); @@ -1428,7 +1501,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(force_not_null)) { if (!py::isinstance(force_not_null)) { - string actual_type = py::str(force_not_null.get_type()); + string actual_type = py::str(py::type::of(force_not_null)); throw BinderException("read_csv only accepts 'force_not_null' as a list[str], not '%s'", actual_type); } auto val = TransformPythonValue(force_not_null, LogicalType::LIST(LogicalTypeId::VARCHAR)); @@ -1437,7 +1510,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(buffer_size)) { if (!py::isinstance(buffer_size)) { - string actual_type = py::str(buffer_size.get_type()); + string actual_type = py::str(py::type::of(buffer_size)); throw BinderException("read_csv only accepts 'buffer_size' as a list[str], not '%s'", actual_type); } auto val = TransformPythonValue(buffer_size, LogicalTypeId::UBIGINT); @@ -1446,7 +1519,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(decimal)) { if (!py::isinstance(decimal)) { - string actual_type = py::str(decimal.get_type()); + string actual_type = py::str(py::type::of(decimal)); throw BinderException("read_csv only accepts 'decimal' as a string, not '%s'", actual_type); } auto val = TransformPythonValue(decimal, LogicalTypeId::VARCHAR); @@ -1455,7 +1528,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ if (!py::none().is(allow_quoted_nulls)) { if (!py::isinstance(allow_quoted_nulls)) { - string actual_type = py::str(allow_quoted_nulls.get_type()); + string actual_type = py::str(py::type::of(allow_quoted_nulls)); throw BinderException("read_csv only accepts 'allow_quoted_nulls' as a bool, not '%s'", actual_type); } auto val = TransformPythonValue(allow_quoted_nulls, LogicalTypeId::BOOLEAN); @@ -1473,11 +1546,11 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ auto &column_name = kv.first; auto &type = kv.second; if (!py::isinstance(column_name)) { - string actual_type = py::str(column_name.get_type()); + string actual_type = py::str(py::type::of(column_name)); throw BinderException("The provided column name must be a str, not of type '%s'", actual_type); } if (!py::isinstance(type)) { - string actual_type = py::str(column_name.get_type()); + string actual_type = py::str(py::type::of(column_name)); throw BinderException("The provided column type must be a str, not of type '%s'", actual_type); } struct_fields.emplace_back(py::str(column_name), Value(py::str(type))); @@ -1496,7 +1569,7 @@ unique_ptr DuckDBPyConnection::ReadCSV(const py::object &name_ read_csv.AddExternalDependency(std::move(file_like_object_wrapper)); } - return make_uniq(read_csv_p->Alias(read_csv.alias)); + return CreateRelation(read_csv_p->Alias(read_csv.alias)); } void DuckDBPyConnection::ExecuteImmediately(vector> statements) { @@ -1543,8 +1616,9 @@ unique_ptr DuckDBPyConnection::RunQuery(const py::object &quer // Attempt to create a Relation for lazy execution if possible shared_ptr relation; - if (py::none().is(params)) { - // FIXME: currently we can't create relations with prepared parameters + bool has_params = !py::none().is(params) && py::len(params) > 0; + if (!has_params) { + // No params (or empty params) — use lazy QueryRelation path { D_ASSERT(py::gil_check()); py::gil_scoped_release gil; @@ -1581,7 +1655,7 @@ unique_ptr DuckDBPyConnection::RunQuery(const py::object &quer relation = make_shared_ptr(connection.context, materialized_result.TakeCollection(), res->names, alias); } - return make_uniq(std::move(relation)); + return CreateRelation(std::move(relation)); } unique_ptr DuckDBPyConnection::Table(const string &tname) { @@ -1591,8 +1665,7 @@ unique_ptr DuckDBPyConnection::Table(const string &tname) { qualified_name.schema = DEFAULT_SCHEMA; } try { - return make_uniq( - connection.Table(qualified_name.catalog, qualified_name.schema, qualified_name.name)); + return CreateRelation(connection.Table(qualified_name.catalog, qualified_name.schema, qualified_name.name)); } catch (const CatalogException &) { // CatalogException will be of the type '... is not a table' // Not a table in the database, make a query relation that can perform replacement scans @@ -1631,7 +1704,7 @@ static vector>> ValueListsFromTuples(const p for (idx_t i = 0; i < arg_count; i++) { py::handle arg = tuples[i]; if (!py::isinstance(arg)) { - string actual_type = py::str(arg.get_type()); + string actual_type = py::str(py::type::of(arg)); throw InvalidInputException("Expected objects of type tuple, not %s", actual_type); } auto expressions = py::cast(arg); @@ -1658,7 +1731,7 @@ unique_ptr DuckDBPyConnection::Values(const py::args &args) { py::handle first_arg = args[0]; if (arg_count == 1 && py::isinstance(first_arg)) { vector> values {DuckDBPyConnection::TransformPythonParamList(first_arg)}; - return make_uniq(connection.Values(values)); + return CreateRelation(connection.Values(values)); } else { vector>> expressions; if (py::isinstance(first_arg)) { @@ -1667,13 +1740,13 @@ unique_ptr DuckDBPyConnection::Values(const py::args &args) { auto values = ValueListFromExpressions(args); expressions.push_back(std::move(values)); } - return make_uniq(connection.Values(std::move(expressions))); + return CreateRelation(connection.Values(std::move(expressions))); } } unique_ptr DuckDBPyConnection::View(const string &vname) { auto &connection = con.GetConnection(); - return make_uniq(connection.View(vname)); + return CreateRelation(connection.View(vname)); } unique_ptr DuckDBPyConnection::TableFunction(const string &fname, py::object params) { @@ -1685,8 +1758,7 @@ unique_ptr DuckDBPyConnection::TableFunction(const string &fna throw InvalidInputException("'params' has to be a list of parameters"); } - return make_uniq( - connection.TableFunction(fname, DuckDBPyConnection::TransformPythonParamList(params))); + return CreateRelation(connection.TableFunction(fname, DuckDBPyConnection::TransformPythonParamList(params))); } unique_ptr DuckDBPyConnection::FromDF(const PandasDataFrame &value) { @@ -1699,7 +1771,7 @@ unique_ptr DuckDBPyConnection::FromDF(const PandasDataFrame &v auto tableref = PythonReplacementScan::ReplacementObject(value, name, *connection.context); D_ASSERT(tableref); auto rel = make_shared_ptr(connection.context, std::move(tableref), name); - return make_uniq(std::move(rel)); + return CreateRelation(std::move(rel)); } unique_ptr DuckDBPyConnection::FromParquetInternal(Value &&file_param, bool binary_as_string, @@ -1724,7 +1796,7 @@ unique_ptr DuckDBPyConnection::FromParquetInternal(Value &&fil } D_ASSERT(py::gil_check()); py::gil_scoped_release gil; - return make_uniq(connection.TableFunction("parquet_scan", params, named_parameters)->Alias(name)); + return CreateRelation(connection.TableFunction("parquet_scan", params, named_parameters)->Alias(name)); } unique_ptr DuckDBPyConnection::FromParquet(const string &file_glob, bool binary_as_string, @@ -1754,13 +1826,13 @@ unique_ptr DuckDBPyConnection::FromArrow(py::object &arrow_obj auto &connection = con.GetConnection(); string name = "arrow_object_" + StringUtil::GenerateRandomName(); if (!IsAcceptedArrowObject(arrow_object)) { - auto py_object_type = string(py::str(arrow_object.get_type().attr("__name__"))); + auto py_object_type = string(py::str(py::type::of(arrow_object).attr("__name__"))); throw InvalidInputException("Python Object Type %s is not an accepted Arrow Object.", py_object_type); } auto tableref = PythonReplacementScan::ReplacementObject(arrow_object, name, *connection.context, true); D_ASSERT(tableref); auto rel = make_shared_ptr(connection.context, std::move(tableref), name); - return make_uniq(std::move(rel)); + return CreateRelation(std::move(rel)); } unordered_set DuckDBPyConnection::GetTableNames(const string &query, bool qualified) { @@ -2140,7 +2212,7 @@ static string GetPathString(const py::object &path) { if (is_path || py::isinstance(path)) { return std::string(py::str(path)); } - string actual_type = py::str(path.get_type()); + string actual_type = py::str(py::type::of(path)); throw InvalidInputException("Please provide either a str or a pathlib.Path, not %s", actual_type); } @@ -2326,26 +2398,16 @@ PyArrowObjectType DuckDBPyConnection::GetArrowType(const py::handle &obj) { if (ModuleIsLoaded()) { auto &import_cache = *DuckDBPyConnection::ImportCache(); - // First Verify Lib Types - auto table_class = import_cache.pyarrow.Table(); - auto record_batch_reader_class = import_cache.pyarrow.RecordBatchReader(); - auto message_reader_class = import_cache.pyarrow.ipc.MessageReader(); - if (py::isinstance(obj, table_class)) { - return PyArrowObjectType::Table; - } else if (py::isinstance(obj, record_batch_reader_class)) { - return PyArrowObjectType::RecordBatchReader; - } else if (py::isinstance(obj, message_reader_class)) { + // MessageReader requires nanoarrow, separate scan function + if (py::isinstance(obj, import_cache.pyarrow.ipc.MessageReader())) { return PyArrowObjectType::MessageReader; } if (ModuleIsLoaded()) { - // Then Verify dataset types - auto dataset_class = import_cache.pyarrow.dataset.Dataset(); - auto scanner_class = import_cache.pyarrow.dataset.Scanner(); - - if (py::isinstance(obj, scanner_class)) { + // Scanner/Dataset don't have __arrow_c_stream__, need dedicated handling + if (py::isinstance(obj, import_cache.pyarrow.dataset.Scanner())) { return PyArrowObjectType::Scanner; - } else if (py::isinstance(obj, dataset_class)) { + } else if (py::isinstance(obj, import_cache.pyarrow.dataset.Dataset())) { return PyArrowObjectType::Dataset; } } diff --git a/src/duckdb_py/pyconnection/type_creation.cpp b/src/duckdb_py/pyconnection/type_creation.cpp index f1839fee..71e6c610 100644 --- a/src/duckdb_py/pyconnection/type_creation.cpp +++ b/src/duckdb_py/pyconnection/type_creation.cpp @@ -26,7 +26,7 @@ static child_list_t GetChildList(const py::object &container) { for (auto &item : fields) { shared_ptr pytype; if (!py::try_cast>(item, pytype)) { - string actual_type = py::str(item.get_type()); + string actual_type = py::str(py::type::of(item)); throw InvalidInputException("object has to be a list of DuckDBPyType's, not '%s'", actual_type); } types.push_back(std::make_pair(StringUtil::Format("v%d", i++), pytype->Type())); @@ -40,14 +40,14 @@ static child_list_t GetChildList(const py::object &container) { string name = py::str(name_p); shared_ptr pytype; if (!py::try_cast>(type_p, pytype)) { - string actual_type = py::str(type_p.get_type()); + string actual_type = py::str(py::type::of(type_p)); throw InvalidInputException("object has to be a list of DuckDBPyType's, not '%s'", actual_type); } types.push_back(std::make_pair(name, pytype->Type())); } return types; } else { - string actual_type = py::str(container.get_type()); + string actual_type = py::str(py::type::of(container)); throw InvalidInputException( "Can not construct a child list from object of type '%s', only dict/list is supported", actual_type); } diff --git a/src/duckdb_py/pyexpression.cpp b/src/duckdb_py/pyexpression.cpp index 5dd551a1..0703389b 100644 --- a/src/duckdb_py/pyexpression.cpp +++ b/src/duckdb_py/pyexpression.cpp @@ -500,7 +500,7 @@ shared_ptr DuckDBPyExpression::FunctionExpression(const stri for (auto arg : args) { shared_ptr py_expr; if (!py::try_cast>(arg, py_expr)) { - string actual_type = py::str(arg.get_type()); + string actual_type = py::str(py::type::of(arg)); throw InvalidInputException("Expected argument of type Expression, received '%s' instead", actual_type); } auto expr = py_expr->GetExpression().Copy(); diff --git a/src/duckdb_py/pyfilesystem.cpp b/src/duckdb_py/pyfilesystem.cpp index d9821779..4b7112eb 100644 --- a/src/duckdb_py/pyfilesystem.cpp +++ b/src/duckdb_py/pyfilesystem.cpp @@ -98,9 +98,11 @@ int64_t PythonFilesystem::Write(FileHandle &handle, void *buffer, int64_t nr_byt return py::int_(write(data)); } void PythonFilesystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - Seek(handle, location); - - Write(handle, buffer, nr_bytes); + PythonGILWrapper gil; + auto &py_handle = PythonFileHandle::GetHandle(handle); + py_handle.attr("seek")(location); + auto data = py::bytes(std::string(const_char_ptr_cast(buffer), nr_bytes)); + py_handle.attr("write")(data); } int64_t PythonFilesystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { @@ -116,9 +118,11 @@ int64_t PythonFilesystem::Read(FileHandle &handle, void *buffer, int64_t nr_byte } void PythonFilesystem::Read(duckdb::FileHandle &handle, void *buffer, int64_t nr_bytes, uint64_t location) { - Seek(handle, location); - - Read(handle, buffer, nr_bytes); + PythonGILWrapper gil; + auto &py_handle = PythonFileHandle::GetHandle(handle); + py_handle.attr("seek")(location); + string data = py::bytes(py_handle.attr("read")(nr_bytes)); + memcpy(buffer, data.c_str(), data.size()); } bool PythonFilesystem::FileExists(const string &filename, optional_ptr opener) { return Exists(filename, "isfile"); @@ -219,14 +223,12 @@ void PythonFilesystem::CreateDirectory(const string &directory, optional_ptr &callback, FileOpener *opener) { - static py::str DIRECTORY("directory"); - D_ASSERT(!py::gil_check()); PythonGILWrapper gil; bool nonempty = false; for (auto item : filesystem.attr("ls")(py::str(directory))) { - bool is_dir = DIRECTORY.equal(item["type"]); + bool is_dir = py::cast(item["type"]) == "directory"; callback(py::str(item["name"]), is_dir); nonempty = true; } diff --git a/src/duckdb_py/pyrelation.cpp b/src/duckdb_py/pyrelation.cpp index 9b605128..8a70f0d2 100644 --- a/src/duckdb_py/pyrelation.cpp +++ b/src/duckdb_py/pyrelation.cpp @@ -76,7 +76,7 @@ DuckDBPyRelation::DuckDBPyRelation(shared_ptr result_p) : rel(nu } unique_ptr DuckDBPyRelation::ProjectFromExpression(const string &expression) { - auto projected_relation = make_uniq(rel->Project(expression)); + auto projected_relation = DeriveRelation(rel->Project(expression)); for (auto &dep : this->rel->external_dependencies) { projected_relation->rel->AddExternalDependency(dep); } @@ -108,9 +108,9 @@ unique_ptr DuckDBPyRelation::Project(const py::args &args, con vector empty_aliases; if (groups.empty()) { // No groups provided - return make_uniq(rel->Project(std::move(expressions), empty_aliases)); + return DeriveRelation(rel->Project(std::move(expressions), empty_aliases)); } - return make_uniq(rel->Aggregate(std::move(expressions), groups)); + return DeriveRelation(rel->Aggregate(std::move(expressions), groups)); } } @@ -128,12 +128,13 @@ unique_ptr DuckDBPyRelation::ProjectFromTypes(const py::object LogicalType type; if (py::isinstance(item)) { string type_str = py::str(item); - type = TransformStringToLogicalType(type_str, *rel->context->GetContext()); + rel->context->GetContext()->RunFunctionInTransaction( + [&]() { type = TransformStringToLogicalType(type_str, *rel->context->GetContext().get()); }); } else if (py::isinstance(item)) { auto *type_p = item.cast(); type = type_p->Type(); } else { - string actual_type = py::str(item.get_type()); + string actual_type = py::str(py::type::of(item)); throw InvalidInputException("Can only project on objects of type DuckDBPyType or str, not '%s'", actual_type); } @@ -179,7 +180,7 @@ unique_ptr DuckDBPyRelation::EmptyResult(const shared_ptr DuckDBPyRelation::SetAlias(const string &expr) { - return make_uniq(rel->Alias(expr)); + return DeriveRelation(rel->Alias(expr)); } py::str DuckDBPyRelation::GetAlias() { @@ -196,19 +197,19 @@ unique_ptr DuckDBPyRelation::Filter(const py::object &expr) { throw InvalidInputException("Please provide either a string or a DuckDBPyExpression object to 'filter'"); } auto expr_p = expression->GetExpression().Copy(); - return make_uniq(rel->Filter(std::move(expr_p))); + return DeriveRelation(rel->Filter(std::move(expr_p))); } unique_ptr DuckDBPyRelation::FilterFromExpression(const string &expr) { - return make_uniq(rel->Filter(expr)); + return DeriveRelation(rel->Filter(expr)); } unique_ptr DuckDBPyRelation::Limit(int64_t n, int64_t offset) { - return make_uniq(rel->Limit(n, offset)); + return DeriveRelation(rel->Limit(n, offset)); } unique_ptr DuckDBPyRelation::Order(const string &expr) { - return make_uniq(rel->Order(expr)); + return DeriveRelation(rel->Order(expr)); } unique_ptr DuckDBPyRelation::Sort(const py::args &args) { @@ -218,7 +219,7 @@ unique_ptr DuckDBPyRelation::Sort(const py::args &args) { for (auto arg : args) { shared_ptr py_expr; if (!py::try_cast>(arg, py_expr)) { - string actual_type = py::str(arg.get_type()); + string actual_type = py::str(py::type::of(arg)); throw InvalidInputException("Expected argument of type Expression, received '%s' instead", actual_type); } auto expr = py_expr->GetExpression().Copy(); @@ -227,7 +228,7 @@ unique_ptr DuckDBPyRelation::Sort(const py::args &args) { if (order_nodes.empty()) { throw InvalidInputException("Please provide at least one expression to sort on"); } - return make_uniq(rel->Order(std::move(order_nodes))); + return DeriveRelation(rel->Order(std::move(order_nodes))); } vector> GetExpressions(ClientContext &context, const py::object &expr) { @@ -247,7 +248,8 @@ vector> GetExpressions(ClientContext &context, cons auto aggregate_list = std::string(py::str(expr)); return Parser::ParseExpressionList(aggregate_list, context.GetParserOptions()); } else { - string actual_type = py::str(expr.get_type()); + // A single Expression could be supported here by wrapping it in a vector + string actual_type = py::str(py::type::of(expr)); throw InvalidInputException("Please provide either a string or list of Expression objects, not %s", actual_type); } @@ -257,9 +259,9 @@ unique_ptr DuckDBPyRelation::Aggregate(const py::object &expr, AssertRelation(); auto expressions = GetExpressions(*rel->context->GetContext(), expr); if (!groups.empty()) { - return make_uniq(rel->Aggregate(std::move(expressions), groups)); + return DeriveRelation(rel->Aggregate(std::move(expressions), groups)); } - return make_uniq(rel->Aggregate(std::move(expressions))); + return DeriveRelation(rel->Aggregate(std::move(expressions))); } void DuckDBPyRelation::AssertResult() const { @@ -352,7 +354,7 @@ unique_ptr DuckDBPyRelation::Describe() { DescribeAggregateInfo("stddev", true), DescribeAggregateInfo("min"), DescribeAggregateInfo("max"), DescribeAggregateInfo("median", true)}; auto expressions = CreateExpressionList(columns, aggregates); - return make_uniq(rel->Aggregate(expressions)); + return DeriveRelation(rel->Aggregate(expressions)); } string DuckDBPyRelation::ToSQL() { @@ -395,10 +397,36 @@ string DuckDBPyRelation::GenerateExpressionList(const string &function_name, vec function_name + "(" + function_parameter + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec; } for (idx_t i = 0; i < input.size(); i++) { + // We parse the input as an expression to validate it. + auto trimmed_input = input[i]; + StringUtil::Trim(trimmed_input); + + unique_ptr expression; + try { + auto expressions = Parser::ParseExpressionList(trimmed_input); + if (expressions.size() == 1) { + expression = std::move(expressions[0]); + } + } catch (const ParserException &) { + // First attempt at parsing failed, the input might be a column name that needs quoting. + auto quoted_input = KeywordHelper::WriteQuoted(trimmed_input, '"'); + auto expressions = Parser::ParseExpressionList(quoted_input); + if (expressions.size() == 1 && expressions[0]->GetExpressionClass() == ExpressionClass::COLUMN_REF) { + expression = std::move(expressions[0]); + } + } + + if (!expression) { + throw ParserException("Invalid column expression: %s", trimmed_input); + } + + // ToString() handles escaping for all expression types + auto escaped_input = expression->ToString(); + if (function_parameter.empty()) { - expr += function_name + "(" + input[i] + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec; + expr += function_name + "(" + escaped_input + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec; } else { - expr += function_name + "(" + input[i] + "," + function_parameter + + expr += function_name + "(" + escaped_input + "," + function_parameter + ((ignore_nulls) ? " ignore nulls) " : ") ") + window_spec; } @@ -428,7 +456,7 @@ DuckDBPyRelation::GenericWindowFunction(const string &function_name, const strin const string &projected_columns) { auto expr = GenerateExpressionList(function_name, aggr_columns, "", function_parameters, ignore_nulls, projected_columns, window_spec); - return make_uniq(rel->Project(expr)); + return DeriveRelation(rel->Project(expr)); } unique_ptr DuckDBPyRelation::ApplyAggOrWin(const string &function_name, const string &agg_columns, @@ -587,7 +615,7 @@ unique_ptr DuckDBPyRelation::Product(const std::string &column unique_ptr DuckDBPyRelation::StringAgg(const std::string &column, const std::string &sep, const std::string &groups, const std::string &window_spec, const std::string &projected_columns) { - auto string_agg_params = "\'" + sep + "\'"; + auto string_agg_params = KeywordHelper::WriteOptionallyQuoted(sep, '\''); return ApplyAggOrWin("string_agg", column, string_agg_params, groups, window_spec, projected_columns); } @@ -694,7 +722,7 @@ py::tuple DuckDBPyRelation::Shape() { } unique_ptr DuckDBPyRelation::Unique(const string &std_columns) { - return make_uniq(rel->Project(std_columns)->Distinct()); + return DeriveRelation(rel->Project(std_columns)->Distinct()); } /* General-purpose window functions */ @@ -768,7 +796,7 @@ unique_ptr DuckDBPyRelation::NthValue(const string &column, co } unique_ptr DuckDBPyRelation::Distinct() { - return make_uniq(rel->Distinct()); + return DeriveRelation(rel->Distinct()); } duckdb::pyarrow::RecordBatchReader DuckDBPyRelation::FetchRecordBatchReader(idx_t rows_per_batch) { @@ -964,6 +992,19 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema) if (!rel) { return py::none(); } + // The PyCapsule protocol doesn't allow custom parameters, so we use the same + // default batch size as fetch_arrow_table / fetch_record_batch. + idx_t batch_size = 1000000; + auto &config = ClientConfig::GetConfig(*rel->context->GetContext()); + ScopedConfigSetting scoped_setting( + config, + [&batch_size](ClientConfig &config) { + config.get_result_collector = [&batch_size](ClientContext &context, + PreparedStatementData &data) -> PhysicalOperator & { + return PhysicalArrowCollector::Create(context, data, batch_size); + }; + }, + [](ClientConfig &config) { config.get_result_collector = nullptr; }); ExecuteOrThrow(); } AssertResultOpen(); @@ -975,7 +1016,8 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema) PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) { if (!lazy) { auto arrow = ToArrowTableInternal(batch_size, true); - return py::cast(pybind11::module_::import("polars").attr("DataFrame")(arrow)); + return py::cast( + pybind11::module_::import("polars").attr("from_arrow")(arrow, py::arg("rechunk") = false)); } auto &import_cache = *DuckDBPyConnection::ImportCache(); auto lazy_frame_produce = import_cache.duckdb.polars_io.duckdb_source(); @@ -1036,6 +1078,22 @@ bool DuckDBPyRelation::ContainsColumnByName(const string &name) const { [&](const string &item) { return StringUtil::CIEquals(name, item); }) != names.end(); } +void DuckDBPyRelation::SetConnectionOwner(py::object owner) { + connection_owner = std::move(owner); +} + +unique_ptr DuckDBPyRelation::DeriveRelation(shared_ptr new_rel) { + auto result = make_uniq(std::move(new_rel)); + result->connection_owner = connection_owner; + return result; +} + +unique_ptr DuckDBPyRelation::DeriveRelation(shared_ptr result_p) { + auto result = make_uniq(std::move(result_p)); + result->connection_owner = connection_owner; + return result; +} + static bool ContainsStructFieldByName(LogicalType &type, const string &name) { if (type.id() != LogicalTypeId::STRUCT) { return false; @@ -1076,19 +1134,19 @@ unique_ptr DuckDBPyRelation::GetAttribute(const string &name) expressions.push_back(std::move(make_uniq(column_names))); vector aliases; aliases.push_back(name); - return make_uniq(rel->Project(std::move(expressions), aliases)); + return DeriveRelation(rel->Project(std::move(expressions), aliases)); } unique_ptr DuckDBPyRelation::Union(DuckDBPyRelation *other) { - return make_uniq(rel->Union(other->rel)); + return DeriveRelation(rel->Union(other->rel)); } unique_ptr DuckDBPyRelation::Except(DuckDBPyRelation *other) { - return make_uniq(rel->Except(other->rel)); + return DeriveRelation(rel->Except(other->rel)); } unique_ptr DuckDBPyRelation::Intersect(DuckDBPyRelation *other) { - return make_uniq(rel->Intersect(other->rel)); + return DeriveRelation(rel->Intersect(other->rel)); } namespace { @@ -1149,14 +1207,14 @@ unique_ptr DuckDBPyRelation::Join(DuckDBPyRelation *other, con } if (py::isinstance(condition)) { auto condition_string = std::string(py::cast(condition)); - return make_uniq(rel->Join(other->rel, condition_string, join_type)); + return DeriveRelation(rel->Join(other->rel, condition_string, join_type)); } vector using_list; if (py::is_list_like(condition)) { auto using_list_p = py::list(condition); for (auto &item : using_list_p) { if (!py::isinstance(item)) { - string actual_type = py::str(item.get_type()); + string actual_type = py::str(py::type::of(item)); throw InvalidInputException("Using clause should be a list of strings, not %s", actual_type); } using_list.push_back(std::string(py::str(item))); @@ -1165,7 +1223,7 @@ unique_ptr DuckDBPyRelation::Join(DuckDBPyRelation *other, con throw InvalidInputException("Please provide at least one string in the condition to create a USING clause"); } auto join_relation = make_shared_ptr(rel, other->rel, std::move(using_list), join_type); - return make_uniq(std::move(join_relation)); + return DeriveRelation(std::move(join_relation)); } shared_ptr condition_expr; if (!py::try_cast(condition, condition_expr)) { @@ -1174,11 +1232,11 @@ unique_ptr DuckDBPyRelation::Join(DuckDBPyRelation *other, con } vector> conditions; conditions.push_back(condition_expr->GetExpression().Copy()); - return make_uniq(rel->Join(other->rel, std::move(conditions), join_type)); + return DeriveRelation(rel->Join(other->rel, std::move(conditions), join_type)); } unique_ptr DuckDBPyRelation::Cross(DuckDBPyRelation *other) { - return make_uniq(rel->CrossProduct(other->rel)); + return DeriveRelation(rel->CrossProduct(other->rel)); } static Value NestedDictToStruct(const py::object &dictionary) { @@ -1474,7 +1532,7 @@ void DuckDBPyRelation::ToCSV(const string &filename, const py::object &sep, cons // should this return a rel with the new view? unique_ptr DuckDBPyRelation::CreateView(const string &view_name, bool replace) { rel->CreateView(view_name, replace); - return make_uniq(rel); + return DeriveRelation(rel); } static bool IsDescribeStatement(SQLStatement &statement) { @@ -1502,7 +1560,7 @@ unique_ptr DuckDBPyRelation::Query(const string &view_name, co auto select_statement = unique_ptr_cast(std::move(parser.statements[0])); auto query_relation = make_shared_ptr(rel->context->GetContext(), std::move(select_statement), sql_query, "query_relation"); - return make_uniq(std::move(query_relation)); + return DeriveRelation(std::move(query_relation)); } else if (IsDescribeStatement(statement)) { auto query = PragmaShow(view_name); return Query(view_name, query); @@ -1567,7 +1625,7 @@ void DuckDBPyRelation::Update(const py::object &set_p, const py::object &where) } shared_ptr py_expr; if (!py::try_cast>(item_value, py_expr)) { - string actual_type = py::str(item_value.get_type()); + string actual_type = py::str(py::type::of(item_value)); throw InvalidInputException("Please provide an object of type Expression as the value, not %s", actual_type); } @@ -1602,7 +1660,7 @@ unique_ptr DuckDBPyRelation::Map(py::function fun, Optional params; params.emplace_back(Value::POINTER(CastPointerToValue(fun.ptr()))); params.emplace_back(Value::POINTER(CastPointerToValue(schema.ptr()))); - auto relation = make_uniq(rel->TableFunction("python_map_function", params)); + auto relation = DeriveRelation(rel->TableFunction("python_map_function", params)); auto rel_dependency = make_uniq(); rel_dependency->AddDependency("map", PythonDependencyItem::Create(std::move(fun))); rel_dependency->AddDependency("schema", PythonDependencyItem::Create(std::move(schema))); diff --git a/src/duckdb_py/pyrelation/initialize.cpp b/src/duckdb_py/pyrelation/initialize.cpp index 35eeff40..4393889a 100644 --- a/src/duckdb_py/pyrelation/initialize.cpp +++ b/src/duckdb_py/pyrelation/initialize.cpp @@ -62,12 +62,21 @@ static void InitializeConsumers(py::class_ &m) { py::arg("date_as_object") = false) .def("fetch_df_chunk", &DuckDBPyRelation::FetchDFChunk, "Execute and fetch a chunk of the rows", py::arg("vectors_per_chunk") = 1, py::kw_only(), py::arg("date_as_object") = false) - .def("arrow", &DuckDBPyRelation::ToRecordBatch, - "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) - .def("fetch_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", - py::arg("batch_size") = 1000000) .def("to_arrow_table", &DuckDBPyRelation::ToArrowTable, "Execute and fetch all rows as an Arrow Table", py::arg("batch_size") = 1000000) + .def("to_arrow_reader", &DuckDBPyRelation::ToRecordBatch, + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) + .def("arrow", &DuckDBPyRelation::ToRecordBatch, + "Alias of to_arrow_reader(). We recommend using to_arrow_reader() instead.", + py::arg("batch_size") = 1000000) + .def( + "fetch_arrow_table", + [](pybind11::object &self, idx_t batch_size) { + PyErr_WarnEx(PyExc_DeprecationWarning, + "fetch_arrow_table() is deprecated, use to_arrow_table() instead.", 0); + return self.attr("to_arrow_table")(batch_size); + }, + "Execute and fetch all rows as an Arrow Table", py::arg("batch_size") = 1000000) .def("pl", &DuckDBPyRelation::ToPolars, "Execute and fetch all rows as a Polars DataFrame", py::arg("batch_size") = 1000000, py::kw_only(), py::arg("lazy") = false) .def("torch", &DuckDBPyRelation::FetchPyTorch, "Fetch a result as dict of PyTorch Tensors") @@ -79,24 +88,31 @@ static void InitializeConsumers(py::class_ &m) { )"; m.def("__arrow_c_stream__", &DuckDBPyRelation::ToArrowCapsule, capsule_docs, py::arg("requested_schema") = py::none()); - m.def("fetch_record_batch", &DuckDBPyRelation::ToRecordBatch, - "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) - .def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch, - "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000) + m.def( + "fetch_record_batch", + [](pybind11::object &self, idx_t rows_per_batch) { + PyErr_WarnEx(PyExc_DeprecationWarning, + "fetch_record_batch() is deprecated, use to_arrow_reader() instead.", 0); + return self.attr("to_arrow_reader")(rows_per_batch); + }, + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("rows_per_batch") = 1000000) .def( - "record_batch", - [](pybind11::object &self, idx_t rows_per_batch) { + "fetch_arrow_reader", + [](pybind11::object &self, idx_t batch_size) { PyErr_WarnEx(PyExc_DeprecationWarning, - "record_batch() is deprecated, use fetch_record_batch() instead.", 0); - return self.attr("fetch_record_batch")(rows_per_batch); + "fetch_arrow_reader() is deprecated, use to_arrow_reader() instead.", 0); + if (PyErr_Occurred()) { + throw py::error_already_set(); + } + return self.attr("to_arrow_reader")(batch_size); }, - py::arg("batch_size") = 1000000); + "Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000); } static void InitializeAggregates(py::class_ &m) { /* General aggregate functions */ - m.def("any_value", &DuckDBPyRelation::AnyValue, "Returns the first non-null value from a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + m.def("any_value", &DuckDBPyRelation::AnyValue, "Returns the first non-null value from a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") .def("arg_max", &DuckDBPyRelation::ArgMax, "Finds the row with the maximum value for a value column and returns the value of that row for an " "argument column", @@ -107,82 +123,98 @@ static void InitializeAggregates(py::class_ &m) { "argument column", py::arg("arg_column"), py::arg("value_column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = ""); - DefineMethod({"avg", "mean"}, m, &DuckDBPyRelation::Avg, "Computes the average on a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", + DefineMethod({"avg", "mean"}, m, &DuckDBPyRelation::Avg, "Computes the average of a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = ""); - m.def("bit_and", &DuckDBPyRelation::BitAnd, "Computes the bitwise AND of all bits present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("bit_or", &DuckDBPyRelation::BitOr, "Computes the bitwise OR of all bits present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("bit_xor", &DuckDBPyRelation::BitXor, "Computes the bitwise XOR of all bits present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + m.def("bit_and", &DuckDBPyRelation::BitAnd, "Computes the bitwise AND of all bits present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + .def("bit_or", &DuckDBPyRelation::BitOr, "Computes the bitwise OR of all bits present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") + .def("bit_xor", &DuckDBPyRelation::BitXor, "Computes the bitwise XOR of all bits present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") .def("bitstring_agg", &DuckDBPyRelation::BitStringAgg, - "Computes a bitstring with bits set for each distinct value in a given column", py::arg("column"), + "Computes a bitstring with bits set for each distinct value in a given expression", py::arg("expression"), py::arg("min") = py::none(), py::arg("max") = py::none(), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("bool_and", &DuckDBPyRelation::BoolAnd, "Computes the logical AND of all values present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("bool_or", &DuckDBPyRelation::BoolOr, "Computes the logical OR of all values present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("count", &DuckDBPyRelation::Count, "Computes the number of elements present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + .def("bool_and", &DuckDBPyRelation::BoolAnd, + "Computes the logical AND of all values present in a given expression", py::arg("expression"), + py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + .def("bool_or", &DuckDBPyRelation::BoolOr, + "Computes the logical OR of all values present in a given expression", py::arg("expression"), + py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + .def("count", &DuckDBPyRelation::Count, "Computes the number of elements present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") .def("value_counts", &DuckDBPyRelation::ValueCounts, - "Computes the number of elements present in a given column, also projecting the original column", - py::arg("column"), py::arg("groups") = "") + "Computes the number of elements present in a given expression, also projecting the original expression", + py::arg("expression"), py::arg("groups") = "") .def("favg", &DuckDBPyRelation::FAvg, - "Computes the average of all values present in a given column using a more accurate floating point " + "Computes the average of all values present in a given expression using a more accurate floating point " "summation (Kahan Sum)", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("first", &DuckDBPyRelation::First, "Returns the first value of a given column", py::arg("column"), + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") + .def("first", &DuckDBPyRelation::First, "Returns the first value of a given expression", py::arg("expression"), py::arg("groups") = "", py::arg("projected_columns") = "") .def("fsum", &DuckDBPyRelation::FSum, - "Computes the sum of all values present in a given column using a more accurate floating point " + "Computes the sum of all values present in a given expression using a more accurate floating point " "summation (Kahan Sum)", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") .def("geomean", &DuckDBPyRelation::GeoMean, - "Computes the geometric mean over all values present in a given column", py::arg("column"), + "Computes the geometric mean over all values present in a given expression", py::arg("expression"), py::arg("groups") = "", py::arg("projected_columns") = "") .def("histogram", &DuckDBPyRelation::Histogram, - "Computes the histogram over all values present in a given column", py::arg("column"), + "Computes the histogram over all values present in a given expression", py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("list", &DuckDBPyRelation::List, "Returns a list containing all values present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("last", &DuckDBPyRelation::Last, "Returns the last value of a given column", py::arg("column"), + .def("list", &DuckDBPyRelation::List, "Returns a list containing all values present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") + .def("last", &DuckDBPyRelation::Last, "Returns the last value of a given expression", py::arg("expression"), py::arg("groups") = "", py::arg("projected_columns") = "") - .def("max", &DuckDBPyRelation::Max, "Returns the maximum value present in a given column", py::arg("column"), - py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("min", &DuckDBPyRelation::Min, "Returns the minimum value present in a given column", py::arg("column"), - py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("product", &DuckDBPyRelation::Product, "Returns the product of all values present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + .def("max", &DuckDBPyRelation::Max, "Returns the maximum value present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") + .def("min", &DuckDBPyRelation::Min, "Returns the minimum value present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") + .def("product", &DuckDBPyRelation::Product, "Returns the product of all values present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") .def("string_agg", &DuckDBPyRelation::StringAgg, - "Concatenates the values present in a given column with a separator", py::arg("column"), + "Concatenates the values present in a given expression with a separator", py::arg("expression"), py::arg("sep") = ",", py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("sum", &DuckDBPyRelation::Sum, "Computes the sum of all values present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + .def("sum", &DuckDBPyRelation::Sum, "Computes the sum of all values present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") .def("unique", &DuckDBPyRelation::Unique, "Returns the distinct values in a column.", py::arg("unique_aggr")); /* TODO: Approximate aggregate functions */ /* TODO: Statistical aggregate functions */ - m.def("median", &DuckDBPyRelation::Median, "Computes the median over all values present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("mode", &DuckDBPyRelation::Mode, "Computes the mode over all values present in a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + m.def("median", &DuckDBPyRelation::Median, "Computes the median over all values present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = "") + .def("mode", &DuckDBPyRelation::Mode, "Computes the mode over all values present in a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = "") .def("quantile_cont", &DuckDBPyRelation::QuantileCont, - "Computes the interpolated quantile value for a given column", py::arg("column"), py::arg("q") = 0.5, - py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = ""); + "Computes the interpolated quantile value for a given expression", py::arg("expression"), + py::arg("q") = 0.5, py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = ""); DefineMethod({"quantile_disc", "quantile"}, m, &DuckDBPyRelation::QuantileDisc, - "Computes the exact quantile value for a given column", py::arg("column"), py::arg("q") = 0.5, + "Computes the exact quantile value for a given expression", py::arg("expression"), py::arg("q") = 0.5, py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = ""); - m.def("stddev_pop", &DuckDBPyRelation::StdPop, "Computes the population standard deviation for a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = ""); + m.def("stddev_pop", &DuckDBPyRelation::StdPop, "Computes the population standard deviation for a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = ""); DefineMethod({"stddev_samp", "stddev", "std"}, m, &DuckDBPyRelation::StdSamp, - "Computes the sample standard deviation for a given column", py::arg("column"), py::arg("groups") = "", - py::arg("window_spec") = "", py::arg("projected_columns") = ""); - m.def("var_pop", &DuckDBPyRelation::VarPop, "Computes the population variance for a given column", - py::arg("column"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = ""); + "Computes the sample standard deviation for a given expression", py::arg("expression"), + py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = ""); + m.def("var_pop", &DuckDBPyRelation::VarPop, "Computes the population variance for a given expression", + py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", + py::arg("projected_columns") = ""); DefineMethod({"var_samp", "variance", "var"}, m, &DuckDBPyRelation::VarSamp, - "Computes the sample variance for a given column", py::arg("column"), py::arg("groups") = "", + "Computes the sample variance for a given expression", py::arg("expression"), py::arg("groups") = "", py::arg("window_spec") = "", py::arg("projected_columns") = ""); } @@ -200,19 +232,19 @@ static void InitializeWindowOperators(py::class_ &m) { .def("cume_dist", &DuckDBPyRelation::CumeDist, "Computes the cumulative distribution within the partition", py::arg("window_spec"), py::arg("projected_columns") = "") .def("first_value", &DuckDBPyRelation::FirstValue, "Computes the first value within the group or partition", - py::arg("column"), py::arg("window_spec") = "", py::arg("projected_columns") = "") + py::arg("expression"), py::arg("window_spec") = "", py::arg("projected_columns") = "") .def("n_tile", &DuckDBPyRelation::NTile, "Divides the partition as equally as possible into num_buckets", py::arg("window_spec"), py::arg("num_buckets"), py::arg("projected_columns") = "") - .def("lag", &DuckDBPyRelation::Lag, "Computes the lag within the partition", py::arg("column"), + .def("lag", &DuckDBPyRelation::Lag, "Computes the lag within the partition", py::arg("expression"), py::arg("window_spec"), py::arg("offset") = 1, py::arg("default_value") = "NULL", py::arg("ignore_nulls") = false, py::arg("projected_columns") = "") .def("last_value", &DuckDBPyRelation::LastValue, "Computes the last value within the group or partition", - py::arg("column"), py::arg("window_spec") = "", py::arg("projected_columns") = "") - .def("lead", &DuckDBPyRelation::Lead, "Computes the lead within the partition", py::arg("column"), + py::arg("expression"), py::arg("window_spec") = "", py::arg("projected_columns") = "") + .def("lead", &DuckDBPyRelation::Lead, "Computes the lead within the partition", py::arg("expression"), py::arg("window_spec"), py::arg("offset") = 1, py::arg("default_value") = "NULL", py::arg("ignore_nulls") = false, py::arg("projected_columns") = "") - .def("nth_value", &DuckDBPyRelation::NthValue, "Computes the nth value within the partition", py::arg("column"), - py::arg("window_spec"), py::arg("offset"), py::arg("ignore_nulls") = false, + .def("nth_value", &DuckDBPyRelation::NthValue, "Computes the nth value within the partition", + py::arg("expression"), py::arg("window_spec"), py::arg("offset"), py::arg("ignore_nulls") = false, py::arg("projected_columns") = ""); } diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp index e92f6abe..34b5f6ff 100644 --- a/src/duckdb_py/pyresult.cpp +++ b/src/duckdb_py/pyresult.cpp @@ -117,16 +117,13 @@ unique_ptr DuckDBPyResult::FetchNextRaw(QueryResult &query_result) { } Optional DuckDBPyResult::Fetchone() { - { - D_ASSERT(py::gil_check()); + if (!result) { + throw InvalidInputException("result closed"); + } + if (!current_chunk || chunk_offset >= current_chunk->size()) { py::gil_scoped_release release; - if (!result) { - throw InvalidInputException("result closed"); - } - if (!current_chunk || chunk_offset >= current_chunk->size()) { - current_chunk = FetchNext(*result); - chunk_offset = 0; - } + current_chunk = FetchNext(*result); + chunk_offset = 0; } if (!current_chunk || current_chunk->size() == 0) { @@ -304,7 +301,7 @@ void DuckDBPyResult::ConvertDateTimeTypes(PandasDataFrame &df, bool date_as_obje // We need to create the column anew because the exact dt changed to a new timezone ReplaceDFColumn(df, names[i].c_str(), i, new_value); } else if (date_as_object && result->types[i] == LogicalType::DATE) { - auto new_value = df[names[i].c_str()].attr("dt").attr("date"); + py::object new_value = df[names[i].c_str()].attr("dt").attr("date"); ReplaceDFColumn(df, names[i].c_str(), i, new_value); } } @@ -499,6 +496,81 @@ duckdb::pyarrow::RecordBatchReader DuckDBPyResult::FetchRecordBatchReader(idx_t return py::cast(record_batch_reader); } +// Wraps pre-built Arrow arrays from an ArrowQueryResult into an ArrowArrayStream. +// This avoids the double-materialization that happens when using ResultArrowArrayStreamWrapper +// with an ArrowQueryResult (which throws NotImplementedException from FetchInternal). +struct ArrowQueryResultStreamWrapper { + ArrowQueryResultStreamWrapper(unique_ptr result_p) : result(std::move(result_p)), index(0) { + auto &arrow_result = result->Cast(); + arrays = arrow_result.ConsumeArrays(); + types = result->types; + names = result->names; + client_properties = result->client_properties; + + stream.private_data = this; + stream.get_schema = GetSchema; + stream.get_next = GetNext; + stream.release = Release; + stream.get_last_error = GetLastError; + } + + static int GetSchema(ArrowArrayStream *stream, ArrowSchema *out) { + if (!stream->release) { + return -1; + } + auto self = reinterpret_cast(stream->private_data); + out->release = nullptr; + try { + ArrowConverter::ToArrowSchema(out, self->types, self->names, self->client_properties); + } catch (std::runtime_error &e) { + self->last_error = e.what(); + return -1; + } + return 0; + } + + static int GetNext(ArrowArrayStream *stream, ArrowArray *out) { + if (!stream->release) { + return -1; + } + auto self = reinterpret_cast(stream->private_data); + if (self->index >= self->arrays.size()) { + out->release = nullptr; + return 0; + } + *out = self->arrays[self->index]->arrow_array; + self->arrays[self->index]->arrow_array.release = nullptr; + self->index++; + return 0; + } + + static void Release(ArrowArrayStream *stream) { + if (!stream || !stream->release) { + return; + } + stream->release = nullptr; + delete reinterpret_cast(stream->private_data); + } + + static const char *GetLastError(ArrowArrayStream *stream) { + if (!stream->release) { + return "stream was released"; + } + auto self = reinterpret_cast(stream->private_data); + return self->last_error.c_str(); + } + + ArrowArrayStream stream; + unique_ptr result; + vector> arrays; + vector types; + vector names; + ClientProperties client_properties; + idx_t index; + string last_error; +}; + +// Destructor for capsules that own a heap-allocated ArrowArrayStream (slow path). static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) { auto data = PyCapsule_GetPointer(object, "arrow_array_stream"); if (!data) { @@ -512,6 +584,18 @@ static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) { } py::object DuckDBPyResult::FetchArrowCapsule(idx_t rows_per_batch) { + if (result && result->type == QueryResultType::ARROW_RESULT) { + // Fast path: yield pre-built Arrow arrays directly. + // The wrapper is heap-allocated; Release() deletes it via private_data. + // We heap-allocate a separate ArrowArrayStream for the capsule so that the capsule + // holds a stable pointer even after the wrapper is consumed and deleted by a scan. + auto wrapper = new ArrowQueryResultStreamWrapper(std::move(result)); + auto stream = new ArrowArrayStream(); + *stream = wrapper->stream; + wrapper->stream.release = nullptr; + return py::capsule(stream, "arrow_array_stream", ArrowArrayStreamPyCapsuleDestructor); + } + // Existing slow path for MaterializedQueryResult / StreamQueryResult auto stream_p = FetchArrowArrayStream(rows_per_batch); auto stream = new ArrowArrayStream(); *stream = stream_p; diff --git a/src/duckdb_py/python_replacement_scan.cpp b/src/duckdb_py/python_replacement_scan.cpp index 843545e8..8bff9e8f 100644 --- a/src/duckdb_py/python_replacement_scan.cpp +++ b/src/duckdb_py/python_replacement_scan.cpp @@ -1,7 +1,5 @@ #include "duckdb_python/python_replacement_scan.hpp" - #include "duckdb/main/db_instance_cache.hpp" - #include "duckdb_python/pybind11/pybind_wrapper.hpp" #include "duckdb/main/client_properties.hpp" #include "duckdb_python/numpy/numpy_type.hpp" @@ -14,12 +12,13 @@ #include "duckdb_python/pandas/pandas_scan.hpp" #include "duckdb/parser/tableref/subqueryref.hpp" #include "duckdb_python/pyrelation.hpp" +#include namespace duckdb { static void CreateArrowScan(const string &name, py::object entry, TableFunctionRef &table_function, vector> &children, ClientProperties &client_properties, - PyArrowObjectType type, DBConfig &config, DatabaseInstance &db) { + PyArrowObjectType type, DatabaseInstance &db) { shared_ptr external_dependency = make_shared_ptr(); if (type == PyArrowObjectType::MessageReader) { if (!db.ExtensionIsLoaded("nanoarrow")) { @@ -52,12 +51,7 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR auto dependency_item = PythonDependencyItem::Create(stream_messages); external_dependency->AddDependency("replacement_cache", std::move(dependency_item)); } else { - if (type == PyArrowObjectType::PyCapsuleInterface) { - entry = entry.attr("__arrow_c_stream__")(); - type = PyArrowObjectType::PyCapsule; - } - - auto stream_factory = make_uniq(entry.ptr(), client_properties, config); + auto stream_factory = make_uniq(entry.ptr(), client_properties, type); auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce; auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema; @@ -67,8 +61,17 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR make_uniq(Value::POINTER(CastPointerToValue(stream_factory_get_schema)))); if (type == PyArrowObjectType::PyCapsule) { - // Disable projection+filter pushdown + // Disable projection+filter pushdown for bare capsules (single-use, no PyArrow wrapper) table_function.function = make_uniq("arrow_scan_dumb", std::move(children)); + } else if (type == PyArrowObjectType::PyCapsuleInterface) { + // Try to load pyarrow.dataset for pushdown support + auto &cache = *DuckDBPyConnection::ImportCache(); + if (!cache.pyarrow.dataset()) { + // No pyarrow.dataset: scan without pushdown, DuckDB handles projection/filter post-scan + table_function.function = make_uniq("arrow_scan_dumb", std::move(children)); + } else { + table_function.function = make_uniq("arrow_scan", std::move(children)); + } } else { table_function.function = make_uniq("arrow_scan", std::move(children)); } @@ -81,7 +84,7 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR static void ThrowScanFailureError(const py::object &entry, const string &name, const string &location = "") { string error; - auto py_object_type = string(py::str(entry.get_type().attr("__name__"))); + auto py_object_type = string(py::str(py::type::of(entry).attr("__name__"))); error += StringUtil::Format("Python Object \"%s\" of type \"%s\"", name, py_object_type); if (!location.empty()) { error += StringUtil::Format(" found on line \"%s\"", location); @@ -114,7 +117,7 @@ unique_ptr PythonReplacementScan::TryReplacementObject(const py::objec if (PandasDataFrame::IsPyArrowBacked(entry)) { auto table = PandasDataFrame::ToArrowTable(entry); CreateArrowScan(name, table, *table_function, children, client_properties, PyArrowObjectType::Table, - DBConfig::GetConfig(context), *context.db); + *context.db); } else { string name = "df_" + StringUtil::GenerateRandomName(); auto new_df = PandasScanFunction::PandasReplaceCopiedNames(entry); @@ -142,19 +145,18 @@ unique_ptr PythonReplacementScan::TryReplacementObject(const py::objec subquery->external_dependency = std::move(dependency); return std::move(subquery); } else if (PolarsDataFrame::IsDataFrame(entry)) { + // Polars DataFrames always go through one-time .to_arrow() materialization. + // Polars's __arrow_c_stream__() serializes from its internal layout on every call, + // which is expensive for repeated scans. The .to_arrow() path converts once. auto arrow_dataset = entry.attr("to_arrow")(); CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table, - DBConfig::GetConfig(context), *context.db); + *context.db); } else if (PolarsDataFrame::IsLazyFrame(entry)) { - auto materialized = entry.attr("collect")(); - auto arrow_dataset = materialized.attr("to_arrow")(); - CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table, - DBConfig::GetConfig(context), *context.db); - } else if (DuckDBPyConnection::GetArrowType(entry) != PyArrowObjectType::Invalid && - !(DuckDBPyConnection::GetArrowType(entry) == PyArrowObjectType::MessageReader && !relation)) { - arrow_type = DuckDBPyConnection::GetArrowType(entry); - CreateArrowScan(name, entry, *table_function, children, client_properties, arrow_type, - DBConfig::GetConfig(context), *context.db); + CreateArrowScan(name, entry, *table_function, children, client_properties, PyArrowObjectType::PolarsLazyFrame, + *context.db); + } else if ((arrow_type = DuckDBPyConnection::GetArrowType(entry)) != PyArrowObjectType::Invalid && + !(arrow_type == PyArrowObjectType::MessageReader && !relation)) { + CreateArrowScan(name, entry, *table_function, children, client_properties, arrow_type, *context.db); } else if (DuckDBPyConnection::IsAcceptedNumpyObject(entry) != NumpyObjectType::INVALID) { numpytype = DuckDBPyConnection::IsAcceptedNumpyObject(entry); string np_name = "np_" + StringUtil::GenerateRandomName(); @@ -299,7 +301,7 @@ unique_ptr PythonReplacementScan::Replace(ClientContext &context, Repl optional_ptr data) { auto &table_name = input.table_name; auto &config = DBConfig::GetConfig(context); - if (!config.options.enable_external_access) { + if (!Settings::Get(config)) { return nullptr; } diff --git a/src/duckdb_py/python_udf.cpp b/src/duckdb_py/python_udf.cpp index fd6775e0..a62004d4 100644 --- a/src/duckdb_py/python_udf.cpp +++ b/src/duckdb_py/python_udf.cpp @@ -75,7 +75,7 @@ static void ConvertArrowTableToVector(const py::object &table, Vector &out, Clie py::gil_scoped_release gil; auto stream_factory = - make_uniq(ptr, context.GetClientProperties(), DBConfig::GetConfig(context)); + make_uniq(ptr, context.GetClientProperties(), PyArrowObjectType::Table); auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce; auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema; @@ -307,40 +307,48 @@ static scalar_function_t CreateNativeFunction(PyObject *function, PythonExceptio for (idx_t row = 0; row < input.size(); row++) { - auto bundled_parameters = py::tuple((int)input.ColumnCount()); - bool contains_null = false; - for (idx_t i = 0; i < input.ColumnCount(); i++) { - // Fill the tuple with the arguments for this row - auto &column = input.data[i]; - auto value = column.GetValue(row); - if (value.IsNull() && default_null_handling) { - contains_null = true; - break; + py::object ret; + if (input.ColumnCount() > 0) { + auto bundled_parameters = py::tuple((int)input.ColumnCount()); + bool contains_null = false; + for (idx_t i = 0; i < input.ColumnCount(); i++) { + // Fill the tuple with the arguments for this row + auto &column = input.data[i]; + auto value = column.GetValue(row); + if (value.IsNull() && default_null_handling) { + contains_null = true; + break; + } + bundled_parameters[i] = PythonObject::FromValue(value, column.GetType(), client_properties); } - bundled_parameters[i] = PythonObject::FromValue(value, column.GetType(), client_properties); - } - if (contains_null) { - // Immediately insert None, no need to call the function - FlatVector::SetNull(result, row, true); - continue; - } - - // Call the function - auto ret = py::reinterpret_steal(PyObject_CallObject(function, bundled_parameters.ptr())); - if (ret == nullptr && PyErr_Occurred()) { - if (exception_handling == PythonExceptionHandling::FORWARD_ERROR) { - auto exception = py::error_already_set(); - throw InvalidInputException("Python exception occurred while executing the UDF: %s", - exception.what()); - } else if (exception_handling == PythonExceptionHandling::RETURN_NULL) { - PyErr_Clear(); + if (contains_null) { + // Immediately insert None, no need to call the function FlatVector::SetNull(result, row, true); continue; - } else { + } + // Call the function + ret = py::reinterpret_steal(PyObject_CallObject(function, bundled_parameters.ptr())); + } else { + ret = py::reinterpret_steal(PyObject_CallObject(function, nullptr)); + } + + if (!ret || ret.is_none()) { + if (PyErr_Occurred()) { + if (exception_handling == PythonExceptionHandling::FORWARD_ERROR) { + auto exception = py::error_already_set(); + throw InvalidInputException("Python exception occurred while executing the UDF: %s", + exception.what()); + } + if (exception_handling == PythonExceptionHandling::RETURN_NULL) { + PyErr_Clear(); + FlatVector::SetNull(result, row, true); + continue; + } throw NotImplementedException("Exception handling type not implemented"); } - } else if ((!ret || ret == Py_None) && default_null_handling) { - throw InvalidInputException(NullHandlingError()); + if (default_null_handling) { + throw InvalidInputException(NullHandlingError()); + } } TransformPythonObject(ret, result, row); } diff --git a/src/duckdb_py/typing/pytype.cpp b/src/duckdb_py/typing/pytype.cpp index e7e31a18..5087de50 100644 --- a/src/duckdb_py/typing/pytype.cpp +++ b/src/duckdb_py/typing/pytype.cpp @@ -122,7 +122,11 @@ static LogicalType FromString(const string &type_str, shared_ptrcon.GetConnection(); - return TransformStringToLogicalType(type_str, *connection.context); + + LogicalType type; + connection.context->RunFunctionInTransaction( + [&]() { type = TransformStringToLogicalType(type_str, *connection.context); }); + return type; } static bool FromNumpyType(const py::object &type, LogicalType &result) { @@ -211,7 +215,7 @@ static py::tuple FilterNones(const py::tuple &args) { for (const auto &arg : args) { py::object object = py::reinterpret_borrow(arg); - if (object.is(py::none().get_type())) { + if (object.is(py::type::of(py::none()))) { continue; } result.append(object); @@ -309,13 +313,13 @@ static LogicalType FromObject(const py::object &object) { case PythonTypeObject::TYPE: { shared_ptr type_object; if (!py::try_cast>(object, type_object)) { - string actual_type = py::str(object.get_type()); + string actual_type = py::str(py::type::of(object)); throw InvalidInputException("Expected argument of type DuckDBPyType, received '%s' instead", actual_type); } return type_object->Type(); } default: { - string actual_type = py::str(object.get_type()); + string actual_type = py::str(py::type::of(object)); throw NotImplementedException("Could not convert from object of type '%s' to DuckDBPyType", actual_type); } } diff --git a/src/duckdb_py/typing/typing.cpp b/src/duckdb_py/typing/typing.cpp index fe990de1..c86f3712 100644 --- a/src/duckdb_py/typing/typing.cpp +++ b/src/duckdb_py/typing/typing.cpp @@ -27,6 +27,7 @@ static void DefineBaseTypes(py::handle &m) { m.attr("TIMESTAMP_S") = make_shared_ptr(LogicalType::TIMESTAMP_S); m.attr("TIME") = make_shared_ptr(LogicalType::TIME); + m.attr("TIME_NS") = make_shared_ptr(LogicalType::TIME_NS); m.attr("TIME_TZ") = make_shared_ptr(LogicalType::TIME_TZ); m.attr("TIMESTAMP_TZ") = make_shared_ptr(LogicalType::TIMESTAMP_TZ); @@ -36,6 +37,7 @@ static void DefineBaseTypes(py::handle &m) { m.attr("BLOB") = make_shared_ptr(LogicalType::BLOB); m.attr("BIT") = make_shared_ptr(LogicalType::BIT); m.attr("INTERVAL") = make_shared_ptr(LogicalType::INTERVAL); + m.attr("VARIANT") = make_shared_ptr(LogicalType::VARIANT()); } void DuckDBPyTyping::Initialize(py::module_ &parent) { diff --git a/tests/conftest.py b/tests/conftest.py index 8a16652d..a5d0249f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,7 @@ import os -import sys import warnings from importlib import import_module from pathlib import Path -from typing import Any, Union import pytest @@ -20,32 +18,34 @@ pandas = None pyarrow_dtype = None - # Only install mock after we've failed to import pandas for conftest.py - class MockPandas: - def __getattr__(self, name: str) -> object: - pytest.skip("pandas not available", allow_module_level=True) - sys.modules["pandas"] = MockPandas() - sys.modules["pandas.testing"] = MockPandas() - sys.modules["pandas._testing"] = MockPandas() +# Version-aware helpers for Pandas 2.x vs 3.0 compatibility +def _get_pandas_ge_3(): + if pandas is None: + return False + from packaging.version import Version -# Check if pandas has arrow dtypes enabled -if pandas is not None: - try: - from pandas.compat import pa_version_under7p0 + return Version(pandas.__version__) >= Version("3.0.0") + + +PANDAS_GE_3 = _get_pandas_ge_3() + + +def is_string_dtype(dtype): + """Check if a dtype is a string dtype (works across Pandas 2.x and 3.0). - pyarrow_dtypes_enabled = not pa_version_under7p0 - except (ImportError, AttributeError): - pyarrow_dtypes_enabled = False -else: - pyarrow_dtypes_enabled = False + Uses pd.api.types.is_string_dtype() which handles: + - Pandas 2.x: object dtype for strings + - Pandas 3.0+: str (StringDtype) for strings + """ + return pandas.api.types.is_string_dtype(dtype) def import_pandas(): if pandas: return pandas else: - pytest.skip("Couldn't import pandas", allow_module_level=True) + pytest.skip("Couldn't import pandas") # https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option @@ -126,81 +126,9 @@ def pandas_supports_arrow_backend(): return pandas_2_or_higher() -def numpy_pandas_df(*args, **kwargs): - return import_pandas().DataFrame(*args, **kwargs) - - -def arrow_pandas_df(*args, **kwargs): - df = numpy_pandas_df(*args, **kwargs) - return df.convert_dtypes(dtype_backend="pyarrow") - - -class NumpyPandas: - def __init__(self) -> None: - self.backend = "numpy_nullable" - self.DataFrame = numpy_pandas_df - self.pandas = import_pandas() - - def __getattr__(self, name: str) -> Any: # noqa: ANN401 - return getattr(self.pandas, name) - - -def convert_arrow_to_numpy_backend(df): - names = df.columns - df_content = {} - for name in names: - df_content[name] = df[name].array.__arrow_array__() - # This should convert the pyarrow chunked arrays into numpy arrays - return import_pandas().DataFrame(df_content) - - -def convert_to_numpy(df): - if ( - pyarrow_dtypes_enabled - and pyarrow_dtype is not None - and any(True for x in df.dtypes if isinstance(x, pyarrow_dtype)) - ): - return convert_arrow_to_numpy_backend(df) - return df - - -def convert_and_equal(df1, df2, **kwargs): - df1 = convert_to_numpy(df1) - df2 = convert_to_numpy(df2) - import_pandas().testing.assert_frame_equal(df1, df2, **kwargs) - - -class ArrowMockTesting: - def __init__(self) -> None: - self.testing = import_pandas().testing - self.assert_frame_equal = convert_and_equal - - def __getattr__(self, name: str) -> Any: # noqa: ANN401 - return getattr(self.testing, name) - - -# This converts dataframes constructed with 'DataFrame(...)' to pyarrow backed dataframes -# Assert equal does the opposite, turning all pyarrow backed dataframes into numpy backed ones -# this is done because we don't produce pyarrow backed dataframes yet -class ArrowPandas: - def __init__(self) -> None: - self.pandas = import_pandas() - if pandas_2_or_higher() and pyarrow_dtypes_enabled: - self.backend = "pyarrow" - self.DataFrame = arrow_pandas_df - else: - # For backwards compatible reasons, just mock regular pandas - self.backend = "numpy_nullable" - self.DataFrame = self.pandas.DataFrame - self.testing = ArrowMockTesting() - - def __getattr__(self, name: str) -> Any: # noqa: ANN401 - return getattr(self.pandas, name) - - @pytest.fixture def require(): - def _require(extension_name, db_name="") -> Union[duckdb.DuckDBPyConnection, None]: + def _require(extension_name, db_name="") -> duckdb.DuckDBPyConnection | None: # Paths to search for extensions build = Path(__file__).parent.parent / "build" diff --git a/tests/coverage/test_pandas_categorical_coverage.py b/tests/coverage/test_pandas_categorical_coverage.py index 7b0645e0..6155138a 100644 --- a/tests/coverage/test_pandas_categorical_coverage.py +++ b/tests/coverage/test_pandas_categorical_coverage.py @@ -1,5 +1,4 @@ -import pytest -from conftest import NumpyPandas +import pandas as pd import duckdb @@ -9,23 +8,23 @@ def check_result_list(res): assert res_item[0] == res_item[1] -def check_create_table(category, pandas): +def check_create_table(category): conn = duckdb.connect() conn.execute("PRAGMA enable_verification") - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { - "x": pandas.Categorical(category, ordered=True), - "y": pandas.Categorical(category, ordered=True), + "x": pd.Categorical(category, ordered=True), + "y": pd.Categorical(category, ordered=True), "z": category, } ) category.append("bla") - df_in_diff = pandas.DataFrame( # noqa: F841 + df_in_diff = pd.DataFrame( # noqa: F841 { - "k": pandas.Categorical(category, ordered=True), + "k": pd.Categorical(category, ordered=True), } ) @@ -68,14 +67,11 @@ def check_create_table(category, pandas): conn.execute("DROP TABLE t1") -# TODO: extend tests with ArrowPandas # noqa: TD002, TD003 class TestCategory: - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_category_string_uint16(self, duckdb_cursor, pandas): + def test_category_string_uint16(self, duckdb_cursor): category = [str(i) for i in range(300)] - check_create_table(category, pandas) + check_create_table(category) - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_category_string_uint32(self, duckdb_cursor, pandas): + def test_category_string_uint32(self, duckdb_cursor): category = [str(i) for i in range(70000)] - check_create_table(category, pandas) + check_create_table(category) diff --git a/tests/extensions/test_httpfs.py b/tests/extensions/test_httpfs.py index 26ce917c..b8335814 100644 --- a/tests/extensions/test_httpfs.py +++ b/tests/extensions/test_httpfs.py @@ -1,8 +1,8 @@ import datetime import os +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb @@ -34,8 +34,7 @@ def test_s3fs(self, require): res = rel.fetchone() assert res == (1, 0, datetime.date(1965, 2, 28), 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 6, 0, 0, 0, 0) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_httpfs(self, require, pandas): + def test_httpfs(self, require): connection = require("httpfs") try: connection.execute(""" @@ -51,14 +50,14 @@ def test_httpfs(self, require, pandas): raise result_df = connection.fetchdf() - exp_result = pandas.DataFrame( + exp_result = pd.DataFrame( { - "id": pandas.Series([1, 2, 3], dtype="int32"), + "id": pd.Series([1, 2, 3], dtype="int32"), "first_name": ["Amanda", "Albert", "Evelyn"], "last_name": ["Jordan", "Freeman", "Morgan"], } ) - pandas.testing.assert_frame_equal(result_df, exp_result) + pd.testing.assert_frame_equal(result_df, exp_result, check_dtype=False) def test_http_exception(self, require): connection = require("httpfs") diff --git a/tests/fast/adbc/test_statement_bind.py b/tests/fast/adbc/test_statement_bind.py index e8df14c7..b6cff16c 100644 --- a/tests/fast/adbc/test_statement_bind.py +++ b/tests/fast/adbc/test_statement_bind.py @@ -118,7 +118,7 @@ def test_bind_composite_type(self): # Create the StructArray struct_array = pa.StructArray.from_arrays(arrays=data_dict.values(), names=data_dict.keys()) - schema = pa.schema([(name, array.type) for name, array in zip(["a"], [struct_array])]) + schema = pa.schema([(name, array.type) for name, array in zip(["a"], [struct_array], strict=False)]) # Create the RecordBatch record_batch = pa.RecordBatch.from_arrays([struct_array], schema=schema) diff --git a/tests/fast/api/test_3654.py b/tests/fast/api/test_3654.py index a6b01dd5..11f37946 100644 --- a/tests/fast/api/test_3654.py +++ b/tests/fast/api/test_3654.py @@ -1,4 +1,4 @@ -import pytest +import pandas as pd import duckdb @@ -8,13 +8,11 @@ can_run = True except Exception: can_run = False -from conftest import ArrowPandas, NumpyPandas class Test3654: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_3654_pandas(self, duckdb_cursor, pandas): - df1 = pandas.DataFrame( + def test_3654_pandas(self, duckdb_cursor): + df1 = pd.DataFrame( { "id": [1, 1, 2], } @@ -25,12 +23,11 @@ def test_3654_pandas(self, duckdb_cursor, pandas): print(rel.execute().fetchall()) assert rel.execute().fetchall() == [(1,), (1,), (2,)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_3654_arrow(self, duckdb_cursor, pandas): + def test_3654_arrow(self, duckdb_cursor): if not can_run: return - df1 = pandas.DataFrame( + df1 = pd.DataFrame( { "id": [1, 1, 2], } diff --git a/tests/fast/api/test_config.py b/tests/fast/api/test_config.py index aaec24c4..7d1370eb 100644 --- a/tests/fast/api/test_config.py +++ b/tests/fast/api/test_config.py @@ -2,37 +2,32 @@ import os import re -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb class TestDBConfig: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_default_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({"a": [1, 2, 3]}) + def test_default_order(self, duckdb_cursor): + df = pd.DataFrame({"a": [1, 2, 3]}) con = duckdb.connect(":memory:", config={"default_order": "desc"}) result = con.execute("select * from df order by a").fetchall() assert result == [(3,), (2,), (1,)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_null_order(self, duckdb_cursor, pandas): - df = pandas.DataFrame({"a": [1, 2, 3, None]}) + def test_null_order(self, duckdb_cursor): + df = pd.DataFrame({"a": [1, 2, 3, None]}) con = duckdb.connect(":memory:", config={"default_null_order": "nulls_last"}) result = con.execute("select * from df order by a").fetchall() assert result == [(1,), (2,), (3,), (None,)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_multiple_options(self, duckdb_cursor, pandas): - df = pandas.DataFrame({"a": [1, 2, 3, None]}) + def test_multiple_options(self, duckdb_cursor): + df = pd.DataFrame({"a": [1, 2, 3, None]}) con = duckdb.connect(":memory:", config={"default_null_order": "nulls_last", "default_order": "desc"}) result = con.execute("select * from df order by a").fetchall() assert result == [(3,), (2,), (1,), (None,)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_external_access(self, duckdb_cursor, pandas): - df = pandas.DataFrame({"a": [1, 2, 3]}) + def test_external_access(self, duckdb_cursor): + df = pd.DataFrame({"a": [1, 2, 3]}) # this works (replacement scan) con_regular = duckdb.connect(":memory:", config={}) con_regular.execute("select * from df") diff --git a/tests/fast/api/test_cursor.py b/tests/fast/api/test_cursor.py index f0d7d332..044e736d 100644 --- a/tests/fast/api/test_cursor.py +++ b/tests/fast/api/test_cursor.py @@ -114,3 +114,23 @@ def test_cursor_used_after_close(self): cursor.close() with pytest.raises(duckdb.ConnectionException): cursor.execute("select [1,2,3,4]") + + def test_cursor_relapi_chaining(self): + """Cursor should stay alive while a relation derived from it exists (GH #315).""" + con = duckdb.connect(":memory:") + # Exact repro from the issue + res = con.cursor().sql("SELECT 1 AS foo").fetchall() + assert res == [(1,)] + + def test_cursor_relapi_chaining_filter(self): + """Derived relations should also keep the cursor alive.""" + con = duckdb.connect(":memory:") + res = con.cursor().sql("SELECT 1 AS foo").filter("foo = 1").fetchall() + assert res == [(1,)] + + def test_cursor_relapi_chaining_table(self): + """Other connection methods returning relations should keep cursor alive.""" + con = duckdb.connect(":memory:") + con.execute("CREATE TABLE tbl AS SELECT 42 AS i") + res = con.cursor().table("tbl").fetchall() + assert res == [(42,)] diff --git a/tests/fast/api/test_dbapi00.py b/tests/fast/api/test_dbapi00.py index 425cb7e1..4a942128 100644 --- a/tests/fast/api/test_dbapi00.py +++ b/tests/fast/api/test_dbapi00.py @@ -1,8 +1,8 @@ # simple DB API testcase import numpy +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas def assert_result_equal(result): @@ -83,30 +83,29 @@ def test_numpy_selection(self, duckdb_cursor, integers, timestamps): arr.mask = [False, False, True] numpy.testing.assert_array_equal(result["t"], arr, "Incorrect result returned") - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_pandas_selection(self, duckdb_cursor, pandas, integers, timestamps): + def test_pandas_selection(self, duckdb_cursor, integers, timestamps): import datetime from packaging.version import Version # I don't know when this exactly changed, but 2.0.3 does not support this, recent versions do - if Version(pandas.__version__) <= Version("2.0.3"): + if Version(pd.__version__) <= Version("2.0.3"): pytest.skip("The resulting dtype is 'object' when given a Series with dtype Int32DType") duckdb_cursor.execute("SELECT * FROM integers") result = duckdb_cursor.fetchdf() array = numpy.ma.masked_array(numpy.arange(11)) array.mask = [False] * 10 + [True] - arr = {"i": pandas.Series(array.data, dtype=pandas.Int32Dtype)} - arr["i"][array.mask] = pandas.NA - arr = pandas.DataFrame(arr) - pandas.testing.assert_frame_equal(result, arr) + arr = {"i": pd.Series(array.data, dtype=pd.Int32Dtype)} + arr["i"][array.mask] = pd.NA + arr = pd.DataFrame(arr) + pd.testing.assert_frame_equal(result, arr) duckdb_cursor.execute("SELECT * FROM timestamps") result = duckdb_cursor.fetchdf() - df = pandas.DataFrame( + df = pd.DataFrame( { - "t": pandas.Series( + "t": pd.Series( data=[ datetime.datetime(year=1992, month=10, day=3, hour=18, minute=34, second=45), datetime.datetime(year=2010, month=1, day=1, hour=0, minute=0, second=1), @@ -116,7 +115,7 @@ def test_pandas_selection(self, duckdb_cursor, pandas, integers, timestamps): ) } ) - pandas.testing.assert_frame_equal(result, df) + pd.testing.assert_frame_equal(result, df) # def test_numpy_creation(self, duckdb_cursor): # # numpyarray = {'i': numpy.arange(10), 'v': numpy.random.randint(100, size=(1, 10))} # segfaults diff --git a/tests/fast/api/test_dbapi08.py b/tests/fast/api/test_dbapi08.py index def4e925..79b2ce0b 100644 --- a/tests/fast/api/test_dbapi08.py +++ b/tests/fast/api/test_dbapi08.py @@ -1,21 +1,19 @@ # test fetchdf with various types -import pytest -from conftest import NumpyPandas +import pandas as pd import duckdb class TestType: - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_fetchdf(self, pandas): + def test_fetchdf(self): con = duckdb.connect() con.execute("CREATE TABLE items(item VARCHAR)") con.execute("INSERT INTO items VALUES ('jeans'), (''), (NULL)") res = con.execute("SELECT item FROM items").fetchdf() - assert isinstance(res, pandas.core.frame.DataFrame) + assert isinstance(res, pd.core.frame.DataFrame) - df = pandas.DataFrame({"item": ["jeans", "", None]}) + df = pd.DataFrame({"item": ["jeans", "", None]}) print(res) print(df) - pandas.testing.assert_frame_equal(res, df) + pd.testing.assert_frame_equal(res, df, check_dtype=False) diff --git a/tests/fast/api/test_dbapi_fetch.py b/tests/fast/api/test_dbapi_fetch.py index 97ff6fe6..c6d3ccaa 100644 --- a/tests/fast/api/test_dbapi_fetch.py +++ b/tests/fast/api/test_dbapi_fetch.py @@ -42,11 +42,11 @@ def test_multiple_fetch_arrow(self, duckdb_cursor): pytest.importorskip("pyarrow") con = duckdb.connect() c = con.execute("SELECT 42::BIGINT AS a") - table = c.fetch_arrow_table() + table = c.to_arrow_table() df = table.to_pandas() pd.testing.assert_frame_equal(df, pd.DataFrame.from_dict({"a": [42]})) - assert c.fetch_arrow_table() is None - assert c.fetch_arrow_table() is None + assert c.to_arrow_table() is None + assert c.to_arrow_table() is None def test_multiple_close(self, duckdb_cursor): con = duckdb.connect() diff --git a/tests/fast/api/test_duckdb_connection.py b/tests/fast/api/test_duckdb_connection.py index 246b9d92..9bca8288 100644 --- a/tests/fast/api/test_duckdb_connection.py +++ b/tests/fast/api/test_duckdb_connection.py @@ -1,7 +1,7 @@ import re +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb @@ -25,10 +25,9 @@ def tmp_database(tmp_path_factory): # This file contains tests for DuckDBPyConnection methods, # wrapped by the 'duckdb' module, to execute with the 'default_connection' class TestDuckDBConnection: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_append(self, pandas): + def test_append(self): duckdb.execute("Create table integers (i integer)") - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } @@ -55,7 +54,7 @@ def test_default_connection_from_connect(self): def test_arrow(self): pytest.importorskip("pyarrow") duckdb.execute("select [1,2,3]") - duckdb.fetch_arrow_table() + duckdb.to_arrow_table() def test_begin_commit(self): duckdb.begin() @@ -177,8 +176,8 @@ def test_pystatement(self): assert duckdb.table("tbl").fetchall() == [(21,), (22,), (23,)] duckdb.execute("drop table tbl") - def test_fetch_arrow_table(self): - # Needed for 'fetch_arrow_table' + def test_arrow_table(self): + # Needed for 'arrow_table' pytest.importorskip("pyarrow") duckdb.execute("Create Table test (a integer)") @@ -196,7 +195,7 @@ def test_fetch_arrow_table(self): result_df = duckdb.execute(sql).df() - arrow_table = duckdb.execute(sql).fetch_arrow_table() + arrow_table = duckdb.execute(sql).to_arrow_table() arrow_df = arrow_table.to_pandas() assert result_df["repetitions"].sum() == arrow_df["repetitions"].sum() @@ -221,12 +220,12 @@ def test_fetch_df_chunk(self): duckdb.execute("DROP TABLE t") def test_fetch_record_batch(self): - # Needed for 'fetch_arrow_table' + # Needed for 'arrow_table' pytest.importorskip("pyarrow") duckdb.execute("CREATE table t as select range a from range(3000);") duckdb.execute("SELECT a FROM t") - record_batch_reader = duckdb.fetch_record_batch(1024) + record_batch_reader = duckdb.to_arrow_reader(1024) chunk = record_batch_reader.read_all() assert len(chunk) == 3000 @@ -300,7 +299,7 @@ def test_unregister_problematic_behavior(self, duckdb_cursor): assert duckdb_cursor.execute("select * from vw").fetchone() == (0,) # Create a registered object called 'vw' - arrow_result = duckdb_cursor.execute("select 42").fetch_arrow_table() + arrow_result = duckdb_cursor.execute("select 42").to_arrow_table() with pytest.raises(duckdb.CatalogException, match='View with name "vw" already exists'): duckdb_cursor.register("vw", arrow_result) @@ -345,13 +344,12 @@ def test_unregister_with_scary_name(self, duckdb_cursor): with pytest.raises(duckdb.CatalogException): duckdb_cursor.sql(f'select * from "{escaped_scary_name}"') - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_relation_out_of_scope(self, pandas): + def test_relation_out_of_scope(self): def temporary_scope(): # Create a connection, we will return this con = duckdb.connect() # Create a dataframe - df = pandas.DataFrame({"a": [1, 2, 3]}) + df = pd.DataFrame({"a": [1, 2, 3]}) # The dataframe has to be registered as well # making sure it does not go out of scope con.register("df", df) @@ -389,10 +387,11 @@ def test_interrupt(self): assert duckdb.interrupt is not None def test_wrap_shadowing(self): - pd = NumpyPandas() + import pandas as pd_local + import duckdb - df = pd.DataFrame({"a": [1, 2, 3]}) # noqa: F841 + df = pd_local.DataFrame({"a": [1, 2, 3]}) # noqa: F841 res = duckdb.sql("from df").fetchall() assert res == [(1,), (2,), (3,)] diff --git a/tests/fast/api/test_duckdb_query.py b/tests/fast/api/test_duckdb_query.py index 04531e49..8be3287c 100644 --- a/tests/fast/api/test_duckdb_query.py +++ b/tests/fast/api/test_duckdb_query.py @@ -1,5 +1,5 @@ +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb from duckdb import Value @@ -21,9 +21,8 @@ def test_duckdb_query(self, duckdb_cursor): res = duckdb_cursor.sql("select 42; select 84;").fetchall() assert res == [(84,)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_duckdb_from_query_multiple_statements(self, pandas): - tst_df = pandas.DataFrame({"a": [1, 23, 3, 5]}) # noqa: F841 + def test_duckdb_from_query_multiple_statements(self): + tst_df = pd.DataFrame({"a": [1, 23, 3, 5]}) # noqa: F841 res = duckdb.sql( """ diff --git a/tests/fast/api/test_fsspec.py b/tests/fast/api/test_fsspec.py index 154f38bd..f68415cb 100644 --- a/tests/fast/api/test_fsspec.py +++ b/tests/fast/api/test_fsspec.py @@ -53,3 +53,53 @@ def __init__(self) -> None: result = duckdb_cursor.read_parquet(file_globs=["deadlock://a", "deadlock://b"], union_by_name=True) assert len(result.fetchall()) == 100_000 + + def test_fsspec_seek_read_atomicity(self, duckdb_cursor, tmp_path): + """Regression test: concurrent positional reads must be atomic (seek+read under one GIL hold). + + Without the fix, separate seek and read GIL acquisitions allow another thread to + seek the same handle between them, corrupting data. We stress this by reading 4 files + with distinct data in parallel (union_by_name) and verifying no cross-contamination. + """ + files = {} + for i, name in enumerate(["a", "b", "c", "d"]): + file_path = tmp_path / f"{name}.parquet" + duckdb_cursor.sql(f"COPY (SELECT {i} AS file_id FROM range(10000)) TO '{file_path!s}' (FORMAT parquet)") + files[name] = file_path.read_bytes() + + class AtomicityTestFS(fsspec.AbstractFileSystem): + protocol = "atomtest" + + @property + def fsid(self): + return "atomtest" + + def ls(self, path, detail=True, **kwargs): + vals = [k for k in self._data if k.startswith(path)] + if detail: + return [ + {"name": p, "size": len(self._data[p]), "type": "file", "created": 0, "islink": False} + for p in vals + ] + return vals + + def modified(self, path): + return datetime.datetime.now() + + def _open(self, path, **kwargs): + return io.BytesIO(self._data[path]) + + def __init__(self) -> None: + super().__init__() + self._data = files + + fsspec.register_implementation("atomtest", AtomicityTestFS, clobber=True) + duckdb_cursor.register_filesystem(fsspec.filesystem("atomtest")) + + globs = ["atomtest://a", "atomtest://b", "atomtest://c", "atomtest://d"] + for _ in range(10): + result = duckdb_cursor.sql( + f"SELECT file_id, count(*) AS cnt FROM read_parquet({globs}, union_by_name=true) " + "GROUP BY ALL ORDER BY file_id" + ).fetchall() + assert result == [(0, 10000), (1, 10000), (2, 10000), (3, 10000)] diff --git a/tests/fast/api/test_native_tz.py b/tests/fast/api/test_native_tz.py index 66b06565..18f7b7e7 100644 --- a/tests/fast/api/test_native_tz.py +++ b/tests/fast/api/test_native_tz.py @@ -1,4 +1,5 @@ import datetime +import zoneinfo from pathlib import Path import pytest @@ -12,6 +13,17 @@ filename = str(Path(__file__).parent / ".." / "data" / "tz.parquet") +def get_tz_string(obj): + if isinstance(obj, zoneinfo.ZoneInfo): + # Pandas 3.0.0 creates ZoneInfo objects + return obj.key + if hasattr(obj, "zone"): + # Before 3.0.0 Pandas created tzdata objects + return obj.zone + msg = f"Can't get tz string from {obj}" + raise ValueError(msg) + + class TestNativeTimeZone: def test_native_python_timestamp_timezone(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") @@ -46,7 +58,7 @@ def test_native_python_time_timezone(self, duckdb_cursor): def test_pandas_timestamp_timezone(self, duckdb_cursor): res = duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").df() - assert res.dtypes["tz"].tz.zone == "America/Los_Angeles" + assert get_tz_string(res.dtypes["tz"].tz) == "America/Los_Angeles" assert res["tz"][0].hour == 14 assert res["tz"][0].minute == 52 @@ -65,29 +77,27 @@ def test_pandas_timestamp_time(self, duckdb_cursor): Version(pa.__version__) < Version("15.0.0"), reason="pyarrow 14.0.2 'to_pandas' causes a DeprecationWarning" ) def test_arrow_timestamp_timezone(self, duckdb_cursor): - res = duckdb_cursor.execute("SET timezone='America/Los_Angeles';") - table = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetch_arrow_table() + duckdb_cursor.execute("SET timezone='America/Los_Angeles';") + table = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").to_arrow_table() res = table.to_pandas() - assert res.dtypes["tz"].tz.zone == "America/Los_Angeles" + assert get_tz_string(res.dtypes["tz"].tz) == "America/Los_Angeles" assert res["tz"][0].hour == 14 assert res["tz"][0].minute == 52 duckdb_cursor.execute("SET timezone='UTC';") - res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").fetch_arrow_table().to_pandas() - assert res.dtypes["tz"].tz.zone == "UTC" + res = duckdb_cursor.execute(f"select TimeRecStart as tz from '{filename}'").to_arrow_table().to_pandas() + assert get_tz_string(res.dtypes["tz"].tz) == "UTC" assert res["tz"][0].hour == 21 assert res["tz"][0].minute == 52 def test_arrow_timestamp_time(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='America/Los_Angeles';") res1 = ( - duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'") - .fetch_arrow_table() - .to_pandas() + duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'").to_arrow_table().to_pandas() ) res2 = ( duckdb_cursor.execute(f"select TimeRecStart::TIMETZ::TIME as tz from '{filename}'") - .fetch_arrow_table() + .to_arrow_table() .to_pandas() ) assert res1["tz"][0].hour == 14 @@ -97,13 +107,11 @@ def test_arrow_timestamp_time(self, duckdb_cursor): duckdb_cursor.execute("SET timezone='UTC';") res1 = ( - duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'") - .fetch_arrow_table() - .to_pandas() + duckdb_cursor.execute(f"select TimeRecStart::TIMETZ as tz from '{filename}'").to_arrow_table().to_pandas() ) res2 = ( duckdb_cursor.execute(f"select TimeRecStart::TIMETZ::TIME as tz from '{filename}'") - .fetch_arrow_table() + .to_arrow_table() .to_pandas() ) assert res1["tz"][0].hour == 21 diff --git a/tests/fast/api/test_sql_params_performance.py b/tests/fast/api/test_sql_params_performance.py new file mode 100644 index 00000000..b2654e8c --- /dev/null +++ b/tests/fast/api/test_sql_params_performance.py @@ -0,0 +1,40 @@ +import time + + +class TestSqlEmptyParams: + """Empty params should use lazy QueryRelation path (same as params=None).""" + + def test_empty_list_returns_same_result(self, duckdb_cursor): + """sql(params=[]) returns same data as sql(params=None).""" + duckdb_cursor.execute("CREATE TABLE t AS SELECT i FROM range(10) t(i)") + expected = duckdb_cursor.sql("SELECT * FROM t").fetchall() + result = duckdb_cursor.sql("SELECT * FROM t", params=[]).fetchall() + assert result == expected + + def test_empty_dict_returns_same_result(self, duckdb_cursor): + """sql(params={}) returns same data as sql(params=None).""" + duckdb_cursor.execute("CREATE TABLE t AS SELECT i FROM range(10) t(i)") + expected = duckdb_cursor.sql("SELECT * FROM t").fetchall() + result = duckdb_cursor.sql("SELECT * FROM t", params={}).fetchall() + assert result == expected + + def test_empty_tuple_returns_same_result(self, duckdb_cursor): + """sql(params=()) returns same data as sql(params=None).""" + duckdb_cursor.execute("CREATE TABLE t AS SELECT i FROM range(10) t(i)") + expected = duckdb_cursor.sql("SELECT * FROM t").fetchall() + result = duckdb_cursor.sql("SELECT * FROM t", params=()).fetchall() + assert result == expected + + def test_empty_params_is_chainable(self, duckdb_cursor): + """Empty params produces a real relation that supports chaining.""" + duckdb_cursor.execute("CREATE TABLE t AS SELECT i FROM range(10) t(i)") + result = duckdb_cursor.sql("SELECT * FROM t", params=[]).filter("i < 3").order("i").fetchall() + assert result == [(0,), (1,), (2,)] + + def test_empty_params_explain_is_fast(self, duckdb_cursor): + """Empty params explain should not trigger expensive ToString.""" + duckdb_cursor.execute("CREATE TABLE t AS SELECT i FROM range(100000) t(i)") + t0 = time.perf_counter() + duckdb_cursor.sql("SELECT * FROM t", params=[]).explain() + elapsed = time.perf_counter() - t0 + assert elapsed < 5.0, f"explain() took {elapsed:.2f}s, expected < 5s" diff --git a/tests/fast/api/test_streaming_result.py b/tests/fast/api/test_streaming_result.py index 4003f20f..9cba78b1 100644 --- a/tests/fast/api/test_streaming_result.py +++ b/tests/fast/api/test_streaming_result.py @@ -41,7 +41,7 @@ def test_record_batch_reader(self, duckdb_cursor): pytest.importorskip("pyarrow.dataset") # record batch reader res = duckdb_cursor.sql("SELECT * FROM range(100000) t(i)") - reader = res.fetch_arrow_reader(batch_size=16_384) + reader = res.to_arrow_reader(batch_size=16_384) result = [] for batch in reader: result += batch.to_pydict()["i"] @@ -52,7 +52,7 @@ def test_record_batch_reader(self, duckdb_cursor): "SELECT CASE WHEN i < 10000 THEN i ELSE concat('hello', i::VARCHAR)::INT END FROM range(100000) t(i)" ) with pytest.raises(duckdb.ConversionException, match="Could not convert string 'hello10000' to INT32"): - reader = res.fetch_arrow_reader(batch_size=16_384) + reader = res.to_arrow_reader(batch_size=16_384) def test_9801(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test(id INTEGER , name VARCHAR NOT NULL);") diff --git a/tests/fast/api/test_to_csv.py b/tests/fast/api/test_to_csv.py index 97f13d8b..1354888a 100644 --- a/tests/fast/api/test_to_csv.py +++ b/tests/fast/api/test_to_csv.py @@ -3,17 +3,17 @@ import os import tempfile +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas, getTimeSeriesData +from conftest import getTimeSeriesData import duckdb class TestToCSV: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_basic_to_csv(self, pandas): + def test_basic_to_csv(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) + df = pd.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) @@ -21,10 +21,9 @@ def test_basic_to_csv(self, pandas): csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_sep(self, pandas): + def test_to_csv_sep(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) + df = pd.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, sep=",") @@ -32,10 +31,9 @@ def test_to_csv_sep(self, pandas): csv_rel = duckdb.read_csv(temp_file_name, sep=",") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_na_rep(self, pandas): + def test_to_csv_na_rep(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) + df = pd.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, na_rep="test") @@ -43,10 +41,9 @@ def test_to_csv_na_rep(self, pandas): csv_rel = duckdb.read_csv(temp_file_name, na_values="test") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_header(self, pandas): + def test_to_csv_header(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) + df = pd.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name) @@ -54,10 +51,9 @@ def test_to_csv_header(self, pandas): csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_quotechar(self, pandas): + def test_to_csv_quotechar(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) + df = pd.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quotechar="'", sep=",") @@ -65,10 +61,9 @@ def test_to_csv_quotechar(self, pandas): csv_rel = duckdb.read_csv(temp_file_name, sep=",", quotechar="'") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_escapechar(self, pandas): + def test_to_csv_escapechar(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame( + df = pd.DataFrame( { "c_bool": [True, False], "c_float": [1.0, 3.2], @@ -81,12 +76,11 @@ def test_to_csv_escapechar(self, pandas): csv_rel = duckdb.read_csv(temp_file_name, quotechar='"', escapechar="!") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_date_format(self, pandas): + def test_to_csv_date_format(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame(getTimeSeriesData()) + df = pd.DataFrame(getTimeSeriesData()) dt_index = df.index - df = pandas.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index) + df = pd.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, date_format="%Y%m%d") @@ -94,11 +88,10 @@ def test_to_csv_date_format(self, pandas): assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_timestamp_format(self, pandas): + def test_to_csv_timestamp_format(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - df = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) + df = pd.DataFrame({"0": pd.Series(data=data, dtype="object")}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, timestamp_format="%m/%d/%Y") @@ -106,68 +99,61 @@ def test_to_csv_timestamp_format(self, pandas): assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_quoting_off(self, pandas): + def test_to_csv_quoting_off(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=None) csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_quoting_on(self, pandas): + def test_to_csv_quoting_on(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting="force") csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_quoting_quote_all(self, pandas): + def test_to_csv_quoting_quote_all(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, quoting=csv.QUOTE_ALL) csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_encoding_incorrect(self, pandas): + def test_to_csv_encoding_incorrect(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) with pytest.raises( duckdb.InvalidInputException, match="Invalid Input Error: The only supported encoding option is 'UTF8" ): rel.to_csv(temp_file_name, encoding="nope") - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_encoding_correct(self, pandas): + def test_to_csv_encoding_correct(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, encoding="UTF-8") csv_rel = duckdb.read_csv(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_compression_gzip(self, pandas): + def test_compression_gzip(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) + df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) rel.to_csv(temp_file_name, compression="gzip") csv_rel = duckdb.read_csv(temp_file_name, compression="gzip") assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_partition(self, pandas): + def test_to_csv_partition(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame( + df = pd.DataFrame( { "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], @@ -190,10 +176,9 @@ def test_to_csv_partition(self, pandas): assert csv_rel.execute().fetchall() == expected - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_partition_with_columns_written(self, pandas): + def test_to_csv_partition_with_columns_written(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame( + df = pd.DataFrame( { "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], @@ -210,10 +195,9 @@ def test_to_csv_partition_with_columns_written(self, pandas): ) assert res.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_overwrite(self, pandas): + def test_to_csv_overwrite(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame( + df = pd.DataFrame( { "c_category_1": ["a", "a", "b", "b"], "c_category_2": ["c", "c", "d", "d"], @@ -238,10 +222,9 @@ def test_to_csv_overwrite(self, pandas): ] assert csv_rel.execute().fetchall() == expected - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_overwrite_with_columns_written(self, pandas): + def test_to_csv_overwrite_with_columns_written(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame( + df = pd.DataFrame( { "c_category_1": ["a", "a", "b", "b"], "c_category_2": ["c", "c", "d", "d"], @@ -264,10 +247,9 @@ def test_to_csv_overwrite_with_columns_written(self, pandas): res = duckdb.sql("FROM rel order by all") assert res.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_overwrite_not_enabled(self, pandas): + def test_to_csv_overwrite_not_enabled(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame( + df = pd.DataFrame( { "c_category_1": ["a", "a", "b", "b"], "c_category_2": ["c", "c", "d", "d"], @@ -282,12 +264,11 @@ def test_to_csv_overwrite_not_enabled(self, pandas): with pytest.raises(duckdb.IOException, match="OVERWRITE"): rel.to_csv(temp_file_name, header=True, partition_by=["c_category_1"]) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_per_thread_output(self, pandas): + def test_to_csv_per_thread_output(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 num_threads = duckdb.sql("select current_setting('threads')").fetchone()[0] print("num_threads:", num_threads) - df = pandas.DataFrame( + df = pd.DataFrame( { "c_category": ["a", "a", "b", "b"], "c_bool": [True, False, True, True], @@ -301,10 +282,9 @@ def test_to_csv_per_thread_output(self, pandas): csv_rel = duckdb.read_csv(f"{temp_file_name}/*.csv", header=True) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_use_tmp_file(self, pandas): + def test_to_csv_use_tmp_file(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 - df = pandas.DataFrame( + df = pd.DataFrame( { "c_category_1": ["a", "a", "b", "b"], "c_category_2": ["c", "c", "d", "d"], diff --git a/tests/fast/api/test_to_parquet.py b/tests/fast/api/test_to_parquet.py index 370ab8e4..5c70bf3f 100644 --- a/tests/fast/api/test_to_parquet.py +++ b/tests/fast/api/test_to_parquet.py @@ -3,15 +3,14 @@ import re import tempfile +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb class TestToParquet: - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_basic_to_parquet(self, pd): + def test_basic_to_parquet(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) rel = duckdb.from_df(df) @@ -21,8 +20,7 @@ def test_basic_to_parquet(self, pd): csv_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_compression_gzip(self, pd): + def test_compression_gzip(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) @@ -50,9 +48,8 @@ def test_field_ids(self): """ ).execute().fetchall() == [("duckdb_schema", None), ("i", 42), ("my_struct", 43), ("j", 44)] - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) @pytest.mark.parametrize("row_group_size_bytes", [122880 * 1024, "2MB"]) - def test_row_group_size_bytes(self, pd, row_group_size_bytes): + def test_row_group_size_bytes(self, row_group_size_bytes): con = duckdb.connect() con.execute("SET preserve_insertion_order=false;") @@ -63,8 +60,7 @@ def test_row_group_size_bytes(self, pd, row_group_size_bytes): parquet_rel = con.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_row_group_size(self, pd): + def test_row_group_size(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame({"a": ["string1", "string2", "string3"]}) rel = duckdb.from_df(df) @@ -72,9 +68,8 @@ def test_row_group_size(self, pd): parquet_rel = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == parquet_rel.execute().fetchall() - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) @pytest.mark.parametrize("write_columns", [None, True, False]) - def test_partition(self, pd, write_columns): + def test_partition(self, write_columns): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { @@ -89,9 +84,8 @@ def test_partition(self, pd, write_columns): expected = [("rei", 321.0, "a"), ("shinji", 123.0, "a"), ("asuka", 23.0, "b"), ("kaworu", 340.0, "c")] assert result.execute().fetchall() == expected - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) @pytest.mark.parametrize("write_columns", [None, True, False]) - def test_overwrite(self, pd, write_columns): + def test_overwrite(self, write_columns): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { @@ -108,8 +102,7 @@ def test_overwrite(self, pd, write_columns): assert result.execute().fetchall() == expected - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_use_tmp_file(self, pd): + def test_use_tmp_file(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { @@ -124,8 +117,7 @@ def test_use_tmp_file(self, pd): result = duckdb.read_parquet(temp_file_name) assert rel.execute().fetchall() == result.execute().fetchall() - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_per_thread_output(self, pd): + def test_per_thread_output(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 num_threads = duckdb.sql("select current_setting('threads')").fetchone()[0] print("threads:", num_threads) @@ -141,8 +133,7 @@ def test_per_thread_output(self, pd): result = duckdb.read_parquet(f"{temp_file_name}/*.parquet") assert rel.execute().fetchall() == result.execute().fetchall() - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_append(self, pd): + def test_append(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { @@ -173,8 +164,7 @@ def test_append(self, pd): ] assert result.execute().fetchall() == expected - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_filename_pattern_with_index(self, pd): + def test_filename_pattern_with_index(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { @@ -199,8 +189,7 @@ def test_filename_pattern_with_index(self, pd): expected = [("rei", 321.0, "a"), ("shinji", 123.0, "a"), ("asuka", 23.0, "b"), ("kaworu", 340.0, "c")] assert result.execute().fetchall() == expected - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_filename_pattern_with_uuid(self, pd): + def test_filename_pattern_with_uuid(self): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { @@ -242,9 +231,8 @@ def test_file_size_bytes_basic(self, file_size_bytes): result = duckdb.read_parquet(f"{temp_file_name}/*.parquet") assert len(result.execute().fetchall()) == 10000 - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) @pytest.mark.parametrize("file_size_bytes", ["256MB", "1G"]) - def test_file_size_bytes_human_readable(self, pd, file_size_bytes): + def test_file_size_bytes_human_readable(self, file_size_bytes): temp_file_name = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names())) # noqa: PTH118 df = pd.DataFrame( { diff --git a/tests/fast/arrow/test_2426.py b/tests/fast/arrow/test_2426.py index 6f76613f..f631ec1a 100644 --- a/tests/fast/arrow/test_2426.py +++ b/tests/fast/arrow/test_2426.py @@ -31,7 +31,7 @@ def test_2426(self, duckdb_cursor): result_df = con.execute(sql).df() - arrow_table = con.execute(sql).fetch_arrow_table() + arrow_table = con.execute(sql).to_arrow_table() arrow_df = arrow_table.to_pandas() assert result_df["repetitions"].sum() == arrow_df["repetitions"].sum() diff --git a/tests/fast/arrow/test_6584.py b/tests/fast/arrow/test_6584.py index feadc6d7..ed7d9a47 100644 --- a/tests/fast/arrow/test_6584.py +++ b/tests/fast/arrow/test_6584.py @@ -9,7 +9,7 @@ def f(cur, i, data): cur.execute(f"create table t_{i} as select * from data") - return cur.execute(f"select * from t_{i}").fetch_arrow_table() + return cur.execute(f"select * from t_{i}").to_arrow_table() def test_6584(): diff --git a/tests/fast/arrow/test_6796.py b/tests/fast/arrow/test_6796.py index bf557038..13286de2 100644 --- a/tests/fast/arrow/test_6796.py +++ b/tests/fast/arrow/test_6796.py @@ -1,15 +1,14 @@ +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb pyarrow = pytest.importorskip("pyarrow") -@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) -def test_6796(pandas): +def test_6796(): conn = duckdb.connect() - input_df = pandas.DataFrame({"foo": ["bar"]}) + input_df = pd.DataFrame({"foo": ["bar"]}) conn.register("input_df", input_df) query = """ @@ -20,7 +19,7 @@ def test_6796(pandas): # fetching directly into Pandas works res_df = conn.execute(query).fetch_df() - res_arrow = conn.execute(query).fetch_arrow_table() # noqa: F841 + res_arrow = conn.execute(query).to_arrow_table() # noqa: F841 df_arrow_table = pyarrow.Table.from_pandas(res_df) # noqa: F841 diff --git a/tests/fast/arrow/test_9443.py b/tests/fast/arrow/test_9443.py index fe5a2ce1..66c8c0be 100644 --- a/tests/fast/arrow/test_9443.py +++ b/tests/fast/arrow/test_9443.py @@ -23,4 +23,4 @@ def test_9443(self, tmp_path, duckdb_cursor): sql = f'SELECT * FROM "{temp_file}"' duckdb_cursor.execute(sql) - duckdb_cursor.fetch_record_batch() + duckdb_cursor.to_arrow_reader() diff --git a/tests/fast/arrow/test_arrow_binary_view.py b/tests/fast/arrow/test_arrow_binary_view.py index 4e161ac3..10580d8e 100644 --- a/tests/fast/arrow/test_arrow_binary_view.py +++ b/tests/fast/arrow/test_arrow_binary_view.py @@ -11,10 +11,10 @@ def test_arrow_binary_view(self, duckdb_cursor): tab = pa.table({"x": pa.array([b"abc", b"thisisaverybigbinaryyaymorethanfifteen", None], pa.binary_view())}) assert con.execute("FROM tab").fetchall() == [(b"abc",), (b"thisisaverybigbinaryyaymorethanfifteen",), (None,)] # By default we won't export a view - assert not con.execute("FROM tab").fetch_arrow_table().equals(tab) + assert not con.execute("FROM tab").to_arrow_table().equals(tab) # We do the binary view from 1.4 onwards con.execute("SET arrow_output_version = 1.4") - assert con.execute("FROM tab").fetch_arrow_table().equals(tab) + assert con.execute("FROM tab").to_arrow_table().equals(tab) assert con.execute("FROM tab where x = 'thisisaverybigbinaryyaymorethanfifteen'").fetchall() == [ (b"thisisaverybigbinaryyaymorethanfifteen",) diff --git a/tests/fast/arrow/test_arrow_decimal256.py b/tests/fast/arrow/test_arrow_decimal256.py index d687ec8a..b7fd1a03 100644 --- a/tests/fast/arrow/test_arrow_decimal256.py +++ b/tests/fast/arrow/test_arrow_decimal256.py @@ -17,4 +17,4 @@ def test_decimal_256_throws(self, duckdb_cursor): with pytest.raises( duckdb.NotImplementedException, match="Unsupported Internal Arrow Type for Decimal d:12,4,256" ): - conn.execute("select * from pa_decimal256;").fetch_arrow_table().to_pylist() + conn.execute("select * from pa_decimal256;").to_arrow_table().to_pylist() diff --git a/tests/fast/arrow/test_arrow_decimal_32_64.py b/tests/fast/arrow/test_arrow_decimal_32_64.py index 301d890f..3ae4dc98 100644 --- a/tests/fast/arrow/test_arrow_decimal_32_64.py +++ b/tests/fast/arrow/test_arrow_decimal_32_64.py @@ -33,7 +33,7 @@ def test_decimal_32(self, duckdb_cursor): ] # Test write - arrow_table = duckdb_cursor.execute("FROM decimal_32").fetch_arrow_table() + arrow_table = duckdb_cursor.execute("FROM decimal_32").to_arrow_table() assert arrow_table.equals(decimal_32) @@ -64,5 +64,5 @@ def test_decimal_64(self, duckdb_cursor): ).fetchall() == [(2,)] # Test write - arrow_table = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table() + arrow_table = duckdb_cursor.execute("FROM decimal_64").to_arrow_table() assert arrow_table.equals(decimal_64) diff --git a/tests/fast/arrow/test_arrow_deprecation.py b/tests/fast/arrow/test_arrow_deprecation.py new file mode 100644 index 00000000..91769288 --- /dev/null +++ b/tests/fast/arrow/test_arrow_deprecation.py @@ -0,0 +1,139 @@ +import warnings + +import pytest + +import duckdb + +pytest.importorskip("pyarrow") + + +class TestArrowDeprecation: + @pytest.fixture(autouse=True) + def setup(self, duckdb_cursor): + self.con = duckdb_cursor + self.con.execute("CREATE TABLE t AS SELECT 1 AS a") + + def test_relation_fetch_arrow_table_deprecated(self): + rel = self.con.table("t") + with pytest.warns( + DeprecationWarning, match="fetch_arrow_table\\(\\) is deprecated, use to_arrow_table\\(\\) instead" + ): + rel.fetch_arrow_table() + + def test_relation_fetch_record_batch_deprecated(self): + rel = self.con.table("t") + with pytest.warns( + DeprecationWarning, match="fetch_record_batch\\(\\) is deprecated, use to_arrow_reader\\(\\) instead" + ): + rel.fetch_record_batch() + + def test_relation_fetch_arrow_reader_deprecated(self): + rel = self.con.table("t") + with pytest.warns( + DeprecationWarning, match="fetch_arrow_reader\\(\\) is deprecated, use to_arrow_reader\\(\\) instead" + ): + rel.fetch_arrow_reader() + + def test_connection_fetch_arrow_table_deprecated(self): + self.con.execute("SELECT 1") + with pytest.warns( + DeprecationWarning, match="fetch_arrow_table\\(\\) is deprecated, use to_arrow_table\\(\\) instead" + ): + self.con.fetch_arrow_table() + + def test_connection_fetch_record_batch_deprecated(self): + self.con.execute("SELECT 1") + with pytest.warns( + DeprecationWarning, match="fetch_record_batch\\(\\) is deprecated, use to_arrow_reader\\(\\) instead" + ): + self.con.fetch_record_batch() + + def test_module_fetch_arrow_table_deprecated(self): + duckdb.execute("SELECT 1") + with pytest.warns( + DeprecationWarning, match="fetch_arrow_table\\(\\) is deprecated, use to_arrow_table\\(\\) instead" + ): + duckdb.fetch_arrow_table() + + def test_module_fetch_record_batch_deprecated(self): + duckdb.execute("SELECT 1") + with pytest.warns( + DeprecationWarning, match="fetch_record_batch\\(\\) is deprecated, use to_arrow_reader\\(\\) instead" + ): + duckdb.fetch_record_batch() + + def test_relation_to_arrow_table_works(self): + rel = self.con.table("t") + with warnings.catch_warnings(): + warnings.simplefilter("error") + result = rel.to_arrow_table() + assert result.num_rows == 1 + + def test_relation_to_arrow_reader_works(self): + rel = self.con.table("t") + with warnings.catch_warnings(): + warnings.simplefilter("error") + reader = rel.to_arrow_reader() + assert reader.read_all().num_rows == 1 + + def test_relation_arrow_no_warning(self): + """relation.arrow() should NOT emit a deprecation warning (soft deprecated).""" + rel = self.con.table("t") + with warnings.catch_warnings(): + warnings.simplefilter("error") + reader = rel.arrow() + assert reader.read_all().num_rows == 1 + + def test_connection_to_arrow_table_works(self): + self.con.execute("SELECT 1") + with warnings.catch_warnings(): + warnings.simplefilter("error") + result = self.con.to_arrow_table() + assert result.num_rows == 1 + + def test_connection_to_arrow_reader_works(self): + self.con.execute("SELECT 1") + with warnings.catch_warnings(): + warnings.simplefilter("error") + reader = self.con.to_arrow_reader() + assert reader.read_all().num_rows == 1 + + def test_connection_arrow_no_warning(self): + """connection.arrow() should NOT emit a deprecation warning (soft deprecated).""" + self.con.execute("SELECT 1") + with warnings.catch_warnings(): + warnings.simplefilter("error") + reader = self.con.arrow() + assert reader.read_all().num_rows == 1 + + def test_module_to_arrow_table_works(self): + duckdb.execute("SELECT 1") + with warnings.catch_warnings(): + warnings.simplefilter("error") + result = duckdb.to_arrow_table() + assert result.num_rows == 1 + + def test_module_to_arrow_reader_works(self): + duckdb.execute("SELECT 1") + with warnings.catch_warnings(): + warnings.simplefilter("error") + reader = duckdb.to_arrow_reader() + assert reader.read_all().num_rows == 1 + + def test_module_arrow_no_warning(self): + """duckdb.arrow(rows_per_batch) should NOT emit a deprecation warning (soft deprecated).""" + duckdb.execute("SELECT 1") + with warnings.catch_warnings(): + warnings.simplefilter("error") + result = duckdb.arrow() + assert result.read_all().num_rows == 1 + + def test_from_arrow_not_deprecated(self): + """duckdb.arrow(arrow_object) should NOT emit a deprecation warning.""" + import pyarrow as pa + + table = pa.table({"a": [1, 2, 3]}) + with warnings.catch_warnings(): + warnings.simplefilter("error") + rel = duckdb.arrow(table) + assert rel.fetchall() == [(1,), (2,), (3,)] diff --git a/tests/fast/arrow/test_arrow_extensions.py b/tests/fast/arrow/test_arrow_extensions.py index f79c32c4..3a73d266 100644 --- a/tests/fast/arrow/test_arrow_extensions.py +++ b/tests/fast/arrow/test_arrow_extensions.py @@ -21,7 +21,7 @@ def test_uuid(self): arrow_table = pa.Table.from_arrays([storage_array], names=["uuid_col"]) - duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").to_arrow_table() assert duck_arrow.equals(arrow_table) @@ -29,7 +29,7 @@ def test_uuid_from_duck(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_lossless_conversion = true") - arrow_table = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() + arrow_table = duckdb_cursor.execute("select uuid from test_all_types()").to_arrow_table() assert arrow_table.to_pylist() == [ {"uuid": UUID("00000000-0000-0000-0000-000000000000")}, @@ -45,7 +45,7 @@ def test_uuid_from_duck(self): arrow_table = duckdb_cursor.execute( "select '00000000-0000-0000-0000-000000000100'::UUID as uuid" - ).fetch_arrow_table() + ).to_arrow_table() assert arrow_table.to_pylist() == [{"uuid": UUID("00000000-0000-0000-0000-000000000100")}] assert duckdb_cursor.execute("FROM arrow_table").fetchall() == [(UUID("00000000-0000-0000-0000-000000000100"),)] @@ -61,7 +61,7 @@ def test_json(self, duckdb_cursor): arrow_table = pa.Table.from_arrays([storage_array], names=["json_col"]) duckdb_cursor.execute("SET arrow_lossless_conversion = true") - duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").to_arrow_table() assert duck_arrow.equals(arrow_table) @@ -69,7 +69,7 @@ def test_uuid_no_def(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_lossless_conversion = true") - res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() + res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").to_arrow_table() res_duck = duckdb_cursor.execute("from res_arrow").fetchall() assert res_duck == [ (UUID("00000000-0000-0000-0000-000000000000"),), @@ -79,7 +79,7 @@ def test_uuid_no_def(self): def test_uuid_no_def_lossless(self): duckdb_cursor = duckdb.connect() - res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_arrow_table() + res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").to_arrow_table() assert res_arrow.to_pylist() == [ {"uuid": "00000000-0000-0000-0000-000000000000"}, {"uuid": "ffffffff-ffff-ffff-ffff-ffffffffffff"}, @@ -97,7 +97,7 @@ def test_uuid_no_def_stream(self): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("SET arrow_lossless_conversion = true") - res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").fetch_record_batch() + res_arrow = duckdb_cursor.execute("select uuid from test_all_types()").to_arrow_reader() res_duck = duckdb.execute("from res_arrow").fetchall() assert res_duck == [ (UUID("00000000-0000-0000-0000-000000000000"),), @@ -136,7 +136,7 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized) -> object: arrow_table = pa.Table.from_arrays([storage_array, age_array], names=["pedro_pedro_pedro", "age"]) - duck_arrow = duckdb_cursor.execute("FROM arrow_table").fetch_arrow_table() + duck_arrow = duckdb_cursor.execute("FROM arrow_table").to_arrow_table() assert duckdb_cursor.execute("FROM duck_arrow").fetchall() == [(b"pedro", 29)] def test_hugeint(self): @@ -153,11 +153,11 @@ def test_hugeint(self): assert con.execute("FROM arrow_table").fetchall() == [(-1,)] - assert con.execute("FROM arrow_table").fetch_arrow_table().equals(arrow_table) + assert con.execute("FROM arrow_table").to_arrow_table().equals(arrow_table) con.execute("SET arrow_lossless_conversion = false") - assert not con.execute("FROM arrow_table").fetch_arrow_table().equals(arrow_table) + assert not con.execute("FROM arrow_table").to_arrow_table().equals(arrow_table) def test_uhugeint(self, duckdb_cursor): storage_array = pa.array([b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"], pa.binary(16)) @@ -171,11 +171,11 @@ def test_uhugeint(self, duckdb_cursor): def test_bit(self): con = duckdb.connect() - res_blob = con.execute("SELECT '0101011'::BIT str FROM range(5) tbl(i)").fetch_arrow_table() + res_blob = con.execute("SELECT '0101011'::BIT str FROM range(5) tbl(i)").to_arrow_table() con.execute("SET arrow_lossless_conversion = true") - res_bit = con.execute("SELECT '0101011'::BIT str FROM range(5) tbl(i)").fetch_arrow_table() + res_bit = con.execute("SELECT '0101011'::BIT str FROM range(5) tbl(i)").to_arrow_table() assert con.execute("FROM res_blob").fetchall() == [ (b"\x01\xab",), @@ -195,11 +195,11 @@ def test_bit(self): def test_timetz(self): con = duckdb.connect() - res_time = con.execute("SELECT '02:30:00+04'::TIMETZ str FROM range(1) tbl(i)").fetch_arrow_table() + res_time = con.execute("SELECT '02:30:00+04'::TIMETZ str FROM range(1) tbl(i)").to_arrow_table() con.execute("SET arrow_lossless_conversion = true") - res_tz = con.execute("SELECT '02:30:00+04'::TIMETZ str FROM range(1) tbl(i)").fetch_arrow_table() + res_tz = con.execute("SELECT '02:30:00+04'::TIMETZ str FROM range(1) tbl(i)").to_arrow_table() assert con.execute("FROM res_time").fetchall() == [(datetime.time(2, 30),)] assert con.execute("FROM res_tz").fetchall() == [ @@ -210,7 +210,7 @@ def test_bignum(self): con = duckdb.connect() res_bignum = con.execute( "SELECT '179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540458953514382464234321326889464182768467546703537516986049910576551282076245490090389328944075868508455133942304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368'::bignum a FROM range(1) tbl(i)" # noqa: E501 - ).fetch_arrow_table() + ).to_arrow_table() assert res_bignum.column("a").type.type_name == "bignum" assert res_bignum.column("a").type.vendor_name == "DuckDB" @@ -226,7 +226,7 @@ def test_nested_types_with_extensions(self): arrow_table = duckdb_cursor.execute( "select map {uuid(): 1::uhugeint, uuid(): 2::uhugeint} as li" - ).fetch_arrow_table() + ).to_arrow_table() assert arrow_table.schema[0].type.key_type.extension_name == "arrow.uuid" assert arrow_table.schema[0].type.item_type.extension_name == "arrow.opaque" @@ -267,7 +267,7 @@ def test_boolean(self): bool8_array = pa.ExtensionArray.from_storage(pa.bool8(), storage_array) arrow_table = pa.Table.from_arrays([bool8_array], names=["bool8"]) assert con.execute("FROM arrow_table").fetchall() == [(True,), (False,), (True,), (True,), (None,)] - result_table = con.execute("FROM arrow_table").fetch_arrow_table() + result_table = con.execute("FROM arrow_table").to_arrow_table() res_storage_array = pa.array([1, 0, 1, 1, None], pa.int8()) res_bool8_array = pa.ExtensionArray.from_storage(pa.bool8(), res_storage_array) @@ -290,7 +290,7 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): schema=schema, ) - tbl = duckdb_cursor.sql("""SELECT geometry as wkt FROM geo_table;""").fetch_arrow_table() + tbl = duckdb_cursor.sql("""SELECT geometry as wkt FROM geo_table;""").to_arrow_table() assert pa.types.is_binary(tbl.schema[0].type) field = pa.field( @@ -307,7 +307,7 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): schema=schema, ) with pytest.raises(duckdb.SerializationException, match="Failed to parse JSON string"): - tbl = duckdb_cursor.sql("""SELECT geometry as wkt FROM geo_table;""").fetch_arrow_table() + tbl = duckdb_cursor.sql("""SELECT geometry as wkt FROM geo_table;""").to_arrow_table() field = pa.field( "geometry", @@ -322,7 +322,7 @@ def test_accept_malformed_complex_json(self, duckdb_cursor): [pa.array([], pa.binary())], schema=schema, ) - tbl = duckdb_cursor.sql("""SELECT geometry as wkt FROM geo_table;""").fetch_arrow_table() + tbl = duckdb_cursor.sql("""SELECT geometry as wkt FROM geo_table;""").to_arrow_table() assert pa.types.is_binary(tbl.schema[0].type) field = pa.field( diff --git a/tests/fast/arrow/test_arrow_fetch.py b/tests/fast/arrow/test_arrow_fetch.py index ba5d13a4..cf1ff46b 100644 --- a/tests/fast/arrow/test_arrow_fetch.py +++ b/tests/fast/arrow/test_arrow_fetch.py @@ -14,7 +14,7 @@ def check_equal(duckdb_conn): true_result = duckdb_conn.execute("SELECT * from test").fetchall() duck_tbl = duckdb_conn.table("test") - duck_from_arrow = duckdb_conn.from_arrow(duck_tbl.fetch_arrow_table()) + duck_from_arrow = duckdb_conn.from_arrow(duck_tbl.to_arrow_table()) duck_from_arrow.create("testarrow") arrow_result = duckdb_conn.execute("SELECT * from testarrow").fetchall() assert arrow_result == true_result @@ -86,7 +86,7 @@ def test_to_arrow_chunk_size(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") relation = duckdb_cursor.table("t") - arrow_tbl = relation.fetch_arrow_table() + arrow_tbl = relation.to_arrow_table() assert arrow_tbl["a"].num_chunks == 1 - arrow_tbl = relation.fetch_arrow_table(2048) + arrow_tbl = relation.to_arrow_table(2048) assert arrow_tbl["a"].num_chunks == 2 diff --git a/tests/fast/arrow/test_arrow_fetch_recordbatch.py b/tests/fast/arrow/test_arrow_fetch_recordbatch.py index a5804c87..d060659f 100644 --- a/tests/fast/arrow/test_arrow_fetch_recordbatch.py +++ b/tests/fast/arrow/test_arrow_fetch_recordbatch.py @@ -12,7 +12,7 @@ def test_record_batch_next_batch_numeric(self, duckdb_cursor): duckdb_cursor_check = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 @@ -24,7 +24,7 @@ def test_record_batch_next_batch_numeric(self, duckdb_cursor): chunk = record_batch_reader.read_next_batch() # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() @@ -38,7 +38,7 @@ def test_record_batch_next_batch_bool(self, duckdb_cursor): "CREATE table t as SELECT CASE WHEN i % 2 = 0 THEN true ELSE false END AS a from range(3000) as tbl(i);" ) query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 @@ -51,7 +51,7 @@ def test_record_batch_next_batch_bool(self, duckdb_cursor): # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() @@ -63,7 +63,7 @@ def test_record_batch_next_batch_varchar(self, duckdb_cursor): duckdb_cursor_check = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range::varchar a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 @@ -76,7 +76,7 @@ def test_record_batch_next_batch_varchar(self, duckdb_cursor): # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() @@ -90,7 +90,7 @@ def test_record_batch_next_batch_struct(self, duckdb_cursor): "CREATE table t as select {'x': i, 'y': i::varchar, 'z': i+1} as a from range(3000) as tbl(i);" ) query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 @@ -103,7 +103,7 @@ def test_record_batch_next_batch_struct(self, duckdb_cursor): # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() @@ -115,7 +115,7 @@ def test_record_batch_next_batch_list(self, duckdb_cursor): duckdb_cursor_check = duckdb.connect() duckdb_cursor.execute("CREATE table t as select [i,i+1] as a from range(3000) as tbl(i);") query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 @@ -128,7 +128,7 @@ def test_record_batch_next_batch_list(self, duckdb_cursor): # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() @@ -141,7 +141,7 @@ def test_record_batch_next_batch_map(self, duckdb_cursor): duckdb_cursor_check = duckdb.connect() duckdb_cursor.execute("CREATE table t as select map([i], [i+1]) as a from range(3000) as tbl(i);") query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 @@ -154,7 +154,7 @@ def test_record_batch_next_batch_map(self, duckdb_cursor): # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() @@ -169,7 +169,7 @@ def test_record_batch_next_batch_with_null(self, duckdb_cursor): "CREATE table t as SELECT CASE WHEN i % 2 = 0 THEN i ELSE NULL END AS a from range(3000) as tbl(i);" ) query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) assert record_batch_reader.schema.names == ["a"] chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1024 @@ -182,7 +182,7 @@ def test_record_batch_next_batch_with_null(self, duckdb_cursor): # Check if we are producing the correct thing query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) res = duckdb_cursor_check.execute("select * from record_batch_reader").fetchall() correct = duckdb_cursor.execute("select * from t").fetchall() @@ -193,7 +193,7 @@ def test_record_batch_read_default(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch() + record_batch_reader = query.to_arrow_reader() chunk = record_batch_reader.read_next_batch() assert len(chunk) == 3000 @@ -201,7 +201,7 @@ def test_record_batch_next_batch_multiple_vectors_per_chunk(self, duckdb_cursor) duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(5000);") query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(2048) + record_batch_reader = query.to_arrow_reader(2048) chunk = record_batch_reader.read_next_batch() assert len(chunk) == 2048 chunk = record_batch_reader.read_next_batch() @@ -212,12 +212,12 @@ def test_record_batch_next_batch_multiple_vectors_per_chunk(self, duckdb_cursor) chunk = record_batch_reader.read_next_batch() query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1) + record_batch_reader = query.to_arrow_reader(1) chunk = record_batch_reader.read_next_batch() assert len(chunk) == 1 query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(2000) + record_batch_reader = query.to_arrow_reader(2000) chunk = record_batch_reader.read_next_batch() assert len(chunk) == 2000 @@ -226,15 +226,15 @@ def test_record_batch_next_batch_multiple_vectors_per_chunk_error(self, duckdb_c duckdb_cursor.execute("CREATE table t as select range a from range(5000);") query = duckdb_cursor.execute("SELECT a FROM t") with pytest.raises(RuntimeError, match="Approximate Batch Size of Record Batch MUST be higher than 0"): - query.fetch_record_batch(0) + query.to_arrow_reader(0) with pytest.raises(TypeError, match="incompatible function arguments"): - query.fetch_record_batch(-1) + query.to_arrow_reader(-1) def test_record_batch_reader_from_relation(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(3000);") relation = duckdb_cursor.table("t") - record_batch_reader = relation.fetch_record_batch() + record_batch_reader = relation.to_arrow_reader() chunk = record_batch_reader.read_next_batch() assert len(chunk) == 3000 @@ -242,7 +242,7 @@ def test_record_coverage(self, duckdb_cursor): duckdb_cursor = duckdb.connect() duckdb_cursor.execute("CREATE table t as select range a from range(2048);") query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(1024) + record_batch_reader = query.to_arrow_reader(1024) chunk = record_batch_reader.read_all() assert len(chunk) == 2048 @@ -268,7 +268,7 @@ def test_many_list_batches(self): # Because this produces multiple chunks, this caused a segfault before # because we changed some data in the first batch fetch - batch_iter = conn.execute(query).fetch_record_batch(chunk_size) + batch_iter = conn.execute(query).to_arrow_reader(chunk_size) for batch in batch_iter: del batch @@ -278,7 +278,7 @@ def test_many_chunk_sizes(self): query = duckdb_cursor.execute(f"CREATE table t as select range a from range({object_size});") for i in [1, 2, 4, 8, 16, 32, 33, 77, 999, 999999]: query = duckdb_cursor.execute("SELECT a FROM t") - record_batch_reader = query.fetch_record_batch(i) + record_batch_reader = query.to_arrow_reader(i) num_loops = int(object_size / i) for _j in range(num_loops): assert record_batch_reader.schema.names == ["a"] diff --git a/tests/fast/arrow/test_arrow_pycapsule.py b/tests/fast/arrow/test_arrow_pycapsule.py index 47f1542b..6d0319f4 100644 --- a/tests/fast/arrow/test_arrow_pycapsule.py +++ b/tests/fast/arrow/test_arrow_pycapsule.py @@ -2,6 +2,7 @@ import duckdb +pa = pytest.importorskip("pyarrow") pl = pytest.importorskip("polars") @@ -11,6 +12,73 @@ def polars_supports_capsule(): return Version(pl.__version__) >= Version("1.4.1") +class TestArrowPyCapsuleExport: + """Tests for the PyCapsule export path (rel.__arrow_c_stream__). + + Validates that the fast path (PhysicalArrowCollector + ArrowQueryResultStreamWrapper) + produces correct data, matching to_arrow_table() across types and edge cases. + """ + + def test_capsule_matches_to_arrow_table(self): + """Fast path produces identical data to to_arrow_table for various types.""" + conn = duckdb.connect() + sql = """ + SELECT + i AS int_col, + i::DOUBLE AS double_col, + 'row_' || i::VARCHAR AS str_col, + i % 2 = 0 AS bool_col, + CASE WHEN i % 3 = 0 THEN NULL ELSE i END AS nullable_col + FROM range(1000) t(i) + """ + expected = conn.sql(sql).to_arrow_table() + actual = pa.table(conn.sql(sql)) + assert actual.equals(expected) + + def test_capsule_matches_to_arrow_table_nested_types(self): + """Fast path handles nested types (struct, list, map).""" + conn = duckdb.connect() + sql = """ + SELECT + {'x': i, 'y': i::VARCHAR} AS struct_col, + [i, i+1, i+2] AS list_col, + MAP {i::VARCHAR: i*10} AS map_col, + FROM range(100) t(i) + """ + expected = conn.sql(sql).to_arrow_table() + actual = pa.table(conn.sql(sql)) + assert actual.equals(expected) + + def test_capsule_multi_batch(self): + """Data exceeding the 1M batch size produces multiple batches, all yielded correctly.""" + conn = duckdb.connect() + sql = "SELECT i, i::DOUBLE AS d FROM range(1500000) t(i)" + expected = conn.sql(sql).to_arrow_table() + actual = pa.table(conn.sql(sql)) + assert actual.num_rows == 1500000 + assert actual.equals(expected) + + def test_capsule_empty_result(self): + """Empty result set produces a valid empty table with correct schema.""" + conn = duckdb.connect() + sql = "SELECT i AS a, i::VARCHAR AS b FROM range(10) t(i) WHERE i < 0" + expected = conn.sql(sql).to_arrow_table() + actual = pa.table(conn.sql(sql)) + assert actual.num_rows == 0 + assert actual.schema.equals(expected.schema) + + def test_capsule_slow_path_after_execute(self): + """Pre-executed relation takes the slow path (MaterializedQueryResult) and still works.""" + conn = duckdb.connect() + sql = "SELECT i, i::DOUBLE AS d FROM range(500) t(i)" + expected = conn.sql(sql).to_arrow_table() + + rel = conn.sql(sql) + rel.execute() # forces MaterializedCollector, not PhysicalArrowCollector + actual = pa.table(rel) + assert actual.equals(expected) + + @pytest.mark.skipif( not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface" ) @@ -29,21 +97,24 @@ def __arrow_c_stream__(self, requested_schema=None) -> object: obj = MyObject(df) # Call the __arrow_c_stream__ from within DuckDB + # MyObject has no __arrow_c_schema__, so GetSchema() falls back to __arrow_c_stream__ (1 call), + # then Produce() calls __arrow_c_stream__ again (1 call) = 2 calls minimum per scan. res = duckdb_cursor.sql("select * from obj") assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)] - assert obj.count == 1 + count_after_first = obj.count + assert count_after_first >= 2 # Call the __arrow_c_stream__ method and pass in the capsule instead capsule = obj.__arrow_c_stream__() res = duckdb_cursor.sql("select * from capsule") assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)] - assert obj.count == 2 + assert obj.count == count_after_first + 1 # Ensure __arrow_c_stream__ accepts a requested_schema argument as noop capsule = obj.__arrow_c_stream__(requested_schema="foo") # noqa: F841 res = duckdb_cursor.sql("select * from capsule") assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)] - assert obj.count == 3 + assert obj.count == count_after_first + 2 def test_capsule_roundtrip(self, duckdb_cursor): def create_capsule(): @@ -68,6 +139,16 @@ def test_automatic_reexecution(self, duckdb_cursor): assert len(res1) == 100 assert res1 == res2 + def test_pycapsule_rescan_error_type(self, duckdb_cursor): + """Issue #105: re-executing a relation backed by a consumed PyCapsule.""" + pa = pytest.importorskip("pyarrow") + tbl = pa.table({"a": [1]}) + capsule = tbl.__arrow_c_stream__() # noqa: F841 + rel = duckdb_cursor.sql("SELECT * FROM capsule") + rel.fetchall() # consumes the capsule + with pytest.raises(duckdb.InvalidInputException): + rel.fetchall() # re-execution should be InvalidInputException, not InternalException + def test_consumer_interface_roundtrip(self, duckdb_cursor): def create_table(): class MyTable: diff --git a/tests/fast/arrow/test_arrow_run_end_encoding.py b/tests/fast/arrow/test_arrow_run_end_encoding.py index e04f9ea0..6451f348 100644 --- a/tests/fast/arrow/test_arrow_run_end_encoding.py +++ b/tests/fast/arrow/test_arrow_run_end_encoding.py @@ -56,7 +56,7 @@ def test_arrow_run_end_encoding_numerics(self, duckdb_cursor, query, run_length, size = 127 query = query.format(run_length, value_type, size) rel = duckdb_cursor.sql(query) - array = rel.fetch_arrow_table()["ree"] + array = rel.to_arrow_table()["ree"] expected = rel.fetchall() encoded_array = pc.run_end_encode(array) @@ -128,7 +128,7 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) expected = duckdb_cursor.query(f"select {projection} from ree_tbl where {filter}").fetchall() # Create an Arrow Table from the table - arrow_conversion = rel.fetch_arrow_table() + arrow_conversion = rel.to_arrow_table() arrays = { "ree": arrow_conversion["ree"], "a": arrow_conversion["a"], @@ -157,7 +157,7 @@ def test_arrow_run_end_encoding(self, duckdb_cursor, dbtype, val1, val2, filter) def test_arrow_ree_empty_table(self, duckdb_cursor): duckdb_cursor.query("create table tbl (ree integer)") rel = duckdb_cursor.table("tbl") - array = rel.fetch_arrow_table()["ree"] + array = rel.to_arrow_table()["ree"] expected = rel.fetchall() encoded_array = pc.run_end_encode(array) @@ -194,7 +194,7 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): ) # Fetch the result as an Arrow Table - result = duckdb_cursor.table("tbl").fetch_arrow_table() + result = duckdb_cursor.table("tbl").to_arrow_table() # Turn 'ree' into a run-end-encoded array and reconstruct a table from it arrays = { @@ -225,7 +225,7 @@ def test_arrow_ree_projections(self, duckdb_cursor, projection): f""" select {projection} from arrow_tbl """ - ).fetch_arrow_table() + ).to_arrow_table() # Verify correctness by fetching from the original table and the constructed result expected = duckdb_cursor.query(f"select {projection} from tbl").fetchall() @@ -249,7 +249,7 @@ def test_arrow_ree_list(self, duckdb_cursor, create_list): """ select * from tbl """ - ).fetch_arrow_table() + ).to_arrow_table() columns = unstructured.columns # Run-encode the first column ('ree') @@ -273,7 +273,7 @@ def test_arrow_ree_list(self, duckdb_cursor, create_list): structured = pa.chunked_array(structured_chunks) arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) - result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() + result = duckdb_cursor.query("select * from arrow_tbl").to_arrow_table() assert arrow_tbl.to_pylist() == result.to_pylist() def test_arrow_ree_struct(self, duckdb_cursor): @@ -294,7 +294,7 @@ def test_arrow_ree_struct(self, duckdb_cursor): """ select * from tbl """ - ).fetch_arrow_table() + ).to_arrow_table() columns = unstructured.columns # Run-encode the first column ('ree') @@ -303,13 +303,13 @@ def test_arrow_ree_struct(self, duckdb_cursor): # Create a (chunked) StructArray from the chunked arrays (columns) of the ArrowTable names = unstructured.column_names iterables = [x.iterchunks() for x in columns] - zipped = zip(*iterables) + zipped = zip(*iterables, strict=False) structured_chunks = [pa.StructArray.from_arrays(list(x), names=names) for x in zipped] structured = pa.chunked_array(structured_chunks) arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) # noqa: F841 - result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # noqa: F841 + result = duckdb_cursor.query("select * from arrow_tbl").to_arrow_table() # noqa: F841 expected = duckdb_cursor.query("select {'ree': ree, 'a': a, 'b': b, 'c': c} as s from tbl").fetchall() actual = duckdb_cursor.query("select * from result").fetchall() @@ -336,7 +336,7 @@ def test_arrow_ree_union(self, duckdb_cursor): """ select * from tbl """ - ).fetch_arrow_table() + ).to_arrow_table() columns = unstructured.columns # Run-encode the first column ('ree') @@ -345,7 +345,7 @@ def test_arrow_ree_union(self, duckdb_cursor): # Create a (chunked) UnionArray from the chunked arrays (columns) of the ArrowTable names = unstructured.column_names iterables = [x.iterchunks() for x in columns] - zipped = zip(*iterables) + zipped = zip(*iterables, strict=False) structured_chunks = [] for chunk in zipped: @@ -358,7 +358,7 @@ def test_arrow_ree_union(self, duckdb_cursor): structured = pa.chunked_array(structured_chunks) arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) # noqa: F841 - result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() # noqa: F841 + result = duckdb_cursor.query("select * from arrow_tbl").to_arrow_table() # noqa: F841 # Recreate the same result set expected = [] @@ -392,7 +392,7 @@ def test_arrow_ree_map(self, duckdb_cursor): """ select * from tbl """ - ).fetch_arrow_table() + ).to_arrow_table() columns = unstructured.columns # Run-encode the first column ('ree') @@ -400,7 +400,7 @@ def test_arrow_ree_map(self, duckdb_cursor): # Create a (chunked) MapArray from the chunked arrays (columns) of the ArrowTable iterables = [x.iterchunks() for x in columns] - zipped = zip(*iterables) + zipped = zip(*iterables, strict=False) structured_chunks = [] for chunk in zipped: @@ -418,7 +418,7 @@ def test_arrow_ree_map(self, duckdb_cursor): structured = pa.chunked_array(structured_chunks) arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) - result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() + result = duckdb_cursor.query("select * from arrow_tbl").to_arrow_table() # Verify that the resulting scan is the same as the input assert result.to_pylist() == arrow_tbl.to_pylist() @@ -440,7 +440,7 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): """ select * from tbl """ - ).fetch_arrow_table() + ).to_arrow_table() columns = unstructured.columns # Run-encode the first column ('ree') @@ -458,7 +458,7 @@ def test_arrow_ree_dictionary(self, duckdb_cursor): structured = pa.chunked_array(structured_chunks) arrow_tbl = pa.Table.from_arrays([structured], names=["ree"]) - result = duckdb_cursor.query("select * from arrow_tbl").fetch_arrow_table() + result = duckdb_cursor.query("select * from arrow_tbl").to_arrow_table() # Verify that the resulting scan is the same as the input assert result.to_pylist() == arrow_tbl.to_pylist() diff --git a/tests/fast/arrow/test_arrow_stream_scan.py b/tests/fast/arrow/test_arrow_stream_scan.py new file mode 100644 index 00000000..2370ec13 --- /dev/null +++ b/tests/fast/arrow/test_arrow_stream_scan.py @@ -0,0 +1,414 @@ +import contextlib +import subprocess +import sys + +import pytest + +import duckdb + +pa = pytest.importorskip("pyarrow") +ds = pytest.importorskip("pyarrow.dataset") + + +class ArrowStream: + """Minimal PyCapsuleInterface wrapper around a PyArrow table. + + This represents any third-party library (not Polars, not PyArrow) that + implements the Arrow PyCapsule interface. DuckDB's replacement scan + handles Polars and PyArrow types explicitly before falling through to + PyCapsuleInterface detection via GetArrowType(), so we need a wrapper + like this to exercise that code path. + """ + + def __init__(self, tbl) -> None: + self.tbl = tbl + self.stream_count = 0 + + def __arrow_c_stream__(self, requested_schema=None): # noqa: ANN204 + self.stream_count += 1 + return self.tbl.__arrow_c_stream__(requested_schema=requested_schema) + + +class ArrowStreamWithSchema(ArrowStream): + """PyCapsuleInterface wrapper that also exposes __arrow_c_schema__.""" + + def __arrow_c_schema__(self): # noqa: ANN204 + return self.tbl.schema.__arrow_c_schema__() + + +class ArrowStreamWithDotSchema(ArrowStream): + """PyCapsuleInterface wrapper that exposes .schema (pyarrow schema with _export_to_c).""" + + def __init__(self, tbl) -> None: + super().__init__(tbl) + self.schema = tbl.schema + + +class SingleUseArrowStream: + """PyCapsuleInterface that can only produce one stream, but exposes .schema.""" + + def __init__(self, tbl) -> None: + self.tbl = tbl + self.schema = tbl.schema + self.stream_count = 0 + + def __arrow_c_stream__(self, requested_schema=None): # noqa: ANN204 + self.stream_count += 1 + if self.stream_count > 1: + msg = "Stream already consumed" + raise RuntimeError(msg) + return self.tbl.__arrow_c_stream__(requested_schema=requested_schema) + + +class TestPyCapsuleInterfaceMultiScan: + """Issue #70: queries requiring multiple scans of an arrow stream. + + PyCapsuleInterface objects support multi-scan because each call to + __arrow_c_stream__() produces a fresh stream. + """ + + def test_union_all(self, duckdb_cursor): + """UNION ALL scans the same PyCapsuleInterface twice in one query.""" + obj = ArrowStream(pa.table({"id": [1, 2, 3, 4, 5]})) # noqa: F841 + result = duckdb_cursor.sql("SELECT id FROM obj UNION ALL SELECT id + 1 FROM obj").fetchall() + ids = sorted(r[0] for r in result) + assert ids == sorted([1, 2, 3, 4, 5, 2, 3, 4, 5, 6]) + + def test_rescan_across_queries(self, duckdb_cursor): + """PyCapsuleInterface scanned in two consecutive queries.""" + obj = ArrowStream(pa.table({"id": [1, 2, 3]})) # noqa: F841 + r1 = duckdb_cursor.sql("SELECT * FROM obj").fetchall() + r2 = duckdb_cursor.sql("SELECT * FROM obj").fetchall() + assert r1 == r2 == [(1,), (2,), (3,)] + + def test_register(self, duckdb_cursor): + """PyCapsuleInterface registered via register() supports multi-scan.""" + obj = ArrowStream(pa.table({"id": [1, 2, 3]})) + duckdb_cursor.register("my_stream", obj) + result = duckdb_cursor.sql("SELECT id FROM my_stream UNION ALL SELECT id FROM my_stream").fetchall() + assert len(result) == 6 + + def test_from_arrow(self, duckdb_cursor): + """PyCapsuleInterface passed to from_arrow() supports multi-scan.""" + obj = ArrowStream(pa.table({"id": [1, 2, 3]})) + rel = duckdb_cursor.from_arrow(obj) + r1 = rel.fetchall() + r2 = rel.fetchall() + assert r1 == r2 == [(1,), (2,), (3,)] + + def test_self_join(self, duckdb_cursor): + """Self-join on PyCapsuleInterface requires two scans.""" + obj = ArrowStream(pa.table({"id": [1, 2, 3], "val": [10, 20, 30]})) # noqa: F841 + result = duckdb_cursor.sql("SELECT a.id, b.val FROM obj a JOIN obj b ON a.id = b.id").fetchall() + assert sorted(result) == [(1, 10), (2, 20), (3, 30)] + + +class TestPyCapsuleInterfacePushdown: + """PyCapsuleInterface objects get projection and filter pushdown via arrow_scan.""" + + def test_projection_pushdown(self, duckdb_cursor): + """Selecting a subset of columns only reads those columns.""" + obj = ArrowStream(pa.table({"a": [1, 2, 3], "b": [10, 20, 30], "c": ["x", "y", "z"]})) # noqa: F841 + result = duckdb_cursor.sql("SELECT a FROM obj").fetchall() + assert result == [(1,), (2,), (3,)] + + def test_filter_pushdown(self, duckdb_cursor): + """Filters are pushed down to the arrow scanner.""" + obj = ArrowStream(pa.table({"a": [1, 2, 3, 4, 5], "b": [10, 20, 30, 40, 50]})) # noqa: F841 + result = duckdb_cursor.sql("SELECT a, b FROM obj WHERE a > 3").fetchall() + assert sorted(result) == [(4, 40), (5, 50)] + + def test_combined_pushdown(self, duckdb_cursor): + """Projection + filter pushdown combined.""" + obj = ArrowStream(pa.table({"a": [1, 2, 3, 4, 5], "b": [10, 20, 30, 40, 50]})) # noqa: F841 + result = duckdb_cursor.sql("SELECT b FROM obj WHERE a <= 2").fetchall() + assert sorted(result) == [(10,), (20,)] + + +class TestPyCapsuleInterfaceSchemaOptimization: + """GetSchema() uses __arrow_c_schema__ when available to avoid allocating a stream.""" + + def test_arrow_c_schema_avoids_stream_call(self, duckdb_cursor): + """When __arrow_c_schema__ is available, GetSchema() does not call __arrow_c_stream__.""" + obj = ArrowStreamWithSchema(pa.table({"a": [1, 2, 3]})) + duckdb_cursor.sql("SELECT * FROM obj").fetchall() + # With __arrow_c_schema__: only Produce() calls __arrow_c_stream__ (1 call). + # Without it: GetSchema() fallback + Produce() = 2 calls. + assert obj.stream_count == 1 + + def test_without_arrow_c_schema_uses_stream_fallback(self, duckdb_cursor): + """Without __arrow_c_schema__, GetSchema() falls back to __arrow_c_stream__.""" + obj = ArrowStream(pa.table({"a": [1, 2, 3]})) + duckdb_cursor.sql("SELECT * FROM obj").fetchall() + # GetSchema() fallback (1) + Produce() (1) = 2 calls minimum + assert obj.stream_count >= 2 + + def test_dot_schema_avoids_stream_call(self, duckdb_cursor): + """When .schema with _export_to_c is available, GetSchema() uses it instead of __arrow_c_stream__.""" + obj = ArrowStreamWithDotSchema(pa.table({"a": [1, 2, 3]})) + result = duckdb_cursor.sql("SELECT * FROM obj").fetchall() + assert result == [(1,), (2,), (3,)] + # With .schema: only Produce() calls __arrow_c_stream__ (1 call). + assert obj.stream_count == 1 + + def test_schema_via_dotschema_preserves_stream(self, duckdb_cursor): + """A SingleUseArrowStream can be scanned because GetSchema uses .schema.""" + obj = SingleUseArrowStream(pa.table({"a": [1, 2, 3], "b": [10, 20, 30]})) + result = duckdb_cursor.sql("SELECT a, b FROM obj").fetchall() + assert sorted(result) == [(1, 10), (2, 20), (3, 30)] + # Only 1 call to __arrow_c_stream__ (from Produce), schema came from .schema + assert obj.stream_count == 1 + + def test_schema_fallback_order(self, duckdb_cursor): + """Schema extraction priority: __arrow_c_schema__ > .schema._export_to_c > __arrow_c_stream__.""" + # Object with __arrow_c_schema__ — should use that, not .schema or stream + obj_with_capsule_schema = ArrowStreamWithSchema(pa.table({"x": [1]})) + duckdb_cursor.sql("SELECT * FROM obj_with_capsule_schema").fetchall() + assert obj_with_capsule_schema.stream_count == 1 # only Produce + + # Object with .schema — should use that, not stream + obj_with_dot_schema = ArrowStreamWithDotSchema(pa.table({"x": [1]})) + duckdb_cursor.sql("SELECT * FROM obj_with_dot_schema").fetchall() + assert obj_with_dot_schema.stream_count == 1 # only Produce + + # Object with neither — falls back to stream + obj_bare = ArrowStream(pa.table({"x": [1]})) + duckdb_cursor.sql("SELECT * FROM obj_bare").fetchall() + assert obj_bare.stream_count >= 2 # GetSchema + Produce + + +class TestPyArrowTableUnifiedPath: + """PyArrow Table now enters via __arrow_c_stream__ (PyCapsuleInterface path). + + This verifies that Table gets multi-scan, pushdown, and correct results + through the unified path instead of the old dedicated Table branch. + """ + + def test_pyarrow_table_scan(self, duckdb_cursor): + """Basic scan of a PyArrow Table through the unified path.""" + tbl = pa.table({"a": [1, 2, 3], "b": [10, 20, 30]}) # noqa: F841 + result = duckdb_cursor.sql("SELECT * FROM tbl").fetchall() + assert sorted(result) == [(1, 10), (2, 20), (3, 30)] + + def test_pyarrow_table_projection(self, duckdb_cursor): + """Projection pushdown on a PyArrow Table.""" + tbl = pa.table({"a": [1, 2, 3], "b": [10, 20, 30], "c": ["x", "y", "z"]}) # noqa: F841 + result = duckdb_cursor.sql("SELECT a FROM tbl").fetchall() + assert result == [(1,), (2,), (3,)] + + def test_pyarrow_table_filter(self, duckdb_cursor): + """Filter pushdown on a PyArrow Table.""" + tbl = pa.table({"a": [1, 2, 3, 4, 5], "b": [10, 20, 30, 40, 50]}) # noqa: F841 + result = duckdb_cursor.sql("SELECT a, b FROM tbl WHERE a > 3").fetchall() + assert sorted(result) == [(4, 40), (5, 50)] + + def test_pyarrow_table_combined_pushdown(self, duckdb_cursor): + """Projection + filter pushdown on a PyArrow Table.""" + tbl = pa.table({"a": [1, 2, 3, 4, 5], "b": [10, 20, 30, 40, 50]}) # noqa: F841 + result = duckdb_cursor.sql("SELECT b FROM tbl WHERE a <= 2").fetchall() + assert sorted(result) == [(10,), (20,)] + + def test_pyarrow_table_union_all(self, duckdb_cursor): + """Table scanned twice in one query via UNION ALL.""" + tbl = pa.table({"id": [1, 2, 3]}) # noqa: F841 + result = duckdb_cursor.sql("SELECT id FROM tbl UNION ALL SELECT id FROM tbl").fetchall() + assert sorted(r[0] for r in result) == [1, 1, 2, 2, 3, 3] + + def test_pyarrow_table_rescan(self, duckdb_cursor): + """Table can be scanned across multiple queries.""" + tbl = pa.table({"id": [1, 2, 3]}) # noqa: F841 + r1 = duckdb_cursor.sql("SELECT * FROM tbl").fetchall() + r2 = duckdb_cursor.sql("SELECT * FROM tbl").fetchall() + assert r1 == r2 == [(1,), (2,), (3,)] + + +class TestRecordBatchReaderSingleUse: + """RecordBatchReaders are inherently single-use streams. + + After the first scan consumes the reader, subsequent scans return empty results. + This is correct behavior — RecordBatchReaders represent forward-only streams + (e.g., reading from a socket or file). + """ + + def test_second_scan_empty(self, duckdb_cursor): + """Second scan of a RecordBatchReader returns empty results.""" + reader = pa.RecordBatchReader.from_batches( # noqa: F841 + pa.schema([("id", pa.int64())]), + [pa.record_batch([pa.array([1, 2, 3])], names=["id"])], + ) + r1 = duckdb_cursor.sql("SELECT * FROM reader").fetchall() + assert r1 == [(1,), (2,), (3,)] + r2 = duckdb_cursor.sql("SELECT * FROM reader").fetchall() + assert r2 == [] + + def test_register_second_scan_empty(self, duckdb_cursor): + """Registered RecordBatchReader is also single-use.""" + reader = pa.RecordBatchReader.from_batches( + pa.schema([("id", pa.int64())]), + [pa.record_batch([pa.array([1, 2, 3])], names=["id"])], + ) + duckdb_cursor.register("my_reader", reader) + r1 = duckdb_cursor.sql("SELECT * FROM my_reader").fetchall() + assert r1 == [(1,), (2,), (3,)] + r2 = duckdb_cursor.sql("SELECT * FROM my_reader").fetchall() + assert r2 == [] + + def test_has_pushdown(self, duckdb_cursor): + """RecordBatchReader gets projection/filter pushdown (not materialized).""" + reader = pa.RecordBatchReader.from_batches( # noqa: F841 + pa.schema([("a", pa.int64()), ("b", pa.int64())]), + [pa.record_batch([pa.array([1, 2, 3]), pa.array([10, 20, 30])], names=["a", "b"])], + ) + result = duckdb_cursor.sql("SELECT b FROM reader WHERE a > 1").fetchall() + assert sorted(result) == [(20,), (30,)] + + +class TestPyCapsuleConsumed: + """Issue #105: scanning a bare PyCapsule twice. + + Bare PyCapsules are single-use (the capsule IS the stream, not a stream factory). + The fix ensures a clear InvalidInputException instead of InternalException. + """ + + def test_error_type(self, duckdb_cursor): + """Consumed PyCapsule raises InvalidInputException, not InternalException.""" + tbl = pa.table({"a": [1]}) + capsule = tbl.__arrow_c_stream__() # noqa: F841 + duckdb_cursor.sql("SELECT * FROM capsule").fetchall() + # Error thrown by GetArrowType() in pyconnection.cpp when it detects the released stream. + with pytest.raises(duckdb.InvalidInputException, match="The ArrowArrayStream was already released"): + duckdb_cursor.sql("SELECT * FROM capsule") + + def test_pycapsule_interface_not_affected(self, duckdb_cursor): + """Scanning through the PyCapsuleInterface object (not the capsule) works repeatedly.""" + obj = ArrowStream(pa.table({"a": [1, 2, 3]})) # noqa: F841 + + # First scan + r1 = duckdb_cursor.sql("SELECT * FROM obj").fetchall() + assert r1 == [(1,), (2,), (3,)] + + # Second scan — works because __arrow_c_stream__() is called lazily each time + r2 = duckdb_cursor.sql("SELECT * FROM obj").fetchall() + assert r2 == [(1,), (2,), (3,)] + + +class TestSameConnectionRecordBatchReader: + """Issue #85: DuckDB-originated RecordBatchReader on the same connection. + + When conn.sql(...).to_arrow_reader() returns a RecordBatchReader backed by + the same connection, scanning it on that connection may deadlock or return + empty results due to lock contention. Run in subprocess to avoid hanging + the test suite. The workaround is to use a different connection for the scan. + """ + + def test_same_connection_no_data(self): + """Same-connection RecordBatchReader scan fails to return data. + + Run in subprocess to prevent hanging the test suite if it deadlocks. + """ + code = """\ +import duckdb +conn = duckdb.connect("") +reader = conn.sql("FROM range(5) T(a)").to_arrow_reader() +result = conn.sql("FROM reader").fetchall() +assert result != [(i,) for i in range(5)], "Expected no data due to lock contention" +""" + with contextlib.suppress(subprocess.TimeoutExpired): + subprocess.run( + [sys.executable, "-c", code], + timeout=5, + capture_output=True, + ) + + def test_different_connection_works(self, duckdb_cursor): + """RecordBatchReader from connection A scanned on connection B works fine.""" + conn_a = duckdb.connect() + conn_b = duckdb.connect() + reader = conn_a.sql("FROM range(5) T(a)").to_arrow_reader() # noqa: F841 + result = conn_b.sql("FROM reader").fetchall() + assert result == [(i,) for i in range(5)] + + def test_arrow_method_different_connection(self, duckdb_cursor): + """The .arrow() method (which returns RecordBatchReader) works cross-connection.""" + conn_a = duckdb.connect() + conn_b = duckdb.connect() + arrow_reader = conn_a.sql("FROM range(5) T(a)").arrow() # noqa: F841 + result = conn_b.sql("FROM arrow_reader").fetchall() + assert result == [(i,) for i in range(5)] + + +class TestPyCapsuleInterfaceNoPyarrowDataset: + """Tier B fallback: PyCapsuleInterface objects are scannable without pyarrow.dataset. + + When pyarrow.dataset is not available, PyCapsuleInterface uses arrow_scan_dumb + (no pushdown). DuckDB handles projection/filter post-scan. + Run in subprocess to avoid polluting the test process's import state. + """ + + def _run_in_subprocess(self, code): + result = subprocess.run( + [sys.executable, "-c", code], + timeout=30, + capture_output=True, + text=True, + ) + if result.returncode != 0: + msg = f"Subprocess failed (rc={result.returncode}):\nstdout: {result.stdout}\nstderr: {result.stderr}" + raise AssertionError(msg) + + def test_pycapsule_interface_no_pyarrow_dataset(self): + """PyCapsuleInterface objects can be scanned without pyarrow.dataset.""" + self._run_in_subprocess("""\ +import pyarrow as pa +import duckdb + +class MyStream: + def __init__(self, tbl): + self.tbl = tbl + def __arrow_c_stream__(self, requested_schema=None): + return self.tbl.__arrow_c_stream__(requested_schema=requested_schema) + def __arrow_c_schema__(self): + return self.tbl.schema.__arrow_c_schema__() + +obj = MyStream(pa.table({"a": [1, 2, 3], "b": [10, 20, 30]})) +result = duckdb.sql("SELECT * FROM obj").fetchall() +assert sorted(result) == [(1, 10), (2, 20), (3, 30)], f"Unexpected: {result}" +""") + + def test_pycapsule_interface_no_pyarrow_dataset_projection(self): + """DuckDB applies projection post-scan when pyarrow.dataset unavailable.""" + self._run_in_subprocess("""\ +import pyarrow as pa +import duckdb + +class MyStream: + def __init__(self, tbl): + self.tbl = tbl + def __arrow_c_stream__(self, requested_schema=None): + return self.tbl.__arrow_c_stream__(requested_schema=requested_schema) + def __arrow_c_schema__(self): + return self.tbl.schema.__arrow_c_schema__() + +obj = MyStream(pa.table({"a": [1, 2, 3], "b": [10, 20, 30], "c": ["x", "y", "z"]})) +result = duckdb.sql("SELECT a FROM obj").fetchall() +assert result == [(1,), (2,), (3,)], f"Unexpected: {result}" +""") + + def test_pycapsule_interface_no_pyarrow_dataset_filter(self): + """DuckDB applies filter post-scan when pyarrow.dataset unavailable.""" + self._run_in_subprocess("""\ +import pyarrow as pa +import duckdb + +class MyStream: + def __init__(self, tbl): + self.tbl = tbl + def __arrow_c_stream__(self, requested_schema=None): + return self.tbl.__arrow_c_stream__(requested_schema=requested_schema) + def __arrow_c_schema__(self): + return self.tbl.schema.__arrow_c_schema__() + +obj = MyStream(pa.table({"a": [1, 2, 3, 4, 5], "b": [10, 20, 30, 40, 50]})) +result = duckdb.sql("SELECT a, b FROM obj WHERE a > 3").fetchall() +assert sorted(result) == [(4, 40), (5, 50)], f"Unexpected: {result}" +""") diff --git a/tests/fast/arrow/test_arrow_string_view.py b/tests/fast/arrow/test_arrow_string_view.py index 9ed9bece..abba336e 100644 --- a/tests/fast/arrow/test_arrow_string_view.py +++ b/tests/fast/arrow/test_arrow_string_view.py @@ -13,7 +13,7 @@ def RoundTripStringView(query, array): con = duckdb.connect() con.execute("SET produce_arrow_string_view=True") - arrow_tbl = con.execute(query).fetch_arrow_table() + arrow_tbl = con.execute(query).to_arrow_table() # Assert that we spit the same as the defined array arrow_tbl[0].validate(full=True) assert arrow_tbl[0].combine_chunks().tolist() == array.tolist() @@ -27,14 +27,14 @@ def RoundTripStringView(query, array): # Create a table using the schema and the array gt_table = pa.Table.from_arrays([array], schema=schema) # noqa: F841 - arrow_table = con.execute("select * from gt_table").fetch_arrow_table() # noqa: F841 + arrow_table = con.execute("select * from gt_table").to_arrow_table() # noqa: F841 assert arrow_tbl[0].combine_chunks().tolist() == array.tolist() def RoundTripDuckDBInternal(query): con = duckdb.connect() con.execute("SET produce_arrow_string_view=True") - arrow_tbl = con.execute(query).fetch_arrow_table() + arrow_tbl = con.execute(query).to_arrow_table() arrow_tbl.validate(full=True) res = con.execute(query).fetchall() from_arrow_res = con.execute("FROM arrow_tbl order by str").fetchall() diff --git a/tests/fast/arrow/test_arrow_types.py b/tests/fast/arrow/test_arrow_types.py index be03009c..5f884f6a 100644 --- a/tests/fast/arrow/test_arrow_types.py +++ b/tests/fast/arrow/test_arrow_types.py @@ -12,7 +12,7 @@ def test_null_type(self, duckdb_cursor): inputs = [pa.array([None, None, None], type=pa.null())] arrow_table = pa.Table.from_arrays(inputs, schema=schema) duckdb_cursor.register("testarrow", arrow_table) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() # We turn it to an array of int32 nulls schema = pa.schema([("data", pa.int32())]) inputs = [pa.array([None, None, None], type=pa.null())] diff --git a/tests/fast/arrow/test_arrow_union.py b/tests/fast/arrow/test_arrow_union.py index 784a5433..1c89c388 100644 --- a/tests/fast/arrow/test_arrow_union.py +++ b/tests/fast/arrow/test_arrow_union.py @@ -36,7 +36,7 @@ def test_unions_with_struct(duckdb_cursor): ) rel = duckdb_cursor.table("tbl") - arrow = rel.fetch_arrow_table() # noqa: F841 + arrow = rel.to_arrow_table() # noqa: F841 duckdb_cursor.execute("create table other as select * from arrow") rel2 = duckdb_cursor.table("other") @@ -45,4 +45,4 @@ def test_unions_with_struct(duckdb_cursor): def run(conn, query): - return conn.sql(query).fetch_arrow_table().columns[0][0] + return conn.sql(query).to_arrow_table().columns[0][0] diff --git a/tests/fast/arrow/test_arrow_version_format.py b/tests/fast/arrow/test_arrow_version_format.py index d2864b15..7e5e37c3 100644 --- a/tests/fast/arrow/test_arrow_version_format.py +++ b/tests/fast/arrow/test_arrow_version_format.py @@ -20,7 +20,7 @@ def test_decimal_v1_5(self, duckdb_cursor): ], pa.schema([("data", pa.decimal32(5, 2))]), ) - col_type = duckdb_cursor.execute("FROM decimal_32").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("FROM decimal_32").to_arrow_table().schema.field("data").type assert col_type.bit_width == 32 assert pa.types.is_decimal(col_type) @@ -33,12 +33,12 @@ def test_decimal_v1_5(self, duckdb_cursor): ], pa.schema([("data", pa.decimal64(16, 3))]), ) - col_type = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("FROM decimal_64").to_arrow_table().schema.field("data").type assert col_type.bit_width == 64 assert pa.types.is_decimal(col_type) for version in ["1.0", "1.1", "1.2", "1.3", "1.4"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") - result = duckdb_cursor.execute("FROM decimal_32").fetch_arrow_table() + result = duckdb_cursor.execute("FROM decimal_32").to_arrow_table() col_type = result.schema.field("data").type assert col_type.bit_width == 128 assert pa.types.is_decimal(col_type) @@ -46,7 +46,7 @@ def test_decimal_v1_5(self, duckdb_cursor): "data": [Decimal("100.20"), Decimal("110.21"), Decimal("31.20"), Decimal("500.20")] } - result = duckdb_cursor.execute("FROM decimal_64").fetch_arrow_table() + result = duckdb_cursor.execute("FROM decimal_64").to_arrow_table() col_type = result.schema.field("data").type assert col_type.bit_width == 128 assert pa.types.is_decimal(col_type) @@ -64,31 +64,31 @@ def test_view_v1_4(self, duckdb_cursor): duckdb_cursor.execute("SET arrow_output_version = 1.5") duckdb_cursor.execute("SET produce_arrow_string_view=True") duckdb_cursor.execute("SET arrow_output_list_view=True") - col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("SELECT 'string' as data ").to_arrow_table().schema.field("data").type assert pa.types.is_string_view(col_type) - col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("SELECT ['string'] as data ").to_arrow_table().schema.field("data").type assert pa.types.is_list_view(col_type) for version in ["1.0", "1.1", "1.2", "1.3"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") - col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("SELECT 'string' as data ").to_arrow_table().schema.field("data").type assert not pa.types.is_string_view(col_type) - col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("SELECT ['string'] as data ").to_arrow_table().schema.field("data").type assert not pa.types.is_list_view(col_type) for version in ["1.4", "1.5"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") - col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("SELECT 'string' as data ").to_arrow_table().schema.field("data").type assert pa.types.is_string_view(col_type) - col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("SELECT ['string'] as data ").to_arrow_table().schema.field("data").type assert pa.types.is_list_view(col_type) duckdb_cursor.execute("SET produce_arrow_string_view=False") duckdb_cursor.execute("SET arrow_output_list_view=False") for version in ["1.4", "1.5"]: duckdb_cursor.execute(f"SET arrow_output_version = {version}") - col_type = duckdb_cursor.execute("SELECT 'string' as data ").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("SELECT 'string' as data ").to_arrow_table().schema.field("data").type assert not pa.types.is_string_view(col_type) - col_type = duckdb_cursor.execute("SELECT ['string'] as data ").fetch_arrow_table().schema.field("data").type + col_type = duckdb_cursor.execute("SELECT ['string'] as data ").to_arrow_table().schema.field("data").type assert not pa.types.is_list_view(col_type) diff --git a/tests/fast/arrow/test_buffer_size_option.py b/tests/fast/arrow/test_buffer_size_option.py index c63adfcc..6f5c41af 100644 --- a/tests/fast/arrow/test_buffer_size_option.py +++ b/tests/fast/arrow/test_buffer_size_option.py @@ -11,23 +11,23 @@ def test_arrow_buffer_size(self): con = duckdb.connect() # All small string - res = con.query("select 'bla'").fetch_arrow_table() + res = con.query("select 'bla'").to_arrow_table() assert res[0][0].type == pa.string() - res = con.query("select 'bla'").fetch_record_batch() + res = con.query("select 'bla'").to_arrow_reader() assert res.schema[0].type == pa.string() # All Large String con.execute("SET arrow_large_buffer_size=True") - res = con.query("select 'bla'").fetch_arrow_table() + res = con.query("select 'bla'").to_arrow_table() assert res[0][0].type == pa.large_string() - res = con.query("select 'bla'").fetch_record_batch() + res = con.query("select 'bla'").to_arrow_reader() assert res.schema[0].type == pa.large_string() # All small string again con.execute("SET arrow_large_buffer_size=False") - res = con.query("select 'bla'").fetch_arrow_table() + res = con.query("select 'bla'").to_arrow_table() assert res[0][0].type == pa.string() - res = con.query("select 'bla'").fetch_record_batch() + res = con.query("select 'bla'").to_arrow_reader() assert res.schema[0].type == pa.string() def test_arrow_buffer_size_udf(self): @@ -37,12 +37,12 @@ def just_return(x): con = duckdb.connect() con.create_function("just_return", just_return, [VARCHAR], VARCHAR, type="arrow") - res = con.query("select just_return('bla')").fetch_arrow_table() + res = con.query("select just_return('bla')").to_arrow_table() assert res[0][0].type == pa.string() # All Large String con.execute("SET arrow_large_buffer_size=True") - res = con.query("select just_return('bla')").fetch_arrow_table() + res = con.query("select just_return('bla')").to_arrow_table() assert res[0][0].type == pa.large_string() diff --git a/tests/fast/arrow/test_dataset.py b/tests/fast/arrow/test_dataset.py index 36e29110..595c384a 100644 --- a/tests/fast/arrow/test_dataset.py +++ b/tests/fast/arrow/test_dataset.py @@ -77,7 +77,7 @@ def test_parallel_dataset_roundtrip(self, duckdb_cursor): duckdb_conn.register("dataset", userdata_parquet_dataset) query = duckdb_conn.execute("SELECT * FROM dataset order by id") - record_batch_reader = query.fetch_record_batch(2048) + record_batch_reader = query.to_arrow_reader(2048) arrow_table = record_batch_reader.read_all() # noqa: F841 # reorder since order of rows isn't deterministic @@ -94,7 +94,7 @@ def test_ducktyping(self, duckdb_cursor): duckdb_conn = duckdb.connect() dataset = CustomDataset() # noqa: F841 query = duckdb_conn.execute("SELECT b FROM dataset WHERE a < 5") - record_batch_reader = query.fetch_record_batch(2048) + record_batch_reader = query.to_arrow_reader(2048) arrow_table = record_batch_reader.read_all() assert arrow_table.equals(CustomDataset.DATA[:5].select(["b"])) diff --git a/tests/fast/arrow/test_date.py b/tests/fast/arrow/test_date.py index 20cf9f0f..951b85ef 100644 --- a/tests/fast/arrow/test_date.py +++ b/tests/fast/arrow/test_date.py @@ -15,7 +15,7 @@ def test_date_types(self, duckdb_cursor): data = (pa.array([1000 * 60 * 60 * 24], type=pa.date64()), pa.array([1], type=pa.date32())) arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == arrow_table["b"] assert rel["b"] == arrow_table["b"] @@ -24,7 +24,7 @@ def test_date_null(self, duckdb_cursor): return data = (pa.array([None], type=pa.date64()), pa.array([None], type=pa.date32())) arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == arrow_table["b"] assert rel["b"] == arrow_table["b"] @@ -38,6 +38,6 @@ def test_max_date(self, duckdb_cursor): pa.array([2147483647], type=pa.date32()), ) arrow_table = pa.Table.from_arrays([data[0], data[1]], ["a", "b"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == result["a"] assert rel["b"] == result["b"] diff --git a/tests/fast/arrow/test_dictionary_arrow.py b/tests/fast/arrow/test_dictionary_arrow.py index 32c348a3..7a6a5be7 100644 --- a/tests/fast/arrow/test_dictionary_arrow.py +++ b/tests/fast/arrow/test_dictionary_arrow.py @@ -127,7 +127,7 @@ def test_dictionary_roundtrip(self, query, element, duckdb_cursor, count): query = query.format(element, count) original_rel = duckdb_cursor.sql(query) expected = original_rel.fetchall() - arrow_res = original_rel.fetch_arrow_table() # noqa: F841 + arrow_res = original_rel.to_arrow_table() # noqa: F841 roundtrip_rel = duckdb_cursor.sql("select * from arrow_res") actual = roundtrip_rel.fetchall() diff --git a/tests/fast/arrow/test_filter_pushdown.py b/tests/fast/arrow/test_filter_pushdown.py index 225f48c0..d2eea92e 100644 --- a/tests/fast/arrow/test_filter_pushdown.py +++ b/tests/fast/arrow/test_filter_pushdown.py @@ -2,32 +2,34 @@ import sys import pytest -from conftest import pandas_supports_arrow_backend +from conftest import PANDAS_GE_3 from packaging.version import Version import duckdb pa = pytest.importorskip("pyarrow") -pd = pytest.importorskip("pyarrow.dataset") +pa_ds = pytest.importorskip("pyarrow.dataset") pa_lib = pytest.importorskip("pyarrow.lib") -pq = pytest.importorskip("pyarrow.parquet") +pa_parquet = pytest.importorskip("pyarrow.parquet") +pd = pytest.importorskip("pandas") np = pytest.importorskip("numpy") re = pytest.importorskip("re") def create_pyarrow_pandas(rel): - if not pandas_supports_arrow_backend(): - pytest.skip(reason="Pandas version doesn't support 'pyarrow' backend") - return rel.df().convert_dtypes(dtype_backend="pyarrow") + if PANDAS_GE_3: + return rel.df() + else: + return rel.df().convert_dtypes(dtype_backend="pyarrow") def create_pyarrow_table(rel): - return rel.fetch_arrow_table() + return rel.to_arrow_table() def create_pyarrow_dataset(rel): table = create_pyarrow_table(rel) - return pd.dataset(table) + return pa_ds.dataset(table) def test_decimal_filter_pushdown(duckdb_cursor): @@ -550,8 +552,8 @@ def test_9371(self, duckdb_cursor, tmp_path): df = df.set_index("ts") # SET INDEX! (It all works correctly when the index is not set) df.to_parquet(str(file_path)) - my_arrow_dataset = pd.dataset(str(file_path)) - res = duckdb_cursor.execute("SELECT * FROM my_arrow_dataset WHERE ts = ?", parameters=[dt]).fetch_arrow_table() + my_arrow_dataset = pa_ds.dataset(str(file_path)) + res = duckdb_cursor.execute("SELECT * FROM my_arrow_dataset WHERE ts = ?", parameters=[dt]).to_arrow_table() output = duckdb_cursor.sql("select * from res").fetchall() expected = [(1, dt), (2, dt), (3, dt)] assert output == expected @@ -702,40 +704,12 @@ def test_filter_pushdown_2145(self, duckdb_cursor, tmp_path, create_table): duckdb_cursor.execute(f"copy (select * from df2) to '{data2.as_posix()}'") glob_pattern = tmp_path / "data*.parquet" - table = duckdb_cursor.read_parquet(glob_pattern.as_posix()).fetch_arrow_table() + table = duckdb_cursor.read_parquet(glob_pattern.as_posix()).to_arrow_table() output_df = duckdb.arrow(table).filter("date > '2019-01-01'").df() expected_df = duckdb.from_parquet(glob_pattern.as_posix()).filter("date > '2019-01-01'").df() pandas.testing.assert_frame_equal(expected_df, output_df) - # https://github.com/duckdb/duckdb/pull/4817/files#r1339973721 - @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) - def test_filter_column_removal(self, duckdb_cursor, create_table): - duckdb_cursor.execute( - """ - CREATE TABLE test AS SELECT - range a, - 100 - range b - FROM range(100) - """ - ) - duck_test_table = duckdb_cursor.table("test") - arrow_table = create_table(duck_test_table) - - # PR 4817 - remove filter columns that are unused in the remainder of the query plan from the table function - query_res = duckdb_cursor.execute( - """ - EXPLAIN SELECT count(*) FROM arrow_table WHERE - a > 25 AND b > 25 - """ - ).fetchall() - - # scanned columns that come out of the scan are displayed like this, so we shouldn't see them - match = re.search("│ +a +│", query_res[0][1]) - assert not match - match = re.search("│ +b +│", query_res[0][1]) - assert not match - @pytest.mark.skipif(sys.version_info < (3, 9), reason="Requires python 3.9") @pytest.mark.parametrize("create_table", [create_pyarrow_pandas, create_pyarrow_table]) def test_struct_filter_pushdown(self, duckdb_cursor, create_table): @@ -895,7 +869,7 @@ def test_filter_pushdown_not_supported(self): con.execute( "CREATE TABLE T as SELECT i::integer a, i::varchar b, i::uhugeint c, i::integer d FROM range(5) tbl(i)" ) - arrow_tbl = con.execute("FROM T").fetch_arrow_table() + arrow_tbl = con.execute("FROM T").to_arrow_table() # No projection just unsupported filter assert con.execute("from arrow_tbl where c == 3").fetchall() == [(3, "3", 3, 3)] @@ -919,7 +893,7 @@ def test_filter_pushdown_not_supported(self): "CREATE TABLE T_2 as SELECT i::integer a, i::varchar b, i::uhugeint c, i::integer d , i::uhugeint e, i::smallint f, i::uhugeint g FROM range(50) tbl(i)" # noqa: E501 ) - arrow_tbl = con.execute("FROM T_2").fetch_arrow_table() + arrow_tbl = con.execute("FROM T_2").to_arrow_table() assert con.execute( "select a, b from arrow_tbl where a > 2 and c < 40 and b == '28' and g > 15 and e < 30" @@ -931,8 +905,8 @@ def test_join_filter_pushdown(self, duckdb_cursor): duckdb_conn.execute("CREATE TABLE build as select (random()*9999)::INT b from range(20);") duck_probe = duckdb_conn.table("probe") duck_build = duckdb_conn.table("build") - duck_probe_arrow = duck_probe.fetch_arrow_table() - duck_build_arrow = duck_build.fetch_arrow_table() + duck_probe_arrow = duck_probe.to_arrow_table() + duck_build_arrow = duck_build.to_arrow_table() duckdb_conn.register("duck_probe_arrow", duck_probe_arrow) duckdb_conn.register("duck_build_arrow", duck_build_arrow) assert duckdb_conn.execute("SELECT count(*) from duck_probe_arrow, duck_build_arrow where a=b").fetchall() == [ @@ -943,10 +917,37 @@ def test_in_filter_pushdown(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE probe as select range a from range(1000);") duck_probe = duckdb_conn.table("probe") - duck_probe_arrow = duck_probe.fetch_arrow_table() + duck_probe_arrow = duck_probe.to_arrow_table() duckdb_conn.register("duck_probe_arrow", duck_probe_arrow) assert duckdb_conn.execute("SELECT * from duck_probe_arrow where a = any([1,999])").fetchall() == [(1,), (999,)] + @pytest.mark.timeout(10) + def test_in_filter_pushdown_large_list(self, duckdb_cursor): + """Large IN lists must not hang. Regression test for https://github.com/duckdb/duckdb-python/issues/52.""" + arrow_table = pa.table({"a": pa.array(range(5000))}) + in_list = ", ".join(str(i) for i in range(0, 5000, 2)) + result = duckdb.sql(f"SELECT count(*) FROM arrow_table WHERE a IN ({in_list})").fetchone() + assert result == (2500,) + + def test_in_filter_pushdown_with_nulls(self, duckdb_cursor): + arrow_table = pa.table({"a": pa.array([1, 2, None, 4, None, 6])}) + # IN list without NULL: null rows should not match + result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4) ORDER BY a").fetchall() + assert result == [(1,), (4,)] + # IN list with NULL: null rows still should not match (SQL semantics) + result = duckdb.sql("SELECT a FROM arrow_table WHERE a IN (1, 4, NULL) ORDER BY a").fetchall() + assert result == [(1,), (4,)] + + def test_in_filter_pushdown_varchar(self, duckdb_cursor): + arrow_table = pa.table({"s": pa.array(["alice", "bob", "charlie", "dave", None])}) + result = duckdb.sql("SELECT s FROM arrow_table WHERE s IN ('bob', 'dave') ORDER BY s").fetchall() + assert result == [("bob",), ("dave",)] + + def test_in_filter_pushdown_float(self, duckdb_cursor): + arrow_table = pa.table({"f": pa.array([1.0, 2.5, 3.75, 4.0, None], type=pa.float64())}) + result = duckdb.sql("SELECT f FROM arrow_table WHERE f IN (2.5, 4.0) ORDER BY f").fetchall() + assert result == [(2.5,), (4.0,)] + def test_pushdown_of_optional_filter(self, duckdb_cursor): cardinality_table = pa.Table.from_pydict( { @@ -1006,7 +1007,7 @@ def assert_equal_results(con, arrow_table, query) -> None: arrow_res = con.sql(query.format(table="arrow_table")).fetchall() assert len(duckdb_res) == len(arrow_res) - arrow_table = duckdb_cursor.table("test").fetch_arrow_table() + arrow_table = duckdb_cursor.table("test").to_arrow_table() assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a > 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a >= 'NaN'::FLOAT") assert_equal_results(duckdb_cursor, arrow_table, "select * from {table} where a < 'NaN'::FLOAT") @@ -1023,14 +1024,14 @@ def test_dynamic_filter(self, duckdb_cursor): def test_binary_view_filter(self, duckdb_cursor): """Filters on a view column work (without pushdown because pyarrow does not support view filters yet).""" table = pa.table({"col": pa.array([b"abc", b"efg"], type=pa.binary_view())}) - dset = pd.dataset(table) + dset = pa_ds.dataset(table) res = duckdb_cursor.sql("select * from dset where col = 'abc'::binary") assert len(res) == 1 def test_string_view_filter(self, duckdb_cursor): """Filters on a view column work (without pushdown because pyarrow does not support view filters yet).""" table = pa.table({"col": pa.array(["abc", "efg"], type=pa.string_view())}) - dset = pd.dataset(table) + dset = pa_ds.dataset(table) res = duckdb_cursor.sql("select * from dset where col = 'abc'") assert len(res) == 1 @@ -1038,10 +1039,10 @@ def test_string_view_filter(self, duckdb_cursor): def test_canary_for_pyarrow_string_view_filter_support(self, duckdb_cursor): """This canary will xpass when pyarrow implements string view filter support.""" # predicate: field == "string value" - filter_expr = pd.field("col") == pd.scalar("val1") + filter_expr = pa_ds.field("col") == pa_ds.scalar("val1") # dataset with a string view column table = pa.table({"col": pa.array(["val1", "val2"], type=pa.string_view())}) - dset = pd.dataset(table) + dset = pa_ds.dataset(table) # creating the scanner fails dset.scanner(columns=["col"], filter=filter_expr) @@ -1049,10 +1050,10 @@ def test_canary_for_pyarrow_string_view_filter_support(self, duckdb_cursor): def test_canary_for_pyarrow_binary_view_filter_support(self, duckdb_cursor): """This canary will xpass when pyarrow implements binary view filter support.""" # predicate: field == const - const = pd.scalar(pa.scalar(b"bin1", pa.binary_view())) - filter_expr = pd.field("col") == const + const = pa_ds.scalar(pa.scalar(b"bin1", pa.binary_view())) + filter_expr = pa_ds.field("col") == const # dataset with a string view column table = pa.table({"col": pa.array([b"bin1", b"bin2"], type=pa.binary_view())}) - dset = pd.dataset(table) + dset = pa_ds.dataset(table) # creating the scanner fails dset.scanner(columns=["col"], filter=filter_expr) diff --git a/tests/fast/arrow/test_integration.py b/tests/fast/arrow/test_integration.py index 1ec3a603..858d9f9c 100644 --- a/tests/fast/arrow/test_integration.py +++ b/tests/fast/arrow/test_integration.py @@ -19,10 +19,10 @@ def test_parquet_roundtrip(self, duckdb_cursor): userdata_parquet_table = pq.read_table(parquet_filename) userdata_parquet_table.validate(full=True) - rel_from_arrow = duckdb.arrow(userdata_parquet_table).project(cols).fetch_arrow_table() + rel_from_arrow = duckdb.arrow(userdata_parquet_table).project(cols).to_arrow_table() rel_from_arrow.validate(full=True) - rel_from_duckdb = duckdb_cursor.from_parquet(parquet_filename).project(cols).fetch_arrow_table() + rel_from_duckdb = duckdb_cursor.from_parquet(parquet_filename).project(cols).to_arrow_table() rel_from_duckdb.validate(full=True) # batched version, lets use various values for batch size @@ -30,7 +30,7 @@ def test_parquet_roundtrip(self, duckdb_cursor): userdata_parquet_table2 = pa.Table.from_batches(userdata_parquet_table.to_batches(i)) assert userdata_parquet_table.equals(userdata_parquet_table2, check_metadata=True) - rel_from_arrow2 = duckdb.arrow(userdata_parquet_table2).project(cols).fetch_arrow_table() + rel_from_arrow2 = duckdb.arrow(userdata_parquet_table2).project(cols).to_arrow_table() rel_from_arrow2.validate(full=True) assert rel_from_arrow.equals(rel_from_arrow2, check_metadata=True) @@ -42,10 +42,10 @@ def test_unsigned_roundtrip(self, duckdb_cursor): unsigned_parquet_table = pq.read_table(parquet_filename) unsigned_parquet_table.validate(full=True) - rel_from_arrow = duckdb.arrow(unsigned_parquet_table).project(cols).fetch_arrow_table() + rel_from_arrow = duckdb.arrow(unsigned_parquet_table).project(cols).to_arrow_table() rel_from_arrow.validate(full=True) - rel_from_duckdb = duckdb_cursor.from_parquet(parquet_filename).project(cols).fetch_arrow_table() + rel_from_duckdb = duckdb_cursor.from_parquet(parquet_filename).project(cols).to_arrow_table() rel_from_duckdb.validate(full=True) assert rel_from_arrow.equals(rel_from_duckdb, check_metadata=True) @@ -53,7 +53,7 @@ def test_unsigned_roundtrip(self, duckdb_cursor): duckdb_cursor.execute( "select NULL c_null, (c % 4 = 0)::bool c_bool, (c%128)::tinyint c_tinyint, c::smallint*1000::INT c_smallint, c::integer*100000 c_integer, c::bigint*1000000000000 c_bigint, c::float c_float, c::double c_double, 'c_' || c::string c_string from (select case when range % 2 == 0 then range else null end as c from range(-10000, 10000)) sq" # noqa: E501 ) - arrow_result = duckdb_cursor.fetch_arrow_table() + arrow_result = duckdb_cursor.to_arrow_table() arrow_result.validate(full=True) arrow_result.combine_chunks() arrow_result.validate(full=True) @@ -72,7 +72,7 @@ def test_decimals_roundtrip(self, duckdb_cursor): duck_tbl = duckdb_cursor.table("test") - duck_from_arrow = duckdb_cursor.from_arrow(duck_tbl.fetch_arrow_table()) + duck_from_arrow = duckdb_cursor.from_arrow(duck_tbl.to_arrow_table()) duck_from_arrow.create("testarrow") @@ -114,7 +114,7 @@ def test_intervals_roundtrip(self, duckdb_cursor): data = pa.array(arr, pa.month_day_nano_interval()) arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervaltbl") - duck_arrow_tbl = duckdb_cursor.table("intervaltbl").fetch_arrow_table()["a"] + duck_arrow_tbl = duckdb_cursor.table("intervaltbl").to_arrow_table()["a"] assert duck_arrow_tbl[0].value == expected_value @@ -122,7 +122,7 @@ def test_intervals_roundtrip(self, duckdb_cursor): duckdb_cursor.execute("CREATE TABLE test (a INTERVAL)") duckdb_cursor.execute("INSERT INTO test VALUES (INTERVAL 1 YEAR + INTERVAL 1 DAY + INTERVAL 1 SECOND)") expected_value = pa.MonthDayNano([12, 1, 1000000000]) - duck_tbl_arrow = duckdb_cursor.table("test").fetch_arrow_table()["a"] + duck_tbl_arrow = duckdb_cursor.table("test").to_arrow_table()["a"] assert duck_tbl_arrow[0].value.months == expected_value.months assert duck_tbl_arrow[0].value.days == expected_value.days assert duck_tbl_arrow[0].value.nanoseconds == expected_value.nanoseconds @@ -144,7 +144,7 @@ def test_null_intervals_roundtrip(self, duckdb_cursor): data = pa.array(arr, pa.month_day_nano_interval()) arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervalnulltbl") - duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").fetch_arrow_table()["a"] + duckdb_tbl_arrow = duckdb_cursor.table("intervalnulltbl").to_arrow_table()["a"] assert duckdb_tbl_arrow[0].value is None assert duckdb_tbl_arrow[1].value == expected_value @@ -158,7 +158,7 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): dict_array = pa.DictionaryArray.from_arrays(indices, dictionary) arrow_table = pa.Table.from_arrays([dict_array], ["a"]) duckdb_cursor.from_arrow(arrow_table).create("dictionarytbl") - duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").fetch_arrow_table()["a"] + duckdb_tbl_arrow = duckdb_cursor.table("dictionarytbl").to_arrow_table()["a"] assert duckdb_tbl_arrow[0].value == first_value assert duckdb_tbl_arrow[1].value == second_value @@ -172,7 +172,7 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): # List query = duckdb_cursor.sql( "SELECT a from (select list_value(INTERVAL 3 MONTHS, INTERVAL 5 DAYS, INTERVAL 10 SECONDS, NULL) as a) as t" - ).fetch_arrow_table()["a"] + ).to_arrow_table()["a"] assert query[0][0].value == pa.MonthDayNano([3, 0, 0]) assert query[0][1].value == pa.MonthDayNano([0, 5, 0]) assert query[0][2].value == pa.MonthDayNano([0, 0, 10000000000]) @@ -181,7 +181,7 @@ def test_nested_interval_roundtrip(self, duckdb_cursor): # Struct query = "SELECT a from (SELECT STRUCT_PACK(a := INTERVAL 1 MONTHS, b := INTERVAL 10 DAYS, c:= INTERVAL 20 SECONDS) as a) as t" # noqa: E501 true_answer = duckdb_cursor.sql(query).fetchall() - from_arrow = duckdb_cursor.from_arrow(duckdb_cursor.sql(query).fetch_arrow_table()).fetchall() + from_arrow = duckdb_cursor.from_arrow(duckdb_cursor.sql(query).to_arrow_table()).fetchall() assert true_answer[0][0]["a"] == from_arrow[0][0]["a"] assert true_answer[0][0]["b"] == from_arrow[0][0]["b"] assert true_answer[0][0]["c"] == from_arrow[0][0]["c"] @@ -193,7 +193,7 @@ def test_min_max_interval_roundtrip(self, duckdb_cursor): arrow_tbl = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_tbl).create("intervalminmaxtbl") - duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").fetch_arrow_table()["a"] + duck_arrow_tbl = duckdb_cursor.table("intervalminmaxtbl").to_arrow_table()["a"] assert duck_arrow_tbl[0].value == pa.MonthDayNano([0, 0, 0]) assert duck_arrow_tbl[1].value == pa.MonthDayNano([2147483647, 2147483647, 9223372036854775000]) @@ -211,7 +211,7 @@ def test_duplicate_column_names(self, duckdb_cursor): df_b table2 ON table1.join_key = table2.join_key """ - ).fetch_arrow_table() + ).to_arrow_table() assert res.schema.names == ["join_key", "col_a", "join_key", "col_a"] def test_strings_roundtrip(self, duckdb_cursor): @@ -227,7 +227,7 @@ def test_strings_roundtrip(self, duckdb_cursor): duck_tbl = duckdb_cursor.table("test") - duck_from_arrow = duckdb_cursor.from_arrow(duck_tbl.fetch_arrow_table()) + duck_from_arrow = duckdb_cursor.from_arrow(duck_tbl.to_arrow_table()) duck_from_arrow.create("testarrow") diff --git a/tests/fast/arrow/test_interval.py b/tests/fast/arrow/test_interval.py index 5426f39d..80d22e66 100644 --- a/tests/fast/arrow/test_interval.py +++ b/tests/fast/arrow/test_interval.py @@ -24,7 +24,7 @@ def test_duration_types(self, duckdb_cursor): pa.array([1], pa.duration("s")), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == expected_arrow["a"] assert rel["b"] == expected_arrow["a"] assert rel["c"] == expected_arrow["a"] @@ -41,7 +41,7 @@ def test_duration_null(self, duckdb_cursor): pa.array([None], pa.duration("s")), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == expected_arrow["a"] assert rel["b"] == expected_arrow["a"] assert rel["c"] == expected_arrow["a"] @@ -56,4 +56,4 @@ def test_duration_overflow(self, duckdb_cursor): arrow_table = pa.Table.from_arrays([data], ["a"]) with pytest.raises(duckdb.ConversionException, match="Could not convert Interval to Microsecond"): - duckdb.from_arrow(arrow_table).fetch_arrow_table() + duckdb.from_arrow(arrow_table).to_arrow_table() diff --git a/tests/fast/arrow/test_large_offsets.py b/tests/fast/arrow/test_large_offsets.py index 45b078b8..f74b9955 100644 --- a/tests/fast/arrow/test_large_offsets.py +++ b/tests/fast/arrow/test_large_offsets.py @@ -18,11 +18,11 @@ def test_large_lists(self, duckdb_cursor): match="Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but " "the offset of 2147481000 exceeds this", ): - duckdb_cursor.sql("SELECT col FROM tbl").fetch_arrow_table() + duckdb_cursor.sql("SELECT col FROM tbl").to_arrow_table() tbl2 = pa.Table.from_pydict({"col": ary.cast(pa.large_list(pa.uint8()))}) # noqa: F841 duckdb_cursor.sql("set arrow_large_buffer_size = true") - res2 = duckdb_cursor.sql("SELECT col FROM tbl2").fetch_arrow_table() + res2 = duckdb_cursor.sql("SELECT col FROM tbl2").to_arrow_table() res2.validate() @pytest.mark.skip(reason="CI does not have enough memory to validate this") @@ -35,8 +35,8 @@ def test_large_maps(self, duckdb_cursor): match="Arrow Appender: The maximum combined list offset for regular list buffers is 2147483647 but the " "offset of 2147481000 exceeds this", ): - duckdb_cursor.sql("select map(col, col) from tbl").fetch_arrow_table() + duckdb_cursor.sql("select map(col, col) from tbl").to_arrow_table() duckdb_cursor.sql("set arrow_large_buffer_size = true") - arrow_map_large = duckdb_cursor.sql("select map(col, col) from tbl").fetch_arrow_table() + arrow_map_large = duckdb_cursor.sql("select map(col, col) from tbl").to_arrow_table() arrow_map_large.validate() diff --git a/tests/fast/arrow/test_nested_arrow.py b/tests/fast/arrow/test_nested_arrow.py index 10fbfae0..0530ee15 100644 --- a/tests/fast/arrow/test_nested_arrow.py +++ b/tests/fast/arrow/test_nested_arrow.py @@ -10,13 +10,13 @@ def compare_results(duckdb_cursor, query): true_answer = duckdb_cursor.query(query).fetchall() - produced_arrow = duckdb_cursor.query(query).fetch_arrow_table() + produced_arrow = duckdb_cursor.query(query).to_arrow_table() from_arrow = duckdb_cursor.from_arrow(produced_arrow).fetchall() assert true_answer == from_arrow def arrow_to_pandas(duckdb_cursor, query): - return duckdb_cursor.query(query).fetch_arrow_table().to_pandas()["a"].values.tolist() + return duckdb_cursor.query(query).to_arrow_table().to_pandas()["a"].values.tolist() def get_use_list_view_options(): @@ -31,23 +31,19 @@ class TestArrowNested: def test_lists_basic(self, duckdb_cursor): # Test Constant List query = ( - duckdb_cursor.query("SELECT a from (select list_value(3,5,10) as a) as t") - .fetch_arrow_table()["a"] - .to_numpy() + duckdb_cursor.query("SELECT a from (select list_value(3,5,10) as a) as t").to_arrow_table()["a"].to_numpy() ) assert query[0][0] == 3 assert query[0][1] == 5 assert query[0][2] == 10 # Empty List - query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").fetch_arrow_table()["a"].to_numpy() + query = duckdb_cursor.query("SELECT a from (select list_value() as a) as t").to_arrow_table()["a"].to_numpy() assert len(query[0]) == 0 # Test Constant List With Null query = ( - duckdb_cursor.query("SELECT a from (select list_value(3,NULL) as a) as t") - .fetch_arrow_table()["a"] - .to_numpy() + duckdb_cursor.query("SELECT a from (select list_value(3,NULL) as a) as t").to_arrow_table()["a"].to_numpy() ) assert query[0][0] == 3 assert np.isnan(query[0][1]) diff --git a/tests/fast/arrow/test_polars.py b/tests/fast/arrow/test_polars.py index 0eba5eeb..f3d7e072 100644 --- a/tests/fast/arrow/test_polars.py +++ b/tests/fast/arrow/test_polars.py @@ -12,6 +12,7 @@ from duckdb.polars_io import _pl_tree_to_sql, _predicate_to_expression # noqa: E402 +pl_pre_1_35_0 = parse_version(pl.__version__) < parse_version("1.35.0") pl_pre_1_36_0 = parse_version(pl.__version__) < parse_version("1.36.0") @@ -437,7 +438,7 @@ def test_polars_lazy_pushdown_timestamp(self, duckdb_cursor): lazy_df.filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)).select(pl.len()).collect().item() == 2 ) - @pytest.mark.skipif(pl_pre_1_36_0, reason="Polars < 1.36.0 expressions on dates produce casts in predicates") + @pytest.mark.skipif(pl_pre_1_35_0, reason="Polars < 1.36.0 expressions on dates produce casts in predicates") def test_polars_predicate_to_expression_post_1_36_0(self): ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1) ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1) @@ -454,7 +455,7 @@ def test_polars_predicate_to_expression_post_1_36_0(self): valid_filter((pl.col("a") == ts_2020) & (pl.col("b") == ts_2010) & (pl.col("c") == ts_2020)) valid_filter((pl.col("a") == ts_2020) | (pl.col("b") == ts_2008)) - @pytest.mark.skipif(not pl_pre_1_36_0, reason="Polars >= 1.36.0 expressions on dates don't produce casts") + @pytest.mark.skipif(not pl_pre_1_35_0, reason="Polars >= 1.36.0 expressions on dates don't produce casts") def test_polars_predicate_to_expression_pre_1_36_0(self): ts_2008 = datetime.datetime(2008, 1, 1, 0, 0, 1) ts_2010 = datetime.datetime(2010, 1, 1, 10, 0, 1) @@ -681,6 +682,30 @@ def test_invalid_expr_json(self): with pytest.raises(AssertionError, match="The col name of a Column should be a str but got"): _pl_tree_to_sql(json.loads(bad_type_expr)) + @pytest.mark.parametrize( + ("dtype", "test_value"), + [ + (pl.Int8, 1), + (pl.Int16, 1), + (pl.Int32, 1), + (pl.Int64, 1), + (pl.Int128, 1), + (pl.UInt8, 1), + (pl.UInt16, 1), + (pl.UInt32, 1), + (pl.UInt64, 1), + (pl.UInt128, 1), + (pl.Float32, 1.0), + (pl.Float64, 1.0), + (pl.Boolean, True), + ], + ) + def test_scalar_type_pushdown(self, dtype, test_value): + """Verify that literals of each scalar type can be pushed down.""" + expr = pl.col("a") == pl.lit(test_value, dtype=dtype) + sql_expression = _predicate_to_expression(expr) + assert sql_expression is not None, f"Pushdown failed for {dtype}" + def test_decimal_scale(self): scalar_decimal_no_scale = """ { "Scalar": { @@ -702,3 +727,59 @@ def test_decimal_scale(self): } } """ assert _pl_tree_to_sql(json.loads(scalar_decimal_scale)) == "1" + + def test_cast_node_unwraps_inner_expression(self): + """Cast nodes should be unwrapped to process the inner expression.""" + # A Cast wrapping a Column reference + cast_column = json.loads( + '{"Cast": {"expr": {"Column": "a"}, "dtype": {"Decimal": [20, 0]}, "options": "NonStrict"}}' + ) + assert _pl_tree_to_sql(cast_column) == '"a"' + + # A Cast wrapping a full binary expression + cast_expr = json.loads(""" + { + "BinaryExpr": { + "left": {"Cast": {"expr": {"Column": "a"}, "dtype": {"Decimal": [20, 0]}, "options": "NonStrict"}}, + "op": "Eq", + "right": {"Literal": {"Int": 1}} + } + } + """) + assert _pl_tree_to_sql(cast_expr) == '("a" = 1)' + + def test_cast_node_predicate_pushdown(self): + """Predicates with Cast nodes should be successfully pushed down.""" + # A decimal with non-38 precision produces a Cast node in Polars + expr = pl.col("a") == pl.lit(1, dtype=pl.Decimal(precision=20, scale=0)) + valid_filter(expr) + + def test_polars_lazy_pushdown_decimal_with_cast(self): + """End-to-end test: decimal columns with non-38 precision should push down filters.""" + con = duckdb.connect() + con.execute("CREATE TABLE test_cast (a DECIMAL(20,0))") + con.execute("INSERT INTO test_cast VALUES (1), (10), (100), (NULL)") + rel = con.sql("FROM test_cast") + lazy_df = rel.pl(lazy=True) + + assert lazy_df.filter(pl.col("a") == 1).collect().to_dicts() == [{"a": 1}] + assert lazy_df.filter(pl.col("a") > 1).collect().to_dicts() == [{"a": 10}, {"a": 100}] + + def test_explicit_cast_not_pushed_down(self): + """Explicit user .cast() (Strict) should not be pushed down - falls back to Polars.""" + # pl.col("a").cast(pl.Int64) produces a Strict Cast node + expr = pl.col("a").cast(pl.Int64) > 5 + invalid_filter(expr) + + def test_polars_lazy_cursor_lifetime(self): + """Cursor should stay alive while a lazy polars frame derived from it exists (GH #161).""" + con = duckdb.connect(":memory:") + + def get_lazy_frame(con): + cur = con.cursor() + return cur.sql("SELECT 1 AS foo, 2 AS bar").pl(lazy=True) + + lf = get_lazy_frame(con) + # Cursor went out of scope, but the lazy frame should keep it alive + result = lf.collect() + assert result.to_dicts() == [{"foo": 1, "bar": 2}] diff --git a/tests/fast/arrow/test_polars_filter_pushdown.py b/tests/fast/arrow/test_polars_filter_pushdown.py new file mode 100644 index 00000000..320dffae --- /dev/null +++ b/tests/fast/arrow/test_polars_filter_pushdown.py @@ -0,0 +1,192 @@ +# ruff: noqa: F841 +import math + +import pytest + +import duckdb + +pl = pytest.importorskip("polars") +pytest.importorskip("pyarrow") + + +class TestPolarsLazyFrameFilterPushdown: + """Tests for filter pushdown on LazyFrames. + + All tests use pl.LazyFrame (the target of this change). DuckDB pushes filters and projections into the Polars lazy + plan before collection, so only surviving rows are ever materialized. + """ + + ##### CONSTANT_COMPARISON: all six comparison operators + + def test_comparison_equal(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a = 3").fetchall() == [(3,)] + + def test_comparison_not_equal(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a != 3").fetchall() == [(1,), (2,), (4,), (5,)] + + def test_comparison_less_than(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a < 3").fetchall() == [(1,), (2,)] + + def test_comparison_less_than_or_equal(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a <= 3").fetchall() == [(1,), (2,), (3,)] + + def test_comparison_greater_than(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a > 3").fetchall() == [(4,), (5,)] + + def test_comparison_greater_than_or_equal(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + assert duckdb.sql("SELECT * FROM lf WHERE a >= 3").fetchall() == [(3,), (4,), (5,)] + + def test_string_comparison(self): + lf = pl.LazyFrame({"name": ["alice", "bob", "charlie"], "val": [1, 2, 3]}) + assert duckdb.sql("SELECT * FROM lf WHERE name = 'bob'").fetchall() == [("bob", 2)] + + ##### NaN comparisons (CONSTANT_COMPARISON with is_nan path) + + def test_nan_equal(self): + """NaN = NaN is true in DuckDB; pushes is_nan().""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 'NaN'::DOUBLE").fetchall() + assert len(result) == 1 + assert math.isnan(result[0][0]) + + def test_nan_greater_than_or_equal(self): + """NaN >= NaN is true; pushes is_nan().""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a >= 'NaN'::DOUBLE").fetchall() + assert len(result) == 1 + assert math.isnan(result[0][0]) + + def test_nan_less_than(self): + """X < NaN is true for non-NaN values; pushes is_nan().__invert__().""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a < 'NaN'::DOUBLE").fetchall() + assert sorted(result) == [(1.0,), (3.0,)] + + def test_nan_not_equal(self): + """X != NaN is true for non-NaN values; pushes is_nan().__invert__().""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a != 'NaN'::DOUBLE").fetchall() + assert sorted(result) == [(1.0,), (3.0,)] + + def test_nan_greater_than(self): + """X > NaN is always false; pushes lit(false).""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a > 'NaN'::DOUBLE").fetchall() + assert result == [] + + def test_nan_less_than_or_equal(self): + """X <= NaN is always true; pushes lit(true).""" + lf = pl.LazyFrame({"a": [1.0, float("nan"), 3.0]}) + result = duckdb.sql("SELECT * FROM lf WHERE a <= 'NaN'::DOUBLE").fetchall() + assert len(result) == 3 + + ##### IS_NULL / IS_NOT_NULL (triggered via DISTINCT FROM NULL inside OR) + + def test_is_null_filter(self): + """IS NOT DISTINCT FROM NULL inside an OR pushes IS_NULL as a child of CONJUNCTION_OR.""" + lf = pl.LazyFrame({"a": [1, None, 3, None, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 1 OR a IS NOT DISTINCT FROM NULL").fetchall() + values = [row[0] for row in result] + assert values.count(None) == 2 + assert 1 in values + assert len(values) == 3 + + def test_is_not_null_filter(self): + """IS DISTINCT FROM NULL inside an OR pushes IS_NOT_NULL as a child of CONJUNCTION_OR.""" + lf = pl.LazyFrame({"a": [1, None, 3, None, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 1 OR a IS DISTINCT FROM NULL").fetchall() + assert sorted(result) == [(1,), (3,), (5,)] + + # ── CONJUNCTION_AND ── + + def test_conjunction_and_range(self): + """BETWEEN on a single column pushes a CONJUNCTION_AND with GTE + LTE children.""" + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a BETWEEN 2 AND 4").fetchall() + assert result == [(2,), (3,), (4,)] + + def test_conjunction_and_multi_column(self): + """Filters on two different columns combine via AND in TransformFilter.""" + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5], "b": ["x", "y", "x", "y", "x"]}) + result = duckdb.sql("SELECT * FROM lf WHERE a > 2 AND b = 'x'").fetchall() + assert result == [(3, "x"), (5, "x")] + + ##### CONJUNCTION_OR + + def test_conjunction_or(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 1 OR a = 5").fetchall() + assert sorted(result) == [(1,), (5,)] + + ##### IN_FILTER + + def test_in_filter(self): + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a IN (2, 4)").fetchall() + assert sorted(result) == [(2,), (4,)] + + ##### STRUCT_EXTRACT + + def test_struct_extract(self): + lf = pl.LazyFrame({"s": [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}]}) + result = duckdb.sql("SELECT * FROM lf WHERE s.x > 1").fetchall() + assert len(result) == 2 + assert all(row[0]["x"] > 1 for row in result) + + ##### OPTIONAL_FILTER + + def test_optional_filter(self): + """OR filters are wrapped in OPTIONAL_FILTER by DuckDB's optimizer.""" + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + result = duckdb.sql("SELECT * FROM lf WHERE a = 1 OR a = 3").fetchall() + assert sorted(result) == [(1,), (3,)] + + ##### Produce path, no filters + + def test_unfiltered_scan(self): + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = duckdb.sql("SELECT * FROM lf").fetchall() + assert result == [(1, 4), (2, 5), (3, 6)] + + ##### Produce path, column projection + + def test_column_projection(self): + lf = pl.LazyFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + result = duckdb.sql("SELECT a, c FROM lf").fetchall() + assert result == [(1, 7), (2, 8), (3, 9)] + + ##### Produce path, cached DataFrame reuse + + def test_cached_dataframe_reuse(self): + """Repeated unfiltered scans on a registered LazyFrame reuse the cached DataFrame.""" + con = duckdb.connect() + lf = pl.LazyFrame({"a": [1, 2, 3]}) + con.register("my_lf", lf) + r1 = con.sql("SELECT * FROM my_lf").fetchall() + r2 = con.sql("SELECT * FROM my_lf").fetchall() + assert r1 == r2 == [(1,), (2,), (3,)] + + ##### Produce path, filter + collect (no cache) + + def test_filtered_scan_not_cached(self): + """Filtered scans collect a new DataFrame each time (not cached).""" + con = duckdb.connect() + lf = pl.LazyFrame({"a": [1, 2, 3, 4, 5]}) + con.register("my_lf", lf) + r1 = con.sql("SELECT * FROM my_lf WHERE a > 3").fetchall() + r2 = con.sql("SELECT * FROM my_lf WHERE a < 3").fetchall() + assert sorted(r1) == [(4,), (5,)] + assert sorted(r2) == [(1,), (2,)] + + ##### Empty result + + def test_empty_result(self): + lf = pl.LazyFrame({"a": [1, 2, 3]}) + result = duckdb.sql("SELECT * FROM lf WHERE a > 100").fetchall() + assert result == [] diff --git a/tests/fast/arrow/test_projection_pushdown.py b/tests/fast/arrow/test_projection_pushdown.py index fbd258e0..7a792b9a 100644 --- a/tests/fast/arrow/test_projection_pushdown.py +++ b/tests/fast/arrow/test_projection_pushdown.py @@ -21,7 +21,7 @@ def test_projection_pushdown_no_filter(self, duckdb_cursor): """ ) duck_tbl = duckdb_cursor.table("test") - arrow_table = duck_tbl.fetch_arrow_table() + arrow_table = duck_tbl.to_arrow_table() assert duckdb_cursor.execute("SELECT sum(c) FROM arrow_table").fetchall() == [(333,)] # RecordBatch does not use projection pushdown, test that this also still works diff --git a/tests/fast/arrow/test_time.py b/tests/fast/arrow/test_time.py index b9bc5a21..1f359bd4 100644 --- a/tests/fast/arrow/test_time.py +++ b/tests/fast/arrow/test_time.py @@ -20,11 +20,11 @@ def test_time_types(self, duckdb_cursor): pa.array([1000000000], pa.time64("ns")), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == arrow_table["c"] assert rel["b"] == arrow_table["c"] assert rel["c"] == arrow_table["c"] - assert rel["d"] == arrow_table["c"] + assert rel["d"] == arrow_table["d"] def test_time_null(self, duckdb_cursor): if not can_run: @@ -36,11 +36,11 @@ def test_time_null(self, duckdb_cursor): pa.array([None], pa.time64("ns")), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == arrow_table["c"] assert rel["b"] == arrow_table["c"] assert rel["c"] == arrow_table["c"] - assert rel["d"] == arrow_table["c"] + assert rel["d"] == arrow_table["d"] def test_max_times(self, duckdb_cursor): if not can_run: @@ -50,7 +50,7 @@ def test_max_times(self, duckdb_cursor): # Max Sec data = pa.array([2147483647], type=pa.time32("s")) arrow_table = pa.Table.from_arrays([data], ["a"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == result["a"] # Max MSec @@ -58,15 +58,15 @@ def test_max_times(self, duckdb_cursor): result = pa.Table.from_arrays([data], ["a"]) data = pa.array([2147483647], type=pa.time32("ms")) arrow_table = pa.Table.from_arrays([data], ["a"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == result["a"] # Max NSec - data = pa.array([9223372036854774], type=pa.time64("us")) + data = pa.array([9223372036854774000], type=pa.time64("ns")) result = pa.Table.from_arrays([data], ["a"]) data = pa.array([9223372036854774000], type=pa.time64("ns")) arrow_table = pa.Table.from_arrays([data], ["a"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() print(rel["a"]) print(result["a"]) diff --git a/tests/fast/arrow/test_timestamp_timezone.py b/tests/fast/arrow/test_timestamp_timezone.py index 27ddf3ac..cec7c20d 100644 --- a/tests/fast/arrow/test_timestamp_timezone.py +++ b/tests/fast/arrow/test_timestamp_timezone.py @@ -45,7 +45,7 @@ def test_timestamp_tz_to_arrow(self, duckdb_cursor): for timezone in timezones: con.execute("SET TimeZone = '" + timezone + "'") arrow_table = generate_table(current_time, precision, timezone) - res = con.from_arrow(arrow_table).fetch_arrow_table() + res = con.from_arrow(arrow_table).to_arrow_table() assert res[0].type == pa.timestamp("us", tz=timezone) assert res == generate_table(current_time, "us", timezone) @@ -54,7 +54,7 @@ def test_timestamp_tz_with_null(self, duckdb_cursor): con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") rel = con.table("t") - arrow_tbl = rel.fetch_arrow_table() + arrow_tbl = rel.to_arrow_table() con.register("t2", arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() @@ -64,7 +64,7 @@ def test_timestamp_stream(self, duckdb_cursor): con.execute("create table t (i timestamptz)") con.execute("insert into t values (NULL),('2021-11-15 02:30:00'::timestamptz)") rel = con.table("t") - arrow_tbl = rel.fetch_record_batch().read_all() + arrow_tbl = rel.to_arrow_reader().read_all() con.register("t2", arrow_tbl) assert con.execute("select * from t").fetchall() == con.execute("select * from t2").fetchall() diff --git a/tests/fast/arrow/test_timestamps.py b/tests/fast/arrow/test_timestamps.py index b00b7982..0733a6d0 100644 --- a/tests/fast/arrow/test_timestamps.py +++ b/tests/fast/arrow/test_timestamps.py @@ -21,7 +21,7 @@ def test_timestamp_types(self, duckdb_cursor): pa.array([datetime.datetime.now()], pa.timestamp("s")), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == arrow_table["a"] assert rel["b"] == arrow_table["b"] assert rel["c"] == arrow_table["c"] @@ -37,7 +37,7 @@ def test_timestamp_nulls(self, duckdb_cursor): pa.array([None], pa.timestamp("s")), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2], data[3]], ["a", "b", "c", "d"]) - rel = duckdb.from_arrow(arrow_table).fetch_arrow_table() + rel = duckdb.from_arrow(arrow_table).to_arrow_table() assert rel["a"] == arrow_table["a"] assert rel["b"] == arrow_table["b"] assert rel["c"] == arrow_table["c"] @@ -52,7 +52,7 @@ def test_timestamp_overflow(self, duckdb_cursor): pa.array([9223372036854775807], pa.timestamp("us")), ) arrow_table = pa.Table.from_arrays([data[0], data[1], data[2]], ["a", "b", "c"]) - arrow_from_duck = duckdb.from_arrow(arrow_table).fetch_arrow_table() + arrow_from_duck = duckdb.from_arrow(arrow_table).to_arrow_table() assert arrow_from_duck["a"] == arrow_table["a"] assert arrow_from_duck["b"] == arrow_table["b"] assert arrow_from_duck["c"] == arrow_table["c"] diff --git a/tests/fast/arrow/test_tpch.py b/tests/fast/arrow/test_tpch.py index cb6024cf..1c52c262 100644 --- a/tests/fast/arrow/test_tpch.py +++ b/tests/fast/arrow/test_tpch.py @@ -47,7 +47,7 @@ def test_tpch_arrow(self, duckdb_cursor): for tpch_table in tpch_tables: duck_tbl = duckdb_conn.table(tpch_table) - arrow_tables.append(duck_tbl.fetch_arrow_table()) + arrow_tables.append(duck_tbl.to_arrow_table()) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) @@ -77,7 +77,7 @@ def test_tpch_arrow_01(self, duckdb_cursor): for tpch_table in tpch_tables: duck_tbl = duckdb_conn.table(tpch_table) - arrow_tables.append(duck_tbl.fetch_arrow_table()) + arrow_tables.append(duck_tbl.to_arrow_table()) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) @@ -105,7 +105,7 @@ def test_tpch_arrow_batch(self, duckdb_cursor): for tpch_table in tpch_tables: duck_tbl = duckdb_conn.table(tpch_table) - arrow_tables.append(pyarrow.Table.from_batches(duck_tbl.fetch_arrow_table().to_batches(10))) + arrow_tables.append(pyarrow.Table.from_batches(duck_tbl.to_arrow_table().to_batches(10))) duck_arrow_table = duckdb_conn.from_arrow(arrow_tables[-1]) duckdb_conn.execute("DROP TABLE " + tpch_table) duck_arrow_table.create(tpch_table) diff --git a/tests/fast/arrow/test_unregister.py b/tests/fast/arrow/test_unregister.py index 0aceaea1..c2ca8e37 100644 --- a/tests/fast/arrow/test_unregister.py +++ b/tests/fast/arrow/test_unregister.py @@ -17,10 +17,10 @@ def test_arrow_unregister1(self, duckdb_cursor): connection = duckdb.connect(":memory:") connection.register("arrow_table", arrow_table_obj) - connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() + connection.execute("SELECT * FROM arrow_table;").to_arrow_table() connection.unregister("arrow_table") with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): - connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() + connection.execute("SELECT * FROM arrow_table;").to_arrow_table() with pytest.raises(duckdb.CatalogException, match="View with name arrow_table does not exist"): connection.execute("DROP VIEW arrow_table;") connection.execute("DROP VIEW IF EXISTS arrow_table;") @@ -39,7 +39,7 @@ def test_arrow_unregister2(self, duckdb_cursor): connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): - connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() + connection.execute("SELECT * FROM arrow_table;").to_arrow_table() connection.close() del arrow_table_obj gc.collect() @@ -47,5 +47,5 @@ def test_arrow_unregister2(self, duckdb_cursor): connection = duckdb.connect(db) assert len(connection.execute("PRAGMA show_tables;").fetchall()) == 0 with pytest.raises(duckdb.CatalogException, match="Table with name arrow_table does not exist"): - connection.execute("SELECT * FROM arrow_table;").fetch_arrow_table() + connection.execute("SELECT * FROM arrow_table;").to_arrow_table() connection.close() diff --git a/tests/fast/pandas/test_2304.py b/tests/fast/pandas/test_2304.py index c60b1b4a..e40c2dd1 100644 --- a/tests/fast/pandas/test_2304.py +++ b/tests/fast/pandas/test_2304.py @@ -1,14 +1,12 @@ import numpy as np -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb class TestPandasMergeSameName: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_2304(self, duckdb_cursor, pandas): - df1 = pandas.DataFrame( + def test_2304(self, duckdb_cursor): + df1 = pd.DataFrame( { "id_1": [1, 1, 1, 2, 2], "agedate": np.array(["2010-01-01", "2010-02-01", "2010-03-01", "2020-02-01", "2020-03-01"]).astype( @@ -19,7 +17,7 @@ def test_2304(self, duckdb_cursor, pandas): } ) - df2 = pandas.DataFrame( + df2 = pd.DataFrame( { "id_1": [1, 1, 2], "agedate": np.array(["2010-01-01", "2010-02-01", "2020-03-01"]).astype("datetime64[D]"), @@ -54,9 +52,8 @@ def test_2304(self, duckdb_cursor, pandas): assert result == expected_result - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_pd_names(self, duckdb_cursor, pandas): - df1 = pandas.DataFrame( + def test_pd_names(self, duckdb_cursor): + df1 = pd.DataFrame( { "id": [1, 1, 2], "id_1": [1, 1, 2], @@ -64,9 +61,9 @@ def test_pd_names(self, duckdb_cursor, pandas): } ) - df2 = pandas.DataFrame({"id": [1, 1, 2], "id_1": [1, 1, 2], "id_2": [1, 1, 1]}) + df2 = pd.DataFrame({"id": [1, 1, 2], "id_1": [1, 1, 2], "id_2": [1, 1, 1]}) - exp_result = pandas.DataFrame( + exp_result = pd.DataFrame( { "id": [1, 1, 2, 1, 1], "id_1": [1, 1, 2, 1, 1], @@ -85,11 +82,10 @@ def test_pd_names(self, duckdb_cursor, pandas): ON (df1.id_1=df2.id_1)""" result_df = con.execute(query).fetchdf() - pandas.testing.assert_frame_equal(exp_result, result_df) + pd.testing.assert_frame_equal(exp_result, result_df) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_repeat_name(self, duckdb_cursor, pandas): - df1 = pandas.DataFrame( + def test_repeat_name(self, duckdb_cursor): + df1 = pd.DataFrame( { "id": [1], "id_1": [1], @@ -97,9 +93,9 @@ def test_repeat_name(self, duckdb_cursor, pandas): } ) - df2 = pandas.DataFrame({"id": [1]}) + df2 = pd.DataFrame({"id": [1]}) - exp_result = pandas.DataFrame( + exp_result = pd.DataFrame( { "id": [1], "id_1": [1], @@ -119,4 +115,4 @@ def test_repeat_name(self, duckdb_cursor, pandas): ON (df1.id=df2.id) """ ).fetchdf() - pandas.testing.assert_frame_equal(exp_result, result_df) + pd.testing.assert_frame_equal(exp_result, result_df) diff --git a/tests/fast/pandas/test_append_df.py b/tests/fast/pandas/test_append_df.py index d93cfa2d..be287a8f 100644 --- a/tests/fast/pandas/test_append_df.py +++ b/tests/fast/pandas/test_append_df.py @@ -1,15 +1,14 @@ +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb class TestAppendDF: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_df_to_table_append(self, duckdb_cursor, pandas): + def test_df_to_table_append(self, duckdb_cursor): conn = duckdb.connect() conn.execute("Create table integers (i integer)") - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } @@ -17,11 +16,10 @@ def test_df_to_table_append(self, duckdb_cursor, pandas): conn.append("integers", df_in) assert conn.execute("select count(*) from integers").fetchone()[0] == 5 - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_append_by_name(self, pandas): + def test_append_by_name(self): con = duckdb.connect() con.execute("create table tbl (a integer, b bool, c varchar)") - df_in = pandas.DataFrame({"c": ["duck", "db"], "b": [False, True], "a": [4, 2]}) + df_in = pd.DataFrame({"c": ["duck", "db"], "b": [False, True], "a": [4, 2]}) # By default we append by position, causing the following exception: with pytest.raises( duckdb.ConversionException, match="Conversion Error: Could not convert string 'duck' to INT32" @@ -33,29 +31,27 @@ def test_append_by_name(self, pandas): res = con.table("tbl").fetchall() assert res == [(4, False, "duck"), (2, True, "db")] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_append_by_name_quoted(self, pandas): + def test_append_by_name_quoted(self): con = duckdb.connect() con.execute( """ create table tbl ("needs to be quoted" integer, other varchar) """ ) - df_in = pandas.DataFrame({"needs to be quoted": [1, 2, 3]}) + df_in = pd.DataFrame({"needs to be quoted": [1, 2, 3]}) con.append("tbl", df_in, by_name=True) res = con.table("tbl").fetchall() assert res == [(1, None), (2, None), (3, None)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_append_by_name_no_exact_match(self, pandas): + def test_append_by_name_no_exact_match(self): con = duckdb.connect() con.execute("create table tbl (a integer, b bool)") - df_in = pandas.DataFrame({"c": ["a", "b"], "b": [True, False], "a": [42, 1337]}) + df_in = pd.DataFrame({"c": ["a", "b"], "b": [True, False], "a": [42, 1337]}) # Too many columns raises an error, because the columns cant be found in the targeted table with pytest.raises(duckdb.BinderException, match='Table "tbl" does not have a column with name "c"'): con.append("tbl", df_in, by_name=True) - df_in = pandas.DataFrame({"b": [False, False, False]}) + df_in = pd.DataFrame({"b": [False, False, False]}) # Not matching all columns is not a problem, as they will be filled with NULL instead con.append("tbl", df_in, by_name=True) @@ -66,7 +62,7 @@ def test_append_by_name_no_exact_match(self, pandas): # Empty the table con.execute("create or replace table tbl (a integer, b bool)") - df_in = pandas.DataFrame({"a": [1, 2, 3]}) + df_in = pd.DataFrame({"a": [1, 2, 3]}) con.append("tbl", df_in, by_name=True) res = con.table("tbl").fetchall() # Also works for missing columns *after* the supplied ones diff --git a/tests/fast/pandas/test_bug5922.py b/tests/fast/pandas/test_bug5922.py index b75ddf1b..196764e3 100644 --- a/tests/fast/pandas/test_bug5922.py +++ b/tests/fast/pandas/test_bug5922.py @@ -1,13 +1,11 @@ -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb class TestPandasAcceptFloat16: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_pandas_accept_float16(self, duckdb_cursor, pandas): - df = pandas.DataFrame({"col": [1, 2, 3]}) + def test_pandas_accept_float16(self, duckdb_cursor): + df = pd.DataFrame({"col": [1, 2, 3]}) df16 = df.astype({"col": "float16"}) # noqa: F841 con = duckdb.connect() con.execute("CREATE TABLE tbl AS SELECT * FROM df16") diff --git a/tests/fast/pandas/test_copy_on_write.py b/tests/fast/pandas/test_copy_on_write.py index 176c2133..417fae0d 100644 --- a/tests/fast/pandas/test_copy_on_write.py +++ b/tests/fast/pandas/test_copy_on_write.py @@ -1,26 +1,27 @@ import datetime import pytest +from packaging.version import Version import duckdb # https://pandas.pydata.org/docs/dev/user_guide/copy_on_write.html pandas = pytest.importorskip("pandas", "1.5", reason="copy_on_write does not exist in earlier versions") +# Starting from Pandas 3.0.0 copy-on-write can no longer be disabled and this setting is deprecated +pre_3_0 = Version(pandas.__version__) < Version("3.0.0") # Make sure the variable get's properly reset even in case of error @pytest.fixture(autouse=True) def scoped_copy_on_write_setting(): - old_value = pandas.options.mode.copy_on_write - pandas.options.mode.copy_on_write = True - yield - # Reset it at the end of the function - pandas.options.mode.copy_on_write = old_value - return - - -def convert_to_result(col): - return [(x,) for x in col] + if pre_3_0: + old_value = pandas.options.mode.copy_on_write + pandas.options.mode.copy_on_write = True + yield + # Reset it at the end of the function + pandas.options.mode.copy_on_write = old_value + else: + yield class TestCopyOnWrite: @@ -35,7 +36,6 @@ class TestCopyOnWrite: ], ) def test_copy_on_write(self, col): - assert pandas.options.mode.copy_on_write con = duckdb.connect() df_in = pandas.DataFrame( # noqa: F841 { @@ -45,5 +45,5 @@ def test_copy_on_write(self, col): rel = con.sql("select * from df_in") res = rel.fetchall() print(res) - expected = convert_to_result(col) + expected = [(x,) for x in col] assert res == expected diff --git a/tests/fast/pandas/test_create_table_from_pandas.py b/tests/fast/pandas/test_create_table_from_pandas.py index 436fd0c8..b9937de2 100644 --- a/tests/fast/pandas/test_create_table_from_pandas.py +++ b/tests/fast/pandas/test_create_table_from_pandas.py @@ -1,12 +1,11 @@ -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb -def assert_create(internal_data, expected_result, data_type, pandas): +def assert_create(internal_data, expected_result, data_type): conn = duckdb.connect() - df_in = pandas.DataFrame(data=internal_data, dtype=data_type) # noqa: F841 + df_in = pd.DataFrame(data=internal_data, dtype=data_type) # noqa: F841 conn.execute("CREATE TABLE t AS SELECT * FROM df_in") @@ -14,9 +13,9 @@ def assert_create(internal_data, expected_result, data_type, pandas): assert result == expected_result -def assert_create_register(internal_data, expected_result, data_type, pandas): +def assert_create_register(internal_data, expected_result, data_type): conn = duckdb.connect() - df_in = pandas.DataFrame(data=internal_data, dtype=data_type) + df_in = pd.DataFrame(data=internal_data, dtype=data_type) conn.register("dataframe", df_in) conn.execute("CREATE TABLE t AS SELECT * FROM dataframe") @@ -25,15 +24,14 @@ def assert_create_register(internal_data, expected_result, data_type, pandas): class TestCreateTableFromPandas: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_integer_create_table(self, duckdb_cursor, pandas): + def test_integer_create_table(self, duckdb_cursor): # TODO: This should work with other data types e.g., int8... # noqa: TD002, TD003 data_types = ["Int8", "Int16", "Int32", "Int64"] internal_data = [1, 2, 3, 4] expected_result = [(1,), (2,), (3,), (4,)] for data_type in data_types: print(data_type) - assert_create_register(internal_data, expected_result, data_type, pandas) - assert_create(internal_data, expected_result, data_type, pandas) + assert_create_register(internal_data, expected_result, data_type) + assert_create(internal_data, expected_result, data_type) # TODO: Also test other data types # noqa: TD002, TD003 diff --git a/tests/fast/pandas/test_datetime_time.py b/tests/fast/pandas/test_datetime_time.py index 0b2642b0..a2fda09a 100644 --- a/tests/fast/pandas/test_datetime_time.py +++ b/tests/fast/pandas/test_datetime_time.py @@ -1,8 +1,8 @@ from datetime import datetime, time, timezone import numpy as np +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb @@ -10,25 +10,22 @@ class TestDateTimeTime: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_time_high(self, duckdb_cursor, pandas): + def test_time_high(self, duckdb_cursor): duckdb_time = duckdb_cursor.sql("SELECT make_time(23, 1, 34.234345) AS '0'").df() data = [time(hour=23, minute=1, second=34, microsecond=234345)] - df_in = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) + df_in = pd.DataFrame({"0": pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() - pandas.testing.assert_frame_equal(df_out, duckdb_time) + pd.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_time_low(self, duckdb_cursor, pandas): + def test_time_low(self, duckdb_cursor): duckdb_time = duckdb_cursor.sql("SELECT make_time(00, 01, 1.000) AS '0'").df() data = [time(hour=0, minute=1, second=1)] - df_in = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) + df_in = pd.DataFrame({"0": pd.Series(data=data, dtype="object")}) df_out = duckdb.query_df(df_in, "df", "select * from df").df() - pandas.testing.assert_frame_equal(df_out, duckdb_time) + pd.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) @pytest.mark.parametrize("input", ["2263-02-28", "9999-01-01"]) - def test_pandas_datetime_big(self, pandas, input): + def test_pandas_datetime_big(self, input): duckdb_con = duckdb.connect() duckdb_con.execute("create table test (date DATE)") @@ -36,8 +33,8 @@ def test_pandas_datetime_big(self, pandas, input): res = duckdb_con.execute("select * from test").df() date_value = np.array([f"{input}"], dtype="datetime64[us]") - df = pandas.DataFrame({"date": date_value}) - pandas.testing.assert_frame_equal(res, df) + df = pd.DataFrame({"date": date_value}) + pd.testing.assert_frame_equal(res, df) def test_timezone_datetime(self): con = duckdb.connect() diff --git a/tests/fast/pandas/test_datetime_timestamp.py b/tests/fast/pandas/test_datetime_timestamp.py index c6d4e3a9..063be160 100644 --- a/tests/fast/pandas/test_datetime_timestamp.py +++ b/tests/fast/pandas/test_datetime_timestamp.py @@ -1,39 +1,35 @@ import datetime +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas from packaging.version import Version -pd = pytest.importorskip("pandas") - class TestDateTimeTimeStamp: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_timestamp_high(self, pandas, duckdb_cursor): + def test_timestamp_high(self, duckdb_cursor): duckdb_time = duckdb_cursor.sql("SELECT '2260-01-01 23:59:00'::TIMESTAMP AS '0'").df() - df_in = pandas.DataFrame( # noqa: F841 + df_in = pd.DataFrame( # noqa: F841 { - 0: pandas.Series( + 0: pd.Series( data=[datetime.datetime(year=2260, month=1, day=1, hour=23, minute=59)], dtype="datetime64[us]", ) } ) df_out = duckdb_cursor.sql("select * from df_in").df() - pandas.testing.assert_frame_equal(df_out, duckdb_time) + pd.testing.assert_frame_equal(df_out, duckdb_time) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_timestamp_low(self, pandas, duckdb_cursor): + def test_timestamp_low(self, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ SELECT '1680-01-01 23:59:00.234243'::TIMESTAMP AS '0' """ ).df() - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { - "0": pandas.Series( + "0": pd.Series( data=[ - pandas.Timestamp( + pd.Timestamp( datetime.datetime(year=1680, month=1, day=1, hour=23, minute=59, microsecond=234243), unit="us", ) @@ -46,13 +42,12 @@ def test_timestamp_low(self, pandas, duckdb_cursor): print("df_in:", df_in["0"].dtype) df_out = duckdb_cursor.sql("select * from df_in").df() print("df_out:", df_out["0"].dtype) - pandas.testing.assert_frame_equal(df_out, duckdb_time) + pd.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): + def test_timestamp_timezone_regular(self, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ SELECT timestamp '2022-01-01 12:00:00' AT TIME ZONE 'Pacific/Easter' as "0" @@ -61,9 +56,9 @@ def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): offset = datetime.timedelta(hours=-2) timezone = datetime.timezone(offset) - df_in = pandas.DataFrame( # noqa: F841 + df_in = pd.DataFrame( # noqa: F841 { - 0: pandas.Series( + 0: pd.Series( data=[datetime.datetime(year=2022, month=1, day=1, hour=15, tzinfo=timezone)], dtype="object" ) } @@ -71,13 +66,12 @@ def test_timestamp_timezone_regular(self, pandas, duckdb_cursor): df_out = duckdb_cursor.sql("select * from df_in").df() print(df_out) print(duckdb_time) - pandas.testing.assert_frame_equal(df_out, duckdb_time) + pd.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): + def test_timestamp_timezone_negative_extreme(self, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ SELECT timestamp '2022-01-01 12:00:00' AT TIME ZONE 'Chile/EasterIsland' as "0" @@ -87,21 +81,20 @@ def test_timestamp_timezone_negative_extreme(self, pandas, duckdb_cursor): offset = datetime.timedelta(hours=-19) timezone = datetime.timezone(offset) - df_in = pandas.DataFrame( # noqa: F841 + df_in = pd.DataFrame( # noqa: F841 { - 0: pandas.Series( + 0: pd.Series( data=[datetime.datetime(year=2021, month=12, day=31, hour=22, tzinfo=timezone)], dtype="object" ) } ) df_out = duckdb_cursor.sql("select * from df_in").df() - pandas.testing.assert_frame_equal(df_out, duckdb_time) + pd.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): + def test_timestamp_timezone_positive_extreme(self, duckdb_cursor): duckdb_time = duckdb_cursor.sql( """ SELECT timestamp '2021-12-31 23:00:00' AT TIME ZONE 'Etc/GMT-14' as "0" @@ -111,22 +104,21 @@ def test_timestamp_timezone_positive_extreme(self, pandas, duckdb_cursor): offset = datetime.timedelta(hours=14) timezone = datetime.timezone(offset) - df_in = pandas.DataFrame( # noqa: F841 + df_in = pd.DataFrame( # noqa: F841 { - 0: pandas.Series( + 0: pd.Series( data=[datetime.datetime(year=2021, month=12, day=31, hour=23, tzinfo=timezone)], dtype="object" ) } ) df_out = duckdb_cursor.sql("""select * from df_in""").df() - pandas.testing.assert_frame_equal(df_out, duckdb_time) + pd.testing.assert_frame_equal(df_out, duckdb_time) @pytest.mark.skipif( Version(pd.__version__) < Version("2.0.2"), reason="pandas < 2.0.2 does not properly convert timezones" ) @pytest.mark.parametrize("unit", ["ms", "ns", "s"]) def test_timestamp_timezone_coverage(self, unit, duckdb_cursor): - pd = pytest.importorskip("pandas") ts_df = pd.DataFrame( # noqa: F841 {"ts": pd.Series(data=[pd.Timestamp(datetime.datetime(1990, 12, 21))], dtype=f"datetime64[{unit}]")} ) diff --git a/tests/fast/pandas/test_df_analyze.py b/tests/fast/pandas/test_df_analyze.py index 96cd426d..d9881ffa 100644 --- a/tests/fast/pandas/test_df_analyze.py +++ b/tests/fast/pandas/test_df_analyze.py @@ -1,58 +1,51 @@ -import numpy as np +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas +from conftest import is_string_dtype import duckdb -def create_generic_dataframe(data, pandas): - return pandas.DataFrame({"col0": pandas.Series(data=data, dtype="object")}) +def create_generic_dataframe(data): + return pd.DataFrame({"col0": pd.Series(data=data, dtype="object")}) class TestResolveObjectColumns: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_sample_low_correct(self, duckdb_cursor, pandas): - print(pandas.backend) + def test_sample_low_correct(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=3") data = [1000008, 6, 9, 4, 1, 6] - df = create_generic_dataframe(data, pandas) + df = create_generic_dataframe(data) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() duckdb_df = duckdb_conn.query("select * FROM (VALUES (1000008), (6), (9), (4), (1), (6)) as '0'").df() - pandas.testing.assert_frame_equal(duckdb_df, roundtripped_df, check_dtype=False) + pd.testing.assert_frame_equal(duckdb_df, roundtripped_df, check_dtype=False) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_sample_low_incorrect_detected(self, duckdb_cursor, pandas): + def test_sample_low_incorrect_detected(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=2") # size of list (6) divided by 'pandas_analyze_sample' (2) is the increment used # in this case index 0 (1000008) and index 3 ([4]) are checked, which dont match data = [1000008, 6, 9, [4], 1, 6] - df = create_generic_dataframe(data, pandas) + df = create_generic_dataframe(data) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() # Sample high enough to detect mismatch in types, fallback to VARCHAR - assert roundtripped_df["col0"].dtype == np.dtype("object") + assert is_string_dtype(roundtripped_df["col0"].dtype) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_sample_zero(self, duckdb_cursor, pandas): + def test_sample_zero_infers_varchar(self, duckdb_cursor): + """Test that with analyze disabled, object columns are treated as VARCHAR.""" duckdb_conn = duckdb.connect() # Disable dataframe analyze duckdb_conn.execute("SET pandas_analyze_sample=0") data = [1000008, 6, 9, 3, 1, 6] - df = create_generic_dataframe(data, pandas) + df = create_generic_dataframe(data) roundtripped_df = duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() - # Always converts to VARCHAR - if pandas.backend == "pyarrow": - assert roundtripped_df["col0"].dtype == np.dtype("int64") - else: - assert roundtripped_df["col0"].dtype == np.dtype("object") + # Always converts to VARCHAR when analyze is disabled + assert is_string_dtype(roundtripped_df["col0"].dtype) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_sample_low_incorrect_undetected(self, duckdb_cursor, pandas): + def test_sample_low_incorrect_undetected(self, duckdb_cursor): duckdb_conn = duckdb.connect() duckdb_conn.execute("SET pandas_analyze_sample=1") data = [1000008, 6, 9, [4], [1], 6] - df = create_generic_dataframe(data, pandas) + df = create_generic_dataframe(data) # Sample size is too low to detect the mismatch, exception is raised when trying to convert with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"): duckdb.query_df(df, "x", "select * from x", connection=duckdb_conn).df() @@ -65,12 +58,11 @@ def test_reset_analyze_sample_setting(self, duckdb_cursor): res = duckdb_cursor.execute("select current_setting('pandas_analyze_sample')").fetchall() assert res == [(1000,)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_10750(self, duckdb_cursor, pandas): + def test_10750(self, duckdb_cursor): max_row_number = 2000 data = {"id": list(range(max_row_number + 1)), "content": [None for _ in range(max_row_number + 1)]} - pdf = pandas.DataFrame(data=data) + pdf = pd.DataFrame(data=data) duckdb_cursor.register("content", pdf) res = duckdb_cursor.query("select id from content").fetchall() expected = [(i,) for i in range(2001)] diff --git a/tests/fast/pandas/test_df_object_resolution.py b/tests/fast/pandas/test_df_object_resolution.py index 58ae0c94..0c5ab311 100644 --- a/tests/fast/pandas/test_df_object_resolution.py +++ b/tests/fast/pandas/test_df_object_resolution.py @@ -7,16 +7,17 @@ from decimal import Decimal import numpy as np +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas +from conftest import is_string_dtype import duckdb standard_vector_size = duckdb.__standard_vector_size__ -def create_generic_dataframe(data, pandas): - return pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) +def create_generic_dataframe(data): + return pd.DataFrame({"0": pd.Series(data=data, dtype="object")}) def create_repeated_nulls(size): @@ -42,11 +43,11 @@ def __str__(self) -> str: # To avoid DECIMAL being upgraded to DOUBLE (because DOUBLE outranks DECIMAL as a LogicalType) # These floats had their precision preserved as string and are now cast to decimal.Decimal -def ConvertStringToDecimal(data: list, pandas): +def ConvertStringToDecimal(data: list): for i in range(len(data)): if isinstance(data[i], str): data[i] = decimal.Decimal(data[i]) - data = pandas.Series(data=data, dtype="object") + data = pd.Series(data=data, dtype="object") return data @@ -74,9 +75,9 @@ def construct_map(pair): ] -def check_struct_upgrade(expected_type: str, creation_method, pair: ObjectPair, pandas, cursor): +def check_struct_upgrade(expected_type: str, creation_method, pair: ObjectPair, cursor): column_data = creation_method(pair) - df = pandas.DataFrame(data={"col": column_data}) + df = pd.DataFrame(data={"col": column_data}) rel = cursor.query("select col from df") res = rel.fetchall() print("COLUMN_DATA", column_data) @@ -85,29 +86,25 @@ def check_struct_upgrade(expected_type: str, creation_method, pair: ObjectPair, class TestResolveObjectColumns: - # TODO: add support for ArrowPandas # noqa: TD002, TD003 - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_integers(self, pandas, duckdb_cursor): + def test_integers(self, duckdb_cursor): data = [5, 0, 3] - df_in = create_generic_dataframe(data, pandas) + df_in = create_generic_dataframe(data) # These are float64 because pandas would force these to be float64 even if we set them to int8, int16, # int32, int64 respectively - df_expected_res = pandas.DataFrame({"0": pandas.Series(data=data, dtype="int32")}) + df_expected_res = pd.DataFrame({"0": pd.Series(data=data, dtype="int32")}) df_out = duckdb_cursor.sql("SELECT * FROM df_in").df() print(df_out) - pandas.testing.assert_frame_equal(df_expected_res, df_out) + pd.testing.assert_frame_equal(df_expected_res, df_out) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_struct_correct(self, pandas, duckdb_cursor): + def test_struct_correct(self, duckdb_cursor): data = [{"a": 1, "b": 3, "c": 3, "d": 7}] - df = pandas.DataFrame({"0": pandas.Series(data=data)}) + df = pd.DataFrame({"0": pd.Series(data=data)}) duckdb_col = duckdb_cursor.sql("SELECT {a: 1, b: 3, c: 3, d: 7} as '0'").df() converted_col = duckdb_cursor.sql("SELECT * FROM df").df() - pandas.testing.assert_frame_equal(duckdb_col, converted_col) + pd.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_fallback_different_keys(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_map_fallback_different_keys(self, duckdb_cursor): + x = pd.DataFrame( [ [{"a": 1, "b": 3, "c": 3, "d": 7}], [{"a": 1, "b": 3, "c": 3, "d": 7}], @@ -118,7 +115,7 @@ def test_map_fallback_different_keys(self, pandas, duckdb_cursor): ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() - y = pandas.DataFrame( + y = pd.DataFrame( [ [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], @@ -128,11 +125,10 @@ def test_map_fallback_different_keys(self, pandas, duckdb_cursor): ] ) equal_df = duckdb_cursor.sql("SELECT * FROM y").df() - pandas.testing.assert_frame_equal(converted_df, equal_df) + pd.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_fallback_incorrect_amount_of_keys(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_map_fallback_incorrect_amount_of_keys(self, duckdb_cursor): + x = pd.DataFrame( [ [{"a": 1, "b": 3, "c": 3, "d": 7}], [{"a": 1, "b": 3, "c": 3, "d": 7}], @@ -142,7 +138,7 @@ def test_map_fallback_incorrect_amount_of_keys(self, pandas, duckdb_cursor): ] ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() - y = pandas.DataFrame( + y = pd.DataFrame( [ [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], @@ -152,11 +148,10 @@ def test_map_fallback_incorrect_amount_of_keys(self, pandas, duckdb_cursor): ] ) equal_df = duckdb_cursor.sql("SELECT * FROM y").df() - pandas.testing.assert_frame_equal(converted_df, equal_df) + pd.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_struct_value_upgrade(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_struct_value_upgrade(self, duckdb_cursor): + x = pd.DataFrame( [ [{"a": 1, "b": 3, "c": 3, "d": "string"}], [{"a": 1, "b": 3, "c": 3, "d": 7}], @@ -165,7 +160,7 @@ def test_struct_value_upgrade(self, pandas, duckdb_cursor): [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) - y = pandas.DataFrame( + y = pd.DataFrame( [ [{"a": 1, "b": 3, "c": 3, "d": "string"}], [{"a": 1, "b": 3, "c": 3, "d": "7"}], @@ -176,11 +171,10 @@ def test_struct_value_upgrade(self, pandas, duckdb_cursor): ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() - pandas.testing.assert_frame_equal(converted_df, equal_df) + pd.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_struct_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_struct_null(self, duckdb_cursor): + x = pd.DataFrame( [ [None], [{"a": 1, "b": 3, "c": 3, "d": 7}], @@ -189,7 +183,7 @@ def test_struct_null(self, pandas, duckdb_cursor): [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) - y = pandas.DataFrame( + y = pd.DataFrame( [ [None], [{"a": 1, "b": 3, "c": 3, "d": 7}], @@ -200,11 +194,10 @@ def test_struct_null(self, pandas, duckdb_cursor): ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() - pandas.testing.assert_frame_equal(converted_df, equal_df) + pd.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_fallback_value_upgrade(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_map_fallback_value_upgrade(self, duckdb_cursor): + x = pd.DataFrame( [ [{"a": 1, "b": 3, "c": 3, "d": "test"}], [{"a": 1, "b": 3, "c": 3, "d": 7}], @@ -213,7 +206,7 @@ def test_map_fallback_value_upgrade(self, pandas, duckdb_cursor): [{"a": 1, "b": 3, "c": 3, "d": 7}], ] ) - y = pandas.DataFrame( + y = pd.DataFrame( [ [{"a": "1", "b": "3", "c": "3", "d": "test"}], [{"a": "1", "b": "3", "c": "3", "d": "7"}], @@ -224,11 +217,10 @@ def test_map_fallback_value_upgrade(self, pandas, duckdb_cursor): ) converted_df = duckdb_cursor.sql("SELECT * FROM x").df() equal_df = duckdb_cursor.sql("SELECT * FROM y").df() - pandas.testing.assert_frame_equal(converted_df, equal_df) + pd.testing.assert_frame_equal(converted_df, equal_df) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_correct(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_map_correct(self, duckdb_cursor): + x = pd.DataFrame( [ [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], @@ -255,23 +247,21 @@ def test_map_correct(self, pandas, duckdb_cursor): duckdb_col = duckdb_cursor.sql("select a from tmp AS '0'").df() print(duckdb_col.columns) print(converted_col.columns) - pandas.testing.assert_frame_equal(converted_col, duckdb_col) + pd.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) @pytest.mark.parametrize("sample_size", [1, 10]) @pytest.mark.parametrize("fill", [1000, 10000]) @pytest.mark.parametrize("get_data", [create_repeated_nulls, create_trailing_non_null]) - def test_analyzing_nulls(self, pandas, duckdb_cursor, fill, sample_size, get_data): + def test_analyzing_nulls(self, duckdb_cursor, fill, sample_size, get_data): data = get_data(fill) - df1 = pandas.DataFrame(data={"col1": data}) + df1 = pd.DataFrame(data={"col1": data}) duckdb_cursor.execute(f"SET GLOBAL pandas_analyze_sample={sample_size}") df = duckdb_cursor.execute("select * from df1").df() - pandas.testing.assert_frame_equal(df1, df) + pd.testing.assert_frame_equal(df1, df, check_dtype=False) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_nested_map(self, pandas, duckdb_cursor): - df = pandas.DataFrame(data={"col1": [{"a": {"b": {"x": "A", "y": "B"}}}, {"c": {"b": {"x": "A"}}}]}) + def test_nested_map(self, duckdb_cursor): + df = pd.DataFrame(data={"col1": [{"a": {"b": {"x": "A", "y": "B"}}}, {"c": {"b": {"x": "A"}}}]}) rel = duckdb_cursor.sql("select * from df") expected_rel = duckdb_cursor.sql( @@ -287,9 +277,8 @@ def test_nested_map(self, pandas, duckdb_cursor): expected_res = str(expected_rel) assert res == expected_res - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_value_upgrade(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_map_value_upgrade(self, duckdb_cursor): + x = pd.DataFrame( [ [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, "test"]}], [{"key": ["a", "b", "c", "d"], "value": [1, 3, 3, 7]}], @@ -321,36 +310,31 @@ def test_map_value_upgrade(self, pandas, duckdb_cursor): duckdb_col = duckdb_cursor.sql("select a from tmp2 AS '0'").df() print(duckdb_col.columns) print(converted_col.columns) - pandas.testing.assert_frame_equal(converted_col, duckdb_col) + pd.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_duplicate(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{"key": ["a", "a", "b"], "value": [4, 0, 4]}]]) + def test_map_duplicate(self, duckdb_cursor): + x = pd.DataFrame([[{"key": ["a", "a", "b"], "value": [4, 0, 4]}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys must be unique"): duckdb_cursor.sql("select * from x").show() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_nullkey(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{"key": [None, "a", "b"], "value": [4, 0, 4]}]]) + def test_map_nullkey(self, duckdb_cursor): + x = pd.DataFrame([[{"key": [None, "a", "b"], "value": [4, 0, 4]}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_nullkeylist(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{"key": None, "value": None}]]) + def test_map_nullkeylist(self, duckdb_cursor): + x = pd.DataFrame([[{"key": None, "value": None}]]) converted_col = duckdb_cursor.sql("select * from x").df() duckdb_col = duckdb_cursor.sql("SELECT MAP(NULL, NULL) as '0'").df() - pandas.testing.assert_frame_equal(duckdb_col, converted_col) + pd.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_fallback_nullkey(self, pandas, duckdb_cursor): - x = pandas.DataFrame([[{"a": 4, None: 0, "c": 4}], [{"a": 4, None: 0, "d": 4}]]) + def test_map_fallback_nullkey(self, duckdb_cursor): + x = pd.DataFrame([[{"a": 4, None: 0, "c": 4}], [{"a": 4, None: 0, "d": 4}]]) with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_map_fallback_nullkey_coverage(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_map_fallback_nullkey_coverage(self, duckdb_cursor): + x = pd.DataFrame( [ [{"key": None, "value": None}], [{"key": None, None: 5}], @@ -359,8 +343,7 @@ def test_map_fallback_nullkey_coverage(self, pandas, duckdb_cursor): with pytest.raises(duckdb.InvalidInputException, match="Map keys can not be NULL"): converted_col = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_structs_in_nested_types(self, pandas, duckdb_cursor): + def test_structs_in_nested_types(self, duckdb_cursor): # This test is testing a bug that occurred when type upgrades occurred inside nested types # STRUCT(key1 varchar) + STRUCT(key1 varchar, key2 varchar) turns into MAP # But when inside a nested structure, this upgrade did not happen properly @@ -373,20 +356,19 @@ def test_structs_in_nested_types(self, pandas, duckdb_cursor): } for pair in pairs.values(): - check_struct_upgrade("MAP(VARCHAR, INTEGER)[]", construct_list, pair, pandas, duckdb_cursor) + check_struct_upgrade("MAP(VARCHAR, INTEGER)[]", construct_list, pair, duckdb_cursor) for key, pair in pairs.items(): expected_type = "MAP(VARCHAR, MAP(VARCHAR, INTEGER))" if key == "v4" else "STRUCT(v1 MAP(VARCHAR, INTEGER))" - check_struct_upgrade(expected_type, construct_struct, pair, pandas, duckdb_cursor) + check_struct_upgrade(expected_type, construct_struct, pair, duckdb_cursor) for pair in pairs.values(): - check_struct_upgrade("MAP(VARCHAR, MAP(VARCHAR, INTEGER))", construct_map, pair, pandas, duckdb_cursor) + check_struct_upgrade("MAP(VARCHAR, MAP(VARCHAR, INTEGER))", construct_map, pair, duckdb_cursor) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_structs_of_different_sizes(self, pandas, duckdb_cursor): + def test_structs_of_different_sizes(self, duckdb_cursor): # This list has both a STRUCT(v1) and a STRUCT(v1, v2) member # Those can't be combined - df = pandas.DataFrame( + df = pd.DataFrame( data={ "col": [ [ @@ -416,9 +398,8 @@ def test_structs_of_different_sizes(self, pandas, duckdb_cursor): ): res = duckdb_cursor.execute("select $1", [malformed_struct]) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_struct_key_conversion(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_struct_key_conversion(self, duckdb_cursor): + x = pd.DataFrame( [ [{IntString(5): 1, IntString(-25): 3, IntString(32): 3, IntString(32456): 7}], ] @@ -426,43 +407,38 @@ def test_struct_key_conversion(self, pandas, duckdb_cursor): duckdb_col = duckdb_cursor.sql("select {'5':1, '-25':3, '32':3, '32456':7} as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") - pandas.testing.assert_frame_equal(duckdb_col, converted_col) + pd.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_list_correct(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{"0": [[5], [34], [-245]]}]) + def test_list_correct(self, duckdb_cursor): + x = pd.DataFrame([{"0": [[5], [34], [-245]]}]) duckdb_col = duckdb_cursor.sql("select [[5], [34], [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") - pandas.testing.assert_frame_equal(duckdb_col, converted_col) + pd.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_list_contains_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{"0": [[5], None, [-245]]}]) + def test_list_contains_null(self, duckdb_cursor): + x = pd.DataFrame([{"0": [[5], None, [-245]]}]) duckdb_col = duckdb_cursor.sql("select [[5], NULL, [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") - pandas.testing.assert_frame_equal(duckdb_col, converted_col) + pd.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_list_starts_with_null(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{"0": [None, [5], [-245]]}]) + def test_list_starts_with_null(self, duckdb_cursor): + x = pd.DataFrame([{"0": [None, [5], [-245]]}]) duckdb_col = duckdb_cursor.sql("select [NULL, [5], [-245]] as '0'").df() converted_col = duckdb_cursor.sql("select * from x").df() duckdb_cursor.sql("drop view if exists tbl") - pandas.testing.assert_frame_equal(duckdb_col, converted_col) + pd.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_list_value_upgrade(self, pandas, duckdb_cursor): - x = pandas.DataFrame([{"0": [["5"], [34], [-245]]}]) + def test_list_value_upgrade(self, duckdb_cursor): + x = pd.DataFrame([{"0": [["5"], [34], [-245]]}]) duckdb_rel = duckdb_cursor.sql("select [['5'], ['34'], ['-245']] as '0'") duckdb_col = duckdb_rel.df() converted_col = duckdb_cursor.sql("select * from x").df() - pandas.testing.assert_frame_equal(duckdb_col, converted_col) + pd.testing.assert_frame_equal(duckdb_col, converted_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_list_column_value_upgrade(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_list_column_value_upgrade(self, duckdb_cursor): + x = pd.DataFrame( [ [[1, 25, 300]], [[500, 345, 30]], @@ -496,46 +472,35 @@ def test_list_column_value_upgrade(self, pandas, duckdb_cursor): duckdb_col = duckdb_cursor.sql("select a from tmp3 AS '0'").df() print(duckdb_col.columns) print(converted_col.columns) - pandas.testing.assert_frame_equal(converted_col, duckdb_col) + pd.testing.assert_frame_equal(converted_col, duckdb_col) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_ubigint_object_conversion(self, pandas, duckdb_cursor): + def test_ubigint_object_conversion(self, duckdb_cursor): # UBIGINT + TINYINT would result in HUGEINT, but conversion to HUGEINT is not supported yet from pandas->duckdb # So this instead becomes a DOUBLE data = [18446744073709551615, 0] - x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) + x = pd.DataFrame({"0": pd.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() - if pandas.backend == "numpy_nullable": - float64 = np.dtype("float64") - assert isinstance(converted_col["0"].dtype, float64.__class__) - else: - uint64 = np.dtype("uint64") - assert isinstance(converted_col["0"].dtype, uint64.__class__) - - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_double_object_conversion(self, pandas, duckdb_cursor): + float64 = np.dtype("float64") + assert isinstance(converted_col["0"].dtype, float64.__class__) + + def test_double_object_conversion(self, duckdb_cursor): data = [18446744073709551616, 0] - x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) + x = pd.DataFrame({"0": pd.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() double_dtype = np.dtype("float64") assert isinstance(converted_col["0"].dtype, double_dtype.__class__) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) @pytest.mark.xfail( condition=platform.system() == "Emscripten", reason="older numpy raises a warning when running with Pyodide", ) - def test_numpy_object_with_stride(self, pandas, duckdb_cursor): - df = pandas.DataFrame(columns=["idx", "evens", "zeros"]) - - df["idx"] = list(range(10)) - for col in df.columns[1:]: - df[col].values[:] = 0 + def test_numpy_object_with_stride(self, duckdb_cursor): + # Create 2D array in C-order (row-major) + data = np.zeros((10, 3), dtype=np.int64) + data[:, 0] = np.arange(10) + data[:, 1] = np.arange(0, 20, 2) - counter = 0 - for i in range(10): - df.loc[df["idx"] == i, "evens"] += counter - counter += 2 + df = pd.DataFrame(data, columns=["idx", "evens", "zeros"]) res = duckdb_cursor.sql("select * from df").fetchall() assert res == [ @@ -551,27 +516,24 @@ def test_numpy_object_with_stride(self, pandas, duckdb_cursor): (9, 18, 0), ] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numpy_stringliterals(self, pandas, duckdb_cursor): - df = pandas.DataFrame({"x": list(map(np.str_, range(3)))}) + def test_numpy_stringliterals(self, duckdb_cursor): + df = pd.DataFrame({"x": list(map(np.str_, range(3)))}) res = duckdb_cursor.execute("select * from df").fetchall() assert res == [("0",), ("1",), ("2",)] - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_integer_conversion_fail(self, pandas, duckdb_cursor): + def test_integer_conversion_fail(self, duckdb_cursor): data = [2**10000, 0] - x = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) + x = pd.DataFrame({"0": pd.Series(data=data, dtype="object")}) converted_col = duckdb_cursor.sql("select * from x").df() print(converted_col["0"]) - double_dtype = np.dtype("object") - assert isinstance(converted_col["0"].dtype, double_dtype.__class__) + # default: VARCHAR + assert is_string_dtype(converted_col["0"].dtype) # Most of the time numpy.datetime64 is just a wrapper around a datetime.datetime object # But to support arbitrary precision, it can fall back to using an `int` internally - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) # Which we don't support yet - def test_numpy_datetime(self, pandas, duckdb_cursor): + def test_numpy_datetime(self, duckdb_cursor): numpy = pytest.importorskip("numpy") data = [] @@ -579,25 +541,23 @@ def test_numpy_datetime(self, pandas, duckdb_cursor): data += [numpy.datetime64("2022-02-21T06:59:23.324812")] * standard_vector_size data += [numpy.datetime64("1974-06-05T13:12:01.000000")] * standard_vector_size data += [numpy.datetime64("2049-01-13T00:24:31.999999")] * standard_vector_size - x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) + x = pd.DataFrame({"dates": pd.Series(data=data, dtype="object")}) res = duckdb_cursor.sql("select distinct * from x").df() assert len(res["dates"].__array__()) == 4 - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_numpy_datetime_int_internally(self, pandas, duckdb_cursor): + def test_numpy_datetime_int_internally(self, duckdb_cursor): numpy = pytest.importorskip("numpy") data = [numpy.datetime64("2022-12-10T21:38:24.0000000000001")] - x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) + x = pd.DataFrame({"dates": pd.Series(data=data, dtype="object")}) with pytest.raises( duckdb.ConversionException, match=re.escape("Conversion Error: Unimplemented type for cast (BIGINT -> TIMESTAMP)"), ): rel = duckdb.query_df(x, "x", "create table dates as select dates::TIMESTAMP WITHOUT TIME ZONE from x") - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_fallthrough_object_conversion(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_fallthrough_object_conversion(self, duckdb_cursor): + x = pd.DataFrame( [ [IntString(4)], [IntString(2)], @@ -605,11 +565,10 @@ def test_fallthrough_object_conversion(self, pandas, duckdb_cursor): ] ) duckdb_col = duckdb_cursor.sql("select * from x").df() - df_expected_res = pandas.DataFrame({"0": pandas.Series(["4", "2", "0"])}) - pandas.testing.assert_frame_equal(duckdb_col, df_expected_res) + df_expected_res = pd.DataFrame({"0": pd.Series(["4", "2", "0"])}) + pd.testing.assert_frame_equal(duckdb_col, df_expected_res, check_dtype=False) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numeric_decimal(self, pandas, duckdb_cursor): + def test_numeric_decimal(self, duckdb_cursor): # DuckDB uses DECIMAL where possible, so all the 'float' types here are actually DECIMAL reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( @@ -625,14 +584,12 @@ def test_numeric_decimal(self, pandas, duckdb_cursor): duckdb_cursor.execute(reference_query) # Because of this we need to wrap these native floats as DECIMAL for this test, to avoid these decimals being # "upgraded" to DOUBLE - x = pandas.DataFrame( + x = pd.DataFrame( { - "0": ConvertStringToDecimal([5, "12.0", "-123.0", "-234234.0", None, "1.234"], pandas), - "1": ConvertStringToDecimal( - [5002340, 13, "-12.0000000005", "7453324234.0", None, "-324234234"], pandas - ), + "0": ConvertStringToDecimal([5, "12.0", "-123.0", "-234234.0", None, "1.234"]), + "1": ConvertStringToDecimal([5002340, 13, "-12.0000000005", "7453324234.0", None, "-324234234"]), "2": ConvertStringToDecimal( - ["-234234234234.0", "324234234.00000005", -128, 345345, "1E5", "1324234359"], pandas + ["-234234234234.0", "324234234.00000005", -128, 345345, "1E5", "1324234359"] ), } ) @@ -641,9 +598,8 @@ def test_numeric_decimal(self, pandas, duckdb_cursor): assert conversion == reference - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numeric_decimal_coverage(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_numeric_decimal_coverage(self, duckdb_cursor): + x = pd.DataFrame( {"0": [Decimal("nan"), Decimal("+nan"), Decimal("-nan"), Decimal("inf"), Decimal("+inf"), Decimal("-inf")]} ) conversion = duckdb_cursor.sql("select * from x").fetchall() @@ -659,22 +615,18 @@ def test_numeric_decimal_coverage(self, pandas, duckdb_cursor): assert str(conversion) == "[(nan,), (nan,), (nan,), (inf,), (inf,), (inf,)]" # Test that the column 'offset' is actually used when converting, - - @pytest.mark.parametrize( - "pandas", [NumpyPandas(), ArrowPandas()] - ) # and that the same 2048 (STANDARD_VECTOR_SIZE) values are not being scanned over and over again - def test_multiple_chunks(self, pandas, duckdb_cursor): + # and that the same 2048 (STANDARD_VECTOR_SIZE) values are not being scanned over and over again + def test_multiple_chunks(self, duckdb_cursor): data = [] data += [datetime.date(2022, 9, 13) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 14) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 15) for x in range(standard_vector_size)] data += [datetime.date(2022, 9, 16) for x in range(standard_vector_size)] - x = pandas.DataFrame({"dates": pandas.Series(data=data, dtype="object")}) + x = pd.DataFrame({"dates": pd.Series(data=data, dtype="object")}) res = duckdb_cursor.sql("select distinct * from x").df() assert len(res["dates"].__array__()) == 4 - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): + def test_multiple_chunks_aggregate(self, duckdb_cursor): duckdb_cursor.execute("SET GLOBAL pandas_analyze_sample=4096") duckdb_cursor.execute( "create table dates as select '2022-09-14'::DATE + INTERVAL (i::INTEGER) DAY as i from range(4096) tbl(i);" @@ -684,7 +636,7 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): date_df = res.copy() # Convert the dataframe to datetime - date_df["i"] = pandas.to_datetime(res["i"]).dt.date + date_df["i"] = pd.to_datetime(res["i"]).dt.date assert str(date_df["i"].dtype) == "object" expected_res = [ @@ -722,7 +674,7 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): ] # Convert the dataframe to datetime date_df = res.copy() - date_df["i"] = pandas.to_datetime(res["i"]).dt.date + date_df["i"] = pd.to_datetime(res["i"]).dt.date assert str(date_df["i"].dtype) == "object" actual_res = duckdb_cursor.sql( @@ -737,21 +689,19 @@ def test_multiple_chunks_aggregate(self, pandas, duckdb_cursor): assert expected_res == actual_res - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_mixed_object_types(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_mixed_object_types(self, duckdb_cursor): + x = pd.DataFrame( { - "nested": pandas.Series( + "nested": pd.Series( data=[{"a": 1, "b": 2}, [5, 4, 3], {"key": [1, 2, 3], "value": ["a", "b", "c"]}], dtype="object" ), } ) res = duckdb_cursor.sql("select * from x").df() - assert res["nested"].dtype == np.dtype("object") + assert is_string_dtype(res["nested"].dtype) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_struct_deeply_nested_in_struct(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_struct_deeply_nested_in_struct(self, duckdb_cursor): + x = pd.DataFrame( [ { # STRUCT(b STRUCT(x VARCHAR, y VARCHAR)) @@ -768,9 +718,8 @@ def test_struct_deeply_nested_in_struct(self, pandas, duckdb_cursor): res = duckdb_cursor.sql("select * from x").fetchall() assert res == [({"b": {"x": "A", "y": "B"}},), ({"b": {"x": "A"}},)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_struct_deeply_nested_in_list(self, pandas, duckdb_cursor): - x = pandas.DataFrame( + def test_struct_deeply_nested_in_list(self, duckdb_cursor): + x = pd.DataFrame( { "a": [ [ @@ -787,16 +736,14 @@ def test_struct_deeply_nested_in_list(self, pandas, duckdb_cursor): res = duckdb_cursor.sql("select * from x").fetchall() assert res == [([{"x": "A", "y": "B"}, {"x": "A"}],)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_analyze_sample_too_small(self, pandas, duckdb_cursor): + def test_analyze_sample_too_small(self, duckdb_cursor): data = [1 for _ in range(9)] + [[1, 2, 3]] + [1 for _ in range(9991)] - x = pandas.DataFrame({"a": pandas.Series(data=data)}) + x = pd.DataFrame({"a": pd.Series(data=data)}) with pytest.raises(duckdb.InvalidInputException, match="Failed to cast value: Unimplemented type for cast"): res = duckdb_cursor.sql("select * from x").df() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numeric_decimal_zero_fractional(self, pandas, duckdb_cursor): - decimals = pandas.DataFrame( + def test_numeric_decimal_zero_fractional(self, duckdb_cursor): + decimals = pd.DataFrame( data={ "0": [ Decimal("0.00"), @@ -827,8 +774,7 @@ def test_numeric_decimal_zero_fractional(self, pandas, duckdb_cursor): assert conversion == reference - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): + def test_numeric_decimal_incompatible(self, duckdb_cursor): reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( VALUES @@ -841,13 +787,11 @@ def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): ) tbl(a, b, c); """ duckdb_cursor.execute(reference_query) - x = pandas.DataFrame( + x = pd.DataFrame( { - "0": ConvertStringToDecimal(["5", "12.0", "-123.0", "-234234.0", None, "1.234"], pandas), - "1": ConvertStringToDecimal([5002340, 13, "-12.0000000005", 7453324234, None, "-324234234"], pandas), - "2": ConvertStringToDecimal( - [-234234234234, "324234234.00000005", -128, 345345, 0, "1324234359"], pandas - ), + "0": ConvertStringToDecimal(["5", "12.0", "-123.0", "-234234.0", None, "1.234"]), + "1": ConvertStringToDecimal([5002340, 13, "-12.0000000005", 7453324234, None, "-324234234"]), + "2": ConvertStringToDecimal([-234234234234, "324234234.00000005", -128, 345345, 0, "1324234359"]), } ) reference = duckdb_cursor.sql("select * from tbl").fetchall() @@ -857,11 +801,9 @@ def test_numeric_decimal_incompatible(self, pandas, duckdb_cursor): print(reference) print(conversion) - @pytest.mark.parametrize( - "pandas", [NumpyPandas(), ArrowPandas()] - ) # result: [('1E-28',), ('10000000000000000000000000.0',)] - def test_numeric_decimal_combined(self, pandas, duckdb_cursor): - decimals = pandas.DataFrame( + # result: [('1E-28',), ('10000000000000000000000000.0',)] + def test_numeric_decimal_combined(self, duckdb_cursor): + decimals = pd.DataFrame( data={"0": [Decimal("0.0000000000000000000000000001"), Decimal("10000000000000000000000000.0")]} ) reference_query = """ @@ -879,9 +821,8 @@ def test_numeric_decimal_combined(self, pandas, duckdb_cursor): print(conversion) # result: [('1234.0',), ('123456789.0',), ('1234567890123456789.0',), ('0.1234567890123456789',)] - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numeric_decimal_varying_sizes(self, pandas, duckdb_cursor): - decimals = pandas.DataFrame( + def test_numeric_decimal_varying_sizes(self, duckdb_cursor): + decimals = pd.DataFrame( data={ "0": [ Decimal("1234.0"), @@ -907,14 +848,13 @@ def test_numeric_decimal_varying_sizes(self, pandas, duckdb_cursor): print(reference) print(conversion) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numeric_decimal_fallback_to_double(self, pandas, duckdb_cursor): + def test_numeric_decimal_fallback_to_double(self, duckdb_cursor): # The widths of these decimal values are bigger than the max supported width for DECIMAL data = [ Decimal("1.234567890123456789012345678901234567890123456789"), Decimal("123456789012345678901234567890123456789012345678.0"), ] - decimals = pandas.DataFrame(data={"0": data}) + decimals = pd.DataFrame(data={"0": data}) reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( VALUES @@ -928,8 +868,7 @@ def test_numeric_decimal_fallback_to_double(self, pandas, duckdb_cursor): assert conversion == reference assert isinstance(conversion[0][0], float) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): + def test_numeric_decimal_double_mixed(self, duckdb_cursor): data = [ Decimal("1.234"), Decimal("1.234567891234567890123456789012345678901234567890123456789"), @@ -940,7 +879,7 @@ def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): Decimal("1232354.000000000000000000000000000035"), Decimal("123.5e300"), ] - decimals = pandas.DataFrame(data={"0": data}) + decimals = pd.DataFrame(data={"0": data}) reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( VALUES @@ -960,10 +899,9 @@ def test_numeric_decimal_double_mixed(self, pandas, duckdb_cursor): assert conversion == reference assert isinstance(conversion[0][0], float) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_numeric_decimal_out_of_range(self, pandas, duckdb_cursor): + def test_numeric_decimal_out_of_range(self, duckdb_cursor): data = [Decimal("1.234567890123456789012345678901234567"), Decimal("123456789012345678901234567890123456.0")] - decimals = pandas.DataFrame(data={"0": data}) + decimals = pd.DataFrame(data={"0": data}) reference_query = """ CREATE TABLE tbl AS SELECT * FROM ( VALUES diff --git a/tests/fast/pandas/test_df_recursive_nested.py b/tests/fast/pandas/test_df_recursive_nested.py index 871132ae..c3971cf6 100644 --- a/tests/fast/pandas/test_df_recursive_nested.py +++ b/tests/fast/pandas/test_df_recursive_nested.py @@ -1,5 +1,4 @@ -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb from duckdb import Value @@ -21,39 +20,35 @@ def create_reference_query(): class TestDFRecursiveNested: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_list_of_structs(self, duckdb_cursor, pandas): + def test_list_of_structs(self, duckdb_cursor): data = [[{"a": 5}, NULL, {"a": NULL}], NULL, [{"a": 5}, NULL, {"a": NULL}]] reference_query = create_reference_query() - df = pandas.DataFrame([{"a": data}]) + df = pd.DataFrame([{"a": data}]) check_equal(duckdb_cursor, df, reference_query, Value(data, "STRUCT(a INTEGER)[]")) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_list_of_map(self, duckdb_cursor, pandas): + def test_list_of_map(self, duckdb_cursor): # LIST(MAP(VARCHAR, VARCHAR)) data = [[{5: NULL}, NULL, {}], NULL, [NULL, {3: NULL, 2: "a", 4: NULL}, {"a": 1, "b": 2, "c": 3}]] reference_query = create_reference_query() print(reference_query) - df = pandas.DataFrame([{"a": data}]) + df = pd.DataFrame([{"a": data}]) check_equal(duckdb_cursor, df, reference_query, Value(data, "MAP(VARCHAR, VARCHAR)[][]")) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_recursive_list(self, duckdb_cursor, pandas): + def test_recursive_list(self, duckdb_cursor): # LIST(LIST(LIST(LIST(INTEGER)))) data = [[[[3, NULL, 5], NULL], NULL, [[5, -20, NULL]]], NULL, [[[NULL]], [[]], NULL]] reference_query = create_reference_query() - df = pandas.DataFrame([{"a": data}]) + df = pd.DataFrame([{"a": data}]) check_equal(duckdb_cursor, df, reference_query, Value(data, "INTEGER[][][][]")) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_recursive_struct(self, duckdb_cursor, pandas): + def test_recursive_struct(self, duckdb_cursor): # STRUCT(STRUCT(STRUCT(LIST))) data = { "A": {"a": {"1": [1, 2, 3]}, "b": NULL, "c": {"1": NULL}}, "B": {"a": {"1": [1, NULL, 3]}, "b": NULL, "c": {"1": NULL}}, } reference_query = create_reference_query() - df = pandas.DataFrame([{"a": data}]) + df = pd.DataFrame([{"a": data}]) check_equal( duckdb_cursor, df, @@ -89,8 +84,7 @@ def test_recursive_struct(self, duckdb_cursor, pandas): ), ) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_recursive_map(self, duckdb_cursor, pandas): + def test_recursive_map(self, duckdb_cursor): # MAP( # MAP( # INTEGER, @@ -106,13 +100,12 @@ def test_recursive_map(self, duckdb_cursor, pandas): "value": [1, 2], } reference_query = create_reference_query() - df = pandas.DataFrame([{"a": data}]) + df = pd.DataFrame([{"a": data}]) check_equal( duckdb_cursor, df, reference_query, Value(data, "MAP(MAP(INTEGER, MAP(INTEGER, VARCHAR)), INTEGER)") ) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_recursive_stresstest(self, duckdb_cursor, pandas): + def test_recursive_stresstest(self, duckdb_cursor): data = [ { "a": { @@ -134,7 +127,7 @@ def test_recursive_stresstest(self, duckdb_cursor, pandas): } ] reference_query = create_reference_query() - df = pandas.DataFrame([{"a": data}]) + df = pd.DataFrame([{"a": data}]) duckdb_type = """ STRUCT( a MAP( diff --git a/tests/fast/pandas/test_implicit_pandas_scan.py b/tests/fast/pandas/test_implicit_pandas_scan.py index 76f2c200..af3a8758 100644 --- a/tests/fast/pandas/test_implicit_pandas_scan.py +++ b/tests/fast/pandas/test_implicit_pandas_scan.py @@ -1,43 +1,27 @@ # simple DB API testcase import pandas as pd -import pytest -from conftest import ArrowPandas, NumpyPandas -from packaging.version import Version import duckdb -numpy_nullable_df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val4", "CoL2": 17}]) - -try: - from pandas.compat import pa_version_under7p0 - - pyarrow_dtypes_enabled = not pa_version_under7p0 -except Exception: - pyarrow_dtypes_enabled = False - -if Version(pd.__version__) >= Version("2.0.0") and pyarrow_dtypes_enabled: - pyarrow_df = numpy_nullable_df.convert_dtypes(dtype_backend="pyarrow") -else: - # dtype_backend is not supported in pandas < 2.0.0 - pyarrow_df = numpy_nullable_df - class TestImplicitPandasScan: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_local_pandas_scan(self, duckdb_cursor, pandas): + def test_local_pandas_scan(self, duckdb_cursor): con = duckdb.connect() - df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) # noqa: F841 + df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) # noqa: F841 r1 = con.execute("select * from df").fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val3" assert r1["CoL2"][0] == 1.05 assert r1["CoL2"][1] == 17 - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_global_pandas_scan(self, duckdb_cursor, pandas): + def test_global_pandas_scan(self, duckdb_cursor): + """Test that DuckDB can scan a module-level DataFrame variable.""" con = duckdb.connect() - r1 = con.execute(f"select * from {pandas.backend}_df").fetchdf() + # Create a global-scope dataframe for this test + global test_global_df + test_global_df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val4", "CoL2": 17}]) + r1 = con.execute("select * from test_global_df").fetchdf() assert r1["COL1"][0] == "val1" assert r1["COL1"][1] == "val4" assert r1["CoL2"][0] == 1.05 diff --git a/tests/fast/pandas/test_import_cache.py b/tests/fast/pandas/test_import_cache.py index eb1c8fb8..1b3a98ee 100644 --- a/tests/fast/pandas/test_import_cache.py +++ b/tests/fast/pandas/test_import_cache.py @@ -1,29 +1,38 @@ +import importlib.util + +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb -@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) -def test_import_cache_explicit_dtype(pandas): - df = pandas.DataFrame( # noqa: F841 +@pytest.mark.parametrize( + "string_dtype", + [ + "python", + pytest.param( + "pyarrow", marks=pytest.mark.skipif(not importlib.util.find_spec("pyarrow"), reason="pyarrow not installed") + ), + ], +) +def test_import_cache_explicit_dtype(string_dtype): + df = pd.DataFrame( # noqa: F841 { "id": [1, 2, 3], - "value": pandas.Series(["123.123", pandas.NaT, pandas.NA], dtype=pandas.StringDtype(storage="python")), + "value": pd.Series(["123.123", pd.NaT, pd.NA], dtype=pd.StringDtype(storage=string_dtype)), } ) con = duckdb.connect() result_df = con.query("select id, value from df").df() - assert result_df["value"][1] is None - assert result_df["value"][2] is None + assert pd.isna(result_df["value"][1]) + assert pd.isna(result_df["value"][2]) -@pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) -def test_import_cache_implicit_dtype(pandas): - df = pandas.DataFrame({"id": [1, 2, 3], "value": pandas.Series(["123.123", pandas.NaT, pandas.NA])}) # noqa: F841 +def test_import_cache_implicit_dtype(): + df = pd.DataFrame({"id": [1, 2, 3], "value": pd.Series(["123.123", pd.NaT, pd.NA])}) # noqa: F841 con = duckdb.connect() result_df = con.query("select id, value from df").df() - assert result_df["value"][1] is None - assert result_df["value"][2] is None + assert pd.isna(result_df["value"][1]) + assert pd.isna(result_df["value"][2]) diff --git a/tests/fast/pandas/test_issue_1767.py b/tests/fast/pandas/test_issue_1767.py index 48d3e852..1677001e 100644 --- a/tests/fast/pandas/test_issue_1767.py +++ b/tests/fast/pandas/test_issue_1767.py @@ -1,22 +1,20 @@ #!/usr/bin/env python -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb # Join from pandas not matching identical strings #1767 class TestIssue1767: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_unicode_join_pandas(self, duckdb_cursor, pandas): - A = pandas.DataFrame({"key": ["a", "п"]}) - B = pandas.DataFrame({"key": ["a", "п"]}) + def test_unicode_join_pandas(self, duckdb_cursor): + A = pd.DataFrame({"key": ["a", "п"]}) + B = pd.DataFrame({"key": ["a", "п"]}) con = duckdb.connect(":memory:") arrow = con.register("A", A).register("B", B) q = arrow.query("""SELECT key FROM "A" FULL JOIN "B" USING ("key") ORDER BY key""") result = q.df() d = {"key": ["a", "п"]} - df = pandas.DataFrame(data=d) - pandas.testing.assert_frame_equal(result, df) + df = pd.DataFrame(data=d) + pd.testing.assert_frame_equal(result, df, check_dtype=False) diff --git a/tests/fast/pandas/test_limit.py b/tests/fast/pandas/test_limit.py index 51c4a382..2fb6c769 100644 --- a/tests/fast/pandas/test_limit.py +++ b/tests/fast/pandas/test_limit.py @@ -1,13 +1,11 @@ -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb class TestLimitPandas: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_limit_df(self, duckdb_cursor, pandas): - df_in = pandas.DataFrame( + def test_limit_df(self, duckdb_cursor): + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } @@ -15,9 +13,8 @@ def test_limit_df(self, duckdb_cursor, pandas): limit_df = duckdb.limit(df_in, 2) assert len(limit_df.execute().fetchall()) == 2 - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_aggregate_df(self, duckdb_cursor, pandas): - df_in = pandas.DataFrame( + def test_aggregate_df(self, duckdb_cursor): + df_in = pd.DataFrame( { "numbers": [1, 2, 2, 2], } diff --git a/tests/fast/pandas/test_new_string_type.py b/tests/fast/pandas/test_new_string_type.py new file mode 100644 index 00000000..bd13d53a --- /dev/null +++ b/tests/fast/pandas/test_new_string_type.py @@ -0,0 +1,20 @@ +import pandas as pd +import pytest +from packaging.version import Version + +import duckdb + + +@pytest.mark.skipif( + Version(pd.__version__) < Version("3.0"), reason="Pandas < 3.0 doesn't have the new string type yet" +) +def test_new_str_type_pandas_3_0(): + df = pd.DataFrame({"s": ["DuckDB"]}) # noqa: F841 + duckdb.sql("select * from df") + + +@pytest.mark.skipif(Version(pd.__version__) >= Version("3.0"), reason="Pandas >= 3.0 has the new string type") +def test_new_str_type_pandas_lt_3_0(): + pd.options.future.infer_string = True + df = pd.DataFrame({"s": ["DuckDB"]}) # noqa: F841 + duckdb.sql("select * from df") diff --git a/tests/fast/pandas/test_pandas_arrow.py b/tests/fast/pandas/test_pandas_arrow.py index 0cb1f00d..ed387d52 100644 --- a/tests/fast/pandas/test_pandas_arrow.py +++ b/tests/fast/pandas/test_pandas_arrow.py @@ -2,16 +2,15 @@ import numpy as np import pytest -from conftest import pandas_supports_arrow_backend import duckdb pd = pytest.importorskip("pandas", "2.0.0") +pytest.importorskip("pyarrow") from pandas.api.types import is_integer_dtype # noqa: E402 -@pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") class TestPandasArrow: def test_pandas_arrow(self, duckdb_cursor): pd = pytest.importorskip("pandas") diff --git a/tests/fast/pandas/test_pandas_na.py b/tests/fast/pandas/test_pandas_na.py index 6462c298..166fc21e 100644 --- a/tests/fast/pandas/test_pandas_na.py +++ b/tests/fast/pandas/test_pandas_na.py @@ -1,8 +1,9 @@ import platform import numpy as np +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas +from conftest import is_string_dtype import duckdb @@ -10,27 +11,25 @@ def assert_nullness(items, null_indices): for i in range(len(items)): if i in null_indices: - assert items[i] is None + assert pd.isna(items[i]) else: - assert items[i] is not None + assert not pd.isna(items[i]) @pytest.mark.skipif(platform.system() == "Emscripten", reason="Pandas interaction is broken in Pyodide 3.11") class TestPandasNA: @pytest.mark.parametrize("rows", [100, duckdb.__standard_vector_size__, 5000, 1000000]) - @pytest.mark.parametrize("pd", [NumpyPandas(), ArrowPandas()]) - def test_pandas_string_null(self, duckdb_cursor, rows, pd): - df: pd.DataFrame = pd.DataFrame(index=np.arange(rows)) + def test_pandas_string_null(self, duckdb_cursor, rows): + df = pd.DataFrame(index=np.arange(rows)) df["string_column"] = pd.Series(dtype="string") e_df_rel = duckdb_cursor.from_df(df) assert e_df_rel.types == ["VARCHAR"] roundtrip = e_df_rel.df() - assert roundtrip["string_column"].dtype == "object" + assert is_string_dtype(roundtrip["string_column"].dtype) expected = pd.DataFrame({"string_column": [None for _ in range(rows)]}) - pd.testing.assert_frame_equal(expected, roundtrip) + pd.testing.assert_frame_equal(expected, roundtrip, check_dtype=False) def test_pandas_na(self, duckdb_cursor): - pd = pytest.importorskip("pandas", minversion="1.0.0", reason="Support for pandas.NA has not been added yet") # DataFrame containing a single pd.NA df = pd.DataFrame(pd.Series([pd.NA])) @@ -74,7 +73,9 @@ def test_pandas_na(self, duckdb_cursor): } ) assert str(nan_df["a"].dtype) == "float64" - assert str(na_df["a"].dtype) == "object" # pd.NA values turn the column into 'object' + # pd.NA values turn the column into 'object' in Pandas 2.x + # In Pandas 3.0+, it may be different but we just check it's not float64 + assert str(na_df["a"].dtype) != "float64" nan_result = duckdb_cursor.execute("select * from nan_df").df() na_result = duckdb_cursor.execute("select * from na_df").df() diff --git a/tests/fast/pandas/test_pandas_types.py b/tests/fast/pandas/test_pandas_types.py index 7510cb28..6335f2ee 100644 --- a/tests/fast/pandas/test_pandas_types.py +++ b/tests/fast/pandas/test_pandas_types.py @@ -56,7 +56,7 @@ def test_pandas_numeric(self): # c=type2 # .. data = {} - for letter, dtype in zip(string.ascii_lowercase, data_types): + for letter, dtype in zip(string.ascii_lowercase, data_types, strict=False): data[letter] = base_df.a.astype(dtype) df = pd.DataFrame.from_dict(data) # noqa: F841 @@ -65,7 +65,7 @@ def test_pandas_numeric(self): # Verify that the types in the out_df are correct # TODO: we don't support outputting pandas specific types (i.e UInt64) # noqa: TD002, TD003 - for letter, item in zip(string.ascii_lowercase, data_types): + for letter, item in zip(string.ascii_lowercase, data_types, strict=False): column_name = letter assert str(out_df[column_name].dtype) == item.lower() diff --git a/tests/fast/pandas/test_pandas_unregister.py b/tests/fast/pandas/test_pandas_unregister.py index ab83eb42..c89ae320 100644 --- a/tests/fast/pandas/test_pandas_unregister.py +++ b/tests/fast/pandas/test_pandas_unregister.py @@ -1,16 +1,15 @@ import gc import tempfile +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb class TestPandasUnregister: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_pandas_unregister1(self, duckdb_cursor, pandas): - df = pandas.DataFrame([[1, 2, 3], [4, 5, 6]]) + def test_pandas_unregister1(self, duckdb_cursor): + df = pd.DataFrame([[1, 2, 3], [4, 5, 6]]) connection = duckdb.connect(":memory:") connection.register("dataframe", df) @@ -22,13 +21,12 @@ def test_pandas_unregister1(self, duckdb_cursor, pandas): connection.execute("DROP VIEW dataframe;") connection.execute("DROP VIEW IF EXISTS dataframe;") - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_pandas_unregister2(self, duckdb_cursor, pandas): + def test_pandas_unregister2(self, duckdb_cursor): with tempfile.NamedTemporaryFile() as tmp: db = tmp.name connection = duckdb.connect(db) - df = pandas.DataFrame([[1, 2, 3], [4, 5, 6]]) + df = pd.DataFrame([[1, 2, 3], [4, 5, 6]]) connection.register("dataframe", df) connection.unregister("dataframe") # Attempting to unregister. diff --git a/tests/fast/pandas/test_parallel_pandas_scan.py b/tests/fast/pandas/test_parallel_pandas_scan.py index 9ac7b738..7e04a933 100644 --- a/tests/fast/pandas/test_parallel_pandas_scan.py +++ b/tests/fast/pandas/test_parallel_pandas_scan.py @@ -2,13 +2,12 @@ import datetime import numpy -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb -def run_parallel_queries(main_table, left_join_table, expected_df, pandas, iteration_count=5): +def run_parallel_queries(main_table, left_join_table, expected_df, iteration_count=5): for _i in range(iteration_count): output_df = None sql = """ @@ -28,7 +27,7 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera duckdb_conn.register("main_table", main_table) duckdb_conn.register("left_join_table", left_join_table) output_df = duckdb_conn.execute(sql).fetchdf() - pandas.testing.assert_frame_equal(expected_df, output_df) + pd.testing.assert_frame_equal(expected_df, output_df, check_dtype=False) print(output_df) except Exception as err: print(err) @@ -37,67 +36,59 @@ def run_parallel_queries(main_table, left_join_table, expected_df, pandas, itera class TestParallelPandasScan: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_parallel_numeric_scan(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": 3}]) - left_join_table = pandas.DataFrame([{"join_column": 3, "other_column": 4}]) - run_parallel_queries(main_table, left_join_table, left_join_table, pandas) + def test_parallel_numeric_scan(self, duckdb_cursor): + main_table = pd.DataFrame([{"join_column": 3}]) + left_join_table = pd.DataFrame([{"join_column": 3, "other_column": 4}]) + run_parallel_queries(main_table, left_join_table, left_join_table) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_parallel_ascii_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": "text"}]) - left_join_table = pandas.DataFrame([{"join_column": "text", "other_column": "more text"}]) - run_parallel_queries(main_table, left_join_table, left_join_table, pandas) + def test_parallel_ascii_text(self, duckdb_cursor): + main_table = pd.DataFrame([{"join_column": "text"}]) + left_join_table = pd.DataFrame([{"join_column": "text", "other_column": "more text"}]) + run_parallel_queries(main_table, left_join_table, left_join_table) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_parallel_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": "mühleisen"}]) - left_join_table = pandas.DataFrame([{"join_column": "mühleisen", "other_column": "höhöhö"}]) - run_parallel_queries(main_table, left_join_table, left_join_table, pandas) + def test_parallel_unicode_text(self, duckdb_cursor): + main_table = pd.DataFrame([{"join_column": "mühleisen"}]) + left_join_table = pd.DataFrame([{"join_column": "mühleisen", "other_column": "höhöhö"}]) + run_parallel_queries(main_table, left_join_table, left_join_table) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_parallel_complex_unicode_text(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": "鴨"}]) - left_join_table = pandas.DataFrame([{"join_column": "鴨", "other_column": "數據庫"}]) - run_parallel_queries(main_table, left_join_table, left_join_table, pandas) + def test_parallel_complex_unicode_text(self, duckdb_cursor): + main_table = pd.DataFrame([{"join_column": "鴨"}]) + left_join_table = pd.DataFrame([{"join_column": "鴨", "other_column": "數據庫"}]) + run_parallel_queries(main_table, left_join_table, left_join_table) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_parallel_emojis(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) - left_join_table = pandas.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": "🦆🍞🦆"}]) - run_parallel_queries(main_table, left_join_table, left_join_table, pandas) + def test_parallel_emojis(self, duckdb_cursor): + main_table = pd.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️"}]) + left_join_table = pd.DataFrame([{"join_column": "🤦🏼‍♂️ L🤦🏼‍♂️R 🤦🏼‍♂️", "other_column": "🦆🍞🦆"}]) + run_parallel_queries(main_table, left_join_table, left_join_table) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_parallel_numeric_object(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({"join_column": pandas.Series([3], dtype="Int8")}) - left_join_table = pandas.DataFrame( - {"join_column": pandas.Series([3], dtype="Int8"), "other_column": pandas.Series([4], dtype="Int8")} + def test_parallel_numeric_object(self, duckdb_cursor): + main_table = pd.DataFrame({"join_column": pd.Series([3], dtype="Int8")}) + left_join_table = pd.DataFrame( + {"join_column": pd.Series([3], dtype="Int8"), "other_column": pd.Series([4], dtype="Int8")} ) - expected_df = pandas.DataFrame( + expected_df = pd.DataFrame( {"join_column": numpy.array([3], dtype=numpy.int8), "other_column": numpy.array([4], dtype=numpy.int8)} ) - run_parallel_queries(main_table, left_join_table, expected_df, pandas) + run_parallel_queries(main_table, left_join_table, expected_df) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_parallel_timestamp(self, duckdb_cursor, pandas): - main_table = pandas.DataFrame({"join_column": [pandas.Timestamp("20180310T11:17:54Z")]}) - left_join_table = pandas.DataFrame( + def test_parallel_timestamp(self, duckdb_cursor): + main_table = pd.DataFrame({"join_column": [pd.Timestamp("20180310T11:17:54Z")]}) + left_join_table = pd.DataFrame( { - "join_column": [pandas.Timestamp("20180310T11:17:54Z")], - "other_column": [pandas.Timestamp("20190310T11:17:54Z")], + "join_column": [pd.Timestamp("20180310T11:17:54Z")], + "other_column": [pd.Timestamp("20190310T11:17:54Z")], } ) - expected_df = pandas.DataFrame( + expected_df = pd.DataFrame( { "join_column": numpy.array([datetime.datetime(2018, 3, 10, 11, 17, 54)], dtype="datetime64[ns]"), "other_column": numpy.array([datetime.datetime(2019, 3, 10, 11, 17, 54)], dtype="datetime64[ns]"), } ) - run_parallel_queries(main_table, left_join_table, expected_df, pandas) + run_parallel_queries(main_table, left_join_table, expected_df) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_parallel_empty(self, duckdb_cursor, pandas): - df_empty = pandas.DataFrame({"A": []}) + def test_parallel_empty(self, duckdb_cursor): + df_empty = pd.DataFrame({"A": []}) duckdb_conn = duckdb.connect() duckdb_conn.execute("PRAGMA threads=4") duckdb_conn.execute("PRAGMA verify_parallelism") diff --git a/tests/fast/pandas/test_pyarrow_projection_pushdown.py b/tests/fast/pandas/test_pyarrow_projection_pushdown.py index 87f49f04..ca7bc905 100644 --- a/tests/fast/pandas/test_pyarrow_projection_pushdown.py +++ b/tests/fast/pandas/test_pyarrow_projection_pushdown.py @@ -1,5 +1,4 @@ import pytest -from conftest import pandas_supports_arrow_backend import duckdb @@ -8,7 +7,6 @@ _ = pytest.importorskip("pandas", "2.0.0") -@pytest.mark.skipif(not pandas_supports_arrow_backend(), reason="pandas does not support the 'pyarrow' backend") class TestArrowDFProjectionPushdown: def test_projection_pushdown_no_filter(self, duckdb_cursor): duckdb_conn = duckdb.connect() diff --git a/tests/fast/pandas/test_stride.py b/tests/fast/pandas/test_stride.py index cbe23cfd..65204ea8 100644 --- a/tests/fast/pandas/test_stride.py +++ b/tests/fast/pandas/test_stride.py @@ -57,7 +57,9 @@ def test_stride_timedelta(self, duckdb_cursor): ] } ) - pd.testing.assert_frame_equal(roundtrip, expected) + # DuckDB INTERVAL type stores in microseconds, so output is always timedelta64[us] + # Check values match without strict dtype comparison + pd.testing.assert_frame_equal(roundtrip, expected, check_dtype=False) def test_stride_fp64(self, duckdb_cursor): expected_df = pd.DataFrame(np.arange(20, dtype="float64").reshape(5, 4), columns=["a", "b", "c", "d"]) diff --git a/tests/fast/pandas/test_timestamp.py b/tests/fast/pandas/test_timestamp.py index 81651634..c6d080b8 100644 --- a/tests/fast/pandas/test_timestamp.py +++ b/tests/fast/pandas/test_timestamp.py @@ -65,7 +65,9 @@ def test_timestamp_timedelta(self): } ) df_from_duck = duckdb.from_df(df).df() - assert df_from_duck.equals(df) + # DuckDB INTERVAL type stores in microseconds, so output is always timedelta64[us] + # Check values match without strict dtype comparison + pd.testing.assert_frame_equal(df_from_duck, df, check_dtype=False) @pytest.mark.xfail( condition=platform.system() == "Emscripten" and os.environ.get("TZ") != "UTC", diff --git a/tests/fast/relational_api/test_rapi_aggregations.py b/tests/fast/relational_api/test_rapi_aggregations.py index ffb7e303..9253d541 100644 --- a/tests/fast/relational_api/test_rapi_aggregations.py +++ b/tests/fast/relational_api/test_rapi_aggregations.py @@ -35,33 +35,33 @@ def test_any_value(self, table): result = table.order("id, t").any_value("v").execute().fetchall() expected = [(1,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( table.order("id, t").any_value("v", groups="id", projected_columns="id").order("id").execute().fetchall() ) expected = [(1, 1), (2, 11), (3, 5)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_arg_max(self, table): result = table.arg_max("t", "v").execute().fetchall() expected = [(-1,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.arg_max("t", "v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 3), (2, -1), (3, -2)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_arg_min(self, table): result = table.arg_min("t", "v").execute().fetchall() expected = [(0,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.arg_min("t", "v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 2), (2, 4), (3, 0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_avg(self, table): result = table.avg("v").execute().fetchall() @@ -78,41 +78,41 @@ def test_bit_and(self, table): result = table.bit_and("v").execute().fetchall() expected = [(0,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.bit_and("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 0), (2, 10), (3, 5)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bit_or(self, table): result = table.bit_or("v").execute().fetchall() expected = [(-1,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.bit_or("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 3), (2, 11), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bit_xor(self, table): result = table.bit_xor("v").execute().fetchall() expected = [(-7,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.bit_xor("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 2), (2, 1), (3, -6)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bitstring_agg(self, table): result = table.bitstring_agg("v").execute().fetchall() expected = [("1011001000011",)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.bitstring_agg("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, "0011000000000"), (2, "0000000000011"), (3, "1000001000000")] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) with pytest.raises(duckdb.InvalidInputException): table.bitstring_agg("v", min="1") with pytest.raises(duckdb.InvalidTypeException): @@ -122,156 +122,156 @@ def test_bool_and(self, table): result = table.bool_and("v::BOOL").execute().fetchall() expected = [(True,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.bool_and("t::BOOL", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, True), (2, True), (3, False)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bool_or(self, table): result = table.bool_or("v::BOOL").execute().fetchall() expected = [(True,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.bool_or("v::BOOL", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, True), (2, True), (3, True)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_count(self, table): result = table.count("*").execute().fetchall() expected = [(8,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.count("*", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 3), (2, 2), (3, 3)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_value_counts(self, table): result = table.value_counts("v").execute().fetchall() expected = [(None, 0), (-1, 1), (1, 2), (2, 1), (5, 1), (10, 1), (11, 1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.value_counts("v", groups="v").order("v").execute().fetchall() expected = [(-1, 1), (1, 2), (2, 1), (5, 1), (10, 1), (11, 1), (None, 0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_favg(self, table): result = [round(r[0], 2) for r in table.favg("f").execute().fetchall()] expected = [5.12] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], round(r[1], 2)) for r in table.favg("f", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.25), (2, 5.24), (3, 9.92)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_first(self, table): result = table.first("v").execute().fetchall() expected = [(1,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.first("v", "id", "id").order("id").execute().fetchall() expected = [(1, 1), (2, 10), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_last(self, table): result = table.last("v").execute().fetchall() expected = [(None,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.last("v", "id", "id").order("id").execute().fetchall() expected = [(1, 2), (2, 11), (3, None)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_fsum(self, table): result = [round(r[0], 2) for r in table.fsum("f").execute().fetchall()] expected = [40.99] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], round(r[1], 2)) for r in table.fsum("f", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.75), (2, 10.49), (3, 29.75)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_geomean(self, table): result = [round(r[0], 2) for r in table.geomean("f").execute().fetchall()] expected = [0.67] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], round(r[1], 2)) for r in table.geomean("f", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.05), (2, 0.65), (3, 9.52)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_histogram(self, table): result = table.histogram("v").execute().fetchall() expected = [({-1: 1, 1: 2, 2: 1, 5: 1, 10: 1, 11: 1},)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.histogram("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, {1: 2, 2: 1}), (2, {10: 1, 11: 1}), (3, {-1: 1, 5: 1})] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_list(self, table): result = table.list("v").execute().fetchall() expected = [([1, 1, 2, 10, 11, -1, 5, None],)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.list("v", groups="id order by t asc", projected_columns="id").order("id").execute().fetchall() expected = [(1, [1, 1, 2]), (2, [10, 11]), (3, [-1, 5, None])] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_max(self, table): result = table.max("v").execute().fetchall() expected = [(11,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.max("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 2), (2, 11), (3, 5)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_min(self, table): result = table.min("v").execute().fetchall() expected = [(-1,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.min("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1), (2, 10), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_product(self, table): result = table.product("v").execute().fetchall() expected = [(-1100,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.product("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 2), (2, 110), (3, -5)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_string_agg(self, table): result = table.string_agg("s", sep="/").execute().fetchall() expected = [("h/e/l/l/o/,/wor/ld",)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( table.string_agg("s", sep="/", groups="id order by t asc", projected_columns="id") .order("id") @@ -280,17 +280,17 @@ def test_string_agg(self, table): ) expected = [(1, "h/e/l"), (2, "l/o"), (3, ",/wor/ld")] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_sum(self, table): result = table.sum("v").execute().fetchall() expected = [(29,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.sum("v", groups="id", projected_columns="id").execute().fetchall() expected = [(1, 4), (2, 21), (3, 4)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) # TODO: Approximate aggregate functions # noqa: TD002, TD003 @@ -299,35 +299,35 @@ def test_median(self, table): result = table.median("v").execute().fetchall() expected = [(2.0,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.median("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1.0), (2, 10.5), (3, 2.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_mode(self, table): result = table.mode("v").execute().fetchall() expected = [(1,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.mode("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1), (2, 10), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_quantile_cont(self, table): result = table.quantile_cont("v").execute().fetchall() expected = [(2.0,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [[round(x, 2) for x in r[0]] for r in table.quantile_cont("v", q=[0.1, 0.5]).execute().fetchall()] expected = [[0.2, 2.0]] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = table.quantile_cont("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1.0), (2, 10.5), (3, 2.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], [round(x, 2) for x in r[1]]) for r in table.quantile_cont("v", q=[0.2, 0.5], groups="id", projected_columns="id") @@ -337,82 +337,225 @@ def test_quantile_cont(self, table): ] expected = [(1, [1.0, 1.0]), (2, [10.2, 10.5]), (3, [0.2, 2.0])] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) @pytest.mark.parametrize("f", ["quantile_disc", "quantile"]) def test_quantile_disc(self, table, f): result = getattr(table, f)("v").execute().fetchall() expected = [(2,)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = getattr(table, f)("v", q=[0.2, 0.5]).execute().fetchall() expected = [([1, 2],)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = getattr(table, f)("v", groups="id", projected_columns="id").order("id").execute().fetchall() expected = [(1, 1), (2, 10), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( getattr(table, f)("v", q=[0.2, 0.8], groups="id", projected_columns="id").order("id").execute().fetchall() ) expected = [(1, [1, 2]), (2, [10, 11]), (3, [-1, 5])] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_std_pop(self, table): result = [round(r[0], 2) for r in table.stddev_pop("v").execute().fetchall()] expected = [4.36] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], round(r[1], 2)) for r in table.stddev_pop("v", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.47), (2, 0.5), (3, 3.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) @pytest.mark.parametrize("f", ["stddev_samp", "stddev", "std"]) def test_std_samp(self, table, f): result = [round(r[0], 2) for r in getattr(table, f)("v").execute().fetchall()] expected = [4.71] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], round(r[1], 2)) for r in getattr(table, f)("v", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.58), (2, 0.71), (3, 4.24)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_var_pop(self, table): result = [round(r[0], 2) for r in table.var_pop("v").execute().fetchall()] expected = [18.98] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], round(r[1], 2)) for r in table.var_pop("v", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.22), (2, 0.25), (3, 9.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) @pytest.mark.parametrize("f", ["var_samp", "variance", "var"]) def test_var_samp(self, table, f): result = [round(r[0], 2) for r in getattr(table, f)("v").execute().fetchall()] expected = [22.14] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], round(r[1], 2)) for r in getattr(table, f)("v", groups="id", projected_columns="id").order("id").execute().fetchall() ] expected = [(1, 0.33), (2, 0.5), (3, 18.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_describe(self, table): assert table.describe().fetchall() is not None + + +class TestRAPIAggregationsColumnEscaping: + """Test that aggregate functions properly escape column names that need quoting.""" + + def test_reserved_keyword_column_name(self, duckdb_cursor): + # Column name "select" is a reserved SQL keyword + rel = duckdb_cursor.sql('select 1 as "select", 2 as "order"') + result = rel.sum("select").fetchall() + assert result == [(1,)] + + result = rel.avg("order").fetchall() + assert result == [(2.0,)] + + def test_column_name_with_space(self, duckdb_cursor): + rel = duckdb_cursor.sql('select 10 as "my column"') + result = rel.sum("my column").fetchall() + assert result == [(10,)] + + def test_column_name_with_quotes(self, duckdb_cursor): + # Column name containing a double quote + rel = duckdb_cursor.sql('select 5 as "col""name"') + result = rel.sum('col"name').fetchall() + assert result == [(5,)] + + def test_qualified_column_name(self, duckdb_cursor): + # Qualified column name like table.column + rel = duckdb_cursor.sql("select 42 as value") + # When using qualified names, they should be properly escaped + result = rel.sum("value").fetchall() + assert result == [(42,)] + + +class TestRAPIAggregationsExpressionPassthrough: + """Test that aggregate functions correctly pass through SQL expressions without escaping.""" + + def test_cast_expression(self, duckdb_cursor): + # Cast expressions should pass through without being quoted + rel = duckdb_cursor.sql("select 1 as v, 0 as f") + result = rel.bool_and("v::BOOL").fetchall() + assert result == [(True,)] + + result = rel.bool_or("f::BOOL").fetchall() + assert result == [(False,)] + + def test_star_expression(self, duckdb_cursor): + # Star (*) should pass through for count + rel = duckdb_cursor.sql("select 1 as a union all select 2") + result = rel.count("*").fetchall() + assert result == [(2,)] + + def test_arithmetic_expression(self, duckdb_cursor): + # Arithmetic expressions should pass through + rel = duckdb_cursor.sql("select 10 as a, 5 as b") + result = rel.sum("a + b").fetchall() + assert result == [(15,)] + + def test_function_expression(self, duckdb_cursor): + # Function calls should pass through + rel = duckdb_cursor.sql("select -5 as v") + result = rel.sum("abs(v)").fetchall() + assert result == [(5,)] + + def test_case_expression(self, duckdb_cursor): + # CASE expressions should pass through + rel = duckdb_cursor.sql("select 1 as v union all select 2 union all select 3") + result = rel.sum("case when v > 1 then v else 0 end").fetchall() + assert result == [(5,)] + + +class TestRAPIAggregationsWithInvalidInput: + """Test that only expression can be used.""" + + def test_injection_with_semicolon_is_neutralized(self, duckdb_cursor): + # Semicolon injection fails to parse as expression, gets quoted as identifier + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.BinderException, match="not found in FROM clause"): + rel.sum("v; drop table agg; --").fetchall() + + def test_injection_with_union_is_neutralized(self, duckdb_cursor): + # UNION fails to parse as single expression, gets quoted + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.BinderException, match="not found in FROM clause"): + rel.sum("v union select * from agg").fetchall() + + def test_subquery_is_contained(self, duckdb_cursor): + # Subqueries are valid expressions - they're contained within the aggregate + # and cannot break out of the expression context + rel = duckdb_cursor.sql("select 1 as v") + # This executes sum((select 1)) = sum(1) = 1 - contained, not an injection + result = rel.sum("(select 1)").fetchall() + assert result == [(1,)] + + def test_injection_closing_paren_is_neutralized(self, duckdb_cursor): + # Adding a closing paren fails to parse, gets quoted + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.BinderException, match="not found in FROM clause"): + rel.sum("v) from agg; drop table agg; --").fetchall() + + def test_comment_is_harmless(self, duckdb_cursor): + # SQL comments are stripped during parsing, so "v -- comment" parses as just "v" + rel = duckdb_cursor.sql("select 1 as v") + result = rel.sum("v -- this is ignored").fetchall() + assert result == [(1,)] + + def test_empty_expression_rejected(self, duckdb_cursor): + # Empty or whitespace-only expressions should be rejected + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.ParserException): + rel.sum("").fetchall() + + def test_whitespace_only_expression_rejected(self, duckdb_cursor): + # Whitespace-only expressions should be rejected + rel = duckdb_cursor.sql("select 1 as v") + with pytest.raises(duckdb.ParserException): + rel.sum(" ").fetchall() + + +class TestRAPIStringAggSeparatorEscaping: + """Test that string_agg separator is properly escaped as a string literal.""" + + def test_simple_separator(self, duckdb_cursor): + rel = duckdb_cursor.sql("select 'a' as s union all select 'b' union all select 'c'") + result = rel.string_agg("s", ",").fetchall() + assert result == [("a,b,c",)] + + def test_separator_with_single_quote(self, duckdb_cursor): + # Separator containing a single quote should be properly escaped + rel = duckdb_cursor.sql("select 'a' as s union all select 'b'") + result = rel.string_agg("s", "','").fetchall() + assert result == [("a','b",)] + + def test_separator_with_special_chars(self, duckdb_cursor): + rel = duckdb_cursor.sql("select 'x' as s union all select 'y'") + result = rel.string_agg("s", " | ").fetchall() + assert result == [("x | y",)] + + def test_separator_injection_attempt(self, duckdb_cursor): + # Attempt to inject via separator - should be safely quoted as string literal + rel = duckdb_cursor.sql("select 'a' as s union all select 'b'") + # This should NOT execute the injection - separator becomes a literal string + result = rel.string_agg("s", "'); drop table agg; --").fetchall() + assert result == [("a'); drop table agg; --b",)] diff --git a/tests/fast/relational_api/test_rapi_close.py b/tests/fast/relational_api/test_rapi_close.py index 969e2792..f8233aa2 100644 --- a/tests/fast/relational_api/test_rapi_close.py +++ b/tests/fast/relational_api/test_rapi_close.py @@ -1,3 +1,5 @@ +from decimal import Decimal + import pytest import duckdb @@ -27,7 +29,7 @@ def test_close_conn_rel(self, duckdb_cursor): with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.arg_min("", "") with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): - rel.fetch_arrow_table() + rel.to_arrow_table() with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): rel.avg("") with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): @@ -182,5 +184,6 @@ def test_del_conn(self, duckdb_cursor): con.execute("INSERT INTO items VALUES ('jeans', 20.0, 1), ('hammer', 42.2, 2)") rel = con.table("items") del con - with pytest.raises(duckdb.ConnectionException, match="Connection has already been closed"): - print(rel) + # Relation keeps the connection alive via connection_owner + res = rel.fetchall() + assert res == [("jeans", Decimal("20.00"), 1), ("hammer", Decimal("42.20"), 2)] diff --git a/tests/fast/relational_api/test_rapi_windows.py b/tests/fast/relational_api/test_rapi_windows.py index 28d533b7..093829a8 100644 --- a/tests/fast/relational_api/test_rapi_windows.py +++ b/tests/fast/relational_api/test_rapi_windows.py @@ -34,7 +34,7 @@ def test_row_number(self, table): result = table.row_number("over ()").execute().fetchall() expected = list(range(1, 9)) assert len(result) == len(expected) - assert all(r[0] == e for r, e in zip(result, expected)) + assert all(r[0] == e for r, e in zip(result, expected, strict=False)) result = table.row_number("over (partition by id order by t asc)", "id, v, t").order("id").execute().fetchall() expected = [ (1, 1, 1, 1), @@ -47,34 +47,34 @@ def test_row_number(self, table): (3, None, 10, 3), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_rank(self, table): result = table.rank("over ()").execute().fetchall() expected = [1] * 8 assert len(result) == len(expected) - assert all(r[0] == e for r, e in zip(result, expected)) + assert all(r[0] == e for r, e in zip(result, expected, strict=False)) result = table.rank("over (partition by id order by v asc)", "id, v").order("id").execute().fetchall() expected = [(1, 1, 1), (1, 1, 1), (1, 2, 3), (2, 10, 1), (2, 11, 2), (3, -1, 1), (3, 5, 2), (3, None, 3)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) @pytest.mark.parametrize("f", ["dense_rank", "rank_dense"]) def test_dense_rank(self, table, f): result = getattr(table, f)("over ()").execute().fetchall() expected = [1] * 8 assert len(result) == len(expected) - assert all(r[0] == e for r, e in zip(result, expected)) + assert all(r[0] == e for r, e in zip(result, expected, strict=False)) result = getattr(table, f)("over (partition by id order by v asc)", "id, v").order("id").execute().fetchall() expected = [(1, 1, 1), (1, 1, 1), (1, 2, 2), (2, 10, 1), (2, 11, 2), (3, -1, 1), (3, 5, 2), (3, None, 3)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_percent_rank(self, table): result = table.percent_rank("over ()").execute().fetchall() expected = [0.0] * 8 assert len(result) == len(expected) - assert all(r[0] == e for r, e in zip(result, expected)) + assert all(r[0] == e for r, e in zip(result, expected, strict=False)) result = table.percent_rank("over (partition by id order by v asc)", "id, v").order("id").execute().fetchall() expected = [ (1, 1, 0.0), @@ -87,13 +87,13 @@ def test_percent_rank(self, table): (3, None, 1.0), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_cume_dist(self, table): result = table.cume_dist("over ()").execute().fetchall() expected = [1.0] * 8 assert len(result) == len(expected) - assert all(r[0] == e for r, e in zip(result, expected)) + assert all(r[0] == e for r, e in zip(result, expected, strict=False)) result = table.cume_dist("over (partition by id order by v asc)", "id, v").order("id").execute().fetchall() expected = [ (1, 1, 2 / 3), @@ -106,13 +106,13 @@ def test_cume_dist(self, table): (3, None, 1.0), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_ntile(self, table): result = table.n_tile("over (order by v)", 3, "v").execute().fetchall() expected = [(-1, 1), (1, 1), (1, 1), (2, 2), (5, 2), (10, 2), (11, 3), (None, 3)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_lag(self, table): result = ( @@ -132,7 +132,7 @@ def test_lag(self, table): (3, None, 10, -1), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( table.lag("v", "over (partition by id order by t asc)", default_value="-1", projected_columns="id, v, t") .order("id") @@ -150,7 +150,7 @@ def test_lag(self, table): (3, None, 10, -1), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( table.lag("v", "over (partition by id order by t asc)", offset=2, projected_columns="id, v, t") .order("id") @@ -168,7 +168,7 @@ def test_lag(self, table): (3, None, 10, 5), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_lead(self, table): result = ( @@ -188,7 +188,7 @@ def test_lead(self, table): (3, None, 10, None), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( table.lead("v", "over (partition by id order by t asc)", default_value="-1", projected_columns="id, v, t") .order("id") @@ -206,7 +206,7 @@ def test_lead(self, table): (3, None, 10, -1), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( table.lead("v", "over (partition by id order by t asc)", offset=2, projected_columns="id, v, t") .order("id") @@ -224,7 +224,7 @@ def test_lead(self, table): (3, None, 10, None), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_first_value(self, table): result = ( @@ -244,7 +244,7 @@ def test_first_value(self, table): (3, None, 10, 5), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_last_value(self, table): result = ( @@ -268,7 +268,7 @@ def test_last_value(self, table): (3, None, 10, None), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_nth_value(self, table): result = ( @@ -288,7 +288,7 @@ def test_nth_value(self, table): (3, None, 10, -1), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( table.nth_value("v", "over (partition by id order by t asc)", offset=4, projected_columns="id, v, t") .order("id") @@ -306,7 +306,7 @@ def test_nth_value(self, table): (3, None, 10, None), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) # agg functions within win def test_any_value(self, table): @@ -318,7 +318,7 @@ def test_any_value(self, table): ) expected = [(1, 1), (1, 1), (1, 1), (2, 11), (2, 11), (3, 5), (3, 5), (3, 5)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_arg_max(self, table): result = ( @@ -329,7 +329,7 @@ def test_arg_max(self, table): ) expected = [(1, 3), (1, 3), (1, 3), (2, -1), (2, -1), (3, -2), (3, -2), (3, -2)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_arg_min(self, table): result = ( @@ -340,7 +340,7 @@ def test_arg_min(self, table): ) expected = [(1, 2), (1, 2), (1, 2), (2, 4), (2, 4), (3, 0), (3, 0), (3, 0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_avg(self, table): result = [ @@ -359,7 +359,7 @@ def test_avg(self, table): ] expected = [(1, 1.0), (1, 1.0), (1, 1.33), (2, 11.0), (2, 10.5), (3, 5.0), (3, 2.0), (3, 2.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bit_and(self, table): result = ( @@ -374,7 +374,7 @@ def test_bit_and(self, table): ) expected = [(1, 1), (1, 1), (1, 0), (2, 11), (2, 10), (3, 5), (3, 5), (3, 5)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bit_or(self, table): result = ( @@ -389,7 +389,7 @@ def test_bit_or(self, table): ) expected = [(1, 1), (1, 1), (1, 3), (2, 11), (2, 11), (3, 5), (3, -1), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bit_xor(self, table): result = ( @@ -404,7 +404,7 @@ def test_bit_xor(self, table): ) expected = [(1, 1), (1, 0), (1, 2), (2, 11), (2, 1), (3, 5), (3, -6), (3, -6)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bitstring_agg(self, table): with pytest.raises(duckdb.BinderException, match="Could not retrieve required statistics"): @@ -436,7 +436,7 @@ def test_bitstring_agg(self, table): (3, "1000001000000"), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bool_and(self, table): result = ( @@ -447,7 +447,7 @@ def test_bool_and(self, table): ) expected = [(1, True), (1, True), (1, True), (2, True), (2, True), (3, False), (3, False), (3, False)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_bool_or(self, table): result = ( @@ -458,7 +458,7 @@ def test_bool_or(self, table): ) expected = [(1, True), (1, True), (1, True), (2, True), (2, True), (3, True), (3, True), (3, True)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_count(self, table): result = ( @@ -473,7 +473,7 @@ def test_count(self, table): ) expected = [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (3, 1), (3, 2), (3, 3)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_favg(self, table): result = [ @@ -489,7 +489,7 @@ def test_favg(self, table): ] expected = [(1, 0.21), (1, 0.38), (1, 0.25), (2, 10.45), (2, 5.24), (3, 9.87), (3, 11.59), (3, 9.92)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_fsum(self, table): result = [ @@ -505,7 +505,7 @@ def test_fsum(self, table): ] expected = [(1, 0.21), (1, 0.75), (1, 0.75), (2, 10.45), (2, 10.49), (3, 9.87), (3, 23.19), (3, 29.75)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) @pytest.mark.skip(reason="geomean is not supported from a windowing context") def test_geomean(self, table): @@ -533,7 +533,7 @@ def test_histogram(self, table): (3, {-1: 1, 5: 1}), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_list(self, table): result = ( @@ -557,7 +557,7 @@ def test_list(self, table): (3, [5, -1, None]), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_max(self, table): result = ( @@ -572,7 +572,7 @@ def test_max(self, table): ) expected = [(1, 1), (1, 1), (1, 2), (2, 11), (2, 11), (3, 5), (3, 5), (3, 5)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_min(self, table): result = ( @@ -587,7 +587,7 @@ def test_min(self, table): ) expected = [(1, 1), (1, 1), (1, 1), (2, 11), (2, 10), (3, 5), (3, -1), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_product(self, table): result = ( @@ -602,7 +602,7 @@ def test_product(self, table): ) expected = [(1, 1), (1, 1), (1, 2), (2, 11), (2, 110), (3, 5), (3, -5), (3, -5)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_string_agg(self, table): result = ( @@ -618,7 +618,7 @@ def test_string_agg(self, table): ) expected = [(1, "e"), (1, "e/h"), (1, "e/h/l"), (2, "o"), (2, "o/l"), (3, "wor"), (3, "wor/,"), (3, "wor/,/ld")] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_sum(self, table): result = ( @@ -633,7 +633,7 @@ def test_sum(self, table): ) expected = [(1, 1), (1, 2), (1, 4), (2, 11), (2, 21), (3, 5), (3, 4), (3, 4)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_median(self, table): result = ( @@ -648,7 +648,7 @@ def test_median(self, table): ) expected = [(1, 1.0), (1, 1.0), (1, 1.0), (2, 11.0), (2, 10.5), (3, 5.0), (3, 2.0), (3, 2.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_mode(self, table): result = ( @@ -663,7 +663,7 @@ def test_mode(self, table): ) expected = [(1, 2), (1, 2), (1, 1), (2, 10), (2, 10), (3, None), (3, -1), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_quantile_cont(self, table): result = ( @@ -678,7 +678,7 @@ def test_quantile_cont(self, table): ) expected = [(1, 2.0), (1, 1.5), (1, 1.0), (2, 10.0), (2, 10.5), (3, None), (3, -1.0), (3, 2.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = [ (r[0], [round(x, 2) for x in r[1]] if r[1] is not None else None) for r in table.quantile_cont( @@ -702,7 +702,7 @@ def test_quantile_cont(self, table): (3, [0.2, 2.0]), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) @pytest.mark.parametrize("f", ["quantile_disc", "quantile"]) def test_quantile_disc(self, table, f): @@ -718,7 +718,7 @@ def test_quantile_disc(self, table, f): ) expected = [(1, 2), (1, 1), (1, 1), (2, 10), (2, 10), (3, None), (3, -1), (3, -1)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) result = ( getattr(table, f)( "v", @@ -741,7 +741,7 @@ def test_quantile_disc(self, table, f): (3, [-1, 5]), ] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_stddev_pop(self, table): result = [ @@ -757,7 +757,7 @@ def test_stddev_pop(self, table): ] expected = [(1, 0.0), (1, 0.5), (1, 0.47), (2, 0.0), (2, 0.5), (3, None), (3, 0.0), (3, 3.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) @pytest.mark.parametrize("f", ["stddev_samp", "stddev", "std"]) def test_stddev_samp(self, table, f): @@ -774,7 +774,7 @@ def test_stddev_samp(self, table, f): ] expected = [(1, None), (1, 0.71), (1, 0.58), (2, None), (2, 0.71), (3, None), (3, None), (3, 4.24)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) def test_var_pop(self, table): result = [ @@ -790,7 +790,7 @@ def test_var_pop(self, table): ] expected = [(1, 0.0), (1, 0.25), (1, 0.22), (2, 0.0), (2, 0.25), (3, None), (3, 0.0), (3, 9.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) @pytest.mark.parametrize("f", ["var_samp", "variance", "var"]) def test_var_samp(self, table, f): @@ -807,4 +807,4 @@ def test_var_samp(self, table, f): ] expected = [(1, None), (1, 0.5), (1, 0.33), (2, None), (2, 0.5), (3, None), (3, None), (3, 18.0)] assert len(result) == len(expected) - assert all(r == e for r, e in zip(result, expected)) + assert all(r == e for r, e in zip(result, expected, strict=False)) diff --git a/tests/fast/spark/test_spark_functions_numeric.py b/tests/fast/spark/test_spark_functions_numeric.py index 98966548..ef24c676 100644 --- a/tests/fast/spark/test_spark_functions_numeric.py +++ b/tests/fast/spark/test_spark_functions_numeric.py @@ -294,7 +294,7 @@ def test_corr(self, spark): a = range(N) b = [2 * x for x in range(N)] # Have to use a groupby to test this as agg is not yet implemented without - df = spark.createDataFrame(zip(a, b, ["group1"] * N), ["a", "b", "g"]) + df = spark.createDataFrame(zip(a, b, ["group1"] * N, strict=False), ["a", "b", "g"]) res = df.groupBy("g").agg(sf.corr("a", "b").alias("c")).collect() assert pytest.approx(res[0].c) == 1 diff --git a/tests/fast/spark/test_spark_to_csv.py b/tests/fast/spark/test_spark_to_csv.py index 10e0028c..5003a20b 100644 --- a/tests/fast/spark/test_spark_to_csv.py +++ b/tests/fast/spark/test_spark_to_csv.py @@ -2,8 +2,9 @@ import datetime import os +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas, getTimeSeriesData +from conftest import getTimeSeriesData from spark_namespace import USE_ACTUAL_SPARK from duckdb import InvalidInputException, read_csv @@ -33,17 +34,15 @@ def df(spark): return dataframe -@pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) -def pandas_df_ints(request, spark): - pandas = request.param - dataframe = pandas.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) +@pytest.fixture +def pandas_df_ints(spark): + dataframe = pd.DataFrame({"a": [5, 3, 23, 2], "b": [45, 234, 234, 2]}) return dataframe -@pytest.fixture(params=[NumpyPandas(), ArrowPandas()]) -def pandas_df_strings(request, spark): - pandas = request.param - dataframe = pandas.DataFrame({"a": ["string1", "string2", "string3"]}) +@pytest.fixture +def pandas_df_strings(spark): + dataframe = pd.DataFrame({"a": ["string1", "string2", "string3"]}) return dataframe @@ -69,10 +68,9 @@ def test_to_csv_sep(self, pandas_df_ints, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name, sep=",") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_na_rep(self, pandas, spark, tmp_path): + def test_to_csv_na_rep(self, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 - pandas_df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) + pandas_df = pd.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -81,10 +79,9 @@ def test_to_csv_na_rep(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name, nullValue="test") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_header(self, pandas, spark, tmp_path): + def test_to_csv_header(self, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 - pandas_df = pandas.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) + pandas_df = pd.DataFrame({"a": [5, None, 23, 2], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -93,11 +90,10 @@ def test_to_csv_header(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name) assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_quotechar(self, pandas, spark, tmp_path): + def test_to_csv_quotechar(self, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 - pandas_df = pandas.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) + pandas_df = pd.DataFrame({"a": ["'a,b,c'", None, "hello", "bye"], "b": [45, 234, 234, 2]}) df = spark.createDataFrame(pandas_df) @@ -106,10 +102,9 @@ def test_to_csv_quotechar(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name, sep=",", quote="'") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_escapechar(self, pandas, spark, tmp_path): + def test_to_csv_escapechar(self, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 - pandas_df = pandas.DataFrame( + pandas_df = pd.DataFrame( { "c_bool": [True, False], "c_float": [1.0, 3.2], @@ -124,12 +119,11 @@ def test_to_csv_escapechar(self, pandas, spark, tmp_path): csv_rel = spark.read.csv(temp_file_name, quote='"', escape="!") assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_date_format(self, pandas, spark, tmp_path): + def test_to_csv_date_format(self, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 - pandas_df = pandas.DataFrame(getTimeSeriesData()) + pandas_df = pd.DataFrame(getTimeSeriesData()) dt_index = pandas_df.index - pandas_df = pandas.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index) + pandas_df = pd.DataFrame({"A": dt_index, "B": dt_index.shift(1)}, index=dt_index) df = spark.createDataFrame(pandas_df) @@ -139,11 +133,10 @@ def test_to_csv_date_format(self, pandas, spark, tmp_path): assert df.collect() == csv_rel.collect() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_to_csv_timestamp_format(self, pandas, spark, tmp_path): + def test_to_csv_timestamp_format(self, spark, tmp_path): temp_file_name = os.path.join(tmp_path, "temp_file.csv") # noqa: PTH118 data = [datetime.time(hour=23, minute=1, second=34, microsecond=234345)] - pandas_df = pandas.DataFrame({"0": pandas.Series(data=data, dtype="object")}) + pandas_df = pd.DataFrame({"0": pd.Series(data=data, dtype="object")}) df = spark.createDataFrame(pandas_df) diff --git a/tests/fast/spark/test_spark_types.py b/tests/fast/spark/test_spark_types.py index 3950ea4d..3f8f1cfd 100644 --- a/tests/fast/spark/test_spark_types.py +++ b/tests/fast/spark/test_spark_types.py @@ -28,6 +28,7 @@ StringType, StructField, StructType, + TimeNSType, TimeNTZType, TimestampMillisecondNTZType, TimestampNanosecondNTZType, @@ -129,5 +130,6 @@ def test_all_types_schema(self, spark): True, ), StructField("map", MapType(StringType(), StringType(), True), True), + StructField("time_ns", TimeNSType(), True), ] ) diff --git a/tests/fast/spark/test_spark_union_by_name.py b/tests/fast/spark/test_spark_union_by_name.py index bec539a2..2e4e998a 100644 --- a/tests/fast/spark/test_spark_union_by_name.py +++ b/tests/fast/spark/test_spark_union_by_name.py @@ -52,3 +52,18 @@ def test_union_by_name_allow_missing_cols(self, df1, df2): Row(name="Jeff", id=None), ] assert res == expected + + def test_union_by_name_allow_missing_cols_rev(self, df1, df2): + rel = df2.drop("id").unionByName(df1, allowMissingColumns=True) + res = rel.collect() + expected = [ + Row(name="James", id=None), + Row(name="Maria", id=None), + Row(name="Jen", id=None), + Row(name="Jeff", id=None), + Row(name="James", id=34), + Row(name="Michael", id=56), + Row(name="Robert", id=30), + Row(name="Maria", id=24), + ] + assert res == expected diff --git a/tests/fast/test_all_types.py b/tests/fast/test_all_types.py index c4ba0e55..07dc5f70 100644 --- a/tests/fast/test_all_types.py +++ b/tests/fast/test_all_types.py @@ -551,16 +551,16 @@ def test_arrow(self, cur_type): conn = duckdb.connect() if cur_type in replacement_values: - arrow_table = conn.execute("select " + replacement_values[cur_type]).fetch_arrow_table() + arrow_table = conn.execute("select " + replacement_values[cur_type]).to_arrow_table() else: - arrow_table = conn.execute(f'select "{cur_type}" from test_all_types()').fetch_arrow_table() + arrow_table = conn.execute(f'select "{cur_type}" from test_all_types()').to_arrow_table() if cur_type in enum_types: - round_trip_arrow_table = conn.execute("select * from arrow_table").fetch_arrow_table() + round_trip_arrow_table = conn.execute("select * from arrow_table").to_arrow_table() result_arrow = conn.execute("select * from arrow_table").fetchall() result_roundtrip = conn.execute("select * from round_trip_arrow_table").fetchall() assert recursive_equality(result_arrow, result_roundtrip) else: - round_trip_arrow_table = conn.execute("select * from arrow_table").fetch_arrow_table() + round_trip_arrow_table = conn.execute("select * from arrow_table").to_arrow_table() assert arrow_table.equals(round_trip_arrow_table, check_metadata=True) @pytest.mark.parametrize("cur_type", all_types) diff --git a/tests/fast/test_case_alias.py b/tests/fast/test_case_alias.py index d1afb4d8..f99b994e 100644 --- a/tests/fast/test_case_alias.py +++ b/tests/fast/test_case_alias.py @@ -1,15 +1,13 @@ -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb class TestCaseAlias: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_case_alias(self, duckdb_cursor, pandas): + def test_case_alias(self, duckdb_cursor): con = duckdb.connect(":memory:") - df = pandas.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) + df = pd.DataFrame([{"COL1": "val1", "CoL2": 1.05}, {"COL1": "val3", "CoL2": 17}]) r1 = con.from_df(df).query("df", "select * from df").df() assert r1["COL1"][0] == "val1" diff --git a/tests/fast/test_expression_implicit_conversion.py b/tests/fast/test_expression_implicit_conversion.py new file mode 100644 index 00000000..d1da498b --- /dev/null +++ b/tests/fast/test_expression_implicit_conversion.py @@ -0,0 +1,282 @@ +"""Tests that all types in _ExpressionLike are accepted by the implicit conversion path. + +pybind11 registers implicit conversions so that any C++ method taking +``const DuckDBPyExpression &`` silently accepts Python scalars. The stubs +declare these as ``_ExpressionLike``. This file verifies every type in that +union actually works at runtime. + +Key semantics: +- ``str`` always becomes a **ColumnExpression** (column reference), never a string constant. +- ``bytes`` is decoded as UTF-8 by pybind11 and also becomes a ColumnExpression. +- All other types become **ConstantExpression** via ``TransformPythonValue``. +""" + +import datetime +import decimal +import platform +import uuid + +import pytest + +import duckdb +from duckdb import ( + CaseExpression, + CoalesceOperator, + ColumnExpression, + FunctionExpression, +) + +pytestmark = pytest.mark.skipif( + platform.system() == "Emscripten", + reason="Extensions are not supported on Emscripten", +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def rel(): + """A one-row relation with columns of various types.""" + con = duckdb.connect() + r = con.sql( + """ + SELECT + 42 AS i, + 3.14 AS f, + 'hello' AS s, + TRUE AS b, + DATE '2024-01-15' AS d, + TIMESTAMP '2024-01-15 10:30:00' AS ts, + TIME '10:30:00' AS t, + INTERVAL 5 DAY AS iv, + 1.23::DECIMAL(18,2) AS dec, + 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'::UUID AS u + """ + ) + yield r + con.close() + + +# --------------------------------------------------------------------------- +# Constant types: these become ConstantExpression via TransformPythonValue. +# Each entry maps (value, compatible_column) so the == comparison is valid. +# --------------------------------------------------------------------------- + +CONSTANT_VALUES = { + "int": (42, "i"), + "float": (3.14, "f"), + "bool": (True, "b"), + "None": (None, "i"), # NULL compares with anything + "date": (datetime.date(2024, 1, 15), "d"), + "datetime": (datetime.datetime(2024, 1, 15, 10, 30), "ts"), + "time": (datetime.time(10, 30), "t"), + "timedelta": (datetime.timedelta(days=5), "iv"), + "Decimal": (decimal.Decimal("1.23"), "dec"), + "UUID": (uuid.UUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), "u"), +} + + +# --------------------------------------------------------------------------- +# 1. Binary operator with constant types: col == +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("value", "column"), + list(CONSTANT_VALUES.values()), + ids=list(CONSTANT_VALUES.keys()), +) +def test_binary_operator_constant_rhs(rel, value, column): + """Expression == should work for every constant type.""" + expr = ColumnExpression(column) == value + result = rel.select(expr).fetchall() + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# 2. Binary operator with str: str becomes a ColumnExpression (column ref) +# --------------------------------------------------------------------------- + + +def test_binary_operator_str_rhs(rel): + """Str on the RHS becomes a ColumnExpression (column reference).""" + # ColumnExpression("i") == "i" → column i == column i → True + expr = ColumnExpression("i") == "i" + result = rel.select(expr).fetchall() + assert result == [(True,)] + + +# --------------------------------------------------------------------------- +# 3. Reflected operators: + col +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "value", + [1, 1.0, decimal.Decimal("1")], + ids=["int", "float", "Decimal"], +) +def test_reflected_operator_lhs(rel, value): + """ + Expression should work via __radd__.""" + expr = value + ColumnExpression("i") + result = rel.select(expr).fetchall() + assert len(result) == 1 + + +# --------------------------------------------------------------------------- +# 4. Expression.isin() / isnotin() with mixed scalar types +# --------------------------------------------------------------------------- + + +def test_isin_with_scalars(rel): + expr = ColumnExpression("i").isin(42, 99, None) + result = rel.select(expr).fetchall() + assert result == [(True,)] + + +def test_isnotin_with_scalars(rel): + expr = ColumnExpression("i").isnotin(1, 2, 3) + result = rel.select(expr).fetchall() + assert result == [(True,)] + + +# --------------------------------------------------------------------------- +# 5. Expression.between() with scalar bounds +# --------------------------------------------------------------------------- + + +def test_between_with_int_scalars(rel): + expr = ColumnExpression("i").between(0, 100) + result = rel.select(expr).fetchall() + assert result == [(True,)] + + +def test_between_with_date_scalars(rel): + expr = ColumnExpression("d").between(datetime.date(2024, 1, 1), datetime.date(2024, 12, 31)) + result = rel.select(expr).fetchall() + assert result == [(True,)] + + +# --------------------------------------------------------------------------- +# 6. CaseExpression / when / otherwise with scalar values +# Note: str values become column refs, so we use int/None scalars here. +# --------------------------------------------------------------------------- + + +def test_case_expression_with_scalars(rel): + case = CaseExpression(ColumnExpression("i") == 42, 1) + case = case.otherwise(0) + result = rel.select(case).fetchall() + assert result == [(1,)] + + +def test_when_otherwise_with_scalars(rel): + case = CaseExpression(ColumnExpression("i") == 0, 0) + case = case.when(ColumnExpression("i") == 42, 42) + case = case.otherwise(None) + result = rel.select(case).fetchall() + assert result == [(42,)] + + +# --------------------------------------------------------------------------- +# 7. CoalesceOperator with scalars +# --------------------------------------------------------------------------- + + +def test_coalesce_with_scalars(rel): + expr = CoalesceOperator(None, None, 42) + result = rel.select(expr).fetchall() + assert result == [(42,)] + + +# --------------------------------------------------------------------------- +# 8. FunctionExpression with scalar args +# --------------------------------------------------------------------------- + + +def test_function_expression_with_scalars(rel): + expr = FunctionExpression("greatest", ColumnExpression("i"), 99) + result = rel.select(expr).fetchall() + assert result == [(99,)] + + +# --------------------------------------------------------------------------- +# 9. Relation.sort() with str (column reference) +# --------------------------------------------------------------------------- + + +def test_sort_with_string(): + con = duckdb.connect() + rel = con.sql("SELECT * FROM (VALUES (2, 'b'), (1, 'a'), (3, 'c')) t(x, y)") + result = rel.sort("x").fetchall() + assert result == [(1, "a"), (2, "b"), (3, "c")] + + +# --------------------------------------------------------------------------- +# 10. Relation.select() with constant scalars +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "value", + [ + 42, + 3.14, + True, + None, + datetime.date(2024, 1, 15), + datetime.datetime(2024, 1, 15, 10, 30), + datetime.time(10, 30), + datetime.timedelta(days=5), + decimal.Decimal("1.23"), + uuid.UUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), + ], + ids=[ + "int", + "float", + "bool", + "None", + "date", + "datetime", + "time", + "timedelta", + "Decimal", + "UUID", + ], +) +def test_select_with_constant(rel, value): + """rel.select() should produce a one-row result.""" + result = rel.select(value).fetchall() + assert len(result) == 1 + + +def test_select_with_string(rel): + """rel.select() selects a column by name.""" + result = rel.select("i").fetchall() + assert result == [(42,)] + + +# --------------------------------------------------------------------------- +# 11. Relation.project() with scalar +# --------------------------------------------------------------------------- + + +def test_project_with_scalar(rel): + result = rel.project(42).fetchall() + assert result == [(42,)] + + +# --------------------------------------------------------------------------- +# 12. Relation.aggregate() with scalar in list +# --------------------------------------------------------------------------- + + +def test_aggregate_with_scalar(): + con = duckdb.connect() + rel = con.sql("SELECT * FROM (VALUES (1), (2), (3)) t(a)") + # A bare int as an aggregate expression is accepted (non-aggregate, one per row) + result = rel.aggregate([5]).fetchall() + assert len(result) == 3 + assert all(row == (5,) for row in result) diff --git a/tests/fast/test_filesystem.py b/tests/fast/test_filesystem.py index f9f08266..758a243e 100644 --- a/tests/fast/test_filesystem.py +++ b/tests/fast/test_filesystem.py @@ -1,8 +1,8 @@ import logging import sys +from collections.abc import Callable from pathlib import Path, PurePosixPath from shutil import copyfileobj -from typing import Callable import pytest @@ -57,8 +57,7 @@ def add_file(fs, filename=FILENAME): class TestPythonFilesystem: def test_unregister_non_existent_filesystem(self, duckdb_cursor: DuckDBPyConnection): - with pytest.raises(InvalidInputException): - duckdb_cursor.unregister_filesystem("fake") + duckdb_cursor.unregister_filesystem("fake") def test_memory_filesystem(self, duckdb_cursor: DuckDBPyConnection, memory: fsspec.AbstractFileSystem): duckdb_cursor.register_filesystem(memory) diff --git a/tests/fast/test_insert.py b/tests/fast/test_insert.py index c5de1589..6eeabd67 100644 --- a/tests/fast/test_insert.py +++ b/tests/fast/test_insert.py @@ -1,13 +1,11 @@ -import pytest -from conftest import ArrowPandas, NumpyPandas +import pandas as pd import duckdb class TestInsert: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_insert(self, pandas): - test_df = pandas.DataFrame({"i": [1, 2, 3], "j": ["one", "two", "three"]}) + def test_insert(self): + test_df = pd.DataFrame({"i": [1, 2, 3], "j": ["one", "two", "three"]}) # connect to an in-memory temporary database conn = duckdb.connect() # get a cursor @@ -18,7 +16,7 @@ def test_insert(self, pandas): rel.insert([2, "two"]) rel.insert([3, "three"]) rel_a3 = cursor.table("test").project("CAST(i as BIGINT)i, j").to_df() - pandas.testing.assert_frame_equal(rel_a3, test_df) + pd.testing.assert_frame_equal(rel_a3, test_df) def test_insert_with_schema(self, duckdb_cursor): duckdb_cursor.sql("create schema not_main") diff --git a/tests/fast/test_map.py b/tests/fast/test_map.py index 336b2775..2209fe1b 100644 --- a/tests/fast/test_map.py +++ b/tests/fast/test_map.py @@ -2,8 +2,8 @@ from datetime import date, timedelta from typing import NoReturn +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb @@ -17,15 +17,13 @@ def evil1(df): class TestMap: - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_evil_map(self, duckdb_cursor, pandas): + def test_evil_map(self, duckdb_cursor): testrel = duckdb.values([1, 2]) rel = testrel.map(evil1, schema={"i": str}) with pytest.raises(duckdb.InvalidInputException, match="Expected 1 columns from UDF, got 2"): rel.df() - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_map(self, duckdb_cursor, pandas): + def test_map(self, duckdb_cursor): testrel = duckdb.values([1, 2]) conn = duckdb_cursor conn.execute("CREATE TABLE t (a integer)") @@ -57,16 +55,16 @@ def evil5(df) -> NoReturn: raise TypeError def return_dataframe(df): - return pandas.DataFrame({"A": [1]}) + return pd.DataFrame({"A": [1]}) def return_big_dataframe(df): - return pandas.DataFrame({"A": [1] * 5000}) + return pd.DataFrame({"A": [1] * 5000}) def return_none(df) -> None: return None def return_empty_df(df): - return pandas.DataFrame() + return pd.DataFrame() with pytest.raises(duckdb.InvalidInputException, match="Expected 1 columns from UDF, got 2"): print(testrel.map(evil1).df()) @@ -93,14 +91,14 @@ def return_empty_df(df): with pytest.raises(TypeError): print(testrel.map().df()) - testrel.map(return_dataframe).df().equals(pandas.DataFrame({"A": [1]})) + testrel.map(return_dataframe).df().equals(pd.DataFrame({"A": [1]})) with pytest.raises( duckdb.InvalidInputException, match="UDF returned more than 2048 rows, which is not allowed" ): testrel.map(return_big_dataframe).df() - empty_rel.map(return_dataframe).df().equals(pandas.DataFrame({"A": []})) + empty_rel.map(return_dataframe).df().equals(pd.DataFrame({"A": []})) with pytest.raises(duckdb.InvalidInputException, match="No return value from Python function"): testrel.map(return_none).df() @@ -118,18 +116,20 @@ def return_with_no_modification(df): # in this case we assume the returned type should be the same as the input type duckdb_cursor.values([b"1234"]).map(return_with_no_modification).fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_isse_3237(self, duckdb_cursor, pandas): + def test_isse_3237(self, duckdb_cursor): def process(rel): def mapper(x): dates = x["date"].to_numpy("datetime64[us]") days = x["days_to_add"].to_numpy("int") - x["result1"] = pandas.Series( - [pandas.to_datetime(y[0]).date() + timedelta(days=y[1].item()) for y in zip(dates, days)], + x["result1"] = pd.Series( + [pd.to_datetime(y[0]).date() + timedelta(days=y[1].item()) for y in zip(dates, days, strict=False)], dtype="datetime64[us]", ) - x["result2"] = pandas.Series( - [pandas.to_datetime(y[0]).date() + timedelta(days=-y[1].item()) for y in zip(dates, days)], + x["result2"] = pd.Series( + [ + pd.to_datetime(y[0]).date() + timedelta(days=-y[1].item()) + for y in zip(dates, days, strict=False) + ], dtype="datetime64[us]", ) return x @@ -140,8 +140,8 @@ def mapper(x): rel = rel.project("*, IF(ABS(one) > ABS(two), one, two) as three") return rel - df = pandas.DataFrame( - {"date": pandas.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), "days_to_add": [1, 2]} + df = pd.DataFrame( + {"date": pd.Series([date(2000, 1, 1), date(2000, 1, 2)], dtype="datetime64[us]"), "days_to_add": [1, 2]} ) rel = duckdb.from_df(df) rel = process(rel) @@ -172,10 +172,9 @@ def does_nothing(df): ): rel.fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_explicit_schema_name_mismatch(self, pandas): + def test_explicit_schema_name_mismatch(self): def renames_column(df): - return pandas.DataFrame({"a": df["i"]}) + return pd.DataFrame({"a": df["i"]}) con = duckdb.connect() rel = con.sql("select i from range(10) tbl(i)") @@ -183,8 +182,7 @@ def renames_column(df): with pytest.raises(duckdb.InvalidInputException, match=re.escape("UDF column name mismatch")): rel.fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_explicit_schema_error(self, pandas): + def test_explicit_schema_error(self): def no_op(df): return df @@ -196,8 +194,7 @@ def no_op(df): ): rel.map(no_op, schema=[int]) - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_returns_non_dataframe(self, pandas): + def test_returns_non_dataframe(self): def returns_series(df): return df.loc[:, "i"] @@ -205,17 +202,14 @@ def returns_series(df): rel = con.sql("select i, i as j from range(10) tbl(i)") with pytest.raises( duckdb.InvalidInputException, - match=re.escape( - "Expected the UDF to return an object of type 'pandas.DataFrame', found " - "'' instead" - ), + match=r"Expected the UDF to return an object of type 'pandas\.DataFrame', found " + r"'' instead", ): rel = rel.map(returns_series) - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_explicit_schema_columncount_mismatch(self, pandas): + def test_explicit_schema_columncount_mismatch(self): def returns_subset(df): - return pandas.DataFrame({"i": df.loc[:, "i"]}) + return pd.DataFrame({"i": df.loc[:, "i"]}) con = duckdb.connect() rel = con.sql("select i, i as j from range(10) tbl(i)") @@ -225,14 +219,13 @@ def returns_subset(df): ): rel.fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas()]) - def test_pyarrow_df(self, pandas): + def test_pyarrow_df(self): # PyArrow backed dataframes only exist on pandas >= 2.0.0 pytest.importorskip("pandas", "2.0.0") def basic_function(df): # Create a pyarrow backed dataframe - df = pandas.DataFrame({"a": [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend="pyarrow") + df = pd.DataFrame({"a": [5, 3, 2, 1, 2]}).convert_dtypes(dtype_backend="pyarrow") return df con = duckdb.connect() diff --git a/tests/fast/test_multithread.py b/tests/fast/test_multithread.py index dfefb918..fec0ed12 100644 --- a/tests/fast/test_multithread.py +++ b/tests/fast/test_multithread.py @@ -4,8 +4,8 @@ from pathlib import Path import numpy as np +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb @@ -25,11 +25,10 @@ def everything_succeeded(results: list[bool]): class DuckDBThreaded: - def __init__(self, duckdb_insert_thread_count, thread_function, pandas) -> None: + def __init__(self, duckdb_insert_thread_count, thread_function) -> None: self.duckdb_insert_thread_count = duckdb_insert_thread_count self.threads = [] self.thread_function = thread_function - self.pandas = pandas def multithread_test(self, result_verification=everything_succeeded): duckdb_conn = duckdb.connect() @@ -38,9 +37,7 @@ def multithread_test(self, result_verification=everything_succeeded): # Create all threads for i in range(self.duckdb_insert_thread_count): self.threads.append( - threading.Thread( - target=self.thread_function, args=(duckdb_conn, queue, self.pandas), name="duckdb_thread_" + str(i) - ) + threading.Thread(target=self.thread_function, args=(duckdb_conn, queue), name="duckdb_thread_" + str(i)) ) # Record for every thread if they succeeded or not @@ -58,7 +55,7 @@ def multithread_test(self, result_verification=everything_succeeded): assert result_verification(thread_results) -def execute_query_same_connection(duckdb_conn, queue, pandas): +def execute_query_same_connection(duckdb_conn, queue): try: duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)") queue.put(False) @@ -66,7 +63,7 @@ def execute_query_same_connection(duckdb_conn, queue, pandas): queue.put(True) -def execute_query(duckdb_conn, queue, pandas): +def execute_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -76,7 +73,7 @@ def execute_query(duckdb_conn, queue, pandas): queue.put(False) -def insert_runtime_error(duckdb_conn, queue, pandas): +def insert_runtime_error(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -86,7 +83,7 @@ def insert_runtime_error(duckdb_conn, queue, pandas): queue.put(True) -def execute_many_query(duckdb_conn, queue, pandas): +def execute_many_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -119,7 +116,7 @@ def execute_many_query(duckdb_conn, queue, pandas): queue.put(False) -def fetchone_query(duckdb_conn, queue, pandas): +def fetchone_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -129,7 +126,7 @@ def fetchone_query(duckdb_conn, queue, pandas): queue.put(False) -def fetchall_query(duckdb_conn, queue, pandas): +def fetchall_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -139,7 +136,7 @@ def fetchall_query(duckdb_conn, queue, pandas): queue.put(False) -def conn_close(duckdb_conn, queue, pandas): +def conn_close(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -149,7 +146,7 @@ def conn_close(duckdb_conn, queue, pandas): queue.put(False) -def fetchnp_query(duckdb_conn, queue, pandas): +def fetchnp_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -159,7 +156,7 @@ def fetchnp_query(duckdb_conn, queue, pandas): queue.put(False) -def fetchdf_query(duckdb_conn, queue, pandas): +def fetchdf_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -169,7 +166,7 @@ def fetchdf_query(duckdb_conn, queue, pandas): queue.put(False) -def fetchdf_chunk_query(duckdb_conn, queue, pandas): +def fetchdf_chunk_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -179,27 +176,27 @@ def fetchdf_chunk_query(duckdb_conn, queue, pandas): queue.put(False) -def fetch_arrow_query(duckdb_conn, queue, pandas): +def arrow_table_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_arrow_table() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").to_arrow_table() queue.put(True) except Exception: queue.put(False) -def fetch_record_batch_query(duckdb_conn, queue, pandas): +def fetch_record_batch_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: - duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").fetch_record_batch() + duckdb_conn.execute("select i from (values (42), (84), (NULL), (128)) tbl(i)").to_arrow_reader() queue.put(True) except Exception: queue.put(False) -def transaction_query(duckdb_conn, queue, pandas): +def transaction_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") @@ -214,11 +211,11 @@ def transaction_query(duckdb_conn, queue, pandas): queue.put(False) -def df_append(duckdb_conn, queue, pandas): +def df_append(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) + df = pd.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: duckdb_conn.append("T", df) queue.put(True) @@ -226,10 +223,10 @@ def df_append(duckdb_conn, queue, pandas): queue.put(False) -def df_register(duckdb_conn, queue, pandas): +def df_register(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) + df = pd.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: duckdb_conn.register("T", df) queue.put(True) @@ -237,10 +234,10 @@ def df_register(duckdb_conn, queue, pandas): queue.put(False) -def df_unregister(duckdb_conn, queue, pandas): +def df_unregister(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) + df = pd.DataFrame(np.random.randint(0, 100, size=15), columns=["A"]) try: duckdb_conn.register("T", df) duckdb_conn.unregister("T") @@ -249,7 +246,7 @@ def df_unregister(duckdb_conn, queue, pandas): queue.put(False) -def arrow_register_unregister(duckdb_conn, queue, pandas): +def arrow_register_unregister(duckdb_conn, queue): # Get a new connection pa = pytest.importorskip("pyarrow") duckdb_conn = duckdb.connect() @@ -262,7 +259,7 @@ def arrow_register_unregister(duckdb_conn, queue, pandas): queue.put(False) -def table(duckdb_conn, queue, pandas): +def table(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") @@ -273,7 +270,7 @@ def table(duckdb_conn, queue, pandas): queue.put(False) -def view(duckdb_conn, queue, pandas): +def view(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE T ( i INTEGER)") @@ -285,7 +282,7 @@ def view(duckdb_conn, queue, pandas): queue.put(False) -def values(duckdb_conn, queue, pandas): +def values(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -295,7 +292,7 @@ def values(duckdb_conn, queue, pandas): queue.put(False) -def from_query(duckdb_conn, queue, pandas): +def from_query(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() try: @@ -305,10 +302,10 @@ def from_query(duckdb_conn, queue, pandas): queue.put(False) -def from_df(duckdb_conn, queue, pandas): +def from_df(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() - df = pandas.DataFrame(["bla", "blabla"] * 10, columns=["A"]) # noqa: F841 + df = pd.DataFrame(["bla", "blabla"] * 10, columns=["A"]) # noqa: F841 try: duckdb_conn.execute("select * from df").fetchall() queue.put(True) @@ -316,7 +313,7 @@ def from_df(duckdb_conn, queue, pandas): queue.put(False) -def from_arrow(duckdb_conn, queue, pandas): +def from_arrow(duckdb_conn, queue): # Get a new connection pa = pytest.importorskip("pyarrow") duckdb_conn = duckdb.connect() @@ -328,7 +325,7 @@ def from_arrow(duckdb_conn, queue, pandas): queue.put(False) -def from_csv_auto(duckdb_conn, queue, pandas): +def from_csv_auto(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() filename = str(Path(__file__).parent / "data" / "integers.csv") @@ -339,7 +336,7 @@ def from_csv_auto(duckdb_conn, queue, pandas): queue.put(False) -def from_parquet(duckdb_conn, queue, pandas): +def from_parquet(duckdb_conn, queue): # Get a new connection duckdb_conn = duckdb.connect() filename = str(Path(__file__).parent / "data" / "binary_string.parquet") @@ -350,7 +347,7 @@ def from_parquet(duckdb_conn, queue, pandas): queue.put(False) -def description(_, queue, __): +def description(_, queue): # Get a new connection duckdb_conn = duckdb.connect() duckdb_conn.execute("CREATE TABLE test (i bool, j TIME, k VARCHAR)") @@ -364,7 +361,7 @@ def description(_, queue, __): queue.put(False) -def cursor(duckdb_conn, queue, pandas): +def cursor(duckdb_conn, queue): # Get a new connection cx = duckdb_conn.cursor() try: @@ -375,136 +372,111 @@ def cursor(duckdb_conn, queue, pandas): class TestDuckMultithread: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_execute(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, execute_query, pandas) + def test_execute(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, execute_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_execute_many(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, execute_many_query, pandas) + def test_execute_many(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, execute_many_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_fetchone(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, fetchone_query, pandas) + def test_fetchone(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, fetchone_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_fetchall(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, fetchall_query, pandas) + def test_fetchall(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, fetchall_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_close(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, conn_close, pandas) + def test_close(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, conn_close) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_fetchnp(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, fetchnp_query, pandas) + def test_fetchnp(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, fetchnp_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_fetchdf(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, fetchdf_query, pandas) + def test_fetchdf(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, fetchdf_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_fetchdfchunk(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, fetchdf_chunk_query, pandas) + def test_fetchdfchunk(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, fetchdf_chunk_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_fetcharrow(self, duckdb_cursor, pandas): + def test_fetcharrow(self, duckdb_cursor): pytest.importorskip("pyarrow") - duck_threads = DuckDBThreaded(10, fetch_arrow_query, pandas) + duck_threads = DuckDBThreaded(10, arrow_table_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_fetch_record_batch(self, duckdb_cursor, pandas): + def test_fetch_record_batch(self, duckdb_cursor): pytest.importorskip("pyarrow") - duck_threads = DuckDBThreaded(10, fetch_record_batch_query, pandas) + duck_threads = DuckDBThreaded(10, fetch_record_batch_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_transaction(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, transaction_query, pandas) + def test_transaction(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, transaction_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_df_append(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, df_append, pandas) + def test_df_append(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, df_append) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_df_register(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, df_register, pandas) + def test_df_register(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, df_register) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_df_unregister(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, df_unregister, pandas) + def test_df_unregister(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, df_unregister) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_arrow_register_unregister(self, duckdb_cursor, pandas): + def test_arrow_register_unregister(self, duckdb_cursor): pytest.importorskip("pyarrow") - duck_threads = DuckDBThreaded(10, arrow_register_unregister, pandas) + duck_threads = DuckDBThreaded(10, arrow_register_unregister) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_table(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, table, pandas) + def test_table(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, table) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_view(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, view, pandas) + def test_view(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, view) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_values(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, values, pandas) + def test_values(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, values) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_from_query(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, from_query, pandas) + def test_from_query(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, from_query) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_from_DF(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, from_df, pandas) + def test_from_DF(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, from_df) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_from_arrow(self, duckdb_cursor, pandas): + def test_from_arrow(self, duckdb_cursor): pytest.importorskip("pyarrow") - duck_threads = DuckDBThreaded(10, from_arrow, pandas) + duck_threads = DuckDBThreaded(10, from_arrow) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_from_csv_auto(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, from_csv_auto, pandas) + def test_from_csv_auto(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, from_csv_auto) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_from_parquet(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, from_parquet, pandas) + def test_from_parquet(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, from_parquet) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_description(self, duckdb_cursor, pandas): - duck_threads = DuckDBThreaded(10, description, pandas) + def test_description(self, duckdb_cursor): + duck_threads = DuckDBThreaded(10, description) duck_threads.multithread_test() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_cursor(self, duckdb_cursor, pandas): + def test_cursor(self, duckdb_cursor): def only_some_succeed(results: list[bool]) -> bool: if not any(result for result in results): return False return not all(result for result in results) - duck_threads = DuckDBThreaded(10, cursor, pandas) + duck_threads = DuckDBThreaded(10, cursor) duck_threads.multithread_test(only_some_succeed) diff --git a/tests/fast/test_parameter_list.py b/tests/fast/test_parameter_list.py index 22413999..6d101bcb 100644 --- a/tests/fast/test_parameter_list.py +++ b/tests/fast/test_parameter_list.py @@ -1,5 +1,5 @@ +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb @@ -12,10 +12,9 @@ def test_bool(self, duckdb_cursor): res = conn.execute("select count(*) from bool_table where a =?", [True]) assert res.fetchone()[0] == 1 - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_exception(self, duckdb_cursor, pandas): + def test_exception(self, duckdb_cursor): conn = duckdb.connect() - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } diff --git a/tests/fast/test_profiler.py b/tests/fast/test_profiler.py new file mode 100644 index 00000000..9e023b0e --- /dev/null +++ b/tests/fast/test_profiler.py @@ -0,0 +1,47 @@ +import pytest + +import duckdb +from duckdb.query_graph import ProfilingInfo + + +@pytest.fixture(scope="session") +def profiling_connection(): + con = duckdb.connect() + con.enable_profiling() + con.execute("SELECT 42;").fetchall() + yield con + con.close() + + +class TestProfiler: + def test_profiler_matches_expected_format(self, profiling_connection, tmp_path_factory): + # Test String returned + profiling_info = ProfilingInfo(profiling_connection) + profiling_info_json = profiling_info.to_json() + assert isinstance(profiling_info_json, str) + + # Test expected metrics are there and profiling is json loadable + profiling_dict = profiling_info.to_pydict() + expected_keys = { + "query_name", + "total_bytes_written", + "total_bytes_read", + "system_peak_temp_dir_size", + "system_peak_buffer_memory", + "rows_returned", + "result_set_size", + "latency", + "cumulative_rows_scanned", + "cumulative_cardinality", + "cpu_time", + "extra_info", + "blocked_thread_time", + "children", + } + assert expected_keys.issubset(profiling_dict.keys()) + + def test_profiler_html_output(self, profiling_connection, tmp_path_factory): + tmp_dir = tmp_path_factory.mktemp("profiler", numbered=True) + profiling_info = ProfilingInfo(profiling_connection) + # Test HTML execution works, nothing to assert! + profiling_info.to_html(output_file=f"{tmp_dir}/profiler_output.html") diff --git a/tests/fast/test_relation.py b/tests/fast/test_relation.py index 32349c68..bc7039fa 100644 --- a/tests/fast/test_relation.py +++ b/tests/fast/test_relation.py @@ -2,13 +2,11 @@ import datetime import gc import os -import platform import tempfile import numpy as np import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb from duckdb import ColumnExpression @@ -39,10 +37,9 @@ def test_csv_auto(self): csv_rel = duckdb.from_csv_auto(temp_file_name) assert df_rel.execute().fetchall() == csv_rel.execute().fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_relation_view(self, duckdb_cursor, pandas): + def test_relation_view(self, duckdb_cursor): def create_view(duckdb_cursor) -> None: - df_in = pandas.DataFrame({"numbers": [1, 2, 3, 4, 5]}) + df_in = pd.DataFrame({"numbers": [1, 2, 3, 4, 5]}) rel = duckdb_cursor.query("select * from df_in") rel.to_view("my_view") @@ -536,15 +533,6 @@ def test_relation_print(self): 1024, 2048, 5000, - 1000000, - pytest.param( - 10000000, - marks=pytest.mark.skipif( - condition=platform.system() == "Emscripten", - reason="Emscripten/Pyodide builds run out of memory at this scale, and error might not " - "thrown reliably", - ), - ), ], ) def test_materialized_relation(self, duckdb_cursor, num_rows): diff --git a/tests/fast/test_relation_dependency_leak.py b/tests/fast/test_relation_dependency_leak.py index 659e1c28..db83ff1c 100644 --- a/tests/fast/test_relation_dependency_leak.py +++ b/tests/fast/test_relation_dependency_leak.py @@ -1,6 +1,7 @@ import os import numpy as np +import pandas as pd import pytest try: @@ -9,67 +10,61 @@ can_run = True except ImportError: can_run = False -from conftest import ArrowPandas, NumpyPandas psutil = pytest.importorskip("psutil") -def check_memory(function_to_check, pandas, duckdb_cursor): +def check_memory(function_to_check, duckdb_cursor): process = psutil.Process(os.getpid()) mem_usage = process.memory_info().rss / (10**9) for __ in range(100): - function_to_check(pandas, duckdb_cursor) + function_to_check(duckdb_cursor) cur_mem_usage = process.memory_info().rss / (10**9) # This seems a good empirical value assert cur_mem_usage / 3 < mem_usage -def from_df(pandas, duckdb_cursor): - df = pandas.DataFrame({"x": np.random.rand(1_000_000)}) +def from_df(duckdb_cursor): + df = pd.DataFrame({"x": np.random.rand(1_000_000)}) return duckdb_cursor.from_df(df) -def from_arrow(pandas, duckdb_cursor): +def from_arrow(duckdb_cursor): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) arrow_table = pa.Table.from_arrays([data], ["a"]) duckdb_cursor.from_arrow(arrow_table) -def arrow_replacement(pandas, duckdb_cursor): +def arrow_replacement(duckdb_cursor): data = pa.array(np.random.rand(1_000_000), type=pa.float32()) arrow_table = pa.Table.from_arrays([data], ["a"]) # noqa: F841 duckdb_cursor.query("select sum(a) from arrow_table").fetchall() -def pandas_replacement(pandas, duckdb_cursor): - df = pandas.DataFrame({"x": np.random.rand(1_000_000)}) # noqa: F841 +def pandas_replacement(duckdb_cursor): + df = pd.DataFrame({"x": np.random.rand(1_000_000)}) # noqa: F841 duckdb_cursor.query("select sum(x) from df").fetchall() class TestRelationDependencyMemoryLeak: - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_from_arrow_leak(self, pandas, duckdb_cursor): + def test_from_arrow_leak(self, duckdb_cursor): if not can_run: return - check_memory(from_arrow, pandas, duckdb_cursor) + check_memory(from_arrow, duckdb_cursor) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_from_df_leak(self, pandas, duckdb_cursor): - check_memory(from_df, pandas, duckdb_cursor) + def test_from_df_leak(self, duckdb_cursor): + check_memory(from_df, duckdb_cursor) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_arrow_replacement_scan_leak(self, pandas, duckdb_cursor): + def test_arrow_replacement_scan_leak(self, duckdb_cursor): if not can_run: return - check_memory(arrow_replacement, pandas, duckdb_cursor) + check_memory(arrow_replacement, duckdb_cursor) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_pandas_replacement_scan_leak(self, pandas, duckdb_cursor): - check_memory(pandas_replacement, pandas, duckdb_cursor) + def test_pandas_replacement_scan_leak(self, duckdb_cursor): + check_memory(pandas_replacement, duckdb_cursor) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_relation_view_leak(self, pandas, duckdb_cursor): - rel = from_df(pandas, duckdb_cursor) + def test_relation_view_leak(self, duckdb_cursor): + rel = from_df(duckdb_cursor) rel.create_view("bla") duckdb_cursor.unregister("bla") assert rel.query("bla", "select count(*) from bla").fetchone()[0] == 1_000_000 diff --git a/tests/fast/test_replacement_scan.py b/tests/fast/test_replacement_scan.py index 1e76d1d5..cafee85d 100644 --- a/tests/fast/test_replacement_scan.py +++ b/tests/fast/test_replacement_scan.py @@ -25,30 +25,26 @@ def using_sql(con, to_scan, object_name): # Fetch methods -def fetch_polars(rel): +def polars_from_rel(rel): return rel.pl() -def fetch_df(rel): +def df_from_rel(rel): return rel.df() -def fetch_arrow(rel): - return rel.fetch_arrow_table() +def arrow_table_from_rel(rel): + return rel.to_arrow_table() -def fetch_arrow_table(rel): - return rel.fetch_arrow_table() - - -def fetch_arrow_record_batch(rel: duckdb.DuckDBPyRelation): +def arrow_reader_from_rel(rel: duckdb.DuckDBPyRelation): # Note: this has to executed first, otherwise we'll create a deadlock # Because it will try to execute the input at the same time as executing the relation # On the same connection (that's the core of the issue) - return rel.execute().fetch_record_batch() + return rel.execute().to_arrow_reader() -def fetch_relation(rel): +def rel_from_rel(rel): return rel @@ -94,7 +90,7 @@ def test_parquet_replacement(self): @pytest.mark.parametrize("get_relation", [using_table, using_sql]) @pytest.mark.parametrize( "fetch_method", - [fetch_polars, fetch_df, fetch_arrow, fetch_arrow_table, fetch_arrow_record_batch, fetch_relation], + [polars_from_rel, df_from_rel, arrow_table_from_rel, arrow_reader_from_rel, rel_from_rel], ) @pytest.mark.parametrize("object_name", ["tbl", "table", "select", "update"]) def test_table_replacement_scans(self, duckdb_cursor, get_relation, fetch_method, object_name): @@ -314,7 +310,6 @@ def test_cte_with_joins(self, duckdb_cursor): res = rel.fetchall() assert res == [(2, 2, 2)] - @pytest.mark.xfail(reason="Bug in DuckDB core (MRE at #19154)") def test_same_name_cte(self, duckdb_cursor): query = """ WITH df AS ( @@ -469,7 +464,8 @@ def test_replacement_disabled(self): with pytest.raises(duckdb.CatalogException, match="Table with name df does not exist!"): create_relation(con, "select * from df") with pytest.raises( - duckdb.InvalidInputException, match="Cannot change enable_external_access setting while database is running" + duckdb.InvalidInputException, + match="Invalid Input Error: Cannot enable external access while database is running", ): con.execute("set enable_external_access=true") diff --git a/tests/fast/test_result.py b/tests/fast/test_result.py index 4210a437..9180f8dc 100644 --- a/tests/fast/test_result.py +++ b/tests/fast/test_result.py @@ -21,9 +21,9 @@ def test_result_closed(self, duckdb_cursor): with pytest.raises(duckdb.InvalidInputException, match="result closed"): res.fetchnumpy() with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): - res.fetch_arrow_table() + res.to_arrow_table() with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): - res.fetch_arrow_reader(1) + res.to_arrow_reader(1) def test_result_describe_types(self, duckdb_cursor): connection = duckdb.connect("") diff --git a/tests/fast/test_runtime_error.py b/tests/fast/test_runtime_error.py index 9f1975a0..62bf7589 100644 --- a/tests/fast/test_runtime_error.py +++ b/tests/fast/test_runtime_error.py @@ -1,5 +1,5 @@ +import pandas as pd import pytest -from conftest import ArrowPandas, NumpyPandas import duckdb @@ -31,7 +31,7 @@ def test_arrow_error(self): con = duckdb.connect() con.execute("create table tbl as select 'hello' i") with pytest.raises(duckdb.ConversionException): - con.execute("select i::int from tbl").fetch_arrow_table() + con.execute("select i::int from tbl").to_arrow_table() def test_register_error(self): con = duckdb.connect() @@ -43,28 +43,27 @@ def test_arrow_fetch_table_error(self): pytest.importorskip("pyarrow") con = duckdb.connect() - arrow_object = con.execute("select 1").fetch_arrow_table() + arrow_object = con.execute("select 1").to_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() with pytest.raises(duckdb.InvalidInputException, match="There is no query result"): - res.fetch_arrow_table() + res.to_arrow_table() def test_arrow_record_batch_reader_error(self): pytest.importorskip("pyarrow") con = duckdb.connect() - arrow_object = con.execute("select 1").fetch_arrow_table() + arrow_object = con.execute("select 1").to_arrow_table() arrow_relation = con.from_arrow(arrow_object) res = arrow_relation.execute() res.close() with pytest.raises(duckdb.ProgrammingError, match="There is no query result"): - res.fetch_arrow_reader(1) + res.to_arrow_reader(1) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_relation_cache_fetchall(self, pandas): + def test_relation_cache_fetchall(self): conn = duckdb.connect() - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } @@ -78,10 +77,9 @@ def test_relation_cache_fetchall(self, pandas): # so the dependency of 'x' on 'df_in' is not registered in 'rel' rel.fetchall() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_relation_cache_execute(self, pandas): + def test_relation_cache_execute(self): conn = duckdb.connect() - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } @@ -92,10 +90,9 @@ def test_relation_cache_execute(self, pandas): with pytest.raises(duckdb.ProgrammingError, match="Table with name df_in does not exist"): rel.execute() - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_relation_query_error(self, pandas): + def test_relation_query_error(self): conn = duckdb.connect() - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } @@ -106,10 +103,9 @@ def test_relation_query_error(self, pandas): with pytest.raises(duckdb.CatalogException, match="Table with name df_in does not exist"): rel.query("bla", "select * from bla") - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_conn_broken_statement_error(self, pandas): + def test_conn_broken_statement_error(self): conn = duckdb.connect() - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } @@ -128,11 +124,10 @@ def test_conn_prepared_statement_error(self): ): conn.execute("select * from integers where a =? and b=?", [1]) - @pytest.mark.parametrize("pandas", [NumpyPandas(), ArrowPandas()]) - def test_closed_conn_exceptions(self, pandas): + def test_closed_conn_exceptions(self): conn = duckdb.connect() conn.close() - df_in = pandas.DataFrame( + df_in = pd.DataFrame( { "numbers": [1, 2, 3, 4, 5], } @@ -190,7 +185,7 @@ def test_missing_result_from_conn_exceptions(self): conn.fetch_df_chunk() with no_result_set(): - conn.fetch_arrow_table() + conn.to_arrow_table() with no_result_set(): - conn.fetch_record_batch() + conn.to_arrow_reader() diff --git a/tests/fast/test_type.py b/tests/fast/test_type.py index d8145166..17cec9e6 100644 --- a/tests/fast/test_type.py +++ b/tests/fast/test_type.py @@ -1,5 +1,4 @@ import sys -from typing import Optional, Union import pytest @@ -138,7 +137,7 @@ def test_implicit_convert_from_builtin_type(self): res = duckdb.list_type(list[dict[str, dict[list[str], str]]]) assert str(res.child) == "MAP(VARCHAR, MAP(VARCHAR[], VARCHAR))[]" - res = duckdb.list_type(list[Union[str, int]]) + res = duckdb.list_type(list[str | int]) assert str(res.child) == "UNION(u1 VARCHAR, u2 BIGINT)[]" def test_implicit_convert_from_numpy(self, duckdb_cursor): @@ -227,21 +226,21 @@ def test_hash_method(self): # NOTE: we can support this, but I don't think going through hoops for an outdated version of python is worth it @pytest.mark.skipif(sys.version_info < (3, 9), reason="python3.7 does not store Optional[..] in a recognized way") def test_optional(self): - type = DuckDBPyType(Optional[str]) + type = DuckDBPyType(str | None) assert type == "VARCHAR" - type = DuckDBPyType(Optional[Union[int, bool]]) + type = DuckDBPyType(int | bool | None) assert type == "UNION(u1 BIGINT, u2 BOOLEAN)" - type = DuckDBPyType(Optional[list[int]]) + type = DuckDBPyType(list[int] | None) assert type == "BIGINT[]" - type = DuckDBPyType(Optional[dict[int, str]]) + type = DuckDBPyType(dict[int, str] | None) assert type == "MAP(BIGINT, VARCHAR)" - type = DuckDBPyType(Optional[dict[Optional[int], Optional[str]]]) + type = DuckDBPyType(dict[int | None, str | None] | None) assert type == "MAP(BIGINT, VARCHAR)" - type = DuckDBPyType(Optional[dict[Optional[int], Optional[str]]]) + type = DuckDBPyType(dict[int | None, str | None] | None) assert type == "MAP(BIGINT, VARCHAR)" - type = DuckDBPyType(Optional[Union[Optional[str], Optional[bool]]]) + type = DuckDBPyType(str | None | bool) assert type == "UNION(u1 VARCHAR, u2 BOOLEAN)" - type = DuckDBPyType(Union[str, None]) + type = DuckDBPyType(str | None) assert type == "VARCHAR" @pytest.mark.skipif(sys.version_info < (3, 10), reason="'str | None' syntax requires Python 3.10 or higher") diff --git a/tests/fast/test_type_conversion.py b/tests/fast/test_type_conversion.py new file mode 100644 index 00000000..9bc2e6d2 --- /dev/null +++ b/tests/fast/test_type_conversion.py @@ -0,0 +1,136 @@ +"""Regression tests for Python-to-DuckDB type conversion bugs. + +Issue #115: Float conversion error with UNION containing float +Issue #171: Dictionary key case sensitivity not respected for parameter bindings +Issue #330: Integers >64-bit lose precision via double conversion +""" + +import numpy as np +import pytest + +import duckdb +from duckdb.sqltypes import BIGINT, DOUBLE, FLOAT, HUGEINT, UHUGEINT, VARCHAR, DuckDBPyType + + +class TestIssue115FloatToUnion: + """HandleDouble should use DefaultCastAs for unknown target types like UNION.""" + + def test_udf_float_to_union_type(self): + conn = duckdb.connect() + conn.create_function( + "return_float", + lambda: 1.5, + return_type=duckdb.union_type({"u1": VARCHAR, "u2": BIGINT, "u3": DOUBLE}), + ) + result = conn.sql("SELECT return_float()").fetchone()[0] + assert result == 1.5 + + def test_udf_float_to_ambiguous_union_type(self): + """UNION with duplicate DOUBLE members (from np.float64 and float) must not raise ambiguity error.""" + conn = duckdb.connect() + conn.create_function( + "return_float", + lambda: 1.5, + return_type=duckdb.union_type({"u1": VARCHAR, "u2": BIGINT, "u3": DOUBLE, "u4": FLOAT, "u5": DOUBLE}), + ) + result = conn.sql("SELECT return_float()").fetchone()[0] + assert result == 1.5 + + def test_udf_dict_with_float_in_union_struct(self): + """Original repro from issue #115 with ambiguous UNION members.""" + conn = duckdb.connect() + + arr = [{"a": 1, "b": 1.2}, {"a": 3, "b": 2.4}] + + def test(): + return arr + + return_type = DuckDBPyType(list[dict[str, str | int | np.float64 | np.float32 | float]]) + conn.create_function("test", test, return_type=return_type) + result = conn.sql("SELECT test()").fetchone()[0] + assert len(result) == 2 + assert result[0]["b"] == pytest.approx(1.2) + assert result[1]["b"] == pytest.approx(2.4) + + def test_udf_int_to_ambiguous_union_type(self): + """HandleBigint default branch: int into UNION with duplicate BIGINT members.""" + conn = duckdb.connect() + conn.create_function( + "return_int", + lambda: 42, + return_type=duckdb.union_type({"u1": VARCHAR, "u2": BIGINT, "u3": BIGINT}), + ) + result = conn.sql("SELECT return_int()").fetchone()[0] + assert result == 42 + + def test_udf_string_to_ambiguous_union_type(self): + """HandleString default branch: str into UNION with duplicate VARCHAR members.""" + conn = duckdb.connect() + conn.create_function( + "return_str", + lambda: "hello", + return_type=duckdb.union_type({"u1": VARCHAR, "u2": BIGINT, "u3": VARCHAR}), + ) + result = conn.sql("SELECT return_str()").fetchone()[0] + assert result == "hello" + + +class TestIssue171DictKeyCaseSensitivity: + """Dict keys differing only by case must preserve their individual values.""" + + def test_case_sensitive_dict_keys(self): + result = duckdb.execute("SELECT ?", [{"Key": "first", "key": "second"}]).fetchone()[0] + assert result["Key"] == "first" + assert result["key"] == "second" + + def test_case_sensitive_dict_keys_three_variants(self): + result = duckdb.execute("SELECT ?", [{"abc": 1, "ABC": 2, "Abc": 3}]).fetchone()[0] + assert result["abc"] == 1 + assert result["ABC"] == 2 + assert result["Abc"] == 3 + + +class TestIssue330LargeIntegerPrecision: + """Integers >64-bit must not lose precision via double conversion.""" + + # --- Parameter binding path (TryTransformPythonNumeric) --- + + def test_param_hugeint_large(self): + """Value with >52 significant bits must not lose precision.""" + value = (2**128 - 1) // 15 * 7 # 0x77777777777777777777777777777777 + result = duckdb.execute("SELECT ?::HUGEINT", [value]).fetchone()[0] + assert result == value + + def test_param_uhugeint_max(self): + """2**128-1 must not overflow when cast to UHUGEINT.""" + value = 2**128 - 1 + result = duckdb.execute("SELECT ?::UHUGEINT", [value]).fetchone()[0] + assert result == value + + def test_param_auto_sniff(self): + """2**64 without explicit cast should sniff as HUGEINT, not lose precision.""" + value = 2**64 + result = duckdb.execute("SELECT ?", [value]).fetchone()[0] + assert result == value + + def test_param_negative_hugeint_no_regression(self): + """Negative overflow path (already correct) must not regress.""" + value = -(2**64) + result = duckdb.execute("SELECT ?::HUGEINT", [value]).fetchone()[0] + assert result == value + + # --- UDF return path (TransformPythonObjectInternal template) --- + + def test_udf_return_large_hugeint(self): + value = (2**128 - 1) // 15 * 7 + conn = duckdb.connect() + conn.create_function("big_hugeint", lambda: value, return_type=HUGEINT) + result = conn.sql("SELECT big_hugeint()").fetchone()[0] + assert result == value + + def test_udf_return_large_uhugeint(self): + value = 2**128 - 1 + conn = duckdb.connect() + conn.create_function("big_uhugeint", lambda: value, return_type=UHUGEINT) + result = conn.sql("SELECT big_uhugeint()").fetchone()[0] + assert result == value diff --git a/tests/fast/test_variant.py b/tests/fast/test_variant.py new file mode 100644 index 00000000..f935d291 --- /dev/null +++ b/tests/fast/test_variant.py @@ -0,0 +1,203 @@ +import numpy as np +import pytest + +import duckdb + + +class TestVariantFetchall: + """Tests for fetchall/fetchone with VARIANT columns (should all pass).""" + + def test_integer(self): + result = duckdb.sql("SELECT 42::VARIANT AS v").fetchone() + assert result[0] == 42 + + def test_string(self): + result = duckdb.sql("SELECT 'hello'::VARIANT AS v").fetchone() + assert result[0] == "hello" + + def test_boolean(self): + result = duckdb.sql("SELECT true::VARIANT AS v").fetchone() + assert result[0] is True + + def test_double(self): + result = duckdb.sql("SELECT 3.14::DOUBLE::VARIANT AS v").fetchone() + assert abs(result[0] - 3.14) < 1e-10 + + def test_null(self): + result = duckdb.sql("SELECT NULL::VARIANT AS v").fetchone() + assert result[0] is None + + def test_list(self): + result = duckdb.sql("SELECT [1, 2, 3]::VARIANT AS v").fetchone() + assert result[0] == [1, 2, 3] + + def test_struct(self): + result = duckdb.sql("SELECT {'a': 1, 'b': 2}::VARIANT AS v").fetchone() + assert result[0] == {"a": 1, "b": 2} + + def test_nested_struct(self): + result = duckdb.sql("SELECT {'x': {'y': 42}}::VARIANT AS v").fetchone() + assert result[0] == {"x": {"y": 42}} + + def test_map(self): + result = duckdb.sql("SELECT MAP {'key1': 'val1', 'key2': 'val2'}::VARIANT AS v").fetchone() + val = result[0] + # VARIANT converts maps to a list of key/value structs + assert val == [{"key": "key1", "value": "val1"}, {"key": "key2", "value": "val2"}] + + def test_multiple_rows_mixed_types(self): + result = duckdb.sql(""" + SELECT * FROM ( + VALUES (42::VARIANT), ('hello'::VARIANT), (true::VARIANT), ([1,2]::VARIANT) + ) AS t(v) + """).fetchall() + assert result[0][0] == 42 + assert result[1][0] == "hello" + assert result[2][0] is True + assert result[3][0] == [1, 2] + + def test_variant_from_table(self): + con = duckdb.connect() + con.execute("CREATE TABLE t (v VARIANT)") + con.execute("INSERT INTO t VALUES (42::VARIANT), ('hello'::VARIANT)") + result = con.execute("SELECT * FROM t").fetchall() + assert result[0][0] == 42 + assert result[1][0] == "hello" + + def test_variant_as_map_key(self): + """The original repro that motivated VARIANT support.""" + result = duckdb.sql(""" + SELECT MAP {42::VARIANT: 'answer'} AS m + """).fetchone() + # MAP with VARIANT keys is returned as a struct with key/value arrays + assert result[0] == {"key": [42], "value": ["answer"]} + + +class TestVariantFetchNumpy: + """Tests for fetchnumpy with VARIANT columns.""" + + def test_single_row(self): + result = duckdb.sql("SELECT 42::VARIANT AS v").fetchnumpy() + assert result["v"][0] == 42 + + def test_multiple_rows(self): + """Exercises chunk_offset > 0 — this was broken by Bug A/B.""" + result = duckdb.sql(""" + SELECT * FROM ( + VALUES (1::VARIANT), (2::VARIANT), (3::VARIANT) + ) AS t(v) + """).fetchnumpy() + values = list(result["v"]) + assert values == [1, 2, 3] + + def test_null_handling(self): + result = duckdb.sql(""" + SELECT * FROM ( + VALUES (42::VARIANT), (NULL::VARIANT), (99::VARIANT) + ) AS t(v) + """).fetchnumpy() + arr = result["v"] + assert arr[0] == 42 + assert arr[1] is np.ma.masked or arr[1] is None + assert arr[2] == 99 + + def test_mixed_types(self): + result = duckdb.sql(""" + SELECT * FROM ( + VALUES (42::VARIANT), ('hello'::VARIANT), (true::VARIANT) + ) AS t(v) + """).fetchnumpy() + values = list(result["v"]) + assert values[0] == 42 + assert values[1] == "hello" + assert values[2] is True + + +class TestVariantFetchDF: + """Tests for Pandas df() with VARIANT columns (goes through numpy).""" + + def test_basic(self): + df = duckdb.sql("SELECT 42::VARIANT AS v").df() + assert df["v"].iloc[0] == 42 + + def test_multiple_types(self): + df = duckdb.sql(""" + SELECT * FROM ( + VALUES (42::VARIANT), ('hello'::VARIANT), (true::VARIANT) + ) AS t(v) + """).df() + assert df["v"].iloc[0] == 42 + assert df["v"].iloc[1] == "hello" + assert df["v"].iloc[2] is True + + def test_null_handling(self): + df = duckdb.sql(""" + SELECT * FROM ( + VALUES (42::VARIANT), (NULL::VARIANT), (99::VARIANT) + ) AS t(v) + """).df() + assert df["v"].iloc[0] == 42 + assert df["v"].iloc[2] == 99 + + +class TestVariantArrow: + """Tests for Arrow/Polars — blocked on DuckDB core Arrow support.""" + + @pytest.mark.xfail(strict=True, reason="Arrow export for VARIANT not yet supported in DuckDB core") + def test_to_arrow_table(self): + duckdb.sql("SELECT 42::VARIANT AS v").arrow() + + @pytest.mark.xfail(strict=True, reason="Arrow export for VARIANT not yet supported in DuckDB core") + def test_fetch_arrow_reader(self): + duckdb.sql("SELECT 42::VARIANT AS v").fetch_arrow_reader() + + @pytest.mark.xfail(strict=True, reason="Polars uses Arrow, which doesn't support VARIANT yet") + def test_polars(self): + duckdb.sql("SELECT 42::VARIANT AS v").pl() + + +class TestVariantIngestion: + """Tests for Python → DuckDB VARIANT ingestion.""" + + def test_insert_with_params(self): + con = duckdb.connect() + con.execute("CREATE TABLE t (v VARIANT)") + con.execute("INSERT INTO t VALUES ($1::VARIANT)", [42]) + result = con.execute("SELECT * FROM t").fetchone() + assert result[0] == 42 + + +class TestVariantType: + """Tests for VARIANT in the type system.""" + + def test_type_from_string(self): + t = duckdb.type("VARIANT") + assert t.id == "variant" + + def test_variant_constant(self): + from duckdb.sqltypes import VARIANT + + assert VARIANT is not None + assert VARIANT.id == "variant" + + def test_children_raises(self): + t = duckdb.type("VARIANT") + with pytest.raises(duckdb.InvalidInputException, match="not nested"): + _ = t.children + + def test_sqltypes_variant(self): + from duckdb.sqltypes import VARIANT + + assert VARIANT.id == "variant" + + +class TestVariantPySpark: + """Tests for PySpark VARIANT type mapping.""" + + def test_variant_converts_to_variant_type(self): + from duckdb.experimental.spark.sql.type_utils import convert_type + from duckdb.experimental.spark.sql.types import VariantType + + t = duckdb.type("VARIANT") + spark_type = convert_type(t) + assert isinstance(spark_type, VariantType) diff --git a/tests/fast/types/test_time_ns.py b/tests/fast/types/test_time_ns.py new file mode 100644 index 00000000..5a27f612 --- /dev/null +++ b/tests/fast/types/test_time_ns.py @@ -0,0 +1,75 @@ +import datetime + +import pytest + +from duckdb import ConversionException, sqltypes + + +def test_time_ns_select(duckdb_cursor): + duckdb_cursor.execute("SELECT TIME_NS '1992-09-20 11:30:00.123456'") + result = duckdb_cursor.fetchone()[0] + assert result + assert isinstance(result, datetime.time) + + +@pytest.mark.xfail( + raises=ConversionException, + reason="Conversion Error: Unimplemented type for cast (TIME -> TIME_NS)", +) +def test_time_ns_insert(duckdb_cursor): + """This tests that datetime.time values can be inserted as TIME_NS.""" + duckdb_cursor.execute("SELECT TIME_NS '1992-09-20 11:30:00.123456'") + result1 = duckdb_cursor.fetchone()[0] + duckdb_cursor.execute("CREATE OR REPLACE TEMP TABLE time_ns_test (time_ns_col TIME_NS)") + duckdb_cursor.execute("INSERT INTO time_ns_test VALUES (?)", [result1]) + duckdb_cursor.execute("SELECT time_ns_col FROM time_ns_test") + result2 = duckdb_cursor.fetchone()[0] + assert isinstance(result2, datetime.time) + assert result1 == result2 + + +def test_time_insert(duckdb_cursor): + """This tests that datetime.time values are casted to TIME when needed.""" + duckdb_cursor.execute("SELECT TIME_NS '1992-09-20 11:30:00.123456'") + result1 = duckdb_cursor.fetchone()[0] + duckdb_cursor.execute("CREATE OR REPLACE TEMP TABLE time_test (time_col TIME)") + duckdb_cursor.execute("INSERT INTO time_test VALUES (?)", [result1]) + duckdb_cursor.execute("SELECT time_col FROM time_test") + result2 = duckdb_cursor.fetchone()[0] + assert isinstance(result2, datetime.time) + assert result1 == result2 + + +def test_time_ns_arrow_roundtrip(duckdb_cursor): + pa = pytest.importorskip("pyarrow") + + # Get a time_ns in an arrow table + arrow_table = duckdb_cursor.execute("SELECT TIME_NS '12:34:56.123456789' AS time_ns_col").to_arrow_table() + + value = arrow_table.column("time_ns_col")[0] + assert isinstance(value, pa.lib.Time64Scalar) + + # Roundtrip back into duckdb and assert the column's type is TIME_NS + duckdb_cursor.execute("CREATE OR REPLACE TEMP TABLE time_ns_test AS SELECT * FROM arrow_table") + col_type = duckdb_cursor.execute("SELECT time_ns_col FROM time_ns_test").description[0][1] + assert col_type == sqltypes.TIME_NS + + +def test_time_ns_pandas_roundtrip(duckdb_cursor): + """Test that we can roundtrip using Pandas.""" + pytest.importorskip("pandas") + df = duckdb_cursor.execute("SELECT TIME_NS '12:34:56.123456789' AS time_ns_col").df() + assert df["time_ns_col"].dtype == "object" + duckdb_cursor.execute("CREATE OR REPLACE TEMP TABLE time_ns_test AS SELECT * FROM df") + col_type = duckdb_cursor.execute("SELECT time_ns_col FROM time_ns_test").description[0][1] + assert col_type == sqltypes.TIME + + +def test_time_pandas_roundtrip(duckdb_cursor): + """For Pandas, creating a table using CREATE .... AS SELECT FROM df, will create TIME_NS cols by default.""" + pytest.importorskip("pandas") + df = duckdb_cursor.execute("SELECT TIME '12:34:56.123456789' AS time_col").df() + assert df["time_col"].dtype == "object" + duckdb_cursor.execute("CREATE OR REPLACE TEMP TABLE time_test AS SELECT * FROM df") + col_type = duckdb_cursor.execute("SELECT time_col FROM time_test").description[0][1] + assert col_type == sqltypes.TIME diff --git a/tests/fast/udf/test_null_filtering.py b/tests/fast/udf/test_null_filtering.py index 8bf2ce73..33ae208c 100644 --- a/tests/fast/udf/test_null_filtering.py +++ b/tests/fast/udf/test_null_filtering.py @@ -180,14 +180,14 @@ class TestUDFNullFiltering: ) @pytest.mark.parametrize("udf_type", ["arrow", "native"]) def test_null_filtering(self, duckdb_cursor, table_data: dict, test_type: Candidate, udf_type): - null_count = sum([1 for x in list(zip(*table_data.values())) if any(y is None for y in x)]) + null_count = sum([1 for x in list(zip(*table_data.values(), strict=False)) if any(y is None for y in x)]) row_count = len(table_data) table_data = { key: [None if not x else test_type.variant_one if x == "x" else test_type.variant_two for x in value] for key, value in table_data.items() } - tuples = list(zip(*table_data.values())) + tuples = list(zip(*table_data.values(), strict=False)) query = construct_query(tuples) parameters = construct_parameters(tuples, test_type.type) rel = duckdb_cursor.sql(query + " t(a, b, c)", params=parameters) @@ -210,7 +210,7 @@ def create_parameters(table_data, dbtype): result = duckdb_cursor.sql(query).fetchall() expected_output = [ - (t[0],) if not any(x is None for x in t) else (None,) for t in list(zip(*table_data.values())) + (t[0],) if not any(x is None for x in t) else (None,) for t in list(zip(*table_data.values(), strict=False)) ] assert result == expected_output assert len(result) == row_count diff --git a/tests/fast/udf/test_scalar_arrow.py b/tests/fast/udf/test_scalar_arrow.py index e3f18344..3d64ec51 100644 --- a/tests/fast/udf/test_scalar_arrow.py +++ b/tests/fast/udf/test_scalar_arrow.py @@ -149,7 +149,7 @@ def return_struct(col): """ select {'a': 5, 'b': 'test', 'c': [5,3,2]} """ - ).fetch_arrow_table() + ).to_arrow_table() con = duckdb.connect() struct_type = con.struct_type({"a": BIGINT, "b": VARCHAR, "c": con.list_type(BIGINT)}) diff --git a/tests/slow/test_materialized_relation.py b/tests/slow/test_materialized_relation.py new file mode 100644 index 00000000..69008adc --- /dev/null +++ b/tests/slow/test_materialized_relation.py @@ -0,0 +1,52 @@ +import platform + +import pytest + + +class TestMaterializedRelationSlow: + @pytest.mark.parametrize( + "num_rows", + [ + 1000000, + pytest.param( + 10000000, + marks=pytest.mark.skipif( + condition=platform.system() == "Emscripten", + reason="Emscripten/Pyodide builds run out of memory at this scale, and error might not " + "thrown reliably", + ), + ), + ], + ) + def test_materialized_relation(self, duckdb_cursor, num_rows): + # Anything that is not a SELECT statement becomes a materialized relation, so we use `CALL` + query = f"call repeat_row(42, 'test', 'this is a long string', true, num_rows={num_rows})" + rel = duckdb_cursor.sql(query) + res = rel.fetchone() + assert res is not None + + res = rel.fetchmany(num_rows) + assert len(res) == num_rows - 1 + + res = rel.fetchmany(5) + assert len(res) == 0 + res = rel.fetchmany(5) + assert len(res) == 0 + res = rel.fetchone() + assert res is None + + rel.execute() + res = rel.fetchone() + assert res is not None + + res = rel.fetchall() + assert len(res) == num_rows - 1 + res = rel.fetchall() + assert len(res) == num_rows + + rel = duckdb_cursor.sql(query) + projection = rel.select("column0") + assert projection.fetchall() == [(42,) for _ in range(num_rows)] + + filtered = rel.filter("column1 != 'test'") + assert filtered.fetchall() == []