diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 83c289bdf8..e8ed72a19d 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -785,15 +785,17 @@ def _delete_sync(): ) structure_deleted = cursor.rowcount + messages_deleted = self._cleanup_orphaned_messages_sync(conn) conn.commit() - return usage_deleted, structure_deleted + return usage_deleted, structure_deleted, messages_deleted - usage_deleted, structure_deleted = await asyncio.to_thread(_delete_sync) + usage_deleted, structure_deleted, messages_deleted = await asyncio.to_thread(_delete_sync) self._logger.info( - f"Deleted branch '{branch_id}': {structure_deleted} message entries, {usage_deleted} usage entries" # noqa: E501 + f"Deleted branch '{branch_id}': {structure_deleted} message entries, " + f"{usage_deleted} usage entries, {messages_deleted} orphaned messages" ) async def list_branches(self) -> list[dict[str, Any]]: diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index ad4b5c4d86..b305c05062 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -694,6 +694,55 @@ async def test_branch_deletion_with_force(): session.close() +async def test_delete_branch_removes_branch_only_messages(): + session_id = "delete_branch_orphans_test" + session = AdvancedSQLiteSession(session_id=session_id, create_tables=True) + + await session.add_items( + [ + {"role": "user", "content": "Main question"}, + {"role": "assistant", "content": "Main answer"}, + ] + ) + await session.create_branch_from_turn(1, "branch_only") + await session.add_items( + [ + {"role": "user", "content": "Branch-only question"}, + {"role": "assistant", "content": "Branch-only answer"}, + ] + ) + + await session.delete_branch("branch_only", force=True) + + with session._locked_connection() as conn: + message_rows = conn.execute( + f""" + SELECT message_data + FROM {session.messages_table} + WHERE session_id = ? + ORDER BY id + """, + (session.session_id,), + ).fetchall() + structure_rows = conn.execute( + """ + SELECT branch_id, message_id + FROM message_structure + WHERE session_id = ? + ORDER BY message_id + """, + (session.session_id,), + ).fetchall() + + assert [json.loads(row[0])["content"] for row in message_rows] == [ + "Main question", + "Main answer", + ] + assert {row[0] for row in structure_rows} == {"main"} + + session.close() + + async def test_get_items_with_parameters(): """Test get_items with new parameters (include_inactive, branch_id).""" session_id = "get_items_params_test"