Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/bub/builtin/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,13 @@ def system_prompt(self, prompt: str | list[dict], state: State) -> str:
@hookimpl
def provide_channels(self, message_handler: MessageHandler) -> list[Channel]:
from bub.channels.cli import CliChannel
from bub.channels.telegram import TelegramChannel
from bub.channels.telegram import TelegramChannel, TelegramSettings

return [
TelegramChannel(on_receive=message_handler),
CliChannel(on_receive=message_handler, agent=self._get_agent()),
]
channels: list[Channel] = []
for bot_config in TelegramSettings.bot_configs():
channels.append(TelegramChannel(on_receive=message_handler, bot_config=bot_config))
channels.append(CliChannel(on_receive=message_handler, agent=self._get_agent()))
return channels

@hookimpl
async def on_error(self, stage: str, error: Exception, message: Envelope | None) -> None:
Expand Down
18 changes: 12 additions & 6 deletions src/bub/builtin/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,22 @@ async def append(self, tape: str, entry: TapeEntry) -> None:
@contextlib.asynccontextmanager
async def fork(self, tape: str, merge_back: bool = True) -> AsyncGenerator[None, None]:
store = InMemoryTapeStore()
token = current_store.set(store)
tape_token = current_fork_tape.set(tape)
reset_token = current_tape_was_reset.set(False)
# Save/restore instead of ContextVar.reset(token) to avoid
# "Token was created in a different Context" when cleanup
# runs in a different asyncio Task (e.g. cancellation, TaskGroup).
prev_store = current_store.get(_empty_store)
prev_fork_tape = current_fork_tape.get()
prev_was_reset = current_tape_was_reset.get()
current_store.set(store)
current_fork_tape.set(tape)
current_tape_was_reset.set(False)
try:
yield
finally:
was_reset = current_tape_was_reset.get()
current_store.reset(token)
current_fork_tape.reset(tape_token)
current_tape_was_reset.reset(reset_token)
current_store.set(prev_store)
current_fork_tape.set(prev_fork_tape)
current_tape_was_reset.set(prev_was_reset)
if merge_back:
if was_reset:
await self._parent.reset(tape)
Expand Down
87 changes: 72 additions & 15 deletions src/bub/channels/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, ClassVar

from loguru import logger
from pydantic import Field
from pydantic import BaseModel, Field
from pydantic_settings import SettingsConfigDict
from telegram import Bot, Message, Update
from telegram.ext import Application, CommandHandler, ContextTypes, filters
Expand All @@ -22,22 +22,74 @@
from bub.utils import exclude_none


class BotConfig(BaseModel):
"""Configuration for a single Telegram bot instance."""

name: str = Field(default="", description="Unique bot name used as channel name suffix.")
token: str = Field(..., description="Telegram bot token.")
allow_users: str | None = Field(default=None, description="Comma-separated allowed user IDs.")
allow_chats: str | None = Field(default=None, description="Comma-separated allowed chat IDs.")
proxy: str | None = Field(
default=None,
description="Optional proxy URL for connecting to Telegram API.",
)


@config(name="telegram")
class TelegramSettings(Settings):
model_config = SettingsConfigDict(env_prefix="BUB_TELEGRAM_", extra="ignore", env_file=".env")

token: str = Field(default="", description="Telegram bot token.")
token: str = Field(default="", description="Telegram bot token (backward compat, single-bot mode).")
bots: str = Field(
default="",
description="JSON array of bot configs for multi-bot mode, e.g. '[{\"name\":\"personal\",\"token\":\"xxx\"}]'.",
)
allow_users: str | None = Field(
default=None, description="Comma-separated list of allowed Telegram user IDs, or empty for no restriction."
default=None, description="Comma-separated list of allowed Telegram user IDs (single-bot mode)."
)
allow_chats: str | None = Field(
default=None, description="Comma-separated list of allowed Telegram chat IDs, or empty for no restriction."
default=None, description="Comma-separated list of allowed Telegram chat IDs (single-bot mode)."
)
proxy: str | None = Field(
default=None,
description="Optional proxy URL for connecting to Telegram API, e.g. 'http://user:pass@host:port' or 'socks5://host:port'.",
description="Optional proxy URL for connecting to Telegram API (single-bot mode).",
)

@staticmethod
def bot_configs() -> list[BotConfig]:
"""Return the list of bot configurations, supporting both single and multi-bot modes."""
settings = ensure_config(TelegramSettings)
if settings.bots:
try:
import json as _json
raw = _json.loads(settings.bots)
if not isinstance(raw, list):
logger.warning("telegram settings: BUB_TELEGRAM_BOTS is not a JSON array, falling back to single-bot")
return TelegramSettings._single_bot_config(settings)
configs = [BotConfig(**item) for item in raw]
if not configs:
return []
return configs
except Exception as exc:
logger.warning("telegram settings: failed to parse BUB_TELEGRAM_BOTS: %s", exc)
return TelegramSettings._single_bot_config(settings)
return TelegramSettings._single_bot_config(settings)

