diff --git a/AGENTS.md b/AGENTS.md index ad23b2d..11d8271 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -180,7 +180,8 @@ uv run mypy statemachine/ tests/ - **Formatter/Linter:** ruff (line length 99, target Python 3.9) - **Rules:** pycodestyle, pyflakes, isort, pyupgrade, flake8-comprehensions, flake8-bugbear, flake8-pytest-style -- **Imports:** single-line, sorted by isort +- **Imports:** single-line, sorted by isort. **Always prefer top-level imports** — only use + lazy (in-function) imports when strictly necessary to break circular dependencies - **Docstrings:** Google convention - **Naming:** PascalCase for classes, snake_case for functions/methods, UPPER_SNAKE_CASE for constants - **Type hints:** used throughout; `TYPE_CHECKING` for circular imports @@ -188,13 +189,19 @@ uv run mypy statemachine/ tests/ ## Design principles -- **Follow SOLID principles.** In particular: +- **Use GRASP/SOLID patterns to guide decisions.** When refactoring or designing, explicitly + apply patterns like Information Expert, Single Responsibility, and Law of Demeter to decide + where logic belongs — don't just pick a convenient location. + - **Information Expert (GRASP):** Place logic in the module/class that already has the + knowledge it needs. If a method computes a result, it should signal or return it rather + than forcing another method to recompute the same thing. - **Law of Demeter:** Methods should depend only on the data they need, not on the objects that contain it. Pass the specific value (e.g., a `Future`) rather than the parent object (e.g., `TriggerData`) — this reduces coupling and removes the need for null-checks on intermediate accessors. - **Single Responsibility:** Each module, class, and function should have one clear reason - to change. + to change. Functions and types belong in the module that owns their domain (e.g., + event-name helpers belong in `event.py`, not in `factory.py`). - **Interface Segregation:** Depend on narrow interfaces. If a helper only needs one field from a dataclass, accept that field directly. - **Decouple infrastructure from domain:** Modules like `signature.py` and `dispatcher.py` are diff --git a/statemachine/engines/async_.py b/statemachine/engines/async_.py index 836dca5..e7b2905 100644 --- a/statemachine/engines/async_.py +++ b/statemachine/engines/async_.py @@ -402,6 +402,7 @@ async def processing_loop( # noqa: C901 # Spawn invoke handlers for states entered during this macrostep. await self._invoke_manager.spawn_pending_async() + self._check_root_final_state() # Phase 2: remaining internal events while not self.internal_queue.is_empty(): # pragma: no cover diff --git a/statemachine/engines/base.py b/statemachine/engines/base.py index 990d530..049f0e4 100644 --- a/statemachine/engines/base.py +++ b/statemachine/engines/base.py @@ -99,6 +99,7 @@ def __init__(self, sm: "StateChart"): self._macrostep_count: int = 0 self._microstep_count: int = 0 self._log_id = f"[{type(sm).__name__}]" + self._root_parallel_final_pending: "State | None" = None def empty(self): # pragma: no cover return self.external_queue.is_empty() @@ -614,6 +615,8 @@ def _handle_final_state(self, target: State, on_entry_result: list): BoundEvent(f"done.state.{grandparent.id}", _sm=self.sm, internal=True).put( *donedata_args, **donedata_kwargs ) + if grandparent.parent is None: + self._root_parallel_final_pending = grandparent def _enter_states( # noqa: C901 self, @@ -908,6 +911,29 @@ def add_ancestor_states_to_enter( default_history_content, ) + def _check_root_final_state(self): + """SCXML spec: terminate when the root configuration is final. + + For top-level parallel states, the machine terminates when all child + regions have reached their final states — equivalent to the SCXML + algorithm's ``isInFinalState(scxml_element)`` check. + + Uses a flag set by ``_handle_final_state`` (Information Expert) to + avoid re-scanning top-level states on every macrostep. The flag is + deferred because ``done.state`` events queued by ``_handle_final_state`` + may trigger transitions that exit the parallel, so we verify the + parallel is still in the configuration before terminating. + """ + state = self._root_parallel_final_pending + if state is None: + return + self._root_parallel_final_pending = None + # A done.state transition may have exited the parallel; verify it's + # still in the configuration before terminating. + if state in self.sm.configuration and self.is_in_final_state(state): + self._invoke_manager.cancel_all() + self.running = False + def is_in_final_state(self, state: State) -> bool: if state.is_compound: return any(s.final and s in self.sm.configuration for s in state.states) diff --git a/statemachine/engines/sync.py b/statemachine/engines/sync.py index d2f9734..6c85650 100644 --- a/statemachine/engines/sync.py +++ b/statemachine/engines/sync.py @@ -118,6 +118,7 @@ def processing_loop(self, caller_future=None): # noqa: C901 # Spawn invoke handlers for states entered during this macrostep. self._invoke_manager.spawn_pending_sync() + self._check_root_final_state() # Process remaining internal events before external events. # Note: the macrostep loop above already drains the internal queue, diff --git a/statemachine/event.py b/statemachine/event.py index cd55c8c..91c9805 100644 --- a/statemachine/event.py +++ b/statemachine/event.py @@ -16,6 +16,24 @@ from .transition_list import TransitionList +def _expand_event_id(key: str) -> str: + """Apply naming conventions for special event prefixes. + + Converts underscore-based Python attribute names to their dot-separated + event equivalents. Returns a space-separated string so ``Events.add()`` + registers both forms. + """ + if key.startswith("done_invoke_"): + suffix = key[len("done_invoke_") :] + return f"{key} done.invoke.{suffix}" + if key.startswith("done_state_"): + suffix = key[len("done_state_") :] + return f"{key} done.state.{suffix}" + if key.startswith("error_"): + return f"{key} {key.replace('_', '.')}" + return key + + _event_data_kwargs = { "event_data", "machine", diff --git a/statemachine/factory.py b/statemachine/factory.py index b7f71ba..d470e3b 100644 --- a/statemachine/factory.py +++ b/statemachine/factory.py @@ -9,6 +9,7 @@ from .callbacks import CallbackPriority from .callbacks import CallbackSpecList from .event import Event +from .event import _expand_event_id from .exceptions import InvalidDefinition from .graph import disconnected_states from .graph import iterate_states @@ -271,29 +272,13 @@ def add_from_attributes(cls, attrs): # noqa: C901 if isinstance(value, State): cls.add_state(key, value) elif isinstance(value, (Transition, TransitionList)): - event_id = key - if key.startswith("error_"): - event_id = f"{key} {key.replace('_', '.')}" - elif key.startswith("done_invoke_"): - suffix = key[len("done_invoke_") :] - event_id = f"{key} done.invoke.{suffix}" - elif key.startswith("done_state_"): - suffix = key[len("done_state_") :] - event_id = f"{key} done.state.{suffix}" + event_id = _expand_event_id(key) cls.add_event(event=Event(transitions=value, id=event_id, name=key)) elif isinstance(value, (Event,)): if value._has_real_id: event_id = value.id - elif key.startswith("error_"): - event_id = f"{key} {key.replace('_', '.')}" - elif key.startswith("done_invoke_"): - suffix = key[len("done_invoke_") :] - event_id = f"{key} done.invoke.{suffix}" - elif key.startswith("done_state_"): - suffix = key[len("done_state_") :] - event_id = f"{key} done.state.{suffix}" else: - event_id = key + event_id = _expand_event_id(key) new_event = Event( transitions=value._transitions, id=event_id, diff --git a/statemachine/state.py b/statemachine/state.py index 3a85a89..32c436f 100644 --- a/statemachine/state.py +++ b/statemachine/state.py @@ -4,11 +4,13 @@ from typing import Dict from typing import Generator from typing import List +from typing import cast from weakref import ref from .callbacks import CallbackGroup from .callbacks import CallbackPriority from .callbacks import CallbackSpecList +from .event import _expand_event_id from .exceptions import InvalidDefinition from .exceptions import StateMachineError from .i18n import _ @@ -32,8 +34,10 @@ def __call__(self, *states: "State", **kwargs): class _ToState(_TransitionBuilder): - def __call__(self, *states: "State | None", **kwargs): - transitions = TransitionList(Transition(self._state, state, **kwargs) for state in states) + def __call__(self, *states: "State | NestedStateFactory | None", **kwargs): + transitions = TransitionList( + Transition(self._state, cast("State | None", state), **kwargs) for state in states + ) self._state.transitions.add_transitions(transitions) return transitions @@ -43,11 +47,12 @@ def any(self, **kwargs): """Create transitions from all non-final states (reversed).""" return self.__call__(AnyState(), **kwargs) - def __call__(self, *states: "State", **kwargs): + def __call__(self, *states: "State | NestedStateFactory", **kwargs): transitions = TransitionList() for origin in states: - transition = Transition(origin, self._state, **kwargs) - origin.transitions.add_transitions(transition) + state = cast(State, origin) + transition = Transition(state, self._state, **kwargs) + state.transitions.add_transitions(transition) transitions.add_transitions(transition) return transitions @@ -78,7 +83,7 @@ def __new__( # type: ignore [misc] value._set_id(key) states.append(value) elif isinstance(value, TransitionList): - value.add_event(key) + value.add_event(_expand_event_id(key)) elif callable(value): callbacks[key] = value @@ -87,7 +92,7 @@ def __new__( # type: ignore [misc] ) @classmethod - def to(cls, *args: "State", **kwargs) -> "_ToState": # pragma: no cover + def to(cls, *args: "State | NestedStateFactory", **kwargs) -> "_ToState": # pragma: no cover """Create transitions to the given target states. .. note: This method is only a type hint for mypy. The actual implementation belongs to the :ref:`State` class. @@ -95,7 +100,9 @@ def to(cls, *args: "State", **kwargs) -> "_ToState": # pragma: no cover return _ToState(State()) @classmethod - def from_(cls, *args: "State", **kwargs) -> "_FromState": # pragma: no cover + def from_( # pragma: no cover + cls, *args: "State | NestedStateFactory", **kwargs + ) -> "_FromState": """Create transitions from the given target states (reversed). .. note: This method is only a type hint for mypy. The actual implementation belongs to the :ref:`State` class. diff --git a/tests/examples/ai_shell_machine.py b/tests/examples/ai_shell_machine.py new file mode 100644 index 0000000..5fc8c4a --- /dev/null +++ b/tests/examples/ai_shell_machine.py @@ -0,0 +1,568 @@ +""" +AI Shell -- coding assistant +============================= + +A feature-rich coding assistant powered by python-statemachine. + +A standalone interactive CLI that uses the OpenAI SDK for LLM calls with +tool_use. Demonstrates **parallel states**, **compound states**, +**HistoryState**, **eventless transitions**, **In() guards**, +**done.state**, **error.execution**, **invoke**, and **raise_()** — all +working together in a practical application. + +.. warning:: + + This example grants an LLM the ability to read files, list directories, + and execute shell commands — which can be very useful for exploring a + codebase, running tests, or automating tasks. However, the actual behavior + depends on the prompts you send and the model you use, and unintended + actions (e.g., deleting files or exposing credentials) are possible. + + **Use at your own risk.** This code is provided for educational and + demonstration purposes only. The authors and contributors of + python-statemachine accept no liability for any damage or data loss. + Consider running it in an isolated environment (e.g., a container or + virtual machine) and avoid using elevated privileges. + +Usage:: + + # Standalone (installs deps from PyPI) + OPENAI_API_KEY=sk-... uv run examples/ai_shell.py + + # From the repo (uses local statemachine) + OPENAI_API_KEY=sk-... uv run --with openai python examples/ai_shell.py + + # Debug mode — shows engine macro/micro step log on stderr + OPENAI_API_KEY=sk-... uv run --with openai python examples/ai_shell.py -v + +""" +# /// script +# requires-python = ">=3.9" +# dependencies = [ +# "openai", +# "python-statemachine", +# ] +# /// + +import itertools +import json +import logging +import os +import random +import subprocess +import sys +import threading + +from statemachine import HistoryState +from statemachine import State +from statemachine import StateChart + +if "-v" in sys.argv or "--verbose" in sys.argv: + logging.basicConfig(level=logging.DEBUG, format="%(name)s %(message)s", stream=sys.stderr) + +# --------------------------------------------------------------------------- +# Tool definitions (OpenAI function calling format) +# --------------------------------------------------------------------------- + +TOOLS = [ + { + "type": "function", + "function": { + "name": "read_file", + "description": ( + "Read the contents of a file at the given path. " + "Returns the file contents (truncated to 10 000 characters)." + ), + "parameters": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "Absolute or relative file path."}, + }, + "required": ["path"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "list_files", + "description": "List files and directories at the given path.", + "parameters": { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Directory path. Defaults to '.' (current directory).", + }, + }, + }, + }, + }, + { + "type": "function", + "function": { + "name": "run_command", + "description": ( + "Run a shell command and return its stdout and stderr. " + "Commands are executed with a 30-second timeout." + ), + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute.", + }, + }, + "required": ["command"], + }, + }, + }, +] + +SYSTEM_PROMPT = ( + "You are a helpful coding assistant. You can read files, list directory contents, " + "and run shell commands to help the user with their tasks. Be concise and practical. " + "You also have tools to introspect the state machine that powers this shell — use them " + "when the user asks about the current state, allowed transitions, or other metadata." +) + +MAX_FILE_CHARS = 10_000 +COMMAND_TIMEOUT = 30 +MAX_RETRIES = 3 + +# --------------------------------------------------------------------------- +# Spinner animation +# --------------------------------------------------------------------------- + +SPINNER_CHARS = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏" + +SPINNER_MESSAGES = [ + "thinking...", + "contemplating...", + "cooking something up...", + "making something special...", + "crunching the data...", + "pondering...", + "culminating...", + "brewing ideas...", + "connecting the dots...", + "almost there...", +] + + +class Spinner: + """Animated terminal spinner shown while the LLM is working.""" + + def __init__(self): + self._stop = threading.Event() + self._thread: "threading.Thread | None" = None + + def __enter__(self): + self._stop.clear() + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + return self + + def __exit__(self, *args): + self._stop.set() + if self._thread is not None: + self._thread.join(timeout=2) + + def _run(self): + messages = SPINNER_MESSAGES[:] + random.shuffle(messages) + msg_cycle = itertools.cycle(messages) + char_cycle = itertools.cycle(SPINNER_CHARS) + msg = next(msg_cycle) + tick = 0 + while not self._stop.wait(timeout=0.08): + char = next(char_cycle) + if tick > 0 and tick % 30 == 0: + msg = next(msg_cycle) + line = f" {char} {msg}" + print(f"\r{line:<50}", end="", flush=True) + tick += 1 + print(f"\r{'':50}\r", end="", flush=True) + + +# --------------------------------------------------------------------------- +# Tool execution +# --------------------------------------------------------------------------- + + +def _tool_read_file(input_data: dict) -> str: + path = input_data["path"] + try: + with open(path) as f: + content = f.read(MAX_FILE_CHARS + 1) + if len(content) > MAX_FILE_CHARS: + content = content[:MAX_FILE_CHARS] + "\n... (truncated)" + return content + except OSError as e: + return f"Error reading file: {e}" + + +def _tool_list_files(input_data: dict) -> str: + path = input_data.get("path", ".") + try: + entries = sorted(os.listdir(path)) + return "\n".join(entries) + except OSError as e: + return f"Error listing directory: {e}" + + +def _tool_run_command(input_data: dict) -> str: + command = input_data["command"] + try: + result = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + timeout=COMMAND_TIMEOUT, + ) + output = "" + if result.stdout: + output += result.stdout + if result.stderr: + output += ("" if not output else "\n") + f"stderr: {result.stderr}" + if result.returncode != 0: + output += f"\n(exit code {result.returncode})" + return output or "(no output)" + except subprocess.TimeoutExpired: + return f"Error: command timed out after {COMMAND_TIMEOUT}s" + except OSError as e: + return f"Error running command: {e}" + + +TOOL_HANDLERS = { + "read_file": _tool_read_file, + "list_files": _tool_list_files, + "run_command": _tool_run_command, +} + + +# --------------------------------------------------------------------------- +# State machine introspection tools +# --------------------------------------------------------------------------- + + +def _tool_sm_configuration(sm, input_data: dict) -> str: + states = sorted(sm.configuration_values) + return json.dumps({"active_states": states}) + + +def _tool_sm_enabled_events(sm, input_data: dict) -> str: + events = sorted({e.name for e in sm.enabled_events()}) + return json.dumps({"enabled_events": events}) + + +def _tool_sm_macrostep_count(sm, input_data: dict) -> str: + return json.dumps({"macrostep_count": sm._engine._macrostep_count}) + + +def _tool_sm_states(sm, input_data: dict) -> str: + all_states = sorted(sm.states_map.keys()) + return json.dumps({"all_states": all_states}) + + +SM_TOOL_HANDLERS = { + "sm_configuration": _tool_sm_configuration, + "sm_enabled_events": _tool_sm_enabled_events, + "sm_macrostep_count": _tool_sm_macrostep_count, + "sm_states": _tool_sm_states, +} + +SM_TOOLS = [ + { + "type": "function", + "function": { + "name": "sm_configuration", + "description": ( + "Get the current active states (configuration) of the state machine. " + "Returns which states are currently active." + ), + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "sm_enabled_events", + "description": ( + "List events (transitions) that can be triggered from the current " + "state machine configuration, considering guard conditions." + ), + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "sm_macrostep_count", + "description": ( + "Get the current macrostep counter of the state machine engine. " + "A macrostep is the full processing cycle for one external event." + ), + "parameters": {"type": "object", "properties": {}}, + }, + }, + { + "type": "function", + "function": { + "name": "sm_states", + "description": ( + "List all states defined in the state machine, including nested states " + "inside compound and parallel states." + ), + "parameters": {"type": "object", "properties": {}}, + }, + }, +] + + +def execute_tool(name: str, input_data: dict, sm=None) -> str: + sm_handler = SM_TOOL_HANDLERS.get(name) + if sm_handler is not None: + return sm_handler(sm, input_data) + handler = TOOL_HANDLERS.get(name) + if handler is None: + return f"Unknown tool: {name}" + return handler(input_data) + + +# --------------------------------------------------------------------------- +# State machine +# --------------------------------------------------------------------------- + +GOODBYE_WORDS = {"bye", "exit", "quit"} + + +class AIShell(StateChart): + """An agentic coding assistant as a StateChart. + + Demonstrates parallel states, compound states, HistoryState, eventless + transitions, In() guards, done.state, error.execution, invoke, and + raise_() — all in a practical application. + + States:: + + session (Parallel) + ├── conversation (Compound) + │ ├── idle (initial) + │ ├── processing (Compound) + │ │ ├── thinking (initial, invoke) ← API call + spinner + │ │ ├── using_tool (invoke) ← tool execution + │ │ ├── done (final) + │ │ └── h = HistoryState(deep) ← for error retry + │ ├── responding + │ ├── recovering ← error.execution handler + │ └── conversation_ended (final) + └── context_tracker (Compound) + ├── fresh (initial) + ├── active (≥4 messages) + ├── deep (≥20 messages, shows warning) + └── tracker_done (final) + + """ + + error_on_execution = True + + # --- Top-level parallel state: two independent regions --- + + class session(State.Parallel): + class conversation(State.Compound): + idle = State("Idle", initial=True) + + class processing(State.Compound): + thinking = State("Thinking", initial=True) + using_tool = State("Using Tool") + done = State("Done", final=True) + h = HistoryState(type="deep") + + # Invoke results route automatically + done_invoke_thinking = thinking.to( + using_tool, cond="has_tool_calls" + ) | thinking.to(done) + done_invoke_using_tool = using_tool.to(thinking) + + responding = State("Responding") + recovering = State("Recovering") + conversation_ended = State("Ended", final=True) + + # Named events + user_message = idle.to(processing, cond="is_not_goodbye") | idle.to( + conversation_ended, cond="is_goodbye" + ) + done_state_processing = processing.to(responding) + error_execution = processing.to(recovering) + + # Eventless transitions + responding.to(idle) + recovering.to(processing.h, cond="can_retry") + recovering.to(idle, cond="cannot_retry") + + class context_tracker(State.Compound): + fresh = State("Fresh", initial=True) + active = State("Active") + deep = State("Deep") + tracker_done = State(final=True) + + # Eventless: track conversation depth + fresh.to(active, cond="is_active_context") + active.to(deep, cond="is_deep_context") + + # Eventless + In() guard: follow conversation end + fresh.to(tracker_done, cond="In('conversation_ended')") + active.to(tracker_done, cond="In('conversation_ended')") + deep.to(tracker_done, cond="In('conversation_ended')") + + # --- Initialization --- + + def __init__(self): + from openai import OpenAI # type: ignore[import-not-found] + + self.client = OpenAI() + self.messages: list = [{"role": "system", "content": SYSTEM_PROMPT}] + self._last_text: str = "" + self._retries: int = 0 + self._ready = threading.Event() + super().__init__() + + # --- Guards --- + + def is_goodbye(self, text="", **kwargs) -> bool: + return text.strip().lower() in GOODBYE_WORDS + + def is_not_goodbye(self, text="", **kwargs) -> bool: + return not self.is_goodbye(text=text) + + def can_retry(self, **kwargs) -> bool: + return self._retries < MAX_RETRIES + + def cannot_retry(self, **kwargs) -> bool: + return not self.can_retry() + + def is_active_context(self, **kwargs) -> bool: + return len(self.messages) >= 5 + + def is_deep_context(self, **kwargs) -> bool: + return len(self.messages) >= 20 + + # --- Callbacks --- + + def on_user_message(self, text, **kwargs): + """Append the user's message to conversation history.""" + self.messages.append({"role": "user", "content": text}) + + def has_tool_calls(self, data=None, **kwargs) -> bool: + """Guard: check if the API response contains tool calls.""" + return bool(getattr(data, "tool_calls", None)) + + def on_invoke_thinking(self, **kwargs): + """Call the OpenAI API with a spinner animation. Returns the message.""" + with Spinner(): + response = self.client.chat.completions.create( + model="gpt-4o-mini", + messages=self.messages, + tools=TOOLS + SM_TOOLS, + ) + + message = response.choices[0].message + self.messages.append(message) + + if not message.tool_calls: + self._last_text = message.content or "" + + return message + + def on_invoke_using_tool(self, data, **kwargs): + """Execute tool calls from the API response.""" + for call in data.tool_calls: + args = json.loads(call.function.arguments) + print(f" [tool] {call.function.name}({json.dumps(args)})") + result = execute_tool(call.function.name, args, sm=self) + self.messages.append( + { + "role": "tool", + "tool_call_id": call.id, + "content": result, + } + ) + + def on_enter_responding(self, **kwargs): + """Print the assistant's text response.""" + if self._last_text: + print(f"\n{self._last_text}") + self._last_text = "" + + def on_enter_idle(self, **kwargs): + """Reset retry counter and signal readiness when returning to idle.""" + self._retries = 0 + self._ready.set() + + def on_enter_recovering(self, **kwargs): + """Handle API errors with retry logic (via error.execution).""" + self._retries += 1 + if self._retries < MAX_RETRIES: + print(f"\n [error] API call failed, retrying ({self._retries}/{MAX_RETRIES})...") + else: + print(f"\n [error] API call failed after {MAX_RETRIES} attempts. Giving up.") + + def on_enter_deep(self, **kwargs): + """Warn when conversation context is getting long.""" + print(" [context] Conversation is getting long — responses may degrade.") + + def on_enter_conversation_ended(self, **kwargs): + print("\nGoodbye!") + + +# --------------------------------------------------------------------------- +# Main loop +# --------------------------------------------------------------------------- + + +def _check_openai(): + """Return True if the openai package is available.""" + try: + import openai # noqa: F401 + + return True + except ImportError: + return False + + +def main(): + if not _check_openai(): + print("This example requires the 'openai' package.") + print("Install it with: pip install openai") + return + + print("AI Shell") + print("A coding assistant powered by python-statemachine + OpenAI.") + print("Type 'bye', 'exit', or 'quit' to end. Ctrl+C to interrupt.") + if "-v" in sys.argv or "--verbose" in sys.argv: + print("Debug mode enabled — engine log is written to stderr.\n") + else: + print("Tip: run with -v to see engine macro/micro step debug log.\n") + + try: + sm = AIShell() + except Exception as e: + sys.exit(f"Error initializing: {e}") + + while not sm.is_terminated: + sm._ready.wait() + sm._ready.clear() + try: + text = input("> ") + except (EOFError, KeyboardInterrupt): + print() + break + if text.strip(): + sm.send("user_message", text=text) + + +if __name__ == "__main__" and "sphinx" not in sys.modules: # pragma: no cover + main() diff --git a/tests/examples/statechart_compound_machine.py b/tests/examples/statechart_compound_machine.py index 613172f..805437c 100644 --- a/tests/examples/statechart_compound_machine.py +++ b/tests/examples/statechart_compound_machine.py @@ -38,7 +38,7 @@ class rivendell(State.Compound): destination = State("Quest continues", final=True) depart_shire = shire.to(wilderness) - arrive_rivendell = wilderness.to(rivendell) # type: ignore[arg-type] + arrive_rivendell = wilderness.to(rivendell) done_state_rivendell = rivendell.to(destination) diff --git a/tests/examples/statechart_delayed_machine.py b/tests/examples/statechart_delayed_machine.py index f3f5f4d..d4e9eb0 100644 --- a/tests/examples/statechart_delayed_machine.py +++ b/tests/examples/statechart_delayed_machine.py @@ -100,7 +100,7 @@ class siege(State.Compound): city_falls = State("Minas Tirith has fallen!", final=True) # External event to kick off the quest - start = idle.to(quest) # type: ignore[arg-type] + start = idle.to(quest) # Eventless transitions -- checked automatically each macrostep quest.to(rohan_rides, cond="In('rohan_reached')") diff --git a/tests/examples/statechart_parallel_machine.py b/tests/examples/statechart_parallel_machine.py index d5ba271..0dd2110 100644 --- a/tests/examples/statechart_parallel_machine.py +++ b/tests/examples/statechart_parallel_machine.py @@ -44,7 +44,7 @@ class gandalfs_defense(State.Compound): ride_to_gondor = rohan.to(gondor) peace = State("Peace in Middle-earth", final=True) - done_state_war = war.to(peace) # type: ignore[arg-type] + done_state_war = war.to(peace) # %% diff --git a/tests/test_statechart_compound.py b/tests/test_statechart_compound.py index 3908d8d..49757ae 100644 --- a/tests/test_statechart_compound.py +++ b/tests/test_statechart_compound.py @@ -266,6 +266,70 @@ def on_enter_troubled(self): await sm_runner.send(sm, "darken") assert log == ["entered troubled times"] + async def test_done_state_inside_compound(self, sm_runner): + """done_state_* bare transition inside a compound body registers done.state.* event.""" + + class InnerDoneState(StateChart): + class outer(State.Compound): + class inner(State.Compound): + start = State(initial=True) + end = State(final=True) + + finish = start.to(end) + + after_inner = State(final=True) + done_state_inner = inner.to(after_inner) + + victory = State(final=True) + done_state_outer = outer.to(victory) + + sm = await sm_runner.start(InnerDoneState) + assert "start" in sm.configuration_values + + await sm_runner.send(sm, "finish") + assert {"victory"} == set(sm.configuration_values) + + async def test_done_invoke_inside_compound(self, sm_runner): + """done_invoke_* bare transition inside a compound registers done.invoke.* event.""" + + class InvokeInCompound(StateChart): + class wrapper(State.Compound): + loading = State(initial=True, invoke=lambda: 42) + loaded = State(final=True) + + done_invoke_loading = loading.to(loaded) + + done = State(final=True) + done_state_wrapper = wrapper.to(done) + + sm = await sm_runner.start(InvokeInCompound) + await sm_runner.sleep(0.15) + await sm_runner.processing_loop(sm) + assert {"done"} == set(sm.configuration_values) + + async def test_error_execution_inside_compound(self, sm_runner): + """error_execution inside a compound body registers error.execution event.""" + + def raise_error(): + raise RuntimeError("boom") + + class ErrorInCompound(StateChart): + class active(State.Compound): + ok = State(initial=True) + failing = State() + + trigger = ok.to(failing, on=raise_error) + + errored = State() + error_execution = failing.to(errored) + + done = State(final=True) + finish = active.to(done) + + sm = await sm_runner.start(ErrorInCompound) + await sm_runner.send(sm, "trigger") + assert "errored" in sm.configuration_values + def test_compound_state_name_attribute(self): """The name= kwarg in class syntax sets the state name.""" diff --git a/tests/test_statechart_parallel.py b/tests/test_statechart_parallel.py index 835451d..6e87d42 100644 --- a/tests/test_statechart_parallel.py +++ b/tests/test_statechart_parallel.py @@ -82,10 +82,16 @@ async def test_exit_parallel_exits_all_regions(self, sm_runner): class WarWithExit(StateChart): class war(State.Parallel): class front_a(State.Compound): - fighting = State(initial=True, final=True) + fighting = State(initial=True) + won = State(final=True) + + win_a = fighting.to(won) class front_b(State.Compound): - holding = State(initial=True, final=True) + holding = State(initial=True) + held = State(final=True) + + hold_b = holding.to(held) peace = State(final=True) truce = war.to(peace) @@ -191,3 +197,79 @@ async def test_transition_within_compound_inside_parallel( vals = set(sm.configuration_values) assert "mount_doom" in vals assert "ranger" in vals # other regions unchanged + + async def test_top_level_parallel_terminates_when_all_children_final(self, sm_runner): + """A root parallel terminates when all regions reach final states.""" + + class Session(StateChart): + class session(State.Parallel): + class ui(State.Compound): + active = State(initial=True) + closed = State(final=True) + + close_ui = active.to(closed) + + class backend(State.Compound): + running = State(initial=True) + stopped = State(final=True) + + stop_backend = running.to(stopped) + + sm = await sm_runner.start(Session) + assert sm.is_terminated is False + + await sm_runner.send(sm, "close_ui") + assert sm.is_terminated is False # one region still active + + await sm_runner.send(sm, "stop_backend") + assert sm.is_terminated is True + + async def test_top_level_parallel_done_state_fires_before_termination(self, sm_runner): + """done.state fires and transitions before root-final check terminates.""" + + class Session(StateChart): + class session(State.Parallel): + class ui(State.Compound): + active = State(initial=True) + closed = State(final=True) + + close_ui = active.to(closed) + + class backend(State.Compound): + running = State(initial=True) + stopped = State(final=True) + + stop_backend = running.to(stopped) + + finished = State(final=True) + done_state_session = session.to(finished) + + sm = await sm_runner.start(Session) + await sm_runner.send(sm, "close_ui") + await sm_runner.send(sm, "stop_backend") + # done.state.session fires, transitions to finished, then terminates + assert {"finished"} == set(sm.configuration_values) + assert sm.is_terminated is True + + async def test_top_level_parallel_not_terminated_when_one_region_pending(self, sm_runner): + """Machine keeps running when only one region reaches final.""" + + class Session(StateChart): + class session(State.Parallel): + class ui(State.Compound): + active = State(initial=True) + closed = State(final=True) + + close_ui = active.to(closed) + + class backend(State.Compound): + running = State(initial=True) + stopped = State(final=True) + + stop_backend = running.to(stopped) + + sm = await sm_runner.start(Session) + await sm_runner.send(sm, "close_ui") + assert sm.is_terminated is False + assert "closed" in sm.configuration_values + assert "running" in sm.configuration_values