diff --git a/drivers/python/README.md b/drivers/python/README.md index e64f9de67..f4fa43919 100644 --- a/drivers/python/README.md +++ b/drivers/python/README.md @@ -89,9 +89,25 @@ SET search_path = ag_catalog, "$user", public; * Make sure to give your non-superuser db account proper permissions to the graph schemas and corresponding objects * Make sure to initiate the Apache Age python driver with the ```load_from_plugins``` parameter. This parameter tries to load the Apache Age extension from the PostgreSQL plugins directory located at ```$libdir/plugins/age```. Example: - ```python. + ```python ag = age.connect(host='localhost', port=5432, user='dbuser', password='strong_password', - dbname=postgres, load_from_plugins=True, graph='graph_name) + dbname='postgres', load_from_plugins=True, graph='graph_name') + ``` + +### Managed PostgreSQL Usage (Azure, AWS RDS, etc.) +* On managed PostgreSQL services where the AGE extension is loaded server-side via ```shared_preload_libraries```, + the ```LOAD 'age'``` command may fail because the binary is not at the expected file path. Use the ```skip_load``` + parameter to skip the ```LOAD``` statement while still performing all other setup: + ```python + ag = age.connect(host='myserver.postgres.database.azure.com', port=5432, + user='dbuser', password='strong_password', + dbname='postgres', skip_load=True, graph='graph_name') + ``` +* **Connection pools:** If you manage connections externally (e.g. via ```psycopg_pool.ConnectionPool```), + you can call ```setUpAge()``` with ```skip_load=True``` on each pooled connection: + ```python + from age.age import setUpAge + setUpAge(conn, 'graph_name', skip_load=True) ``` ### License diff --git a/drivers/python/age/__init__.py b/drivers/python/age/__init__.py index 685f0fe74..caee6a43c 100644 --- a/drivers/python/age/__init__.py +++ b/drivers/python/age/__init__.py @@ -25,13 +25,13 @@ def version(): def connect(dsn=None, graph=None, connection_factory=None, cursor_factory=ClientCursor, load_from_plugins=False, - **kwargs): + skip_load=False, **kwargs): dsn = conninfo.make_conninfo('' if dsn is None else dsn, **kwargs) ag = Age() ag.connect(dsn=dsn, graph=graph, connection_factory=connection_factory, cursor_factory=cursor_factory, - load_from_plugins=load_from_plugins, **kwargs) + load_from_plugins=load_from_plugins, skip_load=skip_load, **kwargs) return ag # Dummy ResultHandler diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py index fad1f27b1..5feba1bd7 100644 --- a/drivers/python/age/age.py +++ b/drivers/python/age/age.py @@ -137,12 +137,20 @@ def load(self, data: bytes | bytearray | memoryview) -> Any | None: return parseAgeValue(data_bytes.decode('utf-8')) -def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False): +def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=False, skip_load:bool=False): + if skip_load and load_from_plugins: + raise ValueError( + "skip_load=True and load_from_plugins=True are contradictory. " + "Set skip_load=False to load the extension from the plugins path, " + "or remove load_from_plugins to skip loading entirely." + ) + with conn.cursor() as cursor: - if load_from_plugins: - cursor.execute("LOAD '$libdir/plugins/age';") - else: - cursor.execute("LOAD 'age';") + if not skip_load: + if load_from_plugins: + cursor.execute("LOAD '$libdir/plugins/age';") + else: + cursor.execute("LOAD 'age';") cursor.execute("SET search_path = ag_catalog, '$user', public;") @@ -333,9 +341,9 @@ def __init__(self): # Connect to PostgreSQL Server and establish session and type extension environment. def connect(self, graph:str=None, dsn:str=None, connection_factory=None, cursor_factory=ClientCursor, - load_from_plugins:bool=False, **kwargs): + load_from_plugins:bool=False, skip_load:bool=False, **kwargs): conn = psycopg.connect(dsn, cursor_factory=cursor_factory, **kwargs) - setUpAge(conn, graph, load_from_plugins) + setUpAge(conn, graph, load_from_plugins, skip_load=skip_load) self.connection = conn self.graphName = graph return self diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py index f904fb9e3..fe4b91d1b 100644 --- a/drivers/python/test_age_py.py +++ b/drivers/python/test_age_py.py @@ -16,6 +16,7 @@ from age.models import Vertex import unittest +import unittest.mock import decimal import age import argparse @@ -28,6 +29,76 @@ TEST_GRAPH_NAME = "test_graph" +class TestSetUpAge(unittest.TestCase): + """Unit tests for setUpAge() skip_load parameter — no DB required.""" + + def _make_mock_conn(self): + mock_conn = unittest.mock.MagicMock() + mock_cursor = unittest.mock.MagicMock() + mock_conn.cursor.return_value.__enter__ = unittest.mock.Mock(return_value=mock_cursor) + mock_conn.cursor.return_value.__exit__ = unittest.mock.Mock(return_value=False) + mock_conn.adapters = unittest.mock.MagicMock() + mock_type_info = unittest.mock.MagicMock() + mock_type_info.oid = 1 + mock_type_info.array_oid = 2 + return mock_conn, mock_cursor, mock_type_info + + def test_skip_load_true_does_not_execute_load(self): + """When skip_load=True, LOAD 'age' must not be executed.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.setUpAge(mock_conn, "test_graph", skip_load=True) + mock_cursor.execute.assert_called_once_with( + "SET search_path = ag_catalog, '$user', public;" + ) + + def test_skip_load_false_executes_load(self): + """When skip_load=False (default), LOAD 'age' must be executed.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.setUpAge(mock_conn, "test_graph", skip_load=False) + mock_cursor.execute.assert_any_call("LOAD 'age';") + + def test_skip_load_with_load_from_plugins(self): + """When skip_load=False and load_from_plugins=True, LOAD from plugins path.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.setUpAge(mock_conn, "test_graph", load_from_plugins=True, skip_load=False) + mock_cursor.execute.assert_any_call("LOAD '$libdir/plugins/age';") + + def test_skip_load_true_still_sets_search_path(self): + """When skip_load=True, search_path must still be set.""" + mock_conn, mock_cursor, mock_type_info = self._make_mock_conn() + with unittest.mock.patch("age.age.TypeInfo.fetch", return_value=mock_type_info), \ + unittest.mock.patch("age.age.checkGraphCreated"): + age.age.setUpAge(mock_conn, "test_graph", skip_load=True) + mock_cursor.execute.assert_any_call( + "SET search_path = ag_catalog, '$user', public;" + ) + + def test_contradictory_skip_load_and_load_from_plugins_raises(self): + """skip_load=True + load_from_plugins=True must raise ValueError.""" + mock_conn, _, _ = self._make_mock_conn() + with self.assertRaises(ValueError): + age.age.setUpAge(mock_conn, "test_graph", load_from_plugins=True, skip_load=True) + + def test_connect_forwards_skip_load_to_setup(self): + """age.connect(skip_load=True) must forward skip_load through the full call chain.""" + with unittest.mock.patch("age.age.psycopg.connect") as mock_psycopg, \ + unittest.mock.patch("age.age.setUpAge") as mock_setup: + mock_psycopg.return_value = unittest.mock.MagicMock() + age.connect(dsn="host=localhost", graph="test_graph", skip_load=True) + mock_setup.assert_called_once() + _, kwargs = mock_setup.call_args + self.assertTrue( + kwargs.get("skip_load", False), + "skip_load must be forwarded from age.connect() to setUpAge()" + ) + + class TestAgeBasic(unittest.TestCase): ag = None args: argparse.Namespace = argparse.Namespace(