@staticmethod
def _single_bot_config(settings: TelegramSettings) -> list[BotConfig]:
"""Build a single BotConfig from the legacy single-bot settings."""
if settings.token:
return [
BotConfig(
name="",
token=settings.token,
allow_users=settings.allow_users,
allow_chats=settings.allow_chats,
proxy=settings.proxy,
)
]
return []


NO_ACCESS_MESSAGE = "You are not allowed to chat with me. Please deploy your own instance of Bub."

Expand Down Expand Up @@ -146,35 +198,39 @@ def _extract_media_items(metadata: dict[str, Any]) -> list[MediaItem]:


class TelegramChannel(Channel):
name = "telegram"
_app: Application

def __init__(self, on_receive: MessageHandler) -> None:
def __init__(self, on_receive: MessageHandler, bot_config: BotConfig) -> None:
self._on_receive = on_receive
self._settings = ensure_config(TelegramSettings)
self._allow_users = {uid.strip() for uid in (self._settings.allow_users or "").split(",") if uid.strip()}
self._allow_chats = {cid.strip() for cid in (self._settings.allow_chats or "").split(",") if cid.strip()}
self._config = bot_config
self._allow_users = {uid.strip() for uid in (self._config.allow_users or "").split(",") if uid.strip()}
self._allow_chats = {cid.strip() for cid in (self._config.allow_chats or "").split(",") if cid.strip()}
self._parser = TelegramMessageParser(bot_getter=lambda: self._app.bot)
self._typing_tasks: dict[str, asyncio.Task] = {}

@property
def name(self) -> str:
return f"telegram-{self._config.name}" if self._config.name else "telegram"

@property
def enabled(self) -> bool:
return bool(self._settings.token)
return bool(self._config.token)

@property
def needs_debounce(self) -> bool:
return True

async def start(self, stop_event: asyncio.Event) -> None:
proxy = self._settings.proxy
proxy = self._config.proxy
logger.info(
"telegram.start allow_users_count={} allow_chats_count={} proxy_enabled={}",
"telegram.start channel={} allow_users_count={} allow_chats_count={} proxy_enabled={}",
self.name,
len(self._allow_users),
len(self._allow_chats),
bool(proxy),
)
get_updates_request = HTTPXRequest(read_timeout=30, proxy=proxy)
builder = Application.builder().token(self._settings.token).get_updates_request(get_updates_request)
builder = Application.builder().token(self._config.token).get_updates_request(get_updates_request)
if proxy:
builder = builder.proxy(proxy)
self._app = builder.build()
Expand All @@ -187,6 +243,7 @@ async def start(self, stop_event: asyncio.Event) -> None:
if updater is None:
return
await updater.start_polling(drop_pending_updates=True, allowed_updates=["message"])
logger.info("telegram.start polling channel={}", self.name)
logger.info("telegram.start polling")

async def stop(self) -> None:
Expand All @@ -201,7 +258,7 @@ async def stop(self) -> None:
with contextlib.suppress(asyncio.CancelledError):
await task
self._typing_tasks.clear()
logger.info("telegram.stopped")
logger.info("telegram.stopped channel={}", self.name)

async def send(self, message: ChannelMessage) -> None:
chat_id = message.chat_id
Expand Down
7 changes: 6 additions & 1 deletion tests/test_builtin_hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def __init__(self, on_receive, agent) -> None:
class DummyTelegramChannel:
name = "telegram"

def __init__(self, on_receive) -> None:
def __init__(self, on_receive, bot_config=None) -> None:
self.on_receive = on_receive

@property
Expand All @@ -218,6 +218,11 @@ def enabled(self) -> bool:

monkeypatch.setattr(bub.channels.cli, "CliChannel", DummyCliChannel)
monkeypatch.setattr(bub.channels.telegram, "TelegramChannel", DummyTelegramChannel)
monkeypatch.setattr(
bub.channels.telegram.TelegramSettings,
"bot_configs",
staticmethod(lambda: [bub.channels.telegram.BotConfig(token="test")]),
)

