diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index adfc50d57395..b64b0972e740 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -263,7 +263,8 @@ def print_entry(label, value):
 
 from pyarrow.lib import (ChunkedArray, RecordBatch, Table, table,
                          concat_arrays, concat_tables, TableGroupBy,
-                         RecordBatchReader, concat_batches)
+                         RecordBatchReader, AsyncRecordBatchReader,
+                         concat_batches)
 
 # Exceptions
 from pyarrow.lib import (ArrowCancelled,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index e96a7d84696d..482f33cf70c3 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -3179,6 +3179,9 @@ cdef extern from "arrow/c/abi.h":
         int64_t device_id
         int32_t device_type
 
+    cdef struct ArrowAsyncDeviceStreamHandler:
+        void (*release)(ArrowAsyncDeviceStreamHandler*) noexcept nogil
+
 cdef extern from "arrow/c/bridge.h" namespace "arrow" nogil:
     CStatus ExportType(CDataType&, ArrowSchema* out)
     CResult[shared_ptr[CDataType]] ImportType(ArrowSchema*)
@@ -3225,6 +3228,45 @@ cdef extern from "arrow/c/bridge.h" namespace "arrow" nogil:
     CResult[shared_ptr[CRecordBatch]] ImportDeviceRecordBatch(
         ArrowDeviceArray*, ArrowSchema*)
 
+    # Opaque type for the async generator callable
+    # (a std::function; invoked through arrow::py::CallAsyncGenerator below)
+    cdef cppclass CAsyncRecordBatchGenerator_Generator \
+            "arrow::AsyncGenerator<arrow::RecordBatchWithMetadata>":
+        pass
+
+    cdef cppclass CAsyncRecordBatchGenerator \
+            "arrow::AsyncRecordBatchGenerator":
+        shared_ptr[CSchema] schema
+        CDeviceAllocationType device_type
+        CAsyncRecordBatchGenerator_Generator generator
+
+    CFuture[CAsyncRecordBatchGenerator] CreateAsyncDeviceStreamHandler(
+        ArrowAsyncDeviceStreamHandler* handler,
+        CExecutor* executor,
+        uint64_t queue_size)
+
+    # Overload that omits queue_size.
+    CFuture[CAsyncRecordBatchGenerator] CreateAsyncDeviceStreamHandler(
+        ArrowAsyncDeviceStreamHandler* handler,
+        CExecutor* executor)
+
+
+cdef extern from "arrow/python/async_stream.h" namespace "arrow::py" nogil:
+    CFuture[CRecordBatchWithMetadata] CallAsyncGenerator(
+        CAsyncRecordBatchGenerator_Generator& generator)
+
+    CFuture[CAsyncRecordBatchGenerator] RoundtripAsyncBatches(
+        shared_ptr[CSchema] schema,
+        vector[shared_ptr[CRecordBatch]] batches,
+        CExecutor* executor,
+        uint64_t queue_size)
+
+    # Overload using the C++ default for queue_size (5 in async_stream.h).
+    CFuture[CAsyncRecordBatchGenerator] RoundtripAsyncBatches(
+        shared_ptr[CSchema] schema,
+        vector[shared_ptr[CRecordBatch]] batches,
+        CExecutor* executor)
+
 cdef extern from "arrow/util/byte_size.h" namespace "arrow::util" nogil:
     CResult[int64_t] ReferencedBufferSize(const CArray& array_data)
 
diff --git a/python/pyarrow/ipc_async.pxi b/python/pyarrow/ipc_async.pxi
new file mode 100644
index 000000000000..9f948f164171
--- /dev/null
+++ b/python/pyarrow/ipc_async.pxi
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+class _AsyncioCall:
+    """State for an async operation using asyncio."""
+
+    def __init__(self):
+        import asyncio
+        self._future = asyncio.get_running_loop().create_future()
+
+    def as_awaitable(self):
+        return self._future
+
+    def wakeup(self, result_or_exception):
+        fut = self._future
+        is_error = isinstance(result_or_exception, BaseException)
+        setter = fut.set_exception if is_error else fut.set_result
+        # Deliver on the loop thread, and only if the future is still pending:
+        # after the awaiter was cancelled, set_* would raise InvalidStateError.
+        fut.get_loop().call_soon_threadsafe(
+            lambda: fut.done() or setter(result_or_exception))
+
+
+cdef object _wrap_record_batch_or_none(CRecordBatchWithMetadata batch_with_md):
+    """Wrap a CRecordBatchWithMetadata as a RecordBatch, or return None at end-of-stream."""
+    if batch_with_md.batch.get() == NULL:
+        return None
+    return pyarrow_wrap_batch(batch_with_md.batch)
+
+
+cdef object _wrap_async_generator(CAsyncRecordBatchGenerator gen):
+    """Wrap a CAsyncRecordBatchGenerator into an AsyncRecordBatchReader."""
+    cdef AsyncRecordBatchReader reader = AsyncRecordBatchReader.__new__(
+        AsyncRecordBatchReader)
+    cdef CAsyncRecordBatchGenerator* p = new CAsyncRecordBatchGenerator()
+    p.schema = gen.schema
+    p.device_type = gen.device_type
+    p.generator = move(gen.generator)
+    reader.generator.reset(p)
+    reader._schema = None  # wrapped lazily by the `schema` property
+    return reader
+
+
+cdef class AsyncRecordBatchReader(_Weakrefable):
+    """Asynchronous reader for a stream of record batches.
+
+    This class provides an async iterator interface for consuming record
+    batches from an asynchronous device stream.
+
+    This interface is EXPERIMENTAL.
+
+    Examples
+    --------
+    >>> async for batch in reader:  # doctest: +SKIP
+    ...     process(batch)
+    """
+
+    def __init__(self):
+        raise TypeError(
+            f"Do not call {self.__class__.__name__}'s constructor directly, "
+            "use factory methods instead.")
+
+    @property
+    def schema(self):
+        """
+        Shared schema of the record batches in the stream.
+
+        Returns
+        -------
+        Schema
+        """
+        if self._schema is None:
+            self._schema = pyarrow_wrap_schema(self.generator.get().schema)
+        return self._schema
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        batch = await self._read_next_async()
+        if batch is None:  # end-of-stream marker from _wrap_record_batch_or_none
+            raise StopAsyncIteration
+        return batch
+
+    async def _read_next_async(self):
+        call = _AsyncioCall()
+        self._read_next(call)
+        return await call.as_awaitable()
+
+    cdef _read_next(self, call):
+        cdef CFuture[CRecordBatchWithMetadata] c_future
+
+        with nogil:
+            c_future = CallAsyncGenerator(self.generator.get().generator)
+
+        BindFuture(move(c_future), call.wakeup, _wrap_record_batch_or_none)
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        pass  # no explicit cleanup is performed on exit
+
+
+async def _test_roundtrip_async(schema, batches, queue_size=5):
+    """Test helper: create an async producer+consumer pair and return reader.
+
+    EXPERIMENTAL: This function is intended for testing purposes only.
+
+    Parameters
+    ----------
+    schema : Schema
+        The schema of the record batches.
+    batches : list of RecordBatch
+        The record batches to produce.
+    queue_size : int, default 5
+        Number of batches to request ahead.
+
+    Returns
+    -------
+    AsyncRecordBatchReader
+    """
+    call = _AsyncioCall()
+    _start_roundtrip(call, schema, batches, queue_size)
+    return await call.as_awaitable()
+
+
+cdef _start_roundtrip(call, Schema schema, list batches, uint64_t queue_size):
+    cdef:
+        shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
+        vector[shared_ptr[CRecordBatch]] c_batches
+        CFuture[CAsyncRecordBatchGenerator] c_future
+
+    for batch in batches:
+        c_batches.push_back((<RecordBatch?> batch).sp_batch)  # checked cast
+
+    with nogil:
+        c_future = RoundtripAsyncBatches(
+            c_schema, move(c_batches), GetCpuThreadPool(), queue_size)
+
+    BindFuture(move(c_future), call.wakeup, _wrap_async_generator)
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 683faa7855c5..9e71f06d7740 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -631,6 +631,14 @@ cdef class RecordBatchReader(_Weakrefable):
         SharedPtrNoGIL[CRecordBatchReader] reader
 
 
+cdef class AsyncRecordBatchReader(_Weakrefable):
+    cdef:
+        SharedPtrNoGIL[CAsyncRecordBatchGenerator] generator
+        Schema _schema
+
+    cdef _read_next(self, call)
+
+
 cdef class CacheOptions(_Weakrefable):
     cdef:
         CCacheOptions wrapped
diff --git a/python/pyarrow/lib.pyx b/python/pyarrow/lib.pyx
index 7e97177a6ec0..28292795ccc7 100644
--- a/python/pyarrow/lib.pyx
+++ b/python/pyarrow/lib.pyx
@@ -239,6 +239,9 @@ include "io.pxi"
 # IPC / Messaging
 include "ipc.pxi"
 
+# Async IPC
+include "ipc_async.pxi"
+
 # Micro-benchmark routines
 include "benchmark.pxi"
 
diff --git a/python/pyarrow/src/arrow/python/async_stream.h b/python/pyarrow/src/arrow/python/async_stream.h
new file mode 100644
index 000000000000..21a2974bc25e
--- /dev/null
+++ b/python/pyarrow/src/arrow/python/async_stream.h
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "arrow/c/bridge.h"
+#include "arrow/record_batch.h"
+#include "arrow/util/async_generator.h"
+#include "arrow/util/future.h"
+#include "arrow/util/thread_pool.h"
+
+namespace arrow::py {
+
+/// \brief Call an AsyncGenerator and return the Future.
+///
+/// This is needed because Cython cannot invoke std::function objects directly.
+inline Future<RecordBatchWithMetadata> CallAsyncGenerator(
+    AsyncGenerator<RecordBatchWithMetadata>& generator) {
+  return generator();
+}
+
+/// \brief Create a roundtrip async producer+consumer pair for testing.
+///
+/// Allocates an ArrowAsyncDeviceStreamHandler on the heap, calls
+/// CreateAsyncDeviceStreamHandler (consumer side), then submits
+/// ExportAsyncRecordBatchReader (producer side) on the given executor.
+/// Returns a Future that resolves to the AsyncRecordBatchGenerator once
+/// the schema is available.
+///
+/// If the export task cannot be submitted, the returned Future is already
+/// finished with the submission error instead.
+inline Future<AsyncRecordBatchGenerator> RoundtripAsyncBatches(
+    std::shared_ptr<Schema> schema, std::vector<std::shared_ptr<RecordBatch>> batches,
+    ::arrow::internal::Executor* executor, uint64_t queue_size = 5) {
+  // Heap-allocate the handler so it outlives this function.
+  // NOTE(review): nothing in this function deletes the struct again; confirm
+  // whether that is acceptable for a test-only helper.
+  auto* handler = new ArrowAsyncDeviceStreamHandler;
+  std::memset(handler, 0, sizeof(ArrowAsyncDeviceStreamHandler));
+
+  auto fut_gen = CreateAsyncDeviceStreamHandler(handler, executor, queue_size);
+
+  // Submit the export to the executor so it runs concurrently with the consumer.
+  auto submit_result = executor->Submit(
+      [schema = std::move(schema), batches = std::move(batches), handler]() mutable {
+        auto generator = MakeVectorGenerator(std::move(batches));
+        return ExportAsyncRecordBatchReader(std::move(schema), std::move(generator),
+                                            DeviceAllocationType::kCPU, handler);
+      });
+
+  if (!submit_result.ok()) {
+    return Future<AsyncRecordBatchGenerator>::MakeFinished(submit_result.status());
+  }
+
+  return fut_gen;
+}
+
+}  // namespace arrow::py
diff --git a/python/pyarrow/tests/test_ipc_async.py b/python/pyarrow/tests/test_ipc_async.py
new file mode 100644
index 000000000000..3c8d3c57e96d
--- /dev/null
+++ b/python/pyarrow/tests/test_ipc_async.py
@@ -0,0 +1,114 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import asyncio
+
+import pyarrow as pa
+from pyarrow.lib import _test_roundtrip_async
+
+
+def test_async_record_batch_reader_basic():  # two batches, order checked
+    schema = pa.schema([('x', pa.int64())])
+    batches = [
+        pa.record_batch([pa.array([1, 2, 3])], schema=schema),
+        pa.record_batch([pa.array([4, 5, 6])], schema=schema),
+    ]
+
+    async def _test():
+        reader = await _test_roundtrip_async(schema, batches)
+        assert isinstance(reader, pa.AsyncRecordBatchReader)
+        assert reader.schema == schema
+
+        results = []
+        async for batch in reader:  # plain async-for loop
+            results.append(batch)
+
+        assert len(results) == 2
+        assert results[0].equals(batches[0])
+        assert results[1].equals(batches[1])
+
+    asyncio.run(_test())
+
+
+def test_async_record_batch_reader_empty():  # stream with zero batches
+    schema = pa.schema([('x', pa.int64())])
+
+    async def _test():
+        reader = await _test_roundtrip_async(schema, [])
+        assert reader.schema == schema
+
+        results = [b async for b in reader]
+        assert len(results) == 0
+
+    asyncio.run(_test())
+
+
+def test_async_record_batch_reader_schema():  # multi-column schema
+    schema = pa.schema([
+        ('a', pa.float32()),
+        ('b', pa.utf8()),
+        ('c', pa.list_(pa.int32())),
+    ])
+    batch = pa.record_batch(
+        [
+            pa.array([1.0, 2.0], type=pa.float32()),
+            pa.array(['hello', 'world']),
+            pa.array([[1, 2], [3]], type=pa.list_(pa.int32())),
+        ],
+        schema=schema,
+    )
+
+    async def _test():
+        reader = await _test_roundtrip_async(schema, [batch])
+        assert reader.schema == schema
+
+        results = [b async for b in reader]
+        assert len(results) == 1
+        assert results[0].equals(batch)
+
+    asyncio.run(_test())
+
+
+def test_async_record_batch_reader_context_manager():  # async with
+    schema = pa.schema([('x', pa.int64())])
+    batches = [pa.record_batch([pa.array([1, 2, 3])], schema=schema)]
+
+    async def _test():
+        reader = await _test_roundtrip_async(schema, batches)
+        async with reader as r:
+            results = [b async for b in r]
+        assert len(results) == 1  # checked after the async-with block
+        assert results[0].equals(batches[0])
+
+    asyncio.run(_test())
+
+
+def test_async_record_batch_reader_many_batches():  # 20 batches, queue_size=2
+    schema = pa.schema([('x', pa.int64())])
+    batches = [
+        pa.record_batch([pa.array([i])], schema=schema)
+        for i in range(20)
+    ]
+
+    async def _test():
+        reader = await _test_roundtrip_async(schema, batches, queue_size=2)
+        results = [b async for b in reader]
+        assert len(results) == 20
+        for i, batch in enumerate(results):  # same order as produced
+            assert batch.equals(batches[i])
+
+    asyncio.run(_test())