Skip to content

Commit e093e30

Browse files
committed
feat: Add FFI_TableProviderFactory support
This wraps the new FFI_TableProviderFactory APIs in datafusion-ffi.
1 parent d322b7b commit e093e30

File tree

6 files changed

+188
-1
lines changed

6 files changed

+188
-1
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import pyarrow as pa
21+
import pytest
22+
from datafusion import SessionContext
23+
from datafusion_ffi_example import MyTableProviderFactory
24+
25+
26+
def test_table_provider_factory_ffi() -> None:
27+
ctx = SessionContext()
28+
table = MyTableProviderFactory()
29+
30+
ctx.register_table_factory("MY_FORMAT", table)
31+
32+
# Create a new external table
33+
ctx.sql("""
34+
CREATE EXTERNAL TABLE
35+
foo
36+
STORED AS my_format
37+
LOCATION '';
38+
""")
39+
40+
# Query the pre-populated table
41+
result = ctx.sql("SELECT * FROM foo;").collect()
42+
assert len(result) == 2
43+
assert result[0].num_columns == 2

examples/datafusion-ffi-example/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogP
2222
use crate::scalar_udf::IsNullUDF;
2323
use crate::table_function::MyTableFunction;
2424
use crate::table_provider::MyTableProvider;
25+
use crate::table_provider_factory::MyTableProviderFactory;
2526
use crate::window_udf::MyRankUDF;
2627

2728
pub(crate) mod aggregate_udf;
2829
pub(crate) mod catalog_provider;
2930
pub(crate) mod scalar_udf;
3031
pub(crate) mod table_function;
3132
pub(crate) mod table_provider;
33+
pub(crate) mod table_provider_factory;
3234
pub(crate) mod utils;
3335
pub(crate) mod window_udf;
3436

@@ -37,6 +39,7 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
3739
pyo3_log::init();
3840

3941
m.add_class::<MyTableProvider>()?;
42+
m.add_class::<MyTableProviderFactory>()?;
4043
m.add_class::<MyTableFunction>()?;
4144
m.add_class::<MyCatalogProvider>()?;
4245
m.add_class::<MyCatalogProviderList>()?;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use async_trait::async_trait;
21+
use datafusion_catalog::{Session, TableProvider, TableProviderFactory};
22+
use datafusion_common::error::Result as DataFusionResult;
23+
use datafusion_expr::CreateExternalTable;
24+
use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
25+
use pyo3::types::PyCapsule;
26+
use pyo3::{Bound, PyAny, PyResult, Python, pyclass, pymethods};
27+
28+
use crate::catalog_provider;
29+
use crate::utils::ffi_logical_codec_from_pycapsule;
30+
31+
#[derive(Debug)]
32+
pub(crate) struct ExampleTableProviderFactory {}
33+
34+
impl ExampleTableProviderFactory {
35+
fn new() -> Self {
36+
Self {}
37+
}
38+
}
39+
40+
#[async_trait]
41+
impl TableProviderFactory for ExampleTableProviderFactory {
42+
async fn create(
43+
&self,
44+
_state: &dyn Session,
45+
_cmd: &CreateExternalTable,
46+
) -> DataFusionResult<Arc<dyn TableProvider>> {
47+
Ok(catalog_provider::my_table())
48+
}
49+
}
50+
51+
#[pyclass(
52+
name = "MyTableProviderFactory",
53+
module = "datafusion_ffi_example",
54+
subclass
55+
)]
56+
#[derive(Debug)]
57+
pub struct MyTableProviderFactory {
58+
inner: Arc<ExampleTableProviderFactory>,
59+
}
60+
61+
impl Default for MyTableProviderFactory {
62+
fn default() -> Self {
63+
let inner = Arc::new(ExampleTableProviderFactory::new());
64+
Self { inner }
65+
}
66+
}
67+
68+
#[pymethods]
69+
impl MyTableProviderFactory {
70+
#[new]
71+
pub fn new() -> Self {
72+
Self::default()
73+
}
74+
75+
pub fn __datafusion_table_provider_factory__<'py>(
76+
&self,
77+
py: Python<'py>,
78+
codec: Bound<PyAny>,
79+
) -> PyResult<Bound<'py, PyCapsule>> {
80+
let name = cr"datafusion_table_provider_factory".into();
81+
let codec = ffi_logical_codec_from_pycapsule(codec)?;
82+
let factory = Arc::clone(&self.inner) as Arc<dyn TableProviderFactory + Send>;
83+
let factory = FFI_TableProviderFactory::new_with_ffi_codec(factory, None, codec);
84+
85+
PyCapsule::new(py, factory, Some(name))
86+
}
87+
}

