diff --git a/examples/extension.py b/examples/extension.py new file mode 100644 index 0000000..6addf63 --- /dev/null +++ b/examples/extension.py @@ -0,0 +1,45 @@ +from matrix import Extension, Context + +extension = Extension("math") + + +@extension.group("math", description="Math Group") +async def math_group(ctx: Context): + pass + + +@math_group.command() +async def add(ctx: Context, a: int, b: int): + await ctx.reply(f"**{a} + {b} = {a + b}**") + + +@math_group.command() +async def subtract(ctx: Context, a: int, b: int): + await ctx.reply(f"{a} - {b} = {a - b}") + + +@math_group.command() +async def multiply(ctx: Context, a: int, b: int): + await ctx.reply(f"{a} x {b} = {a * b}") + + +@math_group.command() +async def divide(ctx: Context, a: int, b: int): + await ctx.reply(f"{a} รท {b} = {a / b}") + + +@divide.error(ZeroDivisionError) +async def divide_error(ctx: Context, error): + await ctx.reply(f"Divide error: {error}") + + +""" +from matrix import Bot +from math_extension import extension as math_extension + +bot = Bot(config="config.yaml") + + +bot.load_extension(math_extension) +bot.start() +""" diff --git a/matrix/__init__.py b/matrix/__init__.py index b4369ec..ef00fe2 100644 --- a/matrix/__init__.py +++ b/matrix/__init__.py @@ -15,6 +15,7 @@ from .help import HelpCommand from .checks import cooldown from .room import Room +from .extension import Extension __all__ = [ "Bot", @@ -26,4 +27,5 @@ "HelpCommand", "cooldown", "Room", + "Extension", ] diff --git a/matrix/bot.py b/matrix/bot.py index feef107..db083a4 100644 --- a/matrix/bot.py +++ b/matrix/bot.py @@ -3,77 +3,30 @@ import asyncio import logging -from collections import defaultdict -from typing import ( - Any, - Dict, - List, - Type, - Union, - Optional, - Callable, - Coroutine, -) - -from nio import ( - AsyncClient, - Event, - MatrixRoom, - RoomMessageText, - RoomMemberEvent, - TypingNoticeEvent, - ReactionEvent, -) +from typing import Union, Optional + +from nio import AsyncClient, Event, MatrixRoom from .room import Room from .group import Group from .config import Config from .context import Context -from .command import Command +from .extension import Extension +from .registry import Registry from .help import HelpCommand, DefaultHelpCommand from .scheduler import Scheduler +from .errors import AlreadyRegisteredError, CommandNotFoundError, CheckError -from .errors import ( - GroupAlreadyRegisteredError, - AlreadyRegisteredError, - CommandNotFoundError, - CheckError, -) - -Callback = Callable[..., Coroutine[Any, Any, Any]] -GroupCallable = Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group] -ErrorCallback = Callable[[Exception], Coroutine] -CommandErrorCallback = Callable[["Context", Exception], Coroutine[Any, Any, Any]] - -class Bot: +class Bot(Registry): """ The base class defining a Matrix bot. This class manages the connection to a Matrix homeserver, listens for events, and dispatches them to registered handlers. It also supports a command system with decorators for easy registration. - - :param config: Configuration for Matrix client settings - :type config: Config - - :raises TypeError: If an event or command handler is not a coroutine. - :raises ValueError: If an unknown event name - :raises AlreadyRegisteredError: If a new command is already registered. """ - EVENT_MAP: Dict[str, Type[Event]] = { - "on_typing": TypingNoticeEvent, - "on_message": RoomMessageText, - "on_react": ReactionEvent, - "on_member_join": RoomMemberEvent, - "on_member_leave": RoomMemberEvent, - "on_member_invite": RoomMemberEvent, - "on_member_ban": RoomMemberEvent, - "on_member_kick": RoomMemberEvent, - "on_member_change": RoomMemberEvent, - } - def __init__( self, *, config: Union[Config, str], help: Optional[HelpCommand] = None ) -> None: @@ -84,263 +37,86 @@ def __init__( else: raise TypeError("config must be a Config instance or a config file path") + super().__init__(self.__class__.__name__, prefix=self.config.prefix) + self.client: AsyncClient = AsyncClient(self.config.homeserver) + self.extensions: dict[str, Extension] = {} + self.scheduler: Scheduler = Scheduler() self.log: logging.Logger = logging.getLogger(__name__) - self.prefix: str = self.config.prefix self.start_at: float | None = None # unix timestamp - self.commands: Dict[str, Command] = {} - self.checks: List[Callback] = [] - self.scheduler = Scheduler() - - self._handlers: Dict[Type[Event], List[Callback]] = defaultdict(list) - self._on_error: Optional[ErrorCallback] = None - self._error_handlers: dict[type[Exception], ErrorCallback] = {} - self._command_error_handlers: dict[type[Exception], CommandErrorCallback] = {} - self.help: HelpCommand = help or DefaultHelpCommand(prefix=self.prefix) self.register_command(self.help) self.client.add_event_callback(self._on_event, Event) self._auto_register_events() - def check(self, func: Callback) -> None: - """ - Register a check callback - - :param func: The check callback - :type func: Callback - - :raises TypeError: If the function is not a coroutine. - """ - if not inspect.iscoroutinefunction(func): - raise TypeError("Checks must be coroutine") + def get_room(self, room_id: str) -> Room: + """Retrieve a Room instance based on the room_id.""" + matrix_room = self.client.rooms[room_id] + return Room(matrix_room=matrix_room, client=self.client) - self.checks.append(func) + def load_extension(self, extension: Extension) -> None: + self.log.debug(f"Loading extension: '{extension.name}'") - def event( - self, - func: Optional[Callback] = None, - *, - event_spec: Union[str, Type[Event], None] = None, - ) -> Union[Callback, Callable[[Callback], Callback]]: - """ - Decorator to register a coroutine as an event handler. - - Can be used with or without arguments: - - - Without arguments, registers based on coroutine name - lookup in ``EVENT_MAP``:: - - @bot.event - async def on_message(room, event): - ... - - - With an explicit event type or event name:: - - @bot.event(event_spec=RoomMemberEvent) - async def handle_member(room, event): - ... - - @bot.event(event_spec="on_member_join") - async def welcome(room, event): - ... - - :param func: The coroutine function to register (used when decorator - is applied without parentheses). - :type func: coroutine function, optional - :param event_spec: The event to register for, either as a string key - matching ``EVENT_MAP`` or a specific event class. If omitted, - the event type is inferred from the coroutine function's name. - :type event_spec: str or type or None, optional - :raises TypeError: If the decorated function is not a coroutine. - :raises ValueError: If the event name or string is unknown. - :return: Decorator that registers the event handler. - :rtype: Callable[[Callable[..., Awaitable[None]]], - Callable[..., Awaitable[None]]] - """ + if extension.name in self.extensions: + raise AlreadyRegisteredError(extension) - def wrapper(f: Callback) -> Callback: - if not inspect.iscoroutinefunction(f): - raise TypeError("Event handlers must be coroutines") - - if event_spec: - if isinstance(event_spec, str): - event_type = self.EVENT_MAP.get(event_spec) - if event_type is None: - raise ValueError(f"Unknown event string: {event_spec}") - else: - event_type = event_spec + for cmd in extension._commands.values(): + if isinstance(cmd, Group): + self.register_group(cmd) else: - event_type = self.EVENT_MAP.get(f.__name__) - if event_type is None: - raise ValueError(f"Unknown event name: {f.__name__}") + self.register_command(cmd) - self._handlers[event_type].append(f) - self.log.debug( - "registered event %s for %s", f.__name__, event_type.__name__ - ) - return f - - if func is None: - return wrapper - - return wrapper(func) - - def command( - self, - name: Optional[str] = None, - *, - description: Optional[str] = None, - parent: Optional[str] = None, - usage: Optional[str] = None, - cooldown: Optional[tuple[int, float]] = None, - ) -> Callable[[Callback], Command]: - """ - Decorator to register a coroutine function as a command handler. + for event_type, handlers in extension._event_handlers.items(): + self._event_handlers[event_type].extend(handlers) - The command name defaults to the function name unless - explicitly provided. + self._checks.extend(extension._checks) + self._error_handlers.update(extension._error_handlers) + self._command_error_handlers.update(extension._command_error_handlers) - ## Example - - ```python - @bot.command(description="Returns pong!") - async def ping(ctx): - await ctx.reply("Pong!") - ``` - """ - - def wrapper(func: Callback) -> Command: - cmd = Command( - func, - name=name, - description=description, - prefix=self.prefix, - parent=parent, - usage=usage, - cooldown=cooldown, + for job in extension._scheduler.jobs: + self.scheduler.scheduler.add_job( + job.func, + trigger=job.trigger, + name=job.name, ) - return self.register_command(cmd) - - return wrapper - def group( - self, - name: Optional[str] = None, - *, - description: Optional[str] = None, - parent: Optional[str] = None, - usage: Optional[str] = None, - cooldown: Optional[tuple[int, float]] = None, - ) -> GroupCallable: - """Decorator to register a coroutine function as a group handler. + self.extensions[extension.name] = extension + extension.load() + self.log.debug("loaded extension '%s'", extension.name) - The group name defaults to the function name unless - explicitly provided. + def unload_extension(self, ext_name: str) -> None: + self.log.debug("Unloading extension: '%s'", ext_name) - ## Example + extension = self.extensions.pop(ext_name, None) + if extension is None: + raise ValueError(f"No extension named '{ext_name}' is loaded") - ```python - @bot.group(description="Group of mathematical commands") - async def math(ctx): - await ctx.reply("You called !math") + for cmd_name in extension._commands: + self._commands.pop(cmd_name, None) + for event_type, handlers in extension._event_handlers.items(): + for handler in handlers: + self._event_handlers[event_type].remove(handler) - @math.command() - async def add(ctx, a: int, b: int): - await ctx.reply(f"{a} + {b} = {a + b}") + for check in extension._checks: + self._checks.remove(check) + for exc_type in extension._error_handlers: + self._error_handlers.pop(exc_type, None) - @math.command() - async def subtract(ctx, a: int, b: int): - await ctx.reply(f"{a} - {b} = {a - b}") - ``` - """ + for exc_type in extension._command_error_handlers: + self._command_error_handlers.pop(exc_type, None) - def wrapper(func: Callback) -> Group: - group = Group( - func, - name=name, - description=description, - prefix=self.prefix, - parent=parent, - usage=usage, - cooldown=cooldown, - ) - return self.register_group(group) - - return wrapper - - def schedule(self, cron: str) -> Callable[..., Callback]: - """ - Decorator to register a coroutine function as a scheduled task. + for job in extension._scheduler.jobs: + bot_job = next((j for j in self.scheduler.jobs if j.func is job.func), None) + if bot_job: + bot_job.remove() - The cron string defines the schedule for the task. - - :param cron: The cron string defining the schedule. - :type cron: str - :raises TypeError: If the decorated function is not a coroutine. - :return: Decorator that registers the scheduled task. - :rtype: Callback - """ - - def wrapper(f: Callback) -> Callback: - if not inspect.iscoroutinefunction(f): - raise TypeError("Scheduled tasks must be coroutines") - - self.scheduler.schedule(cron, f) - self.log.debug("registered scheduled task %s for cron %s", f.__name__, cron) - return f - - return wrapper - - def register_command(self, cmd: Command) -> Command: - if cmd in self.commands: - raise AlreadyRegisteredError(cmd) - - self.commands[cmd.name] = cmd - self.log.debug("command '%s' registered", cmd) - - return cmd - - def register_group(self, group: Group) -> Group: - if group in self.commands: - raise GroupAlreadyRegisteredError(group) - - self.commands[group.name] = group - self.log.debug("group '%s' registered", group) - return group - - def error(self, exception: Optional[type[Exception]] = None) -> Callable: - """ - Decorator to register a custom error handler for commands. - - :param exception: The specific exception type to handle. - :type exception: Optional[Exception] - - :return: A decorator that registers the given coroutine as - an error handler. - :rtype: Callable - """ - - def wrapper(func: ErrorCallback) -> Callable: - if not inspect.iscoroutinefunction(func): - raise TypeError("The error handler must be a coroutine.") - - if exception: - self._error_handlers[exception] = func - else: - self._on_error = func - return func - - return wrapper - - def get_room(self, room_id: str) -> Room: - """Retrieve a Room instance based on the room_id.""" - matrix_room = self.client.rooms[room_id] - return Room(matrix_room=matrix_room, client=self.client) + extension.unload() + self.log.debug("unloaded extension '%s'", ext_name) def _auto_register_events(self) -> None: for attr in dir(self): @@ -369,7 +145,7 @@ async def _on_event(self, room: MatrixRoom, event: Event) -> None: async def _dispatch(self, room: MatrixRoom, event: Event) -> None: """Internal type-based fan-out plus optional command handling.""" - for event_type, funcs in self._handlers.items(): + for event_type, funcs in self._event_handlers.items(): if isinstance(event, event_type): for func in funcs: await func(room, event) @@ -379,28 +155,44 @@ async def _process_commands(self, room: MatrixRoom, event: Event) -> None: ctx = await self._build_context(room, event) if ctx.command: - for check in self.checks: + for check in self._checks: if not await check(ctx): raise CheckError(ctx.command, check) await ctx.command(ctx) async def _build_context(self, matrix_room: MatrixRoom, event: Event) -> Context: - """Builds the base context and extracts the command from the event""" room = self.get_room(matrix_room.room_id) ctx = Context(bot=self, room=room, event=event) + prefix: str | None = None - if not self.prefix or not ctx.body.startswith(self.prefix): + if self.prefix is not None and ctx.body.startswith(self.prefix): + prefix = self.prefix + else: + prefix = next( + ( + cmd.prefix + for cmd in self._commands.values() + if cmd.prefix is not None and ctx.body.startswith(cmd.prefix) + ), + self.config.prefix, + ) + + if prefix is None or not ctx.body.startswith(prefix): return ctx - if parts := ctx.body[len(self.prefix) :].split(): + if parts := ctx.body[len(prefix) :].split(): cmd_name = parts[0] - cmd = self.commands.get(cmd_name) + cmd = self._commands.get(cmd_name) + + if cmd and cmd.prefix and not ctx.body.startswith(cmd.prefix): + return ctx + + if not cmd: + raise CommandNotFoundError(cmd_name) - if not cmd: - raise CommandNotFoundError(cmd_name) + ctx.command = cmd - ctx.command = cmd return ctx async def on_message(self, room: MatrixRoom, event: Event) -> None: diff --git a/matrix/context.py b/matrix/context.py index 1fe15f4..3d1651a 100644 --- a/matrix/context.py +++ b/matrix/context.py @@ -28,7 +28,6 @@ def __init__(self, bot: "Bot", room: Room, event: Event): self.sender: str = event.sender # Command metadata - self.prefix: str = bot.prefix self.command: Optional[Command] = None self.subcommand: Optional[Command] = None self._args: List[str] = shlex.split(self.body) diff --git a/matrix/errors.py b/matrix/errors.py index 18f3890..98e9c18 100644 --- a/matrix/errors.py +++ b/matrix/errors.py @@ -4,6 +4,7 @@ if TYPE_CHECKING: from .command import Command # pragma: no cover from .group import Group # pragma: no cover + from .extension import Extension Callback = Callable[..., Coroutine[Any, Any, Any]] @@ -12,6 +13,17 @@ class MatrixError(Exception): pass +class RegistryError(MatrixError): + pass + + +class AlreadyRegisteredError(RegistryError): + def __init__(self, entry: "Command | Group | Extension"): + super().__init__( + f"{entry.__class__.__name__} '{entry.name}' is already registered" + ) + + class CommandError(MatrixError): pass @@ -21,7 +33,7 @@ def __init__(self, cmd: str): super().__init__(f"Command with name '{cmd}' not found") -class AlreadyRegisteredError(CommandError): +class CommandAlreadyRegisteredError(CommandError): def __init__(self, cmd: "Command"): super().__init__(f"Command '{cmd}' is already registered") @@ -40,11 +52,6 @@ class GroupError(CommandError): pass -class GroupAlreadyRegisteredError(GroupError): - def __init__(self, group: "Group"): - super().__init__(f"Group '{group}' is already registered") - - class ConfigError(MatrixError): def __init__(self, error: str): super().__init__(f"Missing required configuration: '{error}'") diff --git a/matrix/extension.py b/matrix/extension.py new file mode 100644 index 0000000..5fe9ec7 --- /dev/null +++ b/matrix/extension.py @@ -0,0 +1,56 @@ +import logging +import inspect + +from typing import Any, Callable, Coroutine, Optional +from matrix.registry import Registry + +logger = logging.getLogger(__name__) + + +class Extension(Registry): + def __init__(self, name: str, prefix: Optional[str] = None) -> None: + super().__init__(name, prefix=prefix) + self._on_load: Optional[Callable] = None + self._on_unload: Optional[Callable] = None + + def load(self) -> None: + if self._on_load: + self._on_load() + + def on_load(self, func: Callable) -> Callable: + """Decorator to register a function to be called after this extension + is loaded into the bot. + + ## Example + + ```python + @extension.on_load + def setup(): + print("extension loaded") + ``` + """ + if inspect.iscoroutinefunction(func): + raise TypeError("on_load handler must not be a coroutine") + self._on_load = func + return func + + def unload(self) -> None: + if self._on_unload: + self._on_unload() + + def on_unload(self, func: Callable) -> Callable: + """Decorator to register a function to be called before this extension + is unloaded from the bot. + + ## Example + + ```python + @extension.on_unload + def teardown(): + print("extension unloaded") + ``` + """ + if inspect.iscoroutinefunction(func): + raise TypeError("on_unload handler must not be a coroutine") + self._on_unload = func + return func diff --git a/matrix/registry.py b/matrix/registry.py new file mode 100644 index 0000000..2a9683b --- /dev/null +++ b/matrix/registry.py @@ -0,0 +1,333 @@ +import inspect +import logging + +from collections import defaultdict +from typing import Any, Callable, Coroutine, Optional, Type, Union, Dict, List + +from nio import ( + Event, + ReactionEvent, + RoomMemberEvent, + RoomMessageText, + TypingNoticeEvent, +) + +from matrix.group import Group +from matrix.command import Command +from matrix.scheduler import Scheduler +from matrix.context import Context +from matrix.errors import AlreadyRegisteredError + +logger = logging.getLogger(__name__) + +Callback = Callable[..., Coroutine[Any, Any, Any]] +GroupCallable = Callable[[Callable[..., Coroutine[Any, Any, Any]]], Group] +ErrorCallback = Callable[[Exception], Coroutine] +CommandErrorCallback = Callable[[Context, Exception], Coroutine[Any, Any, Any]] + + +class Registry: + """ + Base class providing shared registration behaviour for Bot and Extension. + + Handles registration of commands, groups, events, checks, schedules, + and error handlers. Subclasses must initialize the required attributes + defined below, either directly or via ``super().__init__()``. + """ + + EVENT_MAP: dict[str, Type[Event]] = { + "on_typing": TypingNoticeEvent, + "on_message": RoomMessageText, + "on_react": ReactionEvent, + "on_member_join": RoomMemberEvent, + "on_member_leave": RoomMemberEvent, + "on_member_invite": RoomMemberEvent, + "on_member_ban": RoomMemberEvent, + "on_member_kick": RoomMemberEvent, + "on_member_change": RoomMemberEvent, + } + + def __init__(self, name: str, prefix: Optional[str] = None): + self.name = name + self.prefix = prefix + self.log = logging.getLogger(__name__) + + self._commands: Dict[str, Command] = {} + self._checks: List[Callback] = [] + self._scheduler: Scheduler = Scheduler() + + self._event_handlers: Dict[Type[Event], List[Callback]] = defaultdict(list) + self._on_error: Optional[ErrorCallback] = None + self._error_handlers: Dict[type[Exception], ErrorCallback] = {} + self._command_error_handlers: Dict[type[Exception], CommandErrorCallback] = {} + + @property + def commands(self) -> Dict[str, Command]: + return self._commands + + def command( + self, + name: Optional[str] = None, + *, + description: Optional[str] = None, + usage: Optional[str] = None, + cooldown: Optional[tuple[int, float]] = None, + ) -> Callable[[Callback], Command]: + """Decorator to register a coroutine function as a command handler. + + The command name defaults to the function name unless + explicitly provided. + + ## Example + + ```python + @bot.command(description="Returns pong!") + async def ping(ctx): + await ctx.reply("Pong!") + ``` + """ + + def wrapper(func: Callback) -> Command: + cmd = Command( + func, + name=name, + description=description, + prefix=self.prefix, + usage=usage, + cooldown=cooldown, + ) + return self.register_command(cmd) + + return wrapper + + def register_command(self, cmd: Command) -> Command: + """Register a Command instance directly. + + Prefer the :meth:`command` decorator for typical use. This method + is useful when constructing a ``Command`` object manually or when + loading commands from an extension. + """ + if cmd.name in self._commands: + raise AlreadyRegisteredError(cmd) + + self._commands[cmd.name] = cmd + logger.debug("command '%s' registered on %s", cmd, type(self).__name__) + + return cmd + + def group( + self, + name: Optional[str] = None, + *, + description: Optional[str] = None, + usage: Optional[str] = None, + cooldown: Optional[tuple[int, float]] = None, + ) -> Callable[[Callback], Group]: + """Decorator to register a coroutine function as a command group. + + A group acts as a parent command that can have subcommands attached + to it via its own ``@group.command()`` decorator. The group name + defaults to the function name unless explicitly provided. + + ## Example + + ```python + @bot.group(description="Group of mathematical commands") + async def math(ctx): + await ctx.reply("You called !math") + + @math.command() + async def add(ctx, a: int, b: int): + await ctx.reply(f"{a} + {b} = {a + b}") + + @math.command() + async def subtract(ctx, a: int, b: int): + await ctx.reply(f"{a} - {b} = {a - b}") + ``` + """ + + def wrapper(func: Callback) -> Group: + grp = Group( + func, + name=name, + description=description, + prefix=self.prefix, + usage=usage, + cooldown=cooldown, + ) + return self.register_group(grp) + + return wrapper + + def register_group(self, group: Group) -> Group: + """Register a Group instance directly. + + Prefer the :meth:`group` decorator for typical use. This method + is useful when constructing a ``Group`` object manually or when + loading groups from an extension. + """ + if group.name in self._commands: + raise AlreadyRegisteredError(group) + + self._commands[group.name] = group + logger.debug("group '%s' registered on %s", group, type(self).__name__) + + return group + + def event( + self, + func: Optional[Callback] = None, + *, + event_spec: Union[str, Type[Event], None] = None, + ) -> Union[Callback, Callable[[Callback], Callback]]: + """Decorator to register a coroutine as an event handler. + + Can be used with or without arguments. Without arguments, the event + type is inferred from the function name via ``EVENT_MAP``. Multiple + handlers for the same event type are supported and called in + registration order. + + ## Example + + ```python + @bot.event + async def on_message(room, event): + ... + + @bot.event(event_spec=RoomMemberEvent) + async def handle_member(room, event): + ... + + @bot.event(event_spec="on_member_join") + async def welcome(room, event): + ... + ``` + """ + + def wrapper(f: Callback) -> Callback: + if not inspect.iscoroutinefunction(f): + raise TypeError("Event handlers must be coroutines") + + if event_spec: + if isinstance(event_spec, str): + event_type = self.EVENT_MAP.get(event_spec) + if event_type is None: + raise ValueError(f"Unknown event string: {event_spec}") + else: + event_type = event_spec + else: + event_type = self.EVENT_MAP.get(f.__name__) + if event_type is None: + raise ValueError(f"Unknown event name: {f.__name__}") + + return self.register_event(event_type, f) + + if func is None: + return wrapper + return wrapper(func) + + def register_event(self, event_type: Type[Event], callback: Callback) -> Callback: + """Register an event handler directly for a given event type. + + Prefer the :meth:`event` decorator for typical use. This method + is useful when loading event handlers from an extension. + """ + self._event_handlers[event_type].append(callback) + logger.debug( + "registered event %s for %s", callback.__name__, event_type.__name__ + ) + return callback + + def check(self, func: Callback) -> Callback: + """Register a global check that must pass before any command is invoked. + + The check receives the current :class:`Context` and must return a + boolean. If any check returns ``False``, a :class:`CheckError` is + raised and the command is not executed. + + ## Example + + ```python + @bot.check + async def is_not_banned(ctx): + return ctx.sender not in banned_users + ``` + """ + if not inspect.iscoroutinefunction(func): + raise TypeError("Checks must be coroutines") + + self._checks.append(func) + logger.debug("registered check '%s' on %s", func.__name__, type(self).__name__) + + return func + + def schedule(self, cron: str) -> Callable[[Callback], Callback]: + """Decorator to register a coroutine as a scheduled task. + + When used on an extension, scheduled tasks are merged into the + bot's scheduler when the extension is loaded. + + ## Example + + ```python + @bot.schedule("0 9 * * *") + async def morning_message(): + await room.send("Good morning!") + ``` + """ + + def wrapper(f: Callback) -> Callback: + if not inspect.iscoroutinefunction(f): + raise TypeError("Scheduled tasks must be coroutines") + + self._scheduler.schedule(cron, f) + logger.debug( + "scheduled '%s' for cron '%s' on %s", + f.__name__, + cron, + type(self).__name__, + ) + + return f + + return wrapper + + def error( + self, exception: Optional[type[Exception]] = None + ) -> Callable[[ErrorCallback], ErrorCallback]: + """Decorator to register an error handler. + + If an exception type is provided, the handler is only invoked for + that specific exception. If omitted, the handler acts as a generic + fallback for any unhandled error. + + ## Example + + ```python + @bot.error(ValueError) + async def on_value_error(error): + await room.send(f"Bad value: {error}") + + @bot.error() + async def on_any_error(error): + await room.send(f"Something went wrong: {error}") + ``` + """ + + def wrapper(func: ErrorCallback) -> ErrorCallback: + if not inspect.iscoroutinefunction(func): + raise TypeError("Error handlers must be coroutines") + + if exception: + self._error_handlers[exception] = func + else: + self._on_error = func + logger.debug( + "registered error handler '%s' on %s", + func.__name__, + type(self).__name__, + ) + + return func + + return wrapper diff --git a/matrix/scheduler.py b/matrix/scheduler.py index 1af3e98..2516f6f 100644 --- a/matrix/scheduler.py +++ b/matrix/scheduler.py @@ -1,6 +1,7 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.cron import CronTrigger -from typing import Any, Callable, Coroutine +from apscheduler.job import Job +from typing import cast, Any, Callable, Coroutine Callback = Callable[..., Coroutine[Any, Any, Any]] @@ -13,6 +14,10 @@ def __init__(self) -> None: """ self.scheduler = AsyncIOScheduler() + @property + def jobs(self) -> list[Job]: + return cast(list[Job], self.scheduler.get_jobs()) + def _parse_cron(self, cron: str) -> dict: """ Parse a cron string into a dictionary suitable for CronTrigger. diff --git a/pyproject.toml b/pyproject.toml index 43d9e14..8e45597 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,4 +44,10 @@ packages = ["matrix"] license-files = [] [tool.setuptools_scm] -write_to = "matrix/_version.py" \ No newline at end of file +write_to = "matrix/_version.py" + +[tool.coverage.run] +omit = [ + "matrix/_version.py", + "matrix/__init__.py", +] \ No newline at end of file diff --git a/tests/test_bot.py b/tests/test_bot.py index 8f98402..8ab4651 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -5,6 +5,7 @@ from matrix.bot import Bot from matrix.config import Config +from matrix.extension import Extension from matrix.errors import ( CheckError, CommandNotFoundError, @@ -85,8 +86,8 @@ async def handler1(room, event): async def handler2(room, event): called.append("h2") - bot._handlers[RoomMessageText].append(handler1) - bot._handlers[RoomMessageText].append(handler2) + bot._event_handlers[RoomMessageText].append(handler1) + bot._event_handlers[RoomMessageText].append(handler2) event = RoomMessageText.from_dict( { @@ -410,5 +411,322 @@ async def test_scheduled_task_in_scheduler(bot): async def scheduled_task(): pass - job_names = list(map(lambda j: j.name, bot.scheduler.scheduler.get_jobs())) + job_names = list(map(lambda j: j.name, bot._scheduler.jobs)) assert "scheduled_task" in job_names, "Scheduled task not found in scheduler" + + +@pytest.fixture +def extension() -> Extension: + ext = Extension(name="test_ext", prefix="!") + + @ext.command() + async def hello(ctx): + pass + + return ext + + +def test_load_extension_with_valid_extension__expect_extension_in_registry( + bot: Bot, extension: Extension +): + bot.load_extension(extension) + + assert "test_ext" in bot.extensions + + +def test_load_extension_with_duplicate_extension__expect_already_registered_error( + bot: Bot, extension: Extension +): + bot.load_extension(extension) + + with pytest.raises(AlreadyRegisteredError): + bot.load_extension(extension) + + +def test_load_extension_with_commands__expect_commands_in_bot( + bot: Bot, extension: Extension +): + bot.load_extension(extension) + + assert "hello" in bot.commands + + +def test_load_extension_with_group__expect_group_in_bot(bot: Bot): + ext = Extension(name="math_ext", prefix="!") + + @ext.group() + async def math(ctx): + pass + + bot.load_extension(ext) + + assert "math" in bot.commands + + +def test_load_extension_with_event_handlers__expect_handlers_in_bot(bot: Bot): + ext = Extension(name="events_ext") + + @ext.event + async def on_message(room, event): + pass + + bot.load_extension(ext) + + assert on_message in bot._event_handlers[RoomMessageText] + + +def test_load_extension_with_multiple_event_handlers__expect_all_handlers_in_bot( + bot: Bot, +): + ext = Extension(name="multi_events_ext") + + @ext.event + async def on_message(room, event): + pass + + @ext.event(event_spec="on_message") + async def on_message_two(room, event): + pass + + bot.load_extension(ext) + + assert on_message in bot._event_handlers[RoomMessageText] + assert on_message_two in bot._event_handlers[RoomMessageText] + + +def test_load_extension_with_checks__expect_checks_in_bot(bot: Bot): + ext = Extension(name="checks_ext") + + @ext.check + async def only_admins(ctx): + return True + + bot.load_extension(ext) + + assert only_admins in bot._checks + + +def test_load_extension_with_error_handlers__expect_handlers_in_bot(bot: Bot): + ext = Extension(name="errors_ext") + + @ext.error(ValueError) + async def on_value_error(error): + pass + + bot.load_extension(ext) + + assert ValueError in bot._error_handlers + assert bot._error_handlers[ValueError] is on_value_error + + +def test_load_extension_with_scheduled_tasks__expect_jobs_in_bot_scheduler(bot: Bot): + ext = Extension(name="scheduler_ext") + + @ext.schedule("* * * * *") + async def periodic_task(): + pass + + bot.load_extension(ext) + + job_names = [j.name for j in bot.scheduler.jobs] + assert "periodic_task" in job_names + + +def test_load_extension_does_not_add_jobs_to_extension_scheduler(bot: Bot): + ext = Extension(name="scheduler_ext") + + @ext.schedule("* * * * *") + async def periodic_task(): + pass + + initial_bot_job_count = len(bot.scheduler.jobs) + bot.load_extension(ext) + + assert len(bot.scheduler.jobs) == initial_bot_job_count + 1 + + +def test_load_extension_logs_loading(bot: Bot, extension: Extension): + bot.load_extension(extension) + + bot.log.debug.assert_any_call("loaded extension '%s'", extension.name) + + +def test_load_extension_with_empty_extension__expect_no_commands_added(bot: Bot): + ext = Extension(name="empty_ext") + initial_command_count = len(bot.commands) + + bot.load_extension(ext) + + assert len(bot.commands) == initial_command_count + + +@pytest.fixture +def loaded_extension(bot: Bot) -> Extension: + ext = Extension(name="loaded_ext", prefix="!") + + @ext.command() + async def hello(ctx): + pass + + bot.load_extension(ext) + return ext + + +def test_unload_extension_with_valid_name__expect_extension_removed_from_registry( + bot: Bot, loaded_extension: Extension +): + bot.unload_extension(loaded_extension.name) + + assert loaded_extension.name not in bot.extensions + + +def test_unload_extension_with_unknown_name__expect_value_error(bot: Bot): + with pytest.raises(ValueError): + bot.unload_extension("nonexistent_ext") + + +def test_unload_extension_with_commands__expect_commands_removed_from_bot( + bot: Bot, loaded_extension: Extension +): + bot.unload_extension(loaded_extension.name) + + assert "hello" not in bot.commands + + +def test_unload_extension_with_group__expect_group_removed_from_bot(bot: Bot): + ext = Extension(name="group_ext", prefix="!") + + @ext.group() + async def math(ctx): + pass + + bot.load_extension(ext) + bot.unload_extension(ext.name) + + assert "math" not in bot.commands + + +def test_unload_extension_with_event_handlers__expect_handlers_removed_from_bot( + bot: Bot, +): + ext = Extension(name="events_ext") + + @ext.event + async def on_message(room, event): + pass + + bot.load_extension(ext) + bot.unload_extension(ext.name) + + assert on_message not in bot._event_handlers[RoomMessageText] + + +def test_unload_extension_with_multiple_event_handlers__expect_all_handlers_removed( + bot: Bot, +): + ext = Extension(name="multi_events_ext") + + @ext.event + async def on_message(room, event): + pass + + @ext.event(event_spec="on_message") + async def on_message_two(room, event): + pass + + bot.load_extension(ext) + bot.unload_extension(ext.name) + + assert on_message not in bot._event_handlers[RoomMessageText] + assert on_message_two not in bot._event_handlers[RoomMessageText] + + +def test_unload_extension_does_not_remove_other_extension_handlers(bot: Bot): + ext_a = Extension(name="ext_a") + ext_b = Extension(name="ext_b") + + @ext_a.event + async def on_message(room, event): + pass + + @ext_b.event(event_spec="on_message") + async def on_message_b(room, event): + pass + + bot.load_extension(ext_a) + bot.load_extension(ext_b) + bot.unload_extension(ext_a.name) + + assert on_message not in bot._event_handlers[RoomMessageText] + assert on_message_b in bot._event_handlers[RoomMessageText] + + +def test_unload_extension_with_checks__expect_checks_removed_from_bot(bot: Bot): + ext = Extension(name="checks_ext") + + @ext.check + async def only_admins(ctx): + return True + + bot.load_extension(ext) + bot.unload_extension(ext.name) + + assert only_admins not in bot._checks + + +def test_unload_extension_with_error_handlers__expect_handlers_removed_from_bot( + bot: Bot, +): + ext = Extension(name="errors_ext") + + @ext.error(ValueError) + async def on_value_error(error): + pass + + bot.load_extension(ext) + bot.unload_extension(ext.name) + + assert ValueError not in bot._error_handlers + + +def test_unload_extension_with_scheduled_tasks__expect_jobs_removed_from_bot_scheduler( + bot: Bot, +): + ext = Extension(name="scheduler_ext") + + @ext.schedule("* * * * *") + async def periodic_task(): + pass + + bot.load_extension(ext) + bot.unload_extension(ext.name) + + job_names = [j.name for j in bot.scheduler.jobs] + assert "periodic_task" not in job_names + + +def test_unload_extension_does_not_remove_other_extension_jobs(bot: Bot): + ext_a = Extension(name="scheduler_ext_a") + ext_b = Extension(name="scheduler_ext_b") + + @ext_a.schedule("* * * * *") + async def task_a(): + pass + + @ext_b.schedule("* * * * *") + async def task_b(): + pass + + bot.load_extension(ext_a) + bot.load_extension(ext_b) + bot.unload_extension(ext_a.name) + + job_names = [j.name for j in bot.scheduler.jobs] + assert "task_a" not in job_names + assert "task_b" in job_names + + +def test_unload_extension_logs_unloading(bot: Bot, loaded_extension: Extension): + bot.unload_extension(loaded_extension.name) + + bot.log.debug.assert_any_call("unloaded extension '%s'", loaded_extension.name) diff --git a/tests/test_context.py b/tests/test_context.py index cccf5d7..c5056eb 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -64,7 +64,6 @@ def test_context_initialization__expect_correct_properties(context, bot, room, e assert context.event is event assert context.body == "!echo hello world" assert context.sender == "@user:matrix.org" - assert context.prefix == "!" assert context.command is None assert context.subcommand is None assert context._args == ["!echo", "hello", "world"] diff --git a/tests/test_extension.py b/tests/test_extension.py new file mode 100644 index 0000000..8b96013 --- /dev/null +++ b/tests/test_extension.py @@ -0,0 +1,159 @@ +import pytest + +from matrix.extension import Extension + + +@pytest.fixture +def extension() -> Extension: + return Extension(name="test_ext", prefix="!") + + +def test_init_with_name_and_prefix__expect_attributes_set(): + ext = Extension(name="math", prefix="!") + + assert ext.name == "math" + assert ext.prefix == "!" + + +def test_init_with_name_only__expect_prefix_is_none(): + ext = Extension(name="math") + + assert ext.prefix is None + + +def test_init__expect_on_load_is_none(extension: Extension): + assert extension._on_load is None + + +def test_init__expect_on_unload_is_none(extension: Extension): + assert extension._on_unload is None + + +def test_init__expect_empty_commands(extension: Extension): + assert extension.commands == {} + + +def test_init__expect_empty_event_handlers(extension: Extension): + assert extension._event_handlers == {} + + +def test_init__expect_empty_error_handlers(extension: Extension): + assert extension._error_handlers == {} + + +def test_init__expect_empty_checks(extension: Extension): + assert extension._checks == [] + + +def test_on_load_with_sync_function__expect_handler_registered(extension: Extension): + @extension.on_load + def setup(): + pass + + assert extension._on_load is setup + + +def test_on_load_with_coroutine__expect_type_error(extension: Extension): + with pytest.raises(TypeError): + + @extension.on_load + async def setup(): + pass + + +def test_on_load_returns_the_original_function__expect_same_reference( + extension: Extension, +): + def setup(): + pass + + result = extension.on_load(setup) + + assert result is setup + + +def test_on_load_overwrites_previous_handler__expect_latest_handler( + extension: Extension, +): + @extension.on_load + def first(): + pass + + @extension.on_load + def second(): + pass + + assert extension._on_load is second + + +def test_load_with_registered_handler__expect_handler_called(extension: Extension): + called = [] + + @extension.on_load + def setup(): + called.append(True) + + extension.load() + + assert called == [True] + + +def test_load_with_no_handler__expect_no_error(extension: Extension): + extension.load() + + +def test_on_unload_with_sync_function__expect_handler_registered(extension: Extension): + @extension.on_unload + def teardown(): + pass + + assert extension._on_unload is teardown + + +def test_on_unload_with_coroutine__expect_type_error(extension: Extension): + with pytest.raises(TypeError): + + @extension.on_unload + async def teardown(): + pass + + +def test_on_unload_returns_the_original_function__expect_same_reference( + extension: Extension, +): + def teardown(): + pass + + result = extension.on_unload(teardown) + + assert result is teardown + + +def test_on_unload_overwrites_previous_handler__expect_latest_handler( + extension: Extension, +): + @extension.on_unload + def first(): + pass + + @extension.on_unload + def second(): + pass + + assert extension._on_unload is second + + +def test_unload_with_registered_handler__expect_handler_called(extension: Extension): + called = [] + + @extension.on_unload + def teardown(): + called.append(True) + + extension.unload() + + assert called == [True] + + +def test_unload_with_no_handler__expect_no_error(extension: Extension): + extension.unload() diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..56f326a --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,362 @@ +import pytest + +from nio import RoomMessageText, RoomMemberEvent, TypingNoticeEvent, ReactionEvent + +from matrix.registry import Registry +from matrix.command import Command +from matrix.group import Group +from matrix.errors import AlreadyRegisteredError + + +@pytest.fixture +def registry() -> Registry: + return Registry(name="test", prefix="!") + + +async def _dummy(ctx): + pass + + +async def _dummy_event(room, event): + pass + + +async def _dummy_error(error): + pass + + +async def _dummy_check(ctx): + return True + + +def test_register_command_with_decorator__expect_command_in_registry( + registry: Registry, +): + @registry.command(description="A test command") + async def ping(ctx): + pass + + assert "ping" in registry.commands + + +def test_register_command_with_custom_name__expect_custom_name_in_registry( + registry: Registry, +): + @registry.command(name="pong") + async def ping(ctx): + pass + + assert "pong" in registry.commands + assert "ping" not in registry.commands + + +def test_register_command_with_duplicate_name__expect_already_registered_error( + registry: Registry, +): + @registry.command() + async def ping(ctx): + pass + + with pytest.raises(AlreadyRegisteredError): + + @registry.command(name="ping") + async def ping2(ctx): + pass + + +def test_register_command_returns_command_instance__expect_command_type( + registry: Registry, +): + @registry.command() + async def ping(ctx): + pass + + assert isinstance(registry.commands["ping"], Command) + + +def test_register_command_directly_with_valid_command__expect_command_in_registry( + registry: Registry, +): + cmd = Command(_dummy, name="direct", prefix="!") + registry.register_command(cmd) + + assert "direct" in registry.commands + + +def test_register_command_directly_with_duplicate__expect_already_registered_error( + registry: Registry, +): + cmd = Command(_dummy, name="dupe", prefix="!") + registry.register_command(cmd) + + with pytest.raises(AlreadyRegisteredError): + registry.register_command(cmd) + + +def test_register_group_with_decorator__expect_group_in_registry(registry: Registry): + @registry.group(description="A test group") + async def math(ctx): + pass + + assert "math" in registry.commands + + +def test_register_group_with_custom_name__expect_custom_name_in_registry( + registry: Registry, +): + @registry.group(name="utils") + async def utility(ctx): + pass + + assert "utils" in registry.commands + assert "utility" not in registry.commands + + +def test_register_group_with_duplicate_name__expect_already_registered_error( + registry: Registry, +): + @registry.group() + async def math(ctx): + pass + + with pytest.raises(AlreadyRegisteredError): + + @registry.group(name="math") + async def math2(ctx): + pass + + +def test_register_group_returns_group_instance__expect_group_type(registry: Registry): + @registry.group() + async def math(ctx): + pass + + assert isinstance(registry.commands["math"], Group) + + +# --------------------------------------------------------------------------- +# event() +# --------------------------------------------------------------------------- + + +def test_register_event_by_function_name__expect_handler_registered(registry: Registry): + @registry.event + async def on_message(room, event): + pass + + assert on_message in registry._event_handlers[RoomMessageText] + + +def test_register_event_with_string_spec__expect_handler_registered(registry: Registry): + @registry.event(event_spec="on_typing") + async def handle_typing(room, event): + pass + + assert handle_typing in registry._event_handlers[TypingNoticeEvent] + + +def test_register_event_with_type_spec__expect_handler_registered(registry: Registry): + @registry.event(event_spec=RoomMemberEvent) + async def handle_member(room, event): + pass + + assert handle_member in registry._event_handlers[RoomMemberEvent] + + +def test_register_event_with_unknown_name__expect_value_error(registry: Registry): + with pytest.raises(ValueError): + + @registry.event + async def on_unknown_event(room, event): + pass + + +def test_register_event_with_unknown_string_spec__expect_value_error( + registry: Registry, +): + with pytest.raises(ValueError): + + @registry.event(event_spec="on_nonexistent") + async def handler(room, event): + pass + + +def test_register_event_with_non_coroutine__expect_type_error(registry: Registry): + with pytest.raises(TypeError): + + @registry.event + def on_message(room, event): # not async + pass + + +def test_register_multiple_handlers_for_same_event__expect_all_registered( + registry: Registry, +): + @registry.event + async def on_message(room, event): + pass + + @registry.event(event_spec="on_message") + async def on_message_two(room, event): + pass + + assert len(registry._event_handlers[RoomMessageText]) == 2 + + +# --------------------------------------------------------------------------- +# register_event() +# --------------------------------------------------------------------------- + + +def test_register_event_directly_with_valid_handler__expect_handler_in_registry( + registry: Registry, +): + registry.register_event(ReactionEvent, _dummy_event) + + assert _dummy_event in registry._event_handlers[ReactionEvent] + + +# --------------------------------------------------------------------------- +# check() +# --------------------------------------------------------------------------- + + +def test_register_check_with_coroutine__expect_check_in_list(registry: Registry): + registry.check(_dummy_check) + + assert _dummy_check in registry._checks + + +def test_register_check_with_non_coroutine__expect_type_error(registry: Registry): + def sync_check(ctx): + return True + + with pytest.raises(TypeError): + registry.check(sync_check) + + +def test_register_check_as_decorator__expect_check_in_list(registry: Registry): + @registry.check + async def only_admins(ctx): + return True + + assert only_admins in registry._checks + + +def test_register_multiple_checks__expect_all_checks_in_list(registry: Registry): + @registry.check + async def check_one(ctx): + return True + + @registry.check + async def check_two(ctx): + return True + + assert check_one in registry._checks + assert check_two in registry._checks + + +def test_register_schedule_with_valid_cron__expect_job_in_scheduler(registry: Registry): + @registry.schedule("0 9 * * *") + async def morning_task(): + pass + + jobs = registry._scheduler.jobs + assert any(j.func is morning_task for j in jobs) + + +def test_register_schedule_with_non_coroutine__expect_type_error(registry: Registry): + with pytest.raises(TypeError): + + @registry.schedule("0 9 * * *") + def not_async(): + pass + + +def test_register_multiple_schedules__expect_all_jobs_in_scheduler(registry: Registry): + @registry.schedule("0 9 * * *") + async def morning(): + pass + + @registry.schedule("0 18 * * *") + async def evening(): + pass + + funcs = [j.func for j in registry._scheduler.jobs] + assert morning in funcs + assert evening in funcs + + +def test_register_error_handler_with_exception_type__expect_handler_in_dict( + registry: Registry, +): + @registry.error(ValueError) + async def on_value_error(error): + pass + + assert registry._error_handlers[ValueError] is on_value_error + + +def test_register_generic_error_handler__expect_on_error_set(registry: Registry): + @registry.error() + async def on_any_error(error): + pass + + assert registry._on_error is on_any_error + + +def test_register_error_handler_with_non_coroutine__expect_type_error( + registry: Registry, +): + with pytest.raises(TypeError): + + @registry.error(ValueError) + def sync_handler(error): + pass + + +def test_register_multiple_typed_error_handlers__expect_all_in_dict(registry: Registry): + @registry.error(ValueError) + async def on_value_error(error): + pass + + @registry.error(RuntimeError) + async def on_runtime_error(error): + pass + + assert ValueError in registry._error_handlers + assert RuntimeError in registry._error_handlers + + +def test_register_error_handler_overwrites_previous_handler__expect_latest_handler( + registry: Registry, +): + @registry.error(ValueError) + async def first_handler(error): + pass + + @registry.error(ValueError) + async def second_handler(error): + pass + + assert registry._error_handlers[ValueError] is second_handler + + +def test_commands_property_with_empty_registry__expect_empty_dict(registry: Registry): + assert registry.commands == {} + + +def test_commands_property_reflects_registered_commands__expect_correct_entries( + registry: Registry, +): + @registry.command() + async def foo(ctx): + pass + + @registry.group() + async def bar(ctx): + pass + + assert "foo" in registry.commands + assert "bar" in registry.commands + assert len(registry.commands) == 2