diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 83c289bdf8..ff04eb4500 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -786,14 +786,18 @@ def _delete_sync(): structure_deleted = cursor.rowcount + # Remove messages that are no longer referenced by any branch. + orphans_deleted = self._cleanup_orphaned_messages_sync(conn) + conn.commit() - return usage_deleted, structure_deleted + return usage_deleted, structure_deleted, orphans_deleted - usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync) + usage_deleted, structure_deleted, orphans_deleted = await asyncio.to_thread(_delete_sync) self._logger.info( - f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries" # noqa: E501 + f"Deleted branch '{branch_id}': {structure_deleted} message entries, " + f"{usage_deleted} usage entries, {orphans_deleted} orphaned messages" ) async def list_branches(self) -> list[dict[str, Any]]: diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index ad4b5c4d86..c8fc770ecb 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -694,6 +694,49 @@ async def test_branch_deletion_with_force(): session.close() +async def test_delete_branch_removes_orphaned_messages(): + """Regression: delete_branch must remove messages exclusive to that branch. + + Messages shared with another branch must be kept; messages that exist only + in the deleted branch must be cleaned up from the messages table. + """ + session = AdvancedSQLiteSession(session_id="orphan_delete_test", create_tables=True) + + # Two messages on main branch. + await session.add_items([{"role": "user", "content": "main msg 1"}]) + await session.add_items([{"role": "user", "content": "main msg 2"}]) + + # Branch from turn 1; this branch shares main msg 1 and adds its own message. + await session.create_branch_from_turn(1, "side_branch") + await session.add_items([{"role": "user", "content": "branch only msg"}]) + + # Confirm all three messages are in the messages table before deletion. + with session._locked_connection() as conn: + total_before = conn.execute( + f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?", + (session.session_id,), + ).fetchone()[0] + assert total_before == 3 + + await session.switch_to_branch("main") + await session.delete_branch("side_branch") + + # The branch-only message must be gone; the shared message must remain. + with session._locked_connection() as conn: + total_after = conn.execute( + f"SELECT COUNT(*) FROM {session.messages_table} WHERE session_id = ?", + (session.session_id,), + ).fetchone()[0] + assert total_after == 2, "branch-only messages must be deleted when the branch is removed" + + # Main branch items must still be intact. + await session.switch_to_branch("main") + main_items = await session.get_items() + assert len(main_items) == 2 + + session.close() + + async def test_get_items_with_parameters(): """Test get_items with new parameters (include_inactive, branch_id).""" session_id = "get_items_params_test"