diff --git a/README.md b/README.md index e7d935f..783db87 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,24 @@ filter.raw_query = { e.load_filtered_policy(filter) ``` +## Using an Existing MongoDB Client + +If you already have a MongoDB client instance in your application, you can reuse it: + +```python +from pymongo import MongoClient +import casbin_pymongo_adapter +import casbin + +# Create or use your existing MongoDB client +mongo_client = MongoClient('mongodb://localhost:27017/') + +# Pass the client to the adapter +adapter = casbin_pymongo_adapter.Adapter(client=mongo_client, db_name="casbin") + +e = casbin.Enforcer('path/to/model.conf', adapter, True) +``` + ## Async Example ```python @@ -73,6 +91,23 @@ e = casbin.AsyncEnforcer('path/to/model.conf', adapter) await e.load_policy() ``` +### Using an Existing AsyncMongoClient + +```python +from pymongo import AsyncMongoClient +from casbin_pymongo_adapter.asynchronous import Adapter +import casbin + +# Create or use your existing AsyncMongoClient +mongo_client = AsyncMongoClient('mongodb://localhost:27017/') + +# Pass the client to the adapter +adapter = Adapter(client=mongo_client, db_name="casbin") +e = casbin.AsyncEnforcer('path/to/model.conf', adapter) + +await e.load_policy() +``` + ### Getting Help diff --git a/casbin_pymongo_adapter/adapter.py b/casbin_pymongo_adapter/adapter.py index 9cb4beb..4e4d4c7 100644 --- a/casbin_pymongo_adapter/adapter.py +++ b/casbin_pymongo_adapter/adapter.py @@ -9,22 +9,47 @@ class Adapter(persist.Adapter): def __init__( self, - uri, - dbname, + uri=None, + dbname=None, collection="casbin_rule", filtered=False, + client=None, + db_name=None, ): """Create an adapter for Mongodb Args: - uri (str): This should be the same requiement as pymongo Client's 'uri' parameter. + uri (str, optional): This should be the same requiement as pymongo Client's 'uri' parameter. See https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient. - dbname (str): Database to store policy. + Required if client is not provided. + dbname (str, optional): Database to store policy. Required if client is not provided. collection (str, optional): Collection of the choosen database. Defaults to "casbin_rule". filtered (bool, optional): Whether to use filtered query. Defaults to False. + client (MongoClient, optional): An existing MongoClient instance to reuse. If provided, uri is ignored. + db_name (str, optional): Database name to use with the provided client. Takes precedence over dbname. + + Note: + When both client and uri are provided, client takes precedence and uri is ignored. """ - client = MongoClient(uri) - db = client[dbname] + # Support both db_name and dbname for backward compatibility + database_name = db_name if db_name is not None else dbname + + if client is not None: + # Use the provided client + if database_name is None: + raise ValueError( + "db_name or dbname must be provided when using an existing client" + ) + mongo_client = client + else: + # Create a new client from URI + if uri is None: + raise ValueError("uri must be provided when client is not specified") + if database_name is None: + raise ValueError("dbname must be provided when client is not specified") + mongo_client = MongoClient(uri) + + db = mongo_client[database_name] self._collection = db[collection] self._filtered = filtered diff --git a/casbin_pymongo_adapter/asynchronous/adapter.py b/casbin_pymongo_adapter/asynchronous/adapter.py index b124397..34e8501 100644 --- a/casbin_pymongo_adapter/asynchronous/adapter.py +++ b/casbin_pymongo_adapter/asynchronous/adapter.py @@ -10,22 +10,47 @@ class Adapter(AsyncAdapter): def __init__( self, - uri, - dbname, + uri=None, + dbname=None, collection="casbin_rule", filtered=False, + client=None, + db_name=None, ): """Create an adapter for Mongodb Args: - uri (str): This should be the same requiement as pymongo Client's 'uri' parameter. + uri (str, optional): This should be the same requiement as pymongo Client's 'uri' parameter. See https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient. - dbname (str): Database to store policy. + Required if client is not provided. + dbname (str, optional): Database to store policy. Required if client is not provided. collection (str, optional): Collection of the choosen database. Defaults to "casbin_rule". filtered (bool, optional): Whether to use filtered query. Defaults to False. + client (AsyncMongoClient, optional): An existing AsyncMongoClient instance to reuse. If provided, uri is ignored. + db_name (str, optional): Database name to use with the provided client. Takes precedence over dbname. + + Note: + When both client and uri are provided, client takes precedence and uri is ignored. """ - client = AsyncMongoClient(uri) - db = client[dbname] + # Support both db_name and dbname for backward compatibility + database_name = db_name if db_name is not None else dbname + + if client is not None: + # Use the provided client + if database_name is None: + raise ValueError( + "db_name or dbname must be provided when using an existing client" + ) + mongo_client = client + else: + # Create a new client from URI + if uri is None: + raise ValueError("uri must be provided when client is not specified") + if database_name is None: + raise ValueError("dbname must be provided when client is not specified") + mongo_client = AsyncMongoClient(uri) + + db = mongo_client[database_name] self._collection = db[collection] self._filtered = filtered