diff --git a/casbin_pymongo_adapter/_rule.py b/casbin_pymongo_adapter/_rule.py new file mode 100644 index 0000000..a32d70d --- /dev/null +++ b/casbin_pymongo_adapter/_rule.py @@ -0,0 +1,34 @@ +class CasbinRule: + """ + CasbinRule model + """ + + def __init__( + self, ptype=None, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None + ): + self.ptype = ptype + self.v0 = v0 + self.v1 = v1 + self.v2 = v2 + self.v3 = v3 + self.v4 = v4 + self.v5 = v5 + + def dict(self): + d = {"ptype": self.ptype} + + for value in dir(self): + if ( + getattr(self, value) is not None + and value.startswith("v") + and value[1:].isnumeric() + ): + d[value] = getattr(self, value) + + return d + + def __str__(self): + return ", ".join(self.dict().values()) + + def __repr__(self): + return ''.format(str(self)) diff --git a/casbin_pymongo_adapter/adapter.py b/casbin_pymongo_adapter/adapter.py index 6829f5c..45e2ebc 100644 --- a/casbin_pymongo_adapter/adapter.py +++ b/casbin_pymongo_adapter/adapter.py @@ -1,41 +1,7 @@ from casbin import persist from pymongo import MongoClient - -class CasbinRule: - """ - CasbinRule model - """ - - def __init__( - self, ptype=None, v0=None, v1=None, v2=None, v3=None, v4=None, v5=None - ): - self.ptype = ptype - self.v0 = v0 - self.v1 = v1 - self.v2 = v2 - self.v3 = v3 - self.v4 = v4 - self.v5 = v5 - - def dict(self): - d = {"ptype": self.ptype} - - for value in dir(self): - if ( - getattr(self, value) is not None - and value.startswith("v") - and value[1:].isnumeric() - ): - d[value] = getattr(self, value) - - return d - - def __str__(self): - return ", ".join(self.dict().values()) - - def __repr__(self): - return ''.format(str(self)) +from ._rule import CasbinRule class Adapter(persist.Adapter): diff --git a/casbin_pymongo_adapter/asynchronous/__init__.py b/casbin_pymongo_adapter/asynchronous/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/casbin_pymongo_adapter/asynchronous/adapter.py b/casbin_pymongo_adapter/asynchronous/adapter.py new file mode 100644 index 0000000..b223a22 --- /dev/null +++ b/casbin_pymongo_adapter/asynchronous/adapter.py @@ -0,0 +1,134 @@ +from casbin import persist +from casbin.persist.adapters.asyncio.adapter import AsyncAdapter +from pymongo import AsyncMongoClient + +from .._rule import CasbinRule + + +class Adapter(AsyncAdapter): + """the interface for Casbin adapters.""" + + def __init__(self, uri, dbname, collection="casbin_rule"): + """Create an adapter for Mongodb + + Args: + uri (str): This should be the same requiement as pymongo Client's 'uri' parameter. + See https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient. + dbname (str): Database to store policy. + collection (str, optional): Collection of the choosen database. Defaults to "casbin_rule". + """ + client = AsyncMongoClient(uri) + db = client[dbname] + self._collection = db[collection] + + async def load_policy(self, model): + """Implementing add Interface for casbin. Load all policy rules from mongodb + + Args: + model (CasbinRule): CasbinRule object + """ + + async for line in self._collection.find(): + if "ptype" not in line: + continue + rule = CasbinRule(line["ptype"]) + for key, value in line.items(): + setattr(rule, key, value) + + persist.load_policy_line(str(rule), model) + + async def _save_policy_line(self, ptype, rule): + line = CasbinRule(ptype=ptype) + for index, value in enumerate(rule): + setattr(line, f"v{index}", value) + await self._collection.insert_one(line.dict()) + + async def _delete_policy_lines(self, ptype, rule): + line = CasbinRule(ptype=ptype) + for index, value in enumerate(rule): + setattr(line, f"v{index}", value) + + # if rule is empty, do nothing + # else find all given rules and delete them + if len(line.dict()) == 0: + return 0 + else: + line_dict = line.dict() + line_dict_keys_len = len(line_dict) + to_delete = [ + result["_id"] + async for result in self._collection.find(line_dict) + if line_dict_keys_len == len(result.keys()) - 1 + ] + results = await self._collection.delete_many({"_id": {"$in": to_delete}}) + return results.deleted_count + + async def save_policy(self, model) -> bool: + """Implement add Interface for casbin. Save the policy in mongodb + + Args: + model (Class Model): Casbin Model which loads from .conf file usually. + + Returns: + bool: True if succeed + """ + for sec in ["p", "g"]: + if sec not in model.model.keys(): + continue + for ptype, ast in model.model[sec].items(): + for rule in ast.policy: + await self._save_policy_line(ptype, rule) + return True + + async def add_policy(self, sec, ptype, rule): + """Add policy rules to mongodb + + Args: + sec (str): Section name, 'g' or 'p' + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rule (CasbinRule): Casbin rule will be added + + Returns: + bool: True if succeed else False + """ + await self._save_policy_line(ptype, rule) + return True + + async def remove_policy(self, sec, ptype, rule): + """Remove policy rules in mongodb(rules duplicate are also removed) + + Args: + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rule (CasbinRule): Casbin rule if it is exactly same as will be removed. + + Returns: + Number: Number of policies be removed + """ + deleted_count = await self._delete_policy_lines(ptype, rule) + return deleted_count > 0 + + async def remove_filtered_policy(self, sec, ptype, field_index, *field_values): + """Remove policy rules taht match the filter from the storage. + This is part of the Auto-Save feature. + + Args: + ptype (str): Policy type, 'g', 'g2', 'p', etc. + rule (CasbinRule): Casbin rule will be removed + field_index (int): The policy index at which the filed_values begins filtering. Its range is [0, 5] + field_values(List[str]): A list of rules to filter policy which starts from + + Returns: + bool: True if succeed else False + """ + if not (0 <= field_index <= 5): + return False + if not (1 <= field_index + len(field_values) <= 6): + return False + query = { + f"v{index + field_index}": value + for index, value in enumerate(field_values) + if value != "" + } + query["ptype"] = ptype + results = await self._collection.delete_many(query) + return results.deleted_count > 0 diff --git a/requirements.txt b/requirements.txt index 8e03b91..4e71db0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ casbin>=0.8.4 -pymongo>=3.10.1 \ No newline at end of file +pymongo>=4.13.0 diff --git a/tests/asynchronous/test_adapter.py b/tests/asynchronous/test_adapter.py new file mode 100644 index 0000000..c29b82b --- /dev/null +++ b/tests/asynchronous/test_adapter.py @@ -0,0 +1,213 @@ +from casbin_pymongo_adapter.asynchronous.adapter import Adapter +from casbin_pymongo_adapter._rule import CasbinRule +from pymongo import AsyncMongoClient +from unittest import IsolatedAsyncioTestCase +import casbin + +from tests.helper import get_fixture + + +async def get_enforcer(): + adapter = Adapter("mongodb://localhost:27017", "casbin_test") + e = casbin.AsyncEnforcer(get_fixture("rbac_model.conf"), adapter) + model = e.get_model() + + model.clear_policy() + model.add_policy("p", "p", ["alice", "data1", "read"]) + model.add_policy("p", "p", ["bob", "data2", "write"]) + model.add_policy("p", "p", ["data2_admin", "data2", "read"]) + model.add_policy("p", "p", ["data2_admin", "data2", "write"]) + model.add_policy("g", "g", ["alice", "data2_admin"]) + await adapter.save_policy(model) + + e = casbin.AsyncEnforcer(get_fixture("rbac_model.conf"), adapter) + await e.load_policy() + return e + + +async def clear_db(dbname): + client = AsyncMongoClient("mongodb://localhost:27017") + await client.drop_database(dbname) + + +class TestConfig(IsolatedAsyncioTestCase): + """ + unittest + """ + + async def asyncSetUp(self): + await clear_db("casbin_test") + + async def asyncTearDown(self): + await clear_db("casbin_test") + + async def test_enforcer_basic(self): + """ + test policy + """ + e = await get_enforcer() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + async def test_add_policy(self): + """ + test add_policy + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test add_policy after insert 2 rules + await adapter.add_policy(sec="p", ptype="p", rule=("alice", "data1", "write")) + await adapter.add_policy(sec="p", ptype="p", rule=("bob", "data2", "read")) + + # reload policies from database + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertTrue(e.enforce("alice", "data1", "write")) + self.assertTrue(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + async def test_remove_policy(self): + """ + test remove_policy + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test remove_policy after delete a role definition + result = await adapter.remove_policy( + sec="g", ptype="g", rule=("alice", "data2_admin") + ) + + # reload policies from database + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertFalse(e.enforce("alice", "data2", "read")) + self.assertFalse(e.enforce("alice", "data2", "write")) + self.assertTrue(result) + + async def test_remove_policy_no_remove_when_rule_is_incomplete(self): + adapter = Adapter("mongodb://localhost:27017", "casbin_test") + e = casbin.AsyncEnforcer(get_fixture("rbac_with_resources_roles.conf"), adapter) + + await adapter.add_policy(sec="p", ptype="p", rule=("alice", "data1", "write")) + await adapter.add_policy(sec="p", ptype="p", rule=("alice", "data1", "read")) + await adapter.add_policy(sec="p", ptype="p", rule=("bob", "data2", "read")) + await adapter.add_policy( + sec="p", ptype="p", rule=("data_group_admin", "data_group", "write") + ) + await adapter.add_policy(sec="g", ptype="g", rule=("alice", "data_group_admin")) + await adapter.add_policy(sec="g", ptype="g2", rule=("data2", "data_group")) + + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "write")) + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertTrue(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + # test remove_policy doesn't remove when given an incomplete policy + result = await adapter.remove_policy( + sec="p", ptype="p", rule=("alice", "data1") + ) + await e.load_policy() + + self.assertTrue(e.enforce("alice", "data1", "write")) + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertTrue(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + self.assertFalse(result) + + async def test_save_policy(self): + """ + test save_policy + """ + + e = await get_enforcer() + self.assertFalse(e.enforce("alice", "data4", "read")) + + model = e.get_model() + model.clear_policy() + + model.add_policy("p", "p", ("alice", "data4", "read")) + + adapter = e.get_adapter() + await adapter.save_policy(model) + + self.assertTrue(e.enforce("alice", "data4", "read")) + + async def test_remove_filtered_policy(self): + """ + test remove_filtered_policy + """ + e = await get_enforcer() + adapter = e.get_adapter() + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertTrue(e.enforce("alice", "data2", "read")) + self.assertTrue(e.enforce("alice", "data2", "write")) + + result = await adapter.remove_filtered_policy( + "g", "g", 6, "alice", "data2_admin" + ) + await e.load_policy() + self.assertFalse(result) + + result = await adapter.remove_filtered_policy( + "g", "g", 0, *[f"v{i}" for i in range(7)] + ) + await e.load_policy() + self.assertFalse(result) + + result = await adapter.remove_filtered_policy( + "g", "g", 0, "alice", "data2_admin" + ) + await e.load_policy() + self.assertTrue(result) + self.assertTrue(e.enforce("alice", "data1", "read")) + self.assertFalse(e.enforce("alice", "data1", "write")) + self.assertFalse(e.enforce("bob", "data2", "read")) + self.assertTrue(e.enforce("bob", "data2", "write")) + self.assertFalse(e.enforce("alice", "data2", "read")) + self.assertFalse(e.enforce("alice", "data2", "write")) + + def test_str(self): + """ + test __str__ function + """ + rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read") + self.assertEqual(rule.__str__(), "p, alice, data1, read") + + def test_dict(self): + """ + test __str__ function + """ + rule = CasbinRule(ptype="p", v0="alice", v1="data1", v2="read") + self.assertEqual( + rule.dict(), {"ptype": "p", "v0": "alice", "v1": "data1", "v2": "read"} + ) diff --git a/tests/helper.py b/tests/helper.py new file mode 100644 index 0000000..f23fafc --- /dev/null +++ b/tests/helper.py @@ -0,0 +1,9 @@ +import os + + +def get_fixture(path): + """ + get model path + """ + dir_path = os.path.split(os.path.realpath(__file__))[0] + "/" + return os.path.abspath(dir_path + path) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 19771cd..70afd06 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,17 +1,10 @@ from casbin_pymongo_adapter.adapter import Adapter -from casbin_pymongo_adapter.adapter import CasbinRule +from casbin_pymongo_adapter._rule import CasbinRule from pymongo import MongoClient from unittest import TestCase import casbin -import os - -def get_fixture(path): - """ - get model path - """ - dir_path = os.path.split(os.path.realpath(__file__))[0] + "/" - return os.path.abspath(dir_path + path) +from tests.helper import get_fixture def get_enforcer():