From 00369591e175edc0dc3930f07574ed93d3249176 Mon Sep 17 00:00:00 2001 From: wf-yamaday Date: Sat, 9 Aug 2025 01:08:10 +0900 Subject: [PATCH 1/7] chore: support pymongo>=4.13.0 Support AsyncMongoClient --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8e03b91..4e71db0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ casbin>=0.8.4 -pymongo>=3.10.1 \ No newline at end of file +pymongo>=4.13.0 From f72959a1ea53d4520647be396c10a2b01fbc154d Mon Sep 17 00:00:00 2001 From: wf-yamaday Date: Sat, 9 Aug 2025 01:10:12 +0900 Subject: [PATCH 2/7] refactor: move CasbinRule into modules --- casbin_pymongo_adapter/_rule.py | 34 +++++++++++++++++++++++++++++ casbin_pymongo_adapter/adapter.py | 36 +------------------------------ tests/test_adapter.py | 2 +- 3 files changed, 36 insertions(+), 36 deletions(-) create mode 100644 casbin_pymongo_adapter/_rule.py diff --git a/casbin_pymongo_adapter/_rule.py b/casbin_pymongo_adapter/_rule.py new file mode 100644 index 0000000..a32d70d --- /dev/null +++ b/casbin_pymongo_adapter/_rule.py @@ -0,0 +1,34 @@ +class CasbinRule: + """ + CasbinRule model + """ + + def __init__( + self, ptype=None, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None + ): + self.ptype = ptype + self.v0 = v0 + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.v4 = v4 + self.v5 = v5 + + def dict(self): + d = {"ptype": self.ptype} + + for value in dir(self): + if ( + getattr(self, value) is not None + and value.startswith("v") + and value[1:].isnumeric() + ): + d[value] = getattr(self, value) + + return d + + def __str__(self): + return ", ".join(self.dict().values()) + + def __repr__(self): + return ''.format(str(self)) diff --git a/casbin_pymongo_adapter/adapter.py b/casbin_pymongo_adapter/adapter.py index 6829f5c..45e2ebc 100644 --- a/casbin_pymongo_adapter/adapter.py +++ b/casbin_pymongo_adapter/adapter.py @@ -1,41 +1,7 @@ from casbin import persist from pymongo import MongoClient - -class CasbinRule: - """ - CasbinRule model - """ - - def __init__( - self, ptype=None, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None - ): - self.ptype = ptype - self.v0 = v0 - self.v1 = v1 - self.v2 = v2 - self.v3 = v3 - self.v4 = v4 - self.v5 = v5 - - def dict(self): - d = {"ptype": self.ptype} - - for value in dir(self): - if ( - getattr(self, value) is not None - and value.startswith("v") - and value[1:].isnumeric() - ): - d[value] = getattr(self, value) - - return d - - def __str__(self): - return ", ".join(self.dict().values()) - - def __repr__(self): - return ''.format(str(self)) +from ._rule import CasbinRule class Adapter(persist.Adapter): diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 19771cd..2fb5e85 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,5 +1,5 @@ from casbin_pymongo_adapter.adapter import Adapter -from casbin_pymongo_adapter.adapter import CasbinRule +from casbin_pymongo_adapter._rule import CasbinRule from pymongo import MongoClient from unittest import TestCase import casbin From 2bbe87586636f1b54a9bc21eb292abd48f080334 Mon Sep 17 00:00:00 2001 From: wf-yamaday Date: Sat, 9 Aug 2025 01:25:09 +0900 Subject: [PATCH 3/7] refactor: move get_fixture function to helper module --- tests/helper.py | 9 +++++++++ tests/test_adapter.py | 9 +-------- 2 files changed, 10 insertions(+), 8 deletions(-) create mode 100644 tests/helper.py diff --git a/tests/helper.py b/tests/helper.py new file mode 100644 index 0000000..f23fafc --- /dev/null +++ b/tests/helper.py @@ -0,0 +1,9 @@ +import os + + +def get_fixture(path): + """ + get model path + """ + dir_path = os.path.split(os.path.realpath(__file__))[0] + "/" + return os.path.abspath(dir_path + path) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 2fb5e85..70afd06 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -3,15 +3,8 @@ from pymongo import MongoClient from unittest import TestCase import casbin -import os - -def get_fixture(path): - """ - get model path - """ - dir_path = os.path.split(os.path.realpath(__file__))[0] + "/" - return os.path.abspath(dir_path + path) +from tests.helper import get_fixture def get_enforcer(): From 1f571d1bc6fd3bee4284c3f32a78e4ba2055772f Mon Sep 17 00:00:00 2001 From: wf-yamaday Date: Sat, 9 Aug 2025 02:01:04 +0900 Subject: [PATCH 4/7] feat: implements AsyncEnforcer --- .../asynchronous/__init__.py | 0 .../asynchronous/adapter.py | 140 +++++++++++ tests/asynchronous/test_async_adapter.py | 223 ++++++++++++++++++ 3 files changed, 363 insertions(+) create mode 100644 casbin_pymongo_adapter/asynchronous/__init__.py create mode 100644 casbin_pymongo_adapter/asynchronous/adapter.py create mode 100644 tests/asynchronous/test_async_adapter.py diff --git a/casbin_pymongo_adapter/asynchronous/__init__.py b/casbin_pymongo_adapter/asynchronous/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/casbin_pymongo_adapter/asynchronous/adapter.py b/casbin_pymongo_adapter/asynchronous/adapter.py new file mode 100644 index 0000000..68e99b5 --- /dev/null +++ b/casbin_pymongo_adapter/asynchronous/adapter.py @@ -0,0 +1,140 @@ +from casbin import persist +from casbin.persist.adapters.asyncio.adapter import AsyncAdapter +from pymongo import AsyncMongoClient + +from .._rule import CasbinRule + + +class Adapter(AsyncAdapter): + """the interface for Casbin adapters.""" + + def __init__(self, uri, dbname, collection="casbin_rule"): + """Create an adapter for Mongodb + + Args: + uri (str): This should be the same requiement as pymongo Client's 'uri' parameter. + See https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient. + dbname (str): Database to store policy. + collection (str, optional): Collection of the choosen database. Defaults to "casbin_rule". + """ + client = AsyncMongoClient(uri) + db = client[dbname] + self._collection = db[collection] + + async def load_policy(self, model): + """Implementing add Interface for casbin. Load all policy rules from mongodb + + Args: + model (CasbinRule): CasbinRule object + """ + + async for line in self._collection.find(): + if "ptype" not in line: + continue + rule = CasbinRule(line["ptype"]) + for key, value in line.items(): + setattr(rule, key, value) + + persist.load_policy_line(str(rule), model) + + async def _save_policy_line(self, ptype, rule): + line = CasbinRule(ptype=ptype) + for index, value in enumerate(rule): + setattr(line, f"v{index}", value) + await self._collection.insert_one(line.dict()) + + def _find_policy_lines(self, ptype, rule): + line = CasbinRule(ptype=ptype) + for index, value in enumerate(rule): + setattr(line, f"v{index}", value) + return self._collection.find(line.dict()) + + async def _delete_policy_lines(self, ptype, rule): + line = CasbinRule(ptype=ptype) + for index, value in enumerate(rule): + setattr(line, f"v{index}", value) + + # if rule is empty, do nothing + # else find all given rules and delete them + if len(line.dict()) == 0: + return 0 + else: + line_dict = line.dict() + line_dict_keys_len = len(line_dict) + to_delete = [ + result["_id"] + async for result in self._collection.find(line_dict) + if line_dict_keys_len == len(result.keys()) - 1 + ] + results = await self._collection.delete_many({"_id": {"$in": to_delete}}) + return results.deleted_count + + async def save_policy(self, model) -> bool: + """Implement add Interface for casbin. Save the policy in mongodb + + Args: + model (Class Model): Casbin Model which loads from .conf file usually. + + Returns: + bool: True if succeed + """ + for sec in ["p", "g"]: + if sec not in model.model.keys(): + continue + for ptype, ast in model.model[sec].items(): + for rule in ast.policy: + await self._save_policy_line(ptype, rule) + return True + + async def add_policy(self, sec, ptype, rule): + """Add policy rules to mongodb + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rule (CasbinRule): Casbin rule will be added + + Returns: + bool: True if succeed else False + """ + await self._save_policy_line(ptype, rule) + return True + + async def remove_policy(self, sec, ptype, rule): + """Remove policy rules in mongodb(rules duplicate are also removed) + + Args: + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rule (CasbinRule): Casbin rule if it is exactly same as will be removed. + + Returns: + Number: Number of policies be removed + """ + deleted_count = await self._delete_policy_lines(ptype, rule) + return deleted_count > 0 + + async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): + """Remove policy rules taht match the filter from the storage. + This is part of the Auto-Save feature. + + Args: + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rule (CasbinRule): Casbin rule will be removed + field_index (int): The policy index at which the filed_values begins filtering. Its range is [0, 5] + field_values(List[str]): A list of rules to filter policy which starts from + + Returns: + bool: True if succeed else False + """ + if not (0 <= field_index <= 5): + return False + if not (1 <= field_index + len(field_values) <= 6): + return False + query = { + f"v{index + field_index}": value + for index, value in enumerate(field_values) + if value != "" + } + query["ptype"] = ptype + results = await self._collection.delete_many(query) + return results.deleted_count > 0 diff --git a/tests/asynchronous/test_async_adapter.py b/tests/asynchronous/test_async_adapter.py new file mode 100644 index 0000000..96363f9 --- /dev/null +++ b/tests/asynchronous/test_async_adapter.py @@ -0,0 +1,223 @@ +from casbin_pymongo_adapter.asynchronous.adapter import Adapter +from casbin_pymongo_adapter._rule import CasbinRule +from pymongo import AsyncMongoClient +from unittest import IsolatedAsyncioTestCase +import casbin + +from tests.helper import get_fixture + + +async def get_enforcer(): + adapter = Adapter("mongodb://localhost:27017", "casbin_test") + e = casbin.AsyncEnforcer(get_fixture("rbac_model.conf"), adapter) + model = e.get_model() + + model.clear_policy() + model.add_policy("p", "p", ["alice", "data1", "read"]) + model.add_policy("p", "p", ["bob", "data2", "write"]) + model.add_policy("p", "p", ["data2_admin", "data2", "read"]) + model.add_policy("p", "p", ["data2_admin", "data2", "write"]) + model.add_policy("g", "g", ["alice", "data2_admin"]) + await adapter.save_policy(model) + + e = casbin.AsyncEnforcer(get_fixture("rbac_model.conf"), adapter) + await e.load_policy() + return e + + +async def clear_db(dbname): + client = AsyncMongoClient("mongodb://localhost:27017") + await client.drop_database(dbname) + + +class TestConfig(IsolatedAsyncioTestCase): + """ + unittest + """ + + async def asyncSetUp(self): + await clear_db("casbin_test") + + async def asyncTearDown(self): + await clear_db("casbin_test") + + async def test_enforcer_basic(self): + """ + test policy + """ + e = await get_enforcer() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + async def test_add_policy(self): + """ + test add_policy + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test add_policy after insert 2 rules + await adapter.add_policy(sec="p", ptype="p", rule=("alice", "data1", "write")) + await adapter.add_policy(sec="p", ptype="p", rule=("bob", "data2", "read")) + + # reload policies from database + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertTrue(e.enforce("alice", "data1", "write")) + self.assertTrue(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + async def test_remove_policy(self): + """ + test remove_policy + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test remove_policy after delete a role definition + result = await adapter.remove_policy( + sec="g", ptype="g", rule=("alice", "data2_admin") + ) + + # reload policies from database + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertFalse(e.enforce("alice", "data2", "read")) + self.assertFalse(e.enforce("alice", "data2", "write")) + self.assertTrue(result) + + async def test_remove_policy_no_remove_when_rule_is_incomplete(self): + adapter = Adapter("mongodb://localhost:27017", "casbin_test") + e = casbin.AsyncEnforcer(get_fixture("rbac_with_resources_roles.conf"), adapter) + + await adapter.add_policy(sec="p", ptype="p", rule=("alice", "data1", "write")) + await adapter.add_policy(sec="p", ptype="p", rule=("alice", "data1", "read")) + await adapter.add_policy(sec="p", ptype="p", rule=("bob", "data2", "read")) + await adapter.add_policy( + sec="p", ptype="p", rule=("data_group_admin", "data_group", "write") + ) + await adapter.add_policy(sec="g", ptype="g", rule=("alice", "data_group_admin")) + await adapter.add_policy(sec="g", ptype="g2", rule=("data2", "data_group")) + + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "write")) + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertTrue(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test remove_policy doesn't remove when given an incomplete policy + result = await adapter.remove_policy( + sec="p", ptype="p", rule=("alice", "data1") + ) + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "write")) + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertTrue(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + self.assertFalse(result) + + async def test_save_policy(self): + """ + test save_policy + """ + + e = await get_enforcer() + self.assertFalse(e.enforce("alice", "data4", "read")) + + model = e.get_model() + model.clear_policy() + + model.add_policy("p", "p", ("alice", "data4", "read")) + + adapter = e.get_adapter() + await adapter.save_policy(model) + + self.assertTrue(e.enforce("alice", "data4", "read")) + + async def test_remove_filtered_policy(self): + """ + test remove_filtered_policy + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + result = await adapter.remove_filtered_policy( + "g", "g", 6, "alice", "data2_admin" + ) + await e.load_policy() + self.assertFalse(result) + + result = await adapter.remove_filtered_policy( + "g", "g", 0, *[f"v{i}" for i in range(7)] + ) + await e.load_policy() + self.assertFalse(result) + + result = await adapter.remove_filtered_policy( + "g", "g", 0, "alice", "data2_admin" + ) + await e.load_policy() + self.assertTrue(result) + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertFalse(e.enforce("alice", "data2", "read")) + self.assertFalse(e.enforce("alice", "data2", "write")) + + def test_str(self): + """ + test __str__ function + """ + rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read") + self.assertEqual(rule.__str__(), "p, alice, data1, read") + + def test_dict(self): + """ + test __str__ function + """ + rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read") + self.assertEqual( + rule.dict(), {"ptype": "p", "v0": "alice", "v1": "data1", "v2": "read"} + ) + + def test_repr(self): + """ + test __repr__ function + """ + adapter = Adapter("mongodb://localhost:27017", "casbin_test") + rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read") + self.assertEqual(repr(rule), '') + # adapter.save_policy(rule) + # self.assertRegex(repr(rule), r'') From a751acd95e4271b5f9ece98303e393866de8dfe6 Mon Sep 17 00:00:00 2001 From: wf-yamaday Date: Sat, 9 Aug 2025 02:04:55 +0900 Subject: [PATCH 5/7] test: remove unnecessary test cases --- tests/asynchronous/test_async_adapter.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/asynchronous/test_async_adapter.py b/tests/asynchronous/test_async_adapter.py index 96363f9..c29b82b 100644 --- a/tests/asynchronous/test_async_adapter.py +++ b/tests/asynchronous/test_async_adapter.py @@ -211,13 +211,3 @@ def test_dict(self): self.assertEqual( rule.dict(), {"ptype": "p", "v0": "alice", "v1": "data1", "v2": "read"} ) - - def test_repr(self): - """ - test __repr__ function - """ - adapter = Adapter("mongodb://localhost:27017", "casbin_test") - rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read") - self.assertEqual(repr(rule), '') - # adapter.save_policy(rule) - # self.assertRegex(repr(rule), r'') From f65ddd54948e35e07b6672d1db0d1782affea485 Mon Sep 17 00:00:00 2001 From: wf-yamaday Date: Sat, 9 Aug 2025 02:06:37 +0900 Subject: [PATCH 6/7] rename: rename test file --- tests/asynchronous/{test_async_adapter.py => test_adapter.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/asynchronous/{test_async_adapter.py => test_adapter.py} (100%) diff --git a/tests/asynchronous/test_async_adapter.py b/tests/asynchronous/test_adapter.py similarity index 100% rename from tests/asynchronous/test_async_adapter.py rename to tests/asynchronous/test_adapter.py From 8ad08c0b55defc35773a5a64de197e87ad43e840 Mon Sep 17 00:00:00 2001 From: wf-yamaday Date: Sat, 9 Aug 2025 02:10:19 +0900 Subject: [PATCH 7/7] remove: remove dead code _find_policy_lines not called --- casbin_pymongo_adapter/asynchronous/adapter.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/casbin_pymongo_adapter/asynchronous/adapter.py b/casbin_pymongo_adapter/asynchronous/adapter.py index 68e99b5..b223a22 100644 --- a/casbin_pymongo_adapter/asynchronous/adapter.py +++ b/casbin_pymongo_adapter/asynchronous/adapter.py @@ -43,12 +43,6 @@ async def _save_policy_line(self, ptype, rule): setattr(line, f"v{index}", value) await self._collection.insert_one(line.dict()) - def _find_policy_lines(self, ptype, rule): - line = CasbinRule(ptype=ptype) - for index, value in enumerate(rule): - setattr(line, f"v{index}", value) - return self._collection.find(line.dict()) - async def _delete_policy_lines(self, ptype, rule): line = CasbinRule(ptype=ptype) for index, value in enumerate(rule):