python/datafusion/catalog.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,15 @@ def kind(self) -> str:
243243
return self._inner.kind
244244

245245

246+
class TableProviderFactoryExportable(Protocol):
247+
"""Type hint for object that has __datafusion_table_provider_factory__ PyCapsule.
248+
249+
https://docs.rs/datafusion/latest/datafusion/catalog/trait.TableProviderFactory.html
250+
"""
251+
252+
def __datafusion_table_provider_factory__(self, session: Any) -> object: ...
253+
254+
246255
class CatalogProviderList(ABC):
247256
"""Abstract class for defining a Python based Catalog Provider List."""
248257

python/datafusion/context.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
CatalogProviderExportable,
3838
CatalogProviderList,
3939
CatalogProviderListExportable,
40+
TableProviderFactoryExportable,
4041
)
4142
from datafusion.dataframe import DataFrame
4243
from datafusion.expr import sort_list_to_raw_sort_list
@@ -830,6 +831,21 @@ def deregister_table(self, name: str) -> None:
830831
"""Remove a table from the session."""
831832
self.ctx.deregister_table(name)
832833

834+
def register_table_factory(
835+
self, format: str, factory: TableProviderFactoryExportable
836+
) -> None:
837+
"""Register a :py:class:`~datafusion.TableProviderFactoryExportable`
838+
with this context.
839+
840+
The registered factory can be reference from SQL DDL statements executed
841+
against this context.
842+
843+
Args:
844+
format: The value to be used in `STORED AS ${format}` clause.
845+
factory: A PyCapsule that implements TableProviderFactoryExportable"
846+
"""
847+
self.ctx.register_table_factory(format, factory)
848+
833849
def catalog_names(self) -> set[str]:
834850
"""Returns the list of catalogs in this context."""
835851
return self.ctx.catalog_names()

src/context.rs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use arrow::pyarrow::FromPyArrow;
2727
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
2828
use datafusion::arrow::pyarrow::PyArrowType;
2929
use datafusion::arrow::record_batch::RecordBatch;
30-
use datafusion::catalog::{CatalogProvider, CatalogProviderList};
30+
use datafusion::catalog::{CatalogProvider, CatalogProviderList, TableProviderFactory};
3131
use datafusion::common::{ScalarValue, TableReference, exec_err};
3232
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
3333
use datafusion::datasource::file_format::parquet::ParquetFormat;
@@ -51,6 +51,7 @@ use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
5151
use datafusion_ffi::catalog_provider_list::FFI_CatalogProviderList;
5252
use datafusion_ffi::execution::FFI_TaskContextProvider;
5353
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
54+
use datafusion_ffi::table_provider_factory::FFI_TableProviderFactory;
5455
use datafusion_proto::logical_plan::DefaultLogicalExtensionCodec;
5556
use object_store::ObjectStore;
5657
use pyo3::IntoPyObjectExt;
@@ -659,6 +660,34 @@ impl PySessionContext {
659660
Ok(())
660661
}
661662

663+
pub fn register_table_factory(
664+
&self,
665+
format: &str,
666+
factory: Bound<'_, PyAny>,
667+
) -> PyDataFusionResult<()> {
668+
let py = factory.py();
669+
let codec_capsule = create_logical_extension_capsule(py, self.logical_codec.as_ref())?;
670+
671+
let capsule = factory
672+
.getattr("__datafusion_table_provider_factory__")?
673+
.call1((codec_capsule,))?;
674+
let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
675+
validate_pycapsule(capsule, "datafusion_table_provider_factory")?;
676+
677+
let factory: NonNull<FFI_TableProviderFactory> = capsule
678+
.pointer_checked(Some(c_str!("datafusion_table_provider_factory")))?
679+
.cast();
680+
let factory = unsafe { factory.as_ref() };
681+
let factory: Arc<dyn TableProviderFactory> = factory.into();
682+
683+
let st = self.ctx.state_ref();
684+
let mut lock = st.write();
685+
lock.table_factories_mut()
686+
.insert(format.to_owned(), factory);
687+
688+
Ok(())
689+
}
690+
662691
pub fn register_catalog_provider_list(
663692
&self,
664693
mut provider: Bound<PyAny>,

0 commit comments

Comments
 (0)