def message_handler(message) -> None:
return None
Expand Down
12 changes: 6 additions & 6 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from bub.channels.handler import BufferedMessageHandler
from bub.channels.manager import ChannelManager
from bub.channels.message import ChannelMessage
from bub.channels.telegram import BubMessageFilter, TelegramChannel, TelegramMessageParser
from bub.channels.telegram import BubMessageFilter, BotConfig, TelegramChannel, TelegramMessageParser
from bub.turn_admission import AdmitDecision, SessionTurnController, SteeringBuffer


Expand Down Expand Up @@ -283,7 +283,7 @@ def test_channel_manager_selects_channels_by_runtime_role(
def test_channel_manager_selects_real_channel_types(load_config) -> None:
_load_channel_config(load_config, telegram_value="test-token")
cli = CliChannel.__new__(CliChannel)
telegram = TelegramChannel(lambda message: None)
telegram = TelegramChannel(lambda message: None, bot_config=BotConfig(token="test_token"))
manager = ChannelManager(
FakeFramework({"cli": cli, "telegram": telegram}),
enabled_channels=["all"],
Expand Down Expand Up @@ -752,7 +752,7 @@ def test_bub_message_filter_accepts_group_mention() -> None:
@pytest.mark.asyncio
async def test_telegram_channel_send_extracts_json_message_and_skips_blank(load_config) -> None:
_load_channel_config(load_config, telegram_value="test-token")
channel = TelegramChannel(lambda message: None)
channel = TelegramChannel(lambda message: None, bot_config=BotConfig(token="test_token"))
sent: list[tuple[str, str]] = []

async def send_message(chat_id: str, text: str) -> None:
Expand All @@ -774,7 +774,7 @@ async def test_telegram_channel_start_with_proxy_does_not_call_get_updates_proxy
fake_builder = _FakeTelegramBuilder()
monkeypatch.setattr("bub.channels.telegram.Application.builder", lambda: fake_builder)

channel = TelegramChannel(lambda message: None)
channel = TelegramChannel(lambda message: None, bot_config=BotConfig(token="test_token", proxy="http://127.0.0.1:1087"))
await channel.start(asyncio.Event())

assert fake_builder.proxy_value == "http://127.0.0.1:1087"
Expand All @@ -785,7 +785,7 @@ async def test_telegram_channel_start_with_proxy_does_not_call_get_updates_proxy
@pytest.mark.asyncio
async def test_telegram_channel_build_message_returns_command_directly(load_config) -> None:
_load_channel_config(load_config, telegram_value="test-token")
channel = TelegramChannel(lambda message: None)
channel = TelegramChannel(lambda message: None, bot_config=BotConfig(token="test_token"))
channel._parser = SimpleNamespace(parse=_async_return((",help", {"type": "text"})), get_reply=_async_return(None))

message = SimpleNamespace(chat_id=42)
Expand All @@ -803,7 +803,7 @@ async def test_telegram_channel_build_message_wraps_payload_and_disables_outboun
monkeypatch: pytest.MonkeyPatch, load_config
) -> None:
_load_channel_config(load_config, telegram_value="test-token")
channel = TelegramChannel(lambda message: None)
channel = TelegramChannel(lambda message: None, bot_config=BotConfig(token="test_token"))
parser = SimpleNamespace(
parse=_async_return(("hello", {"type": "text", "sender_id": "7"})),
get_reply=_async_return({"message": "prev", "type": "text"}),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_image_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from bub.builtin.hook_impl import BuiltinImpl
from bub.channels.message import ChannelMessage, MediaItem
from bub.channels.telegram import TelegramChannel, _extract_media_items
from bub.channels.telegram import BotConfig, TelegramChannel, _extract_media_items
from bub.framework import BubFramework

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -186,7 +186,7 @@ async def _receive_message(_message) -> None:
@pytest.mark.asyncio
async def test_telegram_build_message_extracts_media_items(monkeypatch: pytest.MonkeyPatch, load_config) -> None:
load_config("telegram:\n token: test-token")
channel = TelegramChannel(_receive_message)
channel = TelegramChannel(_receive_message, bot_config=BotConfig(token="test_token"))
photo_metadata = {
"type": "photo",
"sender_id": "7",
Expand All @@ -213,7 +213,7 @@ async def test_telegram_build_message_extracts_media_items(monkeypatch: pytest.M
@pytest.mark.asyncio
async def test_telegram_build_message_no_media_for_text(monkeypatch: pytest.MonkeyPatch, load_config) -> None:
load_config("telegram:\n token: test-token")
channel = TelegramChannel(_receive_message)
channel = TelegramChannel(_receive_message, bot_config=BotConfig(token="test_token"))
channel._parser = SimpleNamespace( # type: ignore[assignment]
parse=_async_return(("hello", {"type": "text", "sender_id": "7"})),
get_reply=_async_return(None),
Expand Down
Loading