Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pyrit/prompt_converter/add_image_to_video_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import contextlib
import logging
import os
from pathlib import Path
Expand Down Expand Up @@ -177,7 +178,8 @@ async def _add_image_to_video(self, image_path: str, output_path: str) -> str:
# Release everything
cap.release()
output_video.release()
cv2.destroyAllWindows()
with contextlib.suppress(cv2.error):
cv2.destroyAllWindows() # Not available in headless OpenCV builds
if azure_storage_flag:
os.remove(local_temp_path)

Expand Down
62 changes: 40 additions & 22 deletions tests/integration/memory/test_azure_sql_memory_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
from sqlalchemy.exc import SQLAlchemyError

from pyrit.identifiers import ScorerIdentifier
from pyrit.identifiers import ComponentIdentifier
from pyrit.memory import AzureSQLMemory
from pyrit.memory.memory_models import (
AttackResultEntry,
Expand All @@ -37,23 +37,38 @@ def generate_test_id() -> str:
return str(uuid4())[:8]


def get_test_scorer_identifier(**kwargs) -> ScorerIdentifier:
def get_test_scorer_identifier(**kwargs) -> ComponentIdentifier:
"""
Returns a test ScorerIdentifier for use in integration tests.
Returns a test ComponentIdentifier for use in integration tests.

Args:
**kwargs: Optional overrides for ScorerIdentifier fields.
**kwargs: Optional overrides for ComponentIdentifier fields.

Returns:
ScorerIdentifier: A test scorer identifier with all required fields.
ComponentIdentifier: A test scorer identifier with all required fields.
"""
return ScorerIdentifier(
return ComponentIdentifier(
class_name=kwargs.get("class_name", "TestScorer"),
class_module=kwargs.get("class_module", "tests.integration.memory.test_azure_sql_memory_integration"),
class_description=kwargs.get("class_description", "Test scorer for integration testing"),
identifier_type=kwargs.get("identifier_type", "instance"),
scorer_type=kwargs.get("scorer_type", "true_false"),
system_prompt_template=kwargs.get("system_prompt_template"),
params={
"class_description": kwargs.get("class_description", "Test scorer for integration testing"),
"identifier_type": kwargs.get("identifier_type", "instance"),
"scorer_type": kwargs.get("scorer_type", "true_false"),
"system_prompt_template": kwargs.get("system_prompt_template"),
},
)


def get_test_attack_identifier() -> ComponentIdentifier:
"""
Returns a test ComponentIdentifier for attack results in integration tests.

Returns:
ComponentIdentifier: A test attack identifier.
"""
return ComponentIdentifier(
class_name="test_attack",
class_module="tests.integration.memory.test_azure_sql_memory_integration",
)


Expand Down Expand Up @@ -264,19 +279,19 @@ async def test_get_attack_results_by_harm_categories(azuresql_instance: AzureSQL
result1 = AttackResult(
conversation_id=conversation_ids[0],
objective="Test objective 1",
attack_identifier={"name": "test_attack"},
attack_identifier=get_test_attack_identifier(),
outcome=AttackOutcome.SUCCESS,
)
result2 = AttackResult(
conversation_id=conversation_ids[1],
objective="Test objective 2",
attack_identifier={"name": "test_attack"},
attack_identifier=get_test_attack_identifier(),
outcome=AttackOutcome.SUCCESS,
)
result3 = AttackResult(
conversation_id=conversation_ids[2],
objective="Test objective 3",
attack_identifier={"name": "test_attack"},
attack_identifier=get_test_attack_identifier(),
outcome=AttackOutcome.FAILURE,
)

Expand Down Expand Up @@ -350,19 +365,19 @@ async def test_get_attack_results_by_labels(azuresql_instance: AzureSQLMemory):
result1 = AttackResult(
conversation_id=conversation_ids[0],
objective="Test objective 1",
attack_identifier={"name": "test_attack"},
attack_identifier=get_test_attack_identifier(),
outcome=AttackOutcome.SUCCESS,
)
result2 = AttackResult(
conversation_id=conversation_ids[1],
objective="Test objective 2",
attack_identifier={"name": "test_attack"},
attack_identifier=get_test_attack_identifier(),
outcome=AttackOutcome.SUCCESS,
)
result3 = AttackResult(
conversation_id=conversation_ids[2],
objective="Test objective 3",
attack_identifier={"name": "test_attack"},
attack_identifier=get_test_attack_identifier(),
outcome=AttackOutcome.FAILURE,
)

Expand Down Expand Up @@ -394,13 +409,13 @@ async def test_scenario_result_scorer_identifier_roundtrip(azuresql_instance: Az
"""
Integration test for storing and retrieving objective_scorer_identifier in ScenarioResult.

Verifies that ScorerIdentifier is correctly serialized to JSON when stored
and deserialized back to ScorerIdentifier when retrieved from Azure SQL.
Verifies that ComponentIdentifier is correctly serialized to JSON when stored
and deserialized back to ComponentIdentifier when retrieved from Azure SQL.
"""
test_id = generate_test_id()

with cleanup_scenario_data(azuresql_instance, test_id):
# Create a ScorerIdentifier with various fields
# Create a ComponentIdentifier with various fields
scorer_identifier = get_test_scorer_identifier(
scorer_type="true_false",
system_prompt_template="Test prompt template for {objective}",
Expand All @@ -426,9 +441,12 @@ async def test_scenario_result_scorer_identifier_roundtrip(azuresql_instance: Az

retrieved = results[0]
assert retrieved.objective_scorer_identifier is not None
assert isinstance(retrieved.objective_scorer_identifier, ScorerIdentifier)
assert retrieved.objective_scorer_identifier.scorer_type == "true_false"
assert retrieved.objective_scorer_identifier.system_prompt_template == "Test prompt template for {objective}"
assert isinstance(retrieved.objective_scorer_identifier, ComponentIdentifier)
assert retrieved.objective_scorer_identifier.params["scorer_type"] == "true_false"
assert (
retrieved.objective_scorer_identifier.params["system_prompt_template"]
== "Test prompt template for {objective}"
)
assert retrieved.objective_scorer_identifier.class_name == scorer_identifier.class_name
assert retrieved.objective_scorer_identifier.hash == scorer_identifier.hash

Expand Down
4 changes: 2 additions & 2 deletions tests/integration/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from sqlalchemy import inspect

from pyrit.identifiers import AttackIdentifier
from pyrit.identifiers import ComponentIdentifier
from pyrit.memory import MemoryInterface, SQLiteMemory
from pyrit.models import Message, MessagePiece
from pyrit.prompt_target import PromptChatTarget, limit_requests_per_minute
Expand Down Expand Up @@ -49,7 +49,7 @@ def set_system_prompt(
*,
system_prompt: str,
conversation_id: str,
attack_identifier: Optional[AttackIdentifier] = None,
attack_identifier: Optional[ComponentIdentifier] = None,
labels: Optional[dict[str, str]] = None,
) -> None:
self.system_prompt = system_prompt
Expand Down
12 changes: 10 additions & 2 deletions tests/integration/score/test_hitl_gradio_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,16 @@
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import MessagePiece, Score
from pyrit.score import HumanInTheLoopScorerGradio
from pyrit.ui.rpc import RPCAlreadyRunningException
from pyrit.ui.rpc_client import RPCClient, RPCClientStoppedException

try:
from pyrit.ui.rpc import RPCAlreadyRunningException
from pyrit.ui.rpc_client import RPCClient, RPCClientStoppedException

_rpyc_available = True
except (ImportError, ModuleNotFoundError):
_rpyc_available = False

pytestmark = pytest.mark.skipif(not _rpyc_available, reason="rpyc not installed")


def if_gradio_installed():
Expand Down