class ReporterErrorHandler(Reporter):
    """Wrap a reporter so that its failures are logged instead of raised.

    Every `Reporter` call is forwarded to the wrapped reporter. The first
    exception raised by the wrapped reporter is logged (with traceback) and
    from then on every further call, including `generate()`, is skipped.

    Args:
        reporter (Reporter): The reporter whose calls are guarded.
    """

    def __init__(self, reporter: Reporter) -> None:
        self._reporter = reporter
        # Set on the first failure; once set, the wrapped reporter is never
        # called again by this handler.
        self._error_occurred = False

    def _log_on_exception(self, error: Exception) -> None:
        """Log `error` and mark the wrapped reporter as failed.

        Must be called from inside an `except` block so that
        `logger.exception` can attach the active traceback.
        """
        # Prefer a `message` attribute when the exception provides one,
        # otherwise fall back to its string form.
        error_message = getattr(error, "message", str(error))
        logger.exception(
            "Skipping the usage of reporter %s due to the following exception: %s",
            self._reporter,
            error_message,
        )
        self._error_occurred = True

    def _forward(self, method_name: str, *args: Any) -> None:
        """Call `method_name` on the wrapped reporter unless it already failed.

        This is the single place implementing the skip-and-log policy shared
        by all public methods. The attribute lookup happens inside the `try`
        so that it is guarded exactly like the call itself.
        """
        if self._error_occurred:
            logger.debug("Skipping reporter due to previous error")
            return
        try:
            getattr(self._reporter, method_name)(*args)
        except Exception as e:  # noqa: BLE001
            self._log_on_exception(e)

    @override
    def add_message(
        self,
        role: str,
        content: Union[str, dict[str, Any], list[Any]],
        image: Optional[Image.Image | list[Image.Image] | AnnotatedImage] = None,
    ) -> None:
        """Forward a message to the wrapped reporter, swallowing its errors."""
        self._forward("add_message", role, content, image)

    @override
    def add_usage_summary(self, usage: UsageSummary) -> None:
        """Forward a usage summary to the wrapped reporter, swallowing its errors."""
        self._forward("add_usage_summary", usage)

    @override
    def add_cache_execution_statistics(
        self, original_usage: dict[str, int | None]
    ) -> None:
        """Forward cache statistics to the wrapped reporter, swallowing its errors."""
        self._forward("add_cache_execution_statistics", original_usage)

    @override
    def generate(self) -> None:
        """Ask the wrapped reporter to generate its report, swallowing its errors."""
        self._forward("generate")
def _get_temp_messages_file(self) -> Path:
    """Return the temporary file path used to buffer rendered message rows.

    The previously chosen path is reused as long as that file exists on
    disk. Otherwise the report directory is created if necessary and a new,
    timestamped `.tmp` path inside it is chosen and remembered. The file
    itself is not created here.
    """
    current = self._temp_messages_file
    if current is not None and current.exists():
        return current

    self.report_dir.mkdir(parents=True, exist_ok=True)
    # Microsecond resolution (%f) keeps names distinct between reporters
    # started in the same second.
    stamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S_%f")
    self._temp_messages_file = self.report_dir / f"AskUI_report_{stamp}.tmp"
    return self._temp_messages_file
' + '
{{ content }}
' + "
" + "{% else %}" + "{{ content }}" + "{% endif %}" + "{% for image in images %}" + '
def _render_message_row(
    self,
    timestamp: datetime,
    role: str,
    content: str,
    is_json: bool,
    images: list[str],
) -> str:
    """Render one conversation message through the message-row template.

    Args:
        timestamp (datetime): Time of the message; rendered as
            `HH:MM:SS.mmm`.
        role (str): Role label; also passed lower-cased as `role_lower`.
        content (str): Already formatted message content.
        is_json (bool): Passed to the template as `is_json`.
        images (list[str]): Passed to the template as `images`.

    Returns:
        str: The output of `_MESSAGE_ROW_TEMPLATE.render(...)`.
    """
    role_css = role.lower()
    # %f yields six microsecond digits; dropping the last three leaves
    # millisecond precision.
    clock = timestamp.strftime("%H:%M:%S.%f")[:-3]
    context = {
        "role_lower": role_css,
        "ts_str": clock,
        "role": role,
        "content": content,
        "is_json": is_json,
        "images": images,
    }
    return self._MESSAGE_ROW_TEMPLATE.render(**context)
streamed from a temporary file so that the full set of + base64 images is never held in memory simultaneously. """ - template_str = """ + _HEADER_TEMPLATE = """ @@ -1137,39 +1267,9 @@ def generate(self) -> None: Role Content - {% for msg in messages %} - - {{ msg.timestamp.strftime('%H:%M:%S.%f')[:-3] }} UTC - - - {{ msg.role }} - - - - {% if msg.is_json %} -
-
{{ msg.content }}
-
- {% else %} - {{ msg.content }} - {% endif %} - {% for image in msg.images %} -
- Message image - {% endfor %} - - - {% endfor %} - - - - - """ - template = Template(template_str) + _FOOTER = " \n \n \n \n " # Calculate execution time end_time = datetime.now(tz=timezone.utc) @@ -1198,9 +1298,8 @@ def _format_conversation_duration( ).total_seconds() ) - html = template.render( + header_html = Template(_HEADER_TEMPLATE).render( timestamp=end_time, - messages=self.messages, system_info=self.system_info, usage_summary=self.usage_summary, cache_original_usage=self.cache_original_usage, @@ -1209,11 +1308,26 @@ def _format_conversation_duration( ) report_path = ( - self.report_dir / f"report_{datetime.now(tz=timezone.utc):%Y%m%d%H%M%S%f}" + self.report_dir / f"report_{end_time:%Y%m%d%H%M%S%f}" f"{random.randint(0, 1000):03}.html" ) self.report_dir.mkdir(parents=True, exist_ok=True) - report_path.write_text(html, encoding="utf-8") + + with report_path.open(mode="w", encoding="utf-8") as out: + out.write(header_html) + try: + if ( + self._temp_messages_file is not None + and self._temp_messages_file.exists() + ): + with self._temp_messages_file.open( + mode="r", encoding="utf-8" + ) as tmp: + shutil.copyfileobj(tmp, out) + self._temp_messages_file.unlink() + self._temp_messages_file = None + finally: + out.write(_FOOTER) class AllureReporter(Reporter):