Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions drivers/python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,25 @@ SET search_path = ag_catalog, "$user", public;
* Make sure to give your non-superuser db account proper permissions to the graph schemas and corresponding objects
* Make sure to initiate the Apache Age python driver with the ```load_from_plugins``` parameter. This parameter tries to
load the Apache Age extension from the PostgreSQL plugins directory located at ```$libdir/plugins/age```. Example:
```python.
```python
ag = age.connect(host='localhost', port=5432, user='dbuser', password='strong_password',
dbname=postgres, load_from_plugins=True, graph='graph_name)
dbname='postgres', load_from_plugins=True, graph='graph_name')
```

### Managed PostgreSQL Usage (Azure, AWS RDS, etc.)
* On managed PostgreSQL services where the AGE extension is loaded server-side via ```shared_preload_libraries```,
the ```LOAD 'age'``` command may fail because the binary is not at the expected file path. Use the ```skip_load```
parameter to skip the ```LOAD``` statement while still performing all other setup:
```python
ag = age.connect(host='myserver.postgres.database.azure.com', port=5432,
user='dbuser', password='strong_password',
dbname='postgres', skip_load=True, graph='graph_name')
```
* **Connection pools:** If you manage connections externally (e.g. via ```psycopg_pool.ConnectionPool```),
you can call ```setUpAge()``` with ```skip_load=True``` on each pooled connection:
```python
from age.age import setUpAge
setUpAge(conn, 'graph_name', skip_load=True)
```

### License
Expand Down
4 changes: 2 additions & 2 deletions drivers/python/age/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ def version():


def connect(dsn=None, graph=None, connection_factory=None, cursor_factory=ClientCursor, load_from_plugins=False,
**kwargs):
skip_load=False, **kwargs):

dsn = conninfo.make_conninfo('' if dsn is None else dsn, **kwargs)

ag = Age()
ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory, cursor_factory=cursor_factory,
load_from_plugins=load_from_plugins, **kwargs)
load_from_plugins=load_from_plugins, skip_load=skip_load, **kwargs)
return ag

# Dummy ResultHandler
Expand Down
22 changes: 15 additions & 7 deletions drivers/python/age/age.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,20 @@ def load(self, data: bytes | bytearray | memoryview) -> Any | None:
return parseAgeValue(data_bytes.decode('utf-8'))


def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False):
def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False, skip_load:bool=False):
if skip_load and load_from_plugins:
raise ValueError(
"skip_load=True and load_from_plugins=True are contradictory. "
"Set skip_load=False to load the extension from the plugins path, "
"or remove load_from_plugins to skip loading entirely."
)

with conn.cursor() as cursor:
if load_from_plugins:
cursor.execute("LOAD '$libdir/plugins/age';")
else:
cursor.execute("LOAD 'age';")
if not skip_load:
if load_from_plugins:
cursor.execute("LOAD '$libdir/plugins/age';")
else:
cursor.execute("LOAD 'age';")

cursor.execute("SET search_path = ag_catalog, '$user', public;")

Expand Down Expand Up @@ -333,9 +341,9 @@ def __init__(self):

# Connect to PostgreSQL Server and establish session and type extension environment.
def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=ClientCursor,
load_from_plugins:bool=False, **kwargs):
load_from_plugins:bool=False, skip_load:bool=False, **kwargs):
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connection_factory is accepted and forwarded through the public API, but it is never used: Age.connect() ignores it when calling psycopg.connect(...). Either pass it through (if supported by psycopg) or remove/deprecate it to avoid a misleading no-op parameter.

Suggested change
load_from_plugins:bool=False, skip_load:bool=False, **kwargs):
load_from_plugins:bool=False, skip_load:bool=False, **kwargs):
if connection_factory is not None:
raise TypeError(
"connection_factory is not supported by Age.connect(); "
"use psycopg.connect() options supported by this driver instead."
)

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

connection_factory is a pre-existing parameter in the codebase (not introduced by this PR). Addressing it would be scope creep — happy to open a separate issue if needed.

conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs)
setUpAge(conn, graph, load_from_plugins)
setUpAge(conn, graph, load_from_plugins, skip_load=skip_load)
self.connection = conn
self.graphName = graph
return self
Expand Down
71 changes: 71 additions & 0 deletions drivers/python/test_age_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from age.models import Vertex
import unittest
import unittest.mock
import decimal
import age
import argparse
Expand All @@ -28,6 +29,76 @@
TEST_GRAPH_NAME = "test_graph"


class TestSetUpAge(unittest.TestCase):
"""Unit tests for setUpAge() skip_load parameter — no DB required."""

def _make_mock_conn(self):
mock_conn = unittest.mock.MagicMock()
mock_cursor = unittest.mock.MagicMock()
mock_conn.cursor.return_value.__enter__ = unittest.mock.Mock(return_value=mock_cursor)
mock_conn.cursor.return_value.__exit__ = unittest.mock.Mock(return_value=False)
mock_conn.adapters = unittest.mock.MagicMock()
mock_type_info = unittest.mock.MagicMock()
mock_type_info.oid = 1
mock_type_info.array_oid = 2
return mock_conn, mock_cursor, mock_type_info

def test_skip_load_true_does_not_execute_load(self):
"""When skip_load=True, LOAD 'age' must not be executed."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.setUpAge(mock_conn, "test_graph", skip_load=True)
mock_cursor.execute.assert_called_once_with(
"SET search_path = ag_catalog, '$user', public;"
)

def test_skip_load_false_executes_load(self):
"""When skip_load=False (default), LOAD 'age' must be executed."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.setUpAge(mock_conn, "test_graph", skip_load=False)
mock_cursor.execute.assert_any_call("LOAD 'age';")

def test_skip_load_with_load_from_plugins(self):
"""When skip_load=False and load_from_plugins=True, LOAD from plugins path."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.setUpAge(mock_conn, "test_graph", load_from_plugins=True, skip_load=False)
mock_cursor.execute.assert_any_call("LOAD '$libdir/plugins/age';")

def test_skip_load_true_still_sets_search_path(self):
"""When skip_load=True, search_path must still be set."""
mock_conn, mock_cursor, mock_type_info = self._make_mock_conn()
with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \
unittest.mock.patch("age.age.checkGraphCreated"):
age.age.setUpAge(mock_conn, "test_graph", skip_load=True)
mock_cursor.execute.assert_any_call(
"SET search_path = ag_catalog, '$user', public;"
)

def test_contradictory_skip_load_and_load_from_plugins_raises(self):
"""skip_load=True + load_from_plugins=True must raise ValueError."""
mock_conn, _, _ = self._make_mock_conn()
with self.assertRaises(ValueError):
age.age.setUpAge(mock_conn, "test_graph", load_from_plugins=True, skip_load=True)

def test_connect_forwards_skip_load_to_setup(self):
"""age.connect(skip_load=True) must forward skip_load through the full call chain."""
with unittest.mock.patch("age.age.psycopg.connect") as mock_psycopg, \
unittest.mock.patch("age.age.setUpAge") as mock_setup:
mock_psycopg.return_value = unittest.mock.MagicMock()
age.connect(dsn="host=localhost", graph="test_graph", skip_load=True)
mock_setup.assert_called_once()
_, kwargs = mock_setup.call_args
self.assertTrue(
kwargs.get("skip_load", False),
"skip_load must be forwarded from age.connect() to setUpAge()"
)


class TestAgeBasic(unittest.TestCase):
ag = None
args: argparse.Namespace = argparse.Namespace(
Expand Down