Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def evaluate(
) -> None:
agent_class = get_agent_class_from_agent_name(agent)
classified_articles = await Classifier(
agent=agent_class(len(sources) > 1 if sources else False, prompt, model=model),
agent=agent_class(prompt, model=model),
score=score_threshold,
relevant_articles=relevant_articles,
non_relevant_articles=non_relevant_articles,
Expand Down
3 changes: 1 addition & 2 deletions src/lightman_ai/ai/base/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class BaseAgent(ABC):

def __init__(
self,
multiple_sources: bool,
system_prompt: str,
model: str | None = None,
logger: logging.Logger | None = None,
Expand All @@ -27,7 +26,7 @@ def __init__(
model=agent_model,
output_type=SelectedArticlesList,
system_prompt=system_prompt,
instructions=MERGE_ARTICLES_FROM_DIFFERENT_SOURCES if multiple_sources else None,
instructions=MERGE_ARTICLES_FROM_DIFFERENT_SOURCES,
)
self.logger = logger or logging.getLogger("lightman")
self.logger.info("Selected %s's %s model", self, selected_model)
Expand Down
10 changes: 6 additions & 4 deletions src/lightman_ai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@ async def lightman(
if not sources:
raise NoSourcesError

tasks = [_get_articles_from_source(source, start_date) for source in sources]
articles_lists = await asyncio.gather(*tasks)

articles = ArticlesList()
for source in sources:
articles += await _get_articles_from_source(source, start_date)
for articles_list in articles_lists:
articles += articles_list

multiple_sources = len(sources) > 1
agent_class = get_agent_class_from_agent_name(agent)
agent_instance = agent_class(multiple_sources, prompt, model, logger=logger)
agent_instance = agent_class(prompt, model, logger=logger)

classified_articles = await _classify_articles(
articles=articles,
Expand Down
28 changes: 19 additions & 9 deletions src/lightman_ai/sources/bleeping_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from httpx import AsyncClient
from lightman_ai.article.models import Article, ArticlesList
from lightman_ai.sources.base import BaseSource
from lightman_ai.sources.exceptions import IncompleteArticleFromSourceError, MalformedSourceResponseError
from lightman_ai.sources.exceptions import (
IncompleteArticleFromSourceError,
MalformedSourceResponseError,
NoArticlesError,
SourceError,
)
from pydantic import ValidationError

logger = logging.getLogger("lightman")
Expand All @@ -24,14 +29,19 @@ class BleepingComputerSource(BaseSource):
@override
async def get_articles(self, date: datetime | None = None) -> ArticlesList:
"""Return the articles that are present in BleepingComputer feed."""
logger.info("Downloading articles from %s", BLEEPING_COMPUTER_URL)
feed = await self.get_feed()
articles = self._xml_to_list_of_articles(feed)
logger.info("Articles properly downloaded and parsed.")
if date:
return ArticlesList.get_articles_from_date_onwards(articles=articles, start_date=date)
else:
return ArticlesList(articles=articles)
try:
logger.info("Downloading articles from %s", BLEEPING_COMPUTER_URL)
feed = await self.get_feed()
articles = self._xml_to_list_of_articles(feed)
logger.info("Articles properly downloaded and parsed.")
if not articles:
raise NoArticlesError
if date:
return ArticlesList.get_articles_from_date_onwards(articles=articles, start_date=date)
else:
return ArticlesList(articles=articles)
except Exception as e:
raise SourceError("Could not download articles from BleepingComputer") from e

async def get_feed(self) -> str:
"""Retrieve the BleepingComputer RSS Feed."""
Expand Down
8 changes: 8 additions & 0 deletions src/lightman_ai/sources/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,11 @@ class MalformedSourceResponseError(BaseSourceError):

class IncompleteArticleFromSourceError(MalformedSourceResponseError):
"""Exception for when all the mandatory fields could not be retrieved from an article."""


class SourceError(BaseSourceError):
"""Exception for when something went wrong while downloading or parsing the articles from source."""


class NoArticlesError(BaseSourceError):
"""Exception for when no articles were found after the download was successful."""
23 changes: 14 additions & 9 deletions src/lightman_ai/sources/the_hacker_news.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from httpx import AsyncClient
from lightman_ai.article.models import Article, ArticlesList
from lightman_ai.sources.base import BaseSource
from lightman_ai.sources.exceptions import MalformedSourceResponseError
from lightman_ai.sources.exceptions import MalformedSourceResponseError, NoArticlesError, SourceError
from pydantic import ValidationError

logger = logging.getLogger("lightman")
Expand All @@ -25,14 +25,19 @@ class TheHackerNewsSource(BaseSource):
@override
async def get_articles(self, date: datetime | None = None) -> ArticlesList:
"""Return the articles that are present in THN feed."""
logger.info("Downloading articles from %s", THN_URL)
feed = await self.get_feed()
articles = self._xml_to_list_of_articles(feed)
logger.info("Articles properly downloaded and parsed.")
if date:
return ArticlesList.get_articles_from_date_onwards(articles=articles, start_date=date)
else:
return ArticlesList(articles=articles)
try:
logger.info("Downloading articles from %s", THN_URL)
feed = await self.get_feed()
articles = self._xml_to_list_of_articles(feed)
if not articles:
raise NoArticlesError
logger.info("Articles properly downloaded and parsed.")
if date:
return ArticlesList.get_articles_from_date_onwards(articles=articles, start_date=date)
else:
return ArticlesList(articles=articles)
except Exception as e:
raise SourceError("Could not download articles from THN source") from e

async def get_feed(self) -> str:
"""Retrieve the TheHackerNews' RSS Feed."""
Expand Down
4 changes: 2 additions & 2 deletions tests/ai/base/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestBaseAgent:
@patch("lightman_ai.ai.base.agent.Agent")
async def test__get_prompt_result(self, m_agent: Mock, test_prompt: str, thn_news: ArticlesList) -> None:
"""Check that we receive an instance of `SelectedArticlesList` when running the method."""
agent = FakeAgent(False, test_prompt)
agent = FakeAgent(test_prompt)

with patch("tests.ai.base.test_agent.FakeAgent.run_prompt") as m_run_prompt:
await agent.run_prompt(str(thn_news))
Expand All @@ -32,7 +32,7 @@ async def test__get_prompt_result(self, m_agent: Mock, test_prompt: str, thn_new

@patch("lightman_ai.ai.base.agent.Agent")
async def test_agent_is_intantiated_with_model_when_set(self, m_agent: Mock, test_prompt: str) -> None:
agent = FakeAgent(False, test_prompt, model="my model")
agent = FakeAgent(test_prompt, model="my model")
await agent.run_prompt("")

assert m_agent.call_count == 1
Expand Down
2 changes: 1 addition & 1 deletion tests/ai/gemini/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class TestGeminiAgent:
agent = GeminiAgent(False, system_prompt="Test system prompt")
agent = GeminiAgent(system_prompt="Test system prompt")

async def test__run_prompt(self, test_prompt: str) -> None:
"""Test that we can run a prompt and receive a SelectedArticlesList."""
Expand Down
2 changes: 1 addition & 1 deletion tests/ai/openai/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class TestAgent:
agent = OpenAIAgent(False, system_prompt="Test system prompt")
agent = OpenAIAgent(system_prompt="Test system prompt")

async def test__run_prompt(self, test_prompt: str) -> None:
"""Test that we can run a prompt and receive a SelectedArticlesList."""
Expand Down
13 changes: 12 additions & 1 deletion tests/sources/test_bleeping_computer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import UTC, datetime
from unittest.mock import patch

import pytest
from lightman_ai.article.models import ArticlesList
from lightman_ai.sources.bleeping_computer import BleepingComputerSource
from lightman_ai.sources.exceptions import MalformedSourceResponseError
from lightman_ai.sources.exceptions import MalformedSourceResponseError, SourceError
from tests.conftest import patch_httpx_client_get


Expand Down Expand Up @@ -100,3 +101,13 @@ def test_xml_to_list_of_articles_validation_error(self) -> None:
</rss>""" # no title

assert not BleepingComputerSource()._xml_to_list_of_articles(xml)

async def test_get_articles_raises_when_no_articles_in_feed(self, bc_xml: str) -> None:
with (
patch_httpx_client_get(bc_xml),
patch(
"lightman_ai.sources.bleeping_computer.BleepingComputerSource._xml_to_list_of_articles", return_value=[]
),
pytest.raises(SourceError),
):
await BleepingComputerSource().get_articles()
11 changes: 10 additions & 1 deletion tests/sources/test_the_hacker_news.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from datetime import UTC, datetime
from unittest.mock import patch

import pytest
from lightman_ai.article.models import ArticlesList
from lightman_ai.sources.exceptions import MalformedSourceResponseError
from lightman_ai.sources.exceptions import MalformedSourceResponseError, SourceError
from lightman_ai.sources.the_hacker_news import TheHackerNewsSource
from tests.conftest import patch_httpx_client_get

Expand Down Expand Up @@ -126,3 +127,11 @@ def test_xml_to_list_of_articles_cleans_description(self) -> None:

assert len(articles) == 1
assert articles[0].description == "Test description with spaces"

async def test_get_articles_raises_when_no_articles_in_feed(self, thn_xml: str) -> None:
with (
patch_httpx_client_get(thn_xml),
patch("lightman_ai.sources.the_hacker_news.TheHackerNewsSource._xml_to_list_of_articles", return_value=[]),
pytest.raises(SourceError),
):
await TheHackerNewsSource().get_articles()
38 changes: 37 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from datetime import datetime
from unittest.mock import Mock, call, patch
from unittest.mock import AsyncMock, Mock, call, patch
from zoneinfo import ZoneInfo

import pytest
Expand All @@ -9,6 +9,7 @@
from lightman_ai import cli
from lightman_ai.article.models import PrimarySelectedArticle
from lightman_ai.core.config import FileConfig, PromptConfig
from lightman_ai.sources.utils import SOURCE_CHOICES
from tests.conftest import patch_config_file


Expand Down Expand Up @@ -408,3 +409,38 @@ def test_regular_output_with_no_articles(

assert result.exit_code == 0
assert "No relevant articles found." in result.output

@patch("lightman_ai.cli.load_dotenv")
@patch("lightman_ai.cli.FileConfig.get_config_from_file")
@patch("lightman_ai.cli.PromptConfig.get_config_from_file")
def test_exit_code_is_not_zero_if_a_source_fails(self, m_prompt: Mock, m_config: Mock, m_load_dotenv: Mock) -> None:
"""Test that the exit code is non-zero when the sources fail to download their articles.

Proves that:
- Sources attempted to retrieve news (httpx was called for each source)
- The AI agent was never invoked (no point classifying zero articles)
"""
runner = CliRunner()
m_prompt.return_value = PromptConfig({"eval": "eval prompt"})
m_config.return_value = FileConfig()

with (
patch("httpx.AsyncClient.get") as mock_get,
patch("pydantic_ai.Agent.run", new_callable=AsyncMock) as mock_agent_run,
patch_config_file(),
):
mock_get.side_effect = Exception("Network error")
result = runner.invoke(
cli.run,
[
"--agent",
"openai",
"--prompt",
"eval",
"--dry-run",
],
)

assert result.exit_code != 0
assert mock_get.call_count == len(SOURCE_CHOICES)
assert mock_agent_run.call_count == 0
28 changes: 26 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lightman_ai.core.sentry import configure_sentry
from lightman_ai.exceptions import NoSourcesError
from lightman_ai.main import _create_service_desk_issues, _get_articles_from_source, lightman
from lightman_ai.sources.exceptions import SourceError
from lightman_ai.sources.utils import SOURCE_CHOICES
from tests.conftest import patch_httpx_client_get, patch_multiple_responses
from tests.utils import patch_agent, patch_get_articles_from_xml
Expand Down Expand Up @@ -116,7 +117,7 @@ async def test_lightman_and_service_desk_publish(self, test_prompt: str, thn_xml
assert relevant_article_1.title in called_titles
assert relevant_article_2.title in called_titles

async def test_lightman_no_publish_if_dry_run(self, test_prompt: str, thn_xml: str) -> None:
async def test_lightman_no_publish_if_dry_run(self, test_prompt: str, thn_xml: str, bc_xml: str) -> None:
now = datetime.now(UTC)
relevant_article_1 = PrimarySelectedArticle(
title="article 2", link="https://article2.com", why_is_relevant="a", relevance_score=8, published_at=now
Expand All @@ -129,7 +130,7 @@ async def test_lightman_no_publish_if_dry_run(self, test_prompt: str, thn_xml: s
)
agent_response = SelectedArticlesList(articles=[relevant_article_1, relevant_article_2, not_relevant_article])
with (
patch_httpx_client_get(thn_xml),
patch_multiple_responses([thn_xml, bc_xml]),
patch_agent(agent_response),
patch("lightman_ai.main.ServiceDeskIntegration.from_env") as mock_service_desk_env,
):
Expand Down Expand Up @@ -162,6 +163,29 @@ async def test_lightman_raises_error_when_sources_is_none(self) -> None:
dry_run=True,
)

async def test_lightman_fails_when_one_source_raises_exception(self, test_prompt: str, thn_xml: str) -> None:
"""Test that execution fails when one source raises an exception during download."""
with (
patch("httpx.AsyncClient.get") as mock_get,
patch("pydantic_ai.Agent.run", new_callable=AsyncMock) as mock_agent_run,
):
mock_get.side_effect = [
Mock(text=thn_xml, **{"raise_for_status.return_value": None}),
Exception("Network error: Connection timeout"),
]

with pytest.raises(SourceError):
await lightman(
agent="openai",
prompt=test_prompt,
sources=SOURCE_CHOICES,
score_threshold=8,
dry_run=True,
)

assert mock_get.call_count == len(SOURCE_CHOICES)
mock_agent_run.assert_not_called()


class TestCreateServiceDeskIssues:
"""Tests for the _create_service_desk_issues function."""
Expand Down