diff --git a/src/fireflyframework_genai/pipeline/builder.py b/src/fireflyframework_genai/pipeline/builder.py index 242fbc36..47258751 100644 --- a/src/fireflyframework_genai/pipeline/builder.py +++ b/src/fireflyframework_genai/pipeline/builder.py @@ -147,4 +147,6 @@ def _resolve_step(step: Any) -> Any: if asyncio.iscoroutinefunction(step): return CallableStep(step) - raise TypeError(f"Cannot resolve {type(step).__name__} as a pipeline step. Must be StepExecutor, agent-like, or async callable.") + raise TypeError( + f"Cannot resolve {type(step).__name__} as a pipeline step. Must be StepExecutor, agent-like, or async callable." + ) diff --git a/src/fireflyframework_genai/studio/api/assistant.py b/src/fireflyframework_genai/studio/api/assistant.py index 715abd76..b55a2b90 100644 --- a/src/fireflyframework_genai/studio/api/assistant.py +++ b/src/fireflyframework_genai/studio/api/assistant.py @@ -39,10 +39,15 @@ # Generous limit for complex multi-tool pipelines (each tool call = 1 request). _DEFAULT_REQUEST_LIMIT = 200 -_CANVAS_TOOL_NAMES = frozenset({ - "add_node", "connect_nodes", "configure_node", - "remove_node", "clear_canvas", -}) +_CANVAS_TOOL_NAMES = frozenset( + { + "add_node", + "connect_nodes", + "configure_node", + "remove_node", + "clear_canvas", + } +) def _canvas_to_dict(canvas: Any) -> dict[str, Any]: @@ -116,11 +121,13 @@ def _extract_tool_calls(result: Any) -> list[dict[str, Any]]: for part in parts: part_kind = getattr(part, "part_kind", "") if part_kind == "tool-call": - tool_calls.append({ - "tool": getattr(part, "tool_name", "unknown"), - "args": _normalize_args(getattr(part, "args", {})), - "result": None, - }) + tool_calls.append( + { + "tool": getattr(part, "tool_name", "unknown"), + "args": _normalize_args(getattr(part, "args", {})), + "result": None, + } + ) elif part_kind == "tool-return": content = getattr(part, "content", "") tool_name = getattr(part, "tool_name", "") @@ -171,9 +178,7 @@ def _process_attachments(attachments: list[dict[str, Any]]) -> str: except Exception: parts.append(f"[Attached CSV: {name} ({size} bytes) — could not decode]") elif category == "image": - parts.append( - f"[Attached image: {name} ({size} bytes, type: {att.get('type', 'unknown')})]" - ) + parts.append(f"[Attached image: {name} ({size} bytes, type: {att.get('type', 'unknown')})]") elif category == "pdf": parts.append( f"[Attached PDF: {name} ({size} bytes) — " @@ -240,18 +245,22 @@ async def _handle_chat_streaming( logger.info( "Streaming complete (%d chars, %d tool calls, canvas: %d nodes, %d edges)", - len(full_text), len(tool_calls), - len(canvas.nodes), len(canvas.edges), + len(full_text), + len(tool_calls), + len(canvas.nodes), + len(canvas.edges), ) # Send tool call details so the UI shows what happened for tc in tool_calls: - await websocket.send_json({ - "type": "tool_call", - "tool": tc["tool"], - "args": tc["args"], - "result": tc["result"], - }) + await websocket.send_json( + { + "type": "tool_call", + "tool": tc["tool"], + "args": tc["args"], + "result": tc["result"], + } + ) # Check if present_plan was called plan_call = next( @@ -260,25 +269,30 @@ async def _handle_chat_streaming( ) if plan_call and plan_call["args"]: args = plan_call["args"] - await websocket.send_json({ - "type": "plan", - "summary": args.get("summary", ""), - "steps": args.get("steps", "[]"), - "options": args.get("options", "[]"), - "question": args.get("question", ""), - }) - - await websocket.send_json({ - "type": "response_complete", - "full_text": full_text, - }) + await websocket.send_json( + { + "type": "plan", + "summary": args.get("summary", ""), + "steps": args.get("steps", "[]"), + "options": args.get("options", "[]"), + "question": args.get("question", ""), + } + ) + + await websocket.send_json( + { + "type": "response_complete", + "full_text": full_text, + } + ) # Canvas sync after tool use used_canvas_tools = any(tc["tool"] in _CANVAS_TOOL_NAMES for tc in tool_calls) if used_canvas_tools: logger.info( "Canvas tools used, sending sync (%d nodes, %d edges)", - len(canvas.nodes), len(canvas.edges), + len(canvas.nodes), + len(canvas.edges), ) await _send_canvas_sync(websocket, canvas) @@ -314,17 +328,21 @@ async def _handle_chat_blocking( logger.info( "Blocking complete (%d chars, %d tool calls, canvas: %d nodes, %d edges)", - len(full_text), len(tool_calls), - len(canvas.nodes), len(canvas.edges), + len(full_text), + len(tool_calls), + len(canvas.nodes), + len(canvas.edges), ) for tc in tool_calls: - await websocket.send_json({ - "type": "tool_call", - "tool": tc["tool"], - "args": tc["args"], - "result": tc["result"], - }) + await websocket.send_json( + { + "type": "tool_call", + "tool": tc["tool"], + "args": tc["args"], + "result": tc["result"], + } + ) plan_call = next( (tc for tc in tool_calls if tc["tool"] == "present_plan"), @@ -332,27 +350,32 @@ async def _handle_chat_blocking( ) if plan_call and plan_call["args"]: args = plan_call["args"] - await websocket.send_json({ - "type": "plan", - "summary": args.get("summary", ""), - "steps": args.get("steps", "[]"), - "options": args.get("options", "[]"), - "question": args.get("question", ""), - }) + await websocket.send_json( + { + "type": "plan", + "summary": args.get("summary", ""), + "steps": args.get("steps", "[]"), + "options": args.get("options", "[]"), + "question": args.get("question", ""), + } + ) if full_text: await websocket.send_json({"type": "token", "content": full_text}) - await websocket.send_json({ - "type": "response_complete", - "full_text": full_text, - }) + await websocket.send_json( + { + "type": "response_complete", + "full_text": full_text, + } + ) used_canvas_tools = any(tc["tool"] in _CANVAS_TOOL_NAMES for tc in tool_calls) if used_canvas_tools: logger.info( "Canvas tools used, sending sync (%d nodes, %d edges)", - len(canvas.nodes), len(canvas.edges), + len(canvas.nodes), + len(canvas.edges), ) await _send_canvas_sync(websocket, canvas) @@ -390,16 +413,20 @@ async def _run_reflexion_validation( logger.info( "Reflexion round %d: %d errors, %d warnings", - round_num, len(errors), len(warnings), + round_num, + len(errors), + len(warnings), ) # Notify the user that auto-correction is running - await websocket.send_json({ - "type": "tool_call", - "tool": "reflexion_validation", - "args": {"round": round_num, "errors": len(errors)}, - "result": f"Found {len(errors)} errors. Auto-correcting...", - }) + await websocket.send_json( + { + "type": "tool_call", + "tool": "reflexion_validation", + "args": {"round": round_num, "errors": len(errors)}, + "result": f"Found {len(errors)} errors. Auto-correcting...", + } + ) # Ask the Architect to fix the issues via its canvas tools fix_prompt = ( @@ -408,9 +435,7 @@ async def _run_reflexion_validation( + "\n".join(f" - {e}" for e in errors) ) if warnings: - fix_prompt += "\n\nWarnings (non-blocking):\n" + "\n".join( - f" - {w}" for w in warnings - ) + fix_prompt += "\n\nWarnings (non-blocking):\n" + "\n".join(f" - {w}" for w in warnings) fix_prompt += ( "\n\nFix ALL errors using configure_node, connect_nodes, or add_node as needed. " "Do NOT explain. Just fix the issues with tool calls." @@ -429,12 +454,14 @@ async def _run_reflexion_validation( # Send tool calls to frontend for tc in fix_tool_calls: - await websocket.send_json({ - "type": "tool_call", - "tool": tc["tool"], - "args": tc["args"], - "result": tc["result"], - }) + await websocket.send_json( + { + "type": "tool_call", + "tool": tc["tool"], + "args": tc["args"], + "result": tc["result"], + } + ) # Sync canvas if tools were used used_canvas = any(tc["tool"] in _CANVAS_TOOL_NAMES for tc in fix_tool_calls) @@ -456,13 +483,15 @@ async def _run_reflexion_validation( final = _validate_canvas(canvas) if not final["valid"]: remaining = final.get("errors", []) - await websocket.send_json({ - "type": "token", - "content": ( - f"\n\n[Reflexion completed with {len(remaining)} remaining issue(s). " - "Please review and address manually.]" - ), - }) + await websocket.send_json( + { + "type": "token", + "content": ( + f"\n\n[Reflexion completed with {len(remaining)} remaining issue(s). " + "Please review and address manually.]" + ), + } + ) def _validate_canvas(canvas: Any) -> dict[str, Any]: @@ -557,7 +586,9 @@ async def _handle_chat( logger.info( "Running assistant (canvas: %d nodes, %d edges, attachments: %d)", - len(canvas.nodes), len(canvas.edges), len(attachments or []), + len(canvas.nodes), + len(canvas.edges), + len(attachments or []), ) # Use blocking run() for reliable multi-turn tool execution. @@ -565,8 +596,11 @@ async def _handle_chat( # loop (tool calls → re-prompt → more tool calls → final text), # which run_stream() does not handle correctly. tool_calls = await _handle_chat_blocking( - websocket, agent, effective_message, - message_history, canvas, + websocket, + agent, + effective_message, + message_history, + canvas, ) # Reflexion triggers only on substantial builds: @@ -578,11 +612,15 @@ async def _handle_chat( if (called_validate or canvas_tool_count >= 3) and canvas.nodes: logger.info( "Reflexion triggered (validate_called=%s, canvas_tools=%d)", - called_validate, canvas_tool_count, + called_validate, + canvas_tool_count, ) try: await _run_reflexion_validation( - websocket, agent, message_history, canvas, + websocket, + agent, + message_history, + canvas, ) except Exception as val_exc: logger.warning("Reflexion validation failed: %s", val_exc) @@ -598,17 +636,16 @@ async def _handle_chat( "pipeline you are asking me to build." ) elif "rate" in err_str.lower() and ("limit" in err_str.lower() or "429" in err_str): - user_msg = ( - "The LLM provider is rate-limiting requests. " - "Please wait a moment and try again." - ) + user_msg = "The LLM provider is rate-limiting requests. Please wait a moment and try again." else: user_msg = f"Assistant error: {exc}" - await websocket.send_json({ - "type": "error", - "message": user_msg, - }) + await websocket.send_json( + { + "type": "error", + "message": user_msg, + } + ) def _build_project_context(canvas: Any = None, project_name: str = "") -> str: @@ -657,10 +694,7 @@ def _build_project_context(canvas: Any = None, project_name: str = "") -> str: for t in tools: status = "registered" summaries.append(f" - {t.name} (type={t.tool_type}, {status})") - parts.append( - "[CONTEXT] Custom tools / integrations installed:\n" - + "\n".join(summaries) - ) + parts.append("[CONTEXT] Custom tools / integrations installed:\n" + "\n".join(summaries)) else: parts.append("[CONTEXT] No custom tools or integrations installed yet.") except Exception: @@ -683,9 +717,7 @@ def _build_project_context(canvas: Any = None, project_name: str = "") -> str: settings = load_settings() if settings.model_defaults.default_model: - parts.append( - f"[CONTEXT] Default model: {settings.model_defaults.default_model}." - ) + parts.append(f"[CONTEXT] Default model: {settings.model_defaults.default_model}.") except Exception: pass @@ -704,14 +736,9 @@ def _build_project_context(canvas: Any = None, project_name: str = "") -> str: {"id": n.id, "type": n.type, "data": {"label": n.label or "", "config": n.config or {}}} for n in canvas.nodes ], - "edges": [ - {"source": e.source, "target": e.target} - for e in canvas.edges - ], + "edges": [{"source": e.source, "target": e.target} for e in canvas.edges], } - shared = build_shared_context( - project_name, canvas_dict, exclude_agent="architect" - ) + shared = build_shared_context(project_name, canvas_dict, exclude_agent="architect") if shared: parts.append(shared) except Exception: @@ -723,6 +750,7 @@ def _build_project_context(canvas: Any = None, project_name: str = "") -> str: class InferProjectNameRequest(BaseModel): message: str + class SaveHistoryRequest(BaseModel): messages: list[dict] @@ -743,25 +771,30 @@ def create_assistant_router() -> APIRouter: @router.get("/api/assistant/{project}/history") async def get_chat_history(project: str): from fireflyframework_genai.studio.assistant.history import load_chat_history + return load_chat_history(project) @router.post("/api/assistant/{project}/history") async def save_chat_history_endpoint(project: str, body: SaveHistoryRequest): from fireflyframework_genai.studio.assistant.history import save_chat_history + save_chat_history(project, body.messages) return {"status": "saved"} @router.delete("/api/assistant/{project}/history") async def delete_chat_history(project: str): from fireflyframework_genai.studio.assistant.history import clear_chat_history + clear_chat_history(project) return {"status": "cleared"} @router.post("/api/assistant/infer-project-name") async def infer_project_name(body: InferProjectNameRequest): import time + try: from pydantic_ai import Agent + agent = Agent( "openai:gpt-4.1-mini", system_prompt=( @@ -814,6 +847,7 @@ async def assistant_ws(websocket: WebSocket, project: str = Query(default="")) - if project: try: from fireflyframework_genai.studio.assistant.history import load_chat_history + saved_history = load_chat_history(project) if saved_history: logger.info("Loaded %d saved messages for project '%s'", len(saved_history), project) @@ -857,8 +891,11 @@ async def assistant_ws(websocket: WebSocket, project: str = Query(default="")) - user_message = f"{project_context}\n\n{user_message}" await _handle_chat( - websocket, agent, user_message, - message_history, canvas, + websocket, + agent, + user_message, + message_history, + canvas, attachments=chat_attachments or None, ) @@ -866,9 +903,19 @@ async def assistant_ws(websocket: WebSocket, project: str = Query(default="")) - if project: try: from fireflyframework_genai.studio.assistant.history import save_chat_history - save_chat_history(project, [ - {"role": "user", "content": user_message, "timestamp": __import__("datetime").datetime.now(__import__("datetime").UTC).isoformat()}, - ]) + + save_chat_history( + project, + [ + { + "role": "user", + "content": user_message, + "timestamp": __import__("datetime") + .datetime.now(__import__("datetime").UTC) + .isoformat(), + }, + ], + ) except Exception: pass @@ -886,29 +933,35 @@ async def assistant_ws(websocket: WebSocket, project: str = Query(default="")) - for n in sync_nodes: data = n.get("data", {}) config = {k: v for k, v in data.items() if k not in ("label", "_executionState")} - canvas.nodes.append(CanvasNode( - id=n.get("id", ""), - type=n.get("type", "pipeline_step"), - label=data.get("label", n.get("id", "")), - position=n.get("position", {"x": 0, "y": 0}), - config=config, - )) + canvas.nodes.append( + CanvasNode( + id=n.get("id", ""), + type=n.get("type", "pipeline_step"), + label=data.get("label", n.get("id", "")), + position=n.get("position", {"x": 0, "y": 0}), + config=config, + ) + ) # Track highest numeric suffix for counter m = re.search(r"(\d+)$", n.get("id", "")) if m: max_id = max(max_id, int(m.group(1))) for e in sync_edges: - canvas.edges.append(CanvasEdge( - id=e.get("id", ""), - source=e.get("source", ""), - target=e.get("target", ""), - source_handle=e.get("sourceHandle"), - target_handle=e.get("targetHandle"), - )) + canvas.edges.append( + CanvasEdge( + id=e.get("id", ""), + source=e.get("source", ""), + target=e.get("target", ""), + source_handle=e.get("sourceHandle"), + target_handle=e.get("targetHandle"), + ) + ) canvas._counter = max_id logger.info( "Canvas synced from frontend: %d nodes, %d edges (counter=%d)", - len(canvas.nodes), len(canvas.edges), max_id, + len(canvas.nodes), + len(canvas.edges), + max_id, ) else: diff --git a/src/fireflyframework_genai/studio/api/custom_tools.py b/src/fireflyframework_genai/studio/api/custom_tools.py index 063381ec..3766ae48 100644 --- a/src/fireflyframework_genai/studio/api/custom_tools.py +++ b/src/fireflyframework_genai/studio/api/custom_tools.py @@ -348,9 +348,7 @@ async def list_connectors() -> list[dict[str, Any]]: @router.post("/catalog/{connector_id}/install") async def install_connector(connector_id: str, body: dict[str, Any] | None = None) -> dict[str, Any]: """Install a pre-built connector as a custom tool.""" - connector = next( - (c for c in _CONNECTOR_CATALOG if c["id"] == connector_id), None - ) + connector = next((c for c in _CONNECTOR_CATALOG if c["id"] == connector_id), None) if connector is None: raise HTTPException(status_code=404, detail=f"Connector '{connector_id}' not found") @@ -359,6 +357,7 @@ async def install_connector(connector_id: str, body: dict[str, Any] | None = Non # Resolve credential placeholders from settings if connector.get("requires_credential"): from fireflyframework_genai.studio.settings import load_settings + settings = load_settings() cred_field = connector["requires_credential"] cred_val = getattr(settings.tool_credentials, cred_field, None) @@ -383,6 +382,7 @@ async def install_connector(connector_id: str, body: dict[str, Any] | None = Non try: tool = manager.create_runtime_tool(definition) from fireflyframework_genai.tools.registry import tool_registry + tool_registry.register(tool) except Exception: pass @@ -394,9 +394,7 @@ async def verify_connector(connector_id: str) -> dict[str, Any]: """Verify a connector's credentials by making a test API call.""" import httpx - connector = next( - (c for c in _CONNECTOR_CATALOG if c["id"] == connector_id), None - ) + connector = next((c for c in _CONNECTOR_CATALOG if c["id"] == connector_id), None) if connector is None: raise HTTPException(status_code=404, detail=f"Connector '{connector_id}' not found") @@ -408,14 +406,21 @@ async def verify_connector(connector_id: str) -> dict[str, Any]: token = "" if connector.get("requires_credential"): from fireflyframework_genai.studio.settings import load_settings + settings = load_settings() cred_field = connector["requires_credential"] cred_val = getattr(settings.tool_credentials, cred_field, None) if cred_val is None: - return {"status": "error", "message": f"Missing credential: {cred_field}. Configure it in Tool Credentials."} + return { + "status": "error", + "message": f"Missing credential: {cred_field}. Configure it in Tool Credentials.", + } token = cred_val.get_secret_value() if not token: - return {"status": "error", "message": f"Credential '{cred_field}' is empty. Configure it in Tool Credentials."} + return { + "status": "error", + "message": f"Credential '{cred_field}' is empty. Configure it in Tool Credentials.", + } # For webhook-based connectors (Discord, Teams), check the installed tool URL if verify_method == "head": @@ -481,6 +486,7 @@ async def get_tool(name: str) -> dict[str, Any]: # Include inline Python code if this is a Python tool if tool.tool_type == "python" and tool.module_path: from pathlib import Path + py_path = Path(tool.module_path) if py_path.is_file(): try: diff --git a/src/fireflyframework_genai/studio/api/graphql_api.py b/src/fireflyframework_genai/studio/api/graphql_api.py index 56e71913..75e50f10 100644 --- a/src/fireflyframework_genai/studio/api/graphql_api.py +++ b/src/fireflyframework_genai/studio/api/graphql_api.py @@ -59,9 +59,7 @@ def create_graphql_router(project_manager: ProjectManager) -> Any: import strawberry from strawberry.fastapi import GraphQLRouter except ImportError: - logger.warning( - "strawberry-graphql is not installed; GraphQL endpoint disabled" - ) + logger.warning("strawberry-graphql is not installed; GraphQL endpoint disabled") from fastapi import APIRouter # type: ignore[import-not-found] router = APIRouter() @@ -211,9 +209,7 @@ async def run_pipeline(self, project: str, input: str) -> ExecutionResult: ) except Exception as exc: duration_ms = round((time.monotonic() - start_time) * 1000, 2) - logger.exception( - "GraphQL run_pipeline failed for project '%s'", project - ) + logger.exception("GraphQL run_pipeline failed for project '%s'", project) return ExecutionResult( execution_id=execution_id, status="error", diff --git a/src/fireflyframework_genai/studio/api/oracle.py b/src/fireflyframework_genai/studio/api/oracle.py index 963715f3..72189dad 100644 --- a/src/fireflyframework_genai/studio/api/oracle.py +++ b/src/fireflyframework_genai/studio/api/oracle.py @@ -97,17 +97,20 @@ async def skip_insight(project: str, insight_id: str): @router.get("/api/oracle/{project}/chat-history") async def get_oracle_chat_history(project: str): from fireflyframework_genai.studio.assistant.history import load_oracle_history + return load_oracle_history(project) @router.post("/api/oracle/{project}/chat-history") async def save_oracle_chat_history(project: str, body: _SaveChatHistoryBody): from fireflyframework_genai.studio.assistant.history import save_oracle_history + save_oracle_history(project, body.messages) return {"status": "saved"} @router.delete("/api/oracle/{project}/chat-history") async def delete_oracle_chat_history(project: str): from fireflyframework_genai.studio.assistant.history import clear_oracle_history + clear_oracle_history(project) return {"status": "cleared"} @@ -116,9 +119,7 @@ async def delete_oracle_chat_history(project: str): # ------------------------------------------------------------------ @router.websocket("/ws/oracle") - async def oracle_ws( - websocket: WebSocket, project: str = Query(default="") - ) -> None: + async def oracle_ws(websocket: WebSocket, project: str = Query(default="")) -> None: await websocket.accept() logger.info("Oracle WebSocket connected (project=%s)", project) @@ -140,9 +141,7 @@ def _get_canvas() -> dict[str, Any]: oracle = create_oracle_agent(_get_canvas, user_name=user_name) except Exception as exc: logger.error("Failed to create Oracle agent: %s", exc) - await websocket.send_json( - {"type": "error", "message": f"Oracle unavailable: {exc}"} - ) + await websocket.send_json({"type": "error", "message": f"Oracle unavailable: {exc}"}) await websocket.close() return @@ -154,9 +153,7 @@ def _get_canvas() -> dict[str, Any]: try: message = json.loads(raw) except json.JSONDecodeError: - await websocket.send_json( - {"type": "error", "message": "Invalid JSON"} - ) + await websocket.send_json({"type": "error", "message": "Invalid JSON"}) continue action = message.get("action") @@ -172,8 +169,7 @@ def _get_canvas() -> dict[str, Any]: try: context_block = _build_shared_context_for_oracle(project, canvas_state) result = await oracle.run( - context_block - + "Analyze the current pipeline thoroughly. " + context_block + "Analyze the current pipeline thoroughly. " "Use analyze_pipeline, check_connectivity, and review_agent_setup " "to gather data. Then use suggest_improvement for each issue or " "recommendation you find. Consider the project purpose and " @@ -197,16 +193,12 @@ def _get_canvas() -> dict[str, Any]: title=insight_data.get("title", "Insight"), description=insight_data.get("description", ""), severity=insight_data.get("severity", "suggestion"), - action_instruction=insight_data.get( - "action_instruction" - ), + action_instruction=insight_data.get("action_instruction"), ) if project: add_insight(project, insight) - await websocket.send_json( - {"type": "insight", **asdict(insight)} - ) + await websocket.send_json({"type": "insight", **asdict(insight)}) # Also send any text output text_output = "" @@ -223,24 +215,19 @@ def _get_canvas() -> dict[str, Any]: except Exception as exc: logger.error("Oracle analysis failed: %s", exc, exc_info=True) - await websocket.send_json( - {"type": "error", "message": f"Analysis failed: {exc}"} - ) + await websocket.send_json({"type": "error", "message": f"Analysis failed: {exc}"}) elif action == "analyze_node": # Single node analysis node_id = message.get("node_id", "") if not node_id: - await websocket.send_json( - {"type": "error", "message": "Missing node_id"} - ) + await websocket.send_json({"type": "error", "message": "Missing node_id"}) continue try: context_block = _build_shared_context_for_oracle(project, canvas_state) result = await oracle.run( - context_block - + f"Analyze node '{node_id}' specifically. " + context_block + f"Analyze node '{node_id}' specifically. " f"Use analyze_node_config to check its configuration, " f"then suggest improvements if needed.", message_history=message_history, @@ -260,16 +247,12 @@ def _get_canvas() -> dict[str, Any]: title=insight_data.get("title", "Insight"), description=insight_data.get("description", ""), severity=insight_data.get("severity", "suggestion"), - action_instruction=insight_data.get( - "action_instruction" - ), + action_instruction=insight_data.get("action_instruction"), ) if project: add_insight(project, insight) - await websocket.send_json( - {"type": "insight", **asdict(insight)} - ) + await websocket.send_json({"type": "insight", **asdict(insight)}) text_output = "" if hasattr(result, "output"): @@ -284,20 +267,14 @@ def _get_canvas() -> dict[str, Any]: ) except Exception as exc: - logger.error( - "Oracle node analysis failed: %s", exc, exc_info=True - ) - await websocket.send_json( - {"type": "error", "message": f"Analysis failed: {exc}"} - ) + logger.error("Oracle node analysis failed: %s", exc, exc_info=True) + await websocket.send_json({"type": "error", "message": f"Analysis failed: {exc}"}) elif action == "chat": # Free-form conversational chat with The Oracle user_msg = message.get("message", "").strip() if not user_msg: - await websocket.send_json( - {"type": "error", "message": "Empty message"} - ) + await websocket.send_json({"type": "error", "message": "Empty message"}) continue try: @@ -322,9 +299,7 @@ def _get_canvas() -> dict[str, Any]: _chunk_size = 12 for i in range(0, len(full_text), _chunk_size): chunk = full_text[i : i + _chunk_size] - await websocket.send_json( - {"type": "oracle_token", "content": chunk} - ) + await websocket.send_json({"type": "oracle_token", "content": chunk}) await asyncio.sleep(0.01) # Extract any insights produced during chat @@ -339,15 +314,11 @@ def _get_canvas() -> dict[str, Any]: title=insight_data.get("title", "Insight"), description=insight_data.get("description", ""), severity=insight_data.get("severity", "suggestion"), - action_instruction=insight_data.get( - "action_instruction" - ), + action_instruction=insight_data.get("action_instruction"), ) if project: add_insight(project, insight) - await websocket.send_json( - {"type": "insight", **asdict(insight)} - ) + await websocket.send_json({"type": "insight", **asdict(insight)}) await websocket.send_json( { @@ -357,17 +328,11 @@ def _get_canvas() -> dict[str, Any]: ) except Exception as exc: - logger.error( - "Oracle chat failed: %s", exc, exc_info=True - ) - await websocket.send_json( - {"type": "error", "message": f"Oracle chat error: {exc}"} - ) + logger.error("Oracle chat failed: %s", exc, exc_info=True) + await websocket.send_json({"type": "error", "message": f"Oracle chat error: {exc}"}) else: - await websocket.send_json( - {"type": "error", "message": f"Unknown action: {action}"} - ) + await websocket.send_json({"type": "error", "message": f"Unknown action: {action}"}) except WebSocketDisconnect: logger.info("Oracle WebSocket disconnected") @@ -408,9 +373,7 @@ def _extract_oracle_insights(result: Any) -> list[dict[str, Any]]: return insights -def _build_shared_context_for_oracle( - project: str, canvas_state: dict[str, Any] -) -> str: +def _build_shared_context_for_oracle(project: str, canvas_state: dict[str, Any]) -> str: """Build cross-agent context for the Oracle using the shared builder. Replaces the old frontend-supplied context approach — context is now diff --git a/src/fireflyframework_genai/studio/api/project_api.py b/src/fireflyframework_genai/studio/api/project_api.py index f6fc924f..5675e27c 100644 --- a/src/fireflyframework_genai/studio/api/project_api.py +++ b/src/fireflyframework_genai/studio/api/project_api.py @@ -58,6 +58,7 @@ def _store_execution(record: dict[str, Any]) -> None: oldest_key = next(iter(_executions)) del _executions[oldest_key] + # --------------------------------------------------------------------------- # Request / response models # --------------------------------------------------------------------------- @@ -183,13 +184,15 @@ async def run_pipeline_async(name: str, body: RunRequest) -> dict[str, Any]: graph_model = _load_graph_model(project_manager, name) execution_id = str(uuid.uuid4()) - _store_execution({ - "execution_id": execution_id, - "project": name, - "status": "running", - "result": None, - "duration_ms": None, - }) + _store_execution( + { + "execution_id": execution_id, + "project": name, + "status": "running", + "result": None, + "duration_ms": None, + } + ) async def _run_in_background() -> None: start_time = time.monotonic() @@ -197,18 +200,22 @@ async def _run_in_background() -> None: engine = compile_graph(graph_model) result = await engine.run(inputs=body.input) duration_ms = round((time.monotonic() - start_time) * 1000, 2) - _executions[execution_id].update({ - "status": "completed", - "result": result, - "duration_ms": duration_ms, - }) + _executions[execution_id].update( + { + "status": "completed", + "result": result, + "duration_ms": duration_ms, + } + ) except Exception as exc: duration_ms = round((time.monotonic() - start_time) * 1000, 2) - _executions[execution_id].update({ - "status": "failed", - "result": str(exc), - "duration_ms": duration_ms, - }) + _executions[execution_id].update( + { + "status": "failed", + "result": str(exc), + "duration_ms": duration_ms, + } + ) logger.exception("Async pipeline execution failed for project '%s'", name) asyncio.create_task(_run_in_background()) @@ -270,13 +277,15 @@ async def upload_file(name: str, file: UploadFile) -> dict[str, Any]: duration_ms = round((time.monotonic() - start_time) * 1000, 2) - _store_execution({ - "execution_id": execution_id, - "project": name, - "status": "completed", - "result": result, - "duration_ms": duration_ms, - }) + _store_execution( + { + "execution_id": execution_id, + "project": name, + "status": "completed", + "result": result, + "duration_ms": duration_ms, + } + ) return { "result": result, diff --git a/src/fireflyframework_genai/studio/api/projects.py b/src/fireflyframework_genai/studio/api/projects.py index e1eb4ee9..0aeece6b 100644 --- a/src/fireflyframework_genai/studio/api/projects.py +++ b/src/fireflyframework_genai/studio/api/projects.py @@ -143,6 +143,7 @@ async def save_pipeline(project_name: str, pipeline_name: str, body: SavePipelin # Create a version history entry for every save try: from fireflyframework_genai.studio.versioning import ProjectVersioning + project_dir = manager._safe_path(project_name) versioning = ProjectVersioning(project_dir) node_count = len(body.graph.get("nodes", [])) @@ -169,6 +170,7 @@ def create_versioning_router(project_manager: ProjectManager) -> APIRouter: @router.get("/{name}/history") async def get_project_history(name: str): from fireflyframework_genai.studio.versioning import ProjectVersioning + project_dir = project_manager._safe_path(name) if not project_dir.exists(): raise HTTPException(status_code=404, detail=f"Project '{name}' not found") @@ -178,6 +180,7 @@ async def get_project_history(name: str): @router.post("/{name}/restore") async def restore_project_version(name: str, body: dict): from fireflyframework_genai.studio.versioning import ProjectVersioning + project_dir = project_manager._safe_path(name) if not project_dir.exists(): raise HTTPException(status_code=404, detail=f"Project '{name}' not found") @@ -191,6 +194,7 @@ async def restore_project_version(name: str, body: dict): @router.post("/{name}/bookmark") async def bookmark_project_version(name: str, body: dict): from fireflyframework_genai.studio.versioning import ProjectVersioning + project_dir = project_manager._safe_path(name) if not project_dir.exists(): raise HTTPException(status_code=404, detail=f"Project '{name}' not found") diff --git a/src/fireflyframework_genai/studio/api/settings.py b/src/fireflyframework_genai/studio/api/settings.py index 7b76cea0..819f9f78 100644 --- a/src/fireflyframework_genai/studio/api/settings.py +++ b/src/fireflyframework_genai/studio/api/settings.py @@ -244,11 +244,7 @@ async def list_services() -> list[dict]: for field in ("password", "connection_url", "api_key", "token"): val = getattr(sc, field, None) if val is not None: - secret_val = ( - val.get_secret_value() - if hasattr(val, "get_secret_value") - else str(val) - ) + secret_val = val.get_secret_value() if hasattr(val, "get_secret_value") else str(val) entry[field] = "***" if secret_val else "" else: entry[field] = None @@ -295,13 +291,9 @@ async def delete_service(service_id: str) -> dict: """Delete a service credential by ID.""" settings = load_settings(path) before = len(settings.service_credentials) - settings.service_credentials = [ - s for s in settings.service_credentials if s.id != service_id - ] + settings.service_credentials = [s for s in settings.service_credentials if s.id != service_id] if len(settings.service_credentials) == before: - raise HTTPException( - status_code=404, detail=f"Service '{service_id}' not found" - ) + raise HTTPException(status_code=404, detail=f"Service '{service_id}' not found") save_settings(settings, path) return {"status": "deleted", "id": service_id} @@ -309,13 +301,9 @@ async def delete_service(service_id: str) -> dict: async def test_service(service_id: str) -> dict: """Test connectivity for a service credential.""" settings = load_settings(path) - sc = next( - (s for s in settings.service_credentials if s.id == service_id), None - ) + sc = next((s for s in settings.service_credentials if s.id == service_id), None) if not sc: - raise HTTPException( - status_code=404, detail=f"Service '{service_id}' not found" - ) + raise HTTPException(status_code=404, detail=f"Service '{service_id}' not found") # Basic connectivity test based on service type try: @@ -338,14 +326,10 @@ async def test_service(service_id: str) -> dict: return {"status": "error", "message": "No token set"} # Database/queue services - check host is set - if sc.host or ( - sc.connection_url and sc.connection_url.get_secret_value() - ): + if sc.host or (sc.connection_url and sc.connection_url.get_secret_value()): return { "status": "ok", - "message": ( - f"Connection details configured for {sc.service_type}" - ), + "message": (f"Connection details configured for {sc.service_type}"), } return { "status": "error", diff --git a/src/fireflyframework_genai/studio/api/smith.py b/src/fireflyframework_genai/studio/api/smith.py index fa7760e6..c0ae9bb0 100644 --- a/src/fireflyframework_genai/studio/api/smith.py +++ b/src/fireflyframework_genai/studio/api/smith.py @@ -93,28 +93,33 @@ def create_smith_router() -> APIRouter: @router.get("/api/smith/{project}/history") async def get_smith_history(project: str): from fireflyframework_genai.studio.assistant.history import load_smith_history + return load_smith_history(project) @router.post("/api/smith/{project}/history") async def save_smith_history_endpoint(project: str, body: _SaveHistoryBody): from fireflyframework_genai.studio.assistant.history import save_smith_history + save_smith_history(project, body.messages) return {"status": "saved"} @router.delete("/api/smith/{project}/history") async def delete_smith_history(project: str): from fireflyframework_genai.studio.assistant.history import clear_smith_history + clear_smith_history(project) return {"status": "cleared"} @router.get("/api/smith/{project}/files") async def get_smith_files(project: str): from fireflyframework_genai.studio.assistant.history import load_smith_files + return load_smith_files(project) @router.post("/api/smith/{project}/files") async def save_smith_files_endpoint(project: str, body: _SaveFilesBody): from fireflyframework_genai.studio.assistant.history import save_smith_files + save_smith_files(project, body.files) return {"status": "saved"} @@ -123,9 +128,7 @@ async def save_smith_files_endpoint(project: str, body: _SaveFilesBody): # ------------------------------------------------------------------ @router.websocket("/ws/smith") - async def smith_ws( - websocket: WebSocket, project: str = Query(default="") - ) -> None: + async def smith_ws(websocket: WebSocket, project: str = Query(default="")) -> None: await websocket.accept() logger.info("Smith WebSocket connected (project=%s)", project) @@ -144,44 +147,43 @@ async def smith_ws( try: data = json.loads(raw) except json.JSONDecodeError: - await websocket.send_json( - {"type": "error", "message": "Invalid JSON"} - ) + await websocket.send_json({"type": "error", "message": "Invalid JSON"}) continue action = data.get("action", "") if action == "generate": await _handle_generate( - websocket, data, canvas_state, message_history, + websocket, + data, + canvas_state, + message_history, project, ) elif action == "chat": await _handle_chat( - websocket, data, canvas_state, message_history, - pending_commands, project, + websocket, + data, + canvas_state, + message_history, + pending_commands, + project, ) elif action == "sync_canvas": await _handle_sync_canvas(websocket, data, canvas_state) elif action == "execute": await _handle_execute(websocket, data) elif action == "approve_command": - await _handle_approve_command( - websocket, data, pending_commands - ) + await _handle_approve_command(websocket, data, pending_commands) else: - await websocket.send_json( - {"type": "error", "message": f"Unknown action: {action}"} - ) + await websocket.send_json({"type": "error", "message": f"Unknown action: {action}"}) except WebSocketDisconnect: logger.info("Smith WebSocket disconnected") except Exception as exc: logger.exception("Smith WebSocket error") with contextlib.suppress(Exception): - await websocket.send_json( - {"type": "error", "message": str(exc)} - ) + await websocket.send_json({"type": "error", "message": str(exc)}) return router @@ -221,9 +223,7 @@ async def _handle_generate( graph = data.get("graph", canvas_state) # Notify the frontend that generation has started - await websocket.send_json( - {"type": "smith_token", "content": "Generating code from pipeline...\n"} - ) + await websocket.send_json({"type": "smith_token", "content": "Generating code from pipeline...\n"}) # Pass user name so Smith can personalise responses user_name = "" @@ -237,14 +237,14 @@ async def _handle_generate( build_shared_context, ) - shared_context = build_shared_context( - project, canvas_state, exclude_agent="smith" - ) + shared_context = build_shared_context(project, canvas_state, exclude_agent="smith") except Exception: pass result = await generate_code_with_smith( - graph, settings_dict, user_name=user_name, + graph, + settings_dict, + user_name=user_name, shared_context=shared_context, ) @@ -281,9 +281,7 @@ async def _handle_generate( except Exception as exc: logger.error("Smith generation failed: %s", exc, exc_info=True) - await websocket.send_json( - {"type": "error", "message": f"Code generation failed: {exc}"} - ) + await websocket.send_json({"type": "error", "message": f"Code generation failed: {exc}"}) async def _handle_chat( @@ -302,9 +300,7 @@ async def _handle_chat( """ user_msg = data.get("message", "").strip() if not user_msg: - await websocket.send_json( - {"type": "error", "message": "Empty message"} - ) + await websocket.send_json({"type": "error", "message": "Empty message"}) return try: @@ -329,18 +325,13 @@ async def _handle_chat( build_shared_context, ) - shared = build_shared_context( - project, canvas_state, exclude_agent="smith" - ) + shared = build_shared_context(project, canvas_state, exclude_agent="smith") if shared: context_parts.append(shared) except Exception: pass if canvas_state.get("nodes"): - context_parts.append( - "[CURRENT PIPELINE STATE]\n" - + json.dumps(canvas_state, indent=2) - ) + context_parts.append("[CURRENT PIPELINE STATE]\n" + json.dumps(canvas_state, indent=2)) context_parts.append(user_msg) effective_message = "\n\n".join(context_parts) @@ -379,13 +370,9 @@ async def _handle_chat( _chunk_size = 80 for i in range(0, len(chat_text), _chunk_size): chunk = chat_text[i : i + _chunk_size] - await websocket.send_json( - {"type": "smith_token", "content": chunk} - ) + await websocket.send_json({"type": "smith_token", "content": chunk}) - combined_code = "\n\n".join( - f"# --- {f['path']} ---\n{f['content']}" for f in extracted_files - ) + combined_code = "\n\n".join(f"# --- {f['path']} ---\n{f['content']}" for f in extracted_files) await websocket.send_json( { @@ -399,9 +386,7 @@ async def _handle_chat( _chunk_size = 80 for i in range(0, len(full_text), _chunk_size): chunk = full_text[i : i + _chunk_size] - await websocket.send_json( - {"type": "smith_token", "content": chunk} - ) + await websocket.send_json({"type": "smith_token", "content": chunk}) await websocket.send_json( { @@ -423,15 +408,11 @@ async def _handle_chat( ) # Check for pending approvals in tool return parts - await _check_pending_approvals( - websocket, result, pending_commands - ) + await _check_pending_approvals(websocket, result, pending_commands) except Exception as exc: logger.error("Smith chat failed: %s", exc, exc_info=True) - await websocket.send_json( - {"type": "error", "message": f"Smith chat error: {exc}"} - ) + await websocket.send_json({"type": "error", "message": f"Smith chat error: {exc}"}) async def _handle_sync_canvas( @@ -465,9 +446,7 @@ async def _handle_execute( """ code = data.get("code", "").strip() if not code: - await websocket.send_json( - {"type": "error", "message": "No code to execute"} - ) + await websocket.send_json({"type": "error", "message": "No code to execute"}) return tmp_file: Path | None = None @@ -491,9 +470,7 @@ async def _handle_execute( ) try: - stdout, stderr = await asyncio.wait_for( - proc.communicate(), timeout=timeout - ) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) except TimeoutError: proc.kill() await proc.communicate() @@ -518,9 +495,7 @@ async def _handle_execute( except Exception as exc: logger.error("Smith execution failed: %s", exc, exc_info=True) - await websocket.send_json( - {"type": "error", "message": f"Execution failed: {exc}"} - ) + await websocket.send_json({"type": "error", "message": f"Execution failed: {exc}"}) finally: # Clean up temp file if tmp_file is not None: @@ -543,9 +518,7 @@ async def _handle_approve_command( approved = data.get("approved", False) if not command_id: - await websocket.send_json( - {"type": "error", "message": "Missing command_id"} - ) + await websocket.send_json({"type": "error", "message": "Missing command_id"}) return command = pending_commands.pop(command_id, None) @@ -589,9 +562,7 @@ async def _handle_approve_command( ) timeout = command.get("timeout", 30) try: - stdout, stderr = await asyncio.wait_for( - proc.communicate(), timeout=timeout - ) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout) except TimeoutError: proc.kill() await proc.communicate() @@ -617,9 +588,7 @@ async def _handle_approve_command( ) except Exception as exc: logger.error("Approved command execution failed: %s", exc, exc_info=True) - await websocket.send_json( - {"type": "error", "message": f"Command execution failed: {exc}"} - ) + await websocket.send_json({"type": "error", "message": f"Command execution failed: {exc}"}) # --------------------------------------------------------------------------- @@ -634,13 +603,24 @@ async def _handle_approve_command( # Map language tags to file extensions _LANG_TO_EXT: dict[str, str] = { - "python": ".py", "py": ".py", - "javascript": ".js", "js": ".js", - "typescript": ".ts", "ts": ".ts", - "json": ".json", "yaml": ".yaml", "yml": ".yaml", - "bash": ".sh", "sh": ".sh", "shell": ".sh", - "sql": ".sql", "html": ".html", "css": ".css", - "xml": ".xml", "toml": ".toml", "ini": ".ini", + "python": ".py", + "py": ".py", + "javascript": ".js", + "js": ".js", + "typescript": ".ts", + "ts": ".ts", + "json": ".json", + "yaml": ".yaml", + "yml": ".yaml", + "bash": ".sh", + "sh": ".sh", + "shell": ".sh", + "sql": ".sql", + "html": ".html", + "css": ".css", + "xml": ".xml", + "toml": ".toml", + "ini": ".ini", } # Minimum total code length to be considered "substantial" (skip tiny snippets) @@ -672,11 +652,13 @@ def _extract_code_blocks(text: str) -> list[dict[str, Any]]: count = counters.get(lang, 0) + 1 counters[lang] = count name = f"smith_code_{count}{ext}" if count > 1 else f"smith_code{ext}" - files.append({ - "path": name, - "content": content.strip(), - "language": lang, - }) + files.append( + { + "path": name, + "content": content.strip(), + "language": lang, + } + ) return files @@ -782,9 +764,7 @@ def _extract_tool_calls(result: Any) -> list[dict[str, Any]]: tool_name = getattr(part, "tool_name", "") for tc in tool_calls: if tc["tool"] == tool_name and tc["result"] is None: - tc["result"] = ( - str(content)[:500] if content else "" - ) + tc["result"] = str(content)[:500] if content else "" break except Exception as exc: logger.warning("Could not extract tool calls: %s", exc) diff --git a/src/fireflyframework_genai/studio/api/tunnel.py b/src/fireflyframework_genai/studio/api/tunnel.py index 8f5207c3..66b1e46a 100644 --- a/src/fireflyframework_genai/studio/api/tunnel.py +++ b/src/fireflyframework_genai/studio/api/tunnel.py @@ -13,6 +13,7 @@ # limitations under the License. """API endpoints for Cloudflare Tunnel management.""" + from __future__ import annotations from typing import Any diff --git a/src/fireflyframework_genai/studio/assistant/agent.py b/src/fireflyframework_genai/studio/assistant/agent.py index 895a6351..4659de9d 100644 --- a/src/fireflyframework_genai/studio/assistant/agent.py +++ b/src/fireflyframework_genai/studio/assistant/agent.py @@ -39,11 +39,21 @@ # Canvas data models # --------------------------------------------------------------------------- -_VALID_NODE_TYPES = frozenset({ - "agent", "tool", "reasoning", "condition", - "memory", "validator", "custom_code", "fan_out", "fan_in", - "input", "output", -}) +_VALID_NODE_TYPES = frozenset( + { + "agent", + "tool", + "reasoning", + "condition", + "memory", + "validator", + "custom_code", + "fan_out", + "fan_in", + "input", + "output", + } +) class CanvasNode(BaseModel): @@ -119,19 +129,13 @@ async def add_node( if not canvas.nodes: x, y = start_x, start_y else: - occupied = { - (int(n.position.get("x", 0)), int(n.position.get("y", 0))) - for n in canvas.nodes - } + occupied = {(int(n.position.get("x", 0)), int(n.position.get("y", 0))) for n in canvas.nodes} rightmost = max(canvas.nodes, key=lambda n: n.position.get("x", 0)) x = rightmost.position.get("x", 0) + h_gap y = rightmost.position.get("y", start_y) # Avoid vertical collision: offset downward if position is taken - while any( - abs(ox - x) < 100 and abs(oy - y) < 80 - for ox, oy in occupied - ): + while any(abs(ox - x) < 100 and abs(oy - y) < 80 for ox, oy in occupied): y += v_gap node = CanvasNode( @@ -317,9 +321,13 @@ async def validate_pipeline() -> str: if ntype == "agent": if not cfg.get("model"): - errors.append(f"Agent '{node.id}' ({node.label or 'unnamed'}) is missing 'model'. Set it with configure_node.") + errors.append( + f"Agent '{node.id}' ({node.label or 'unnamed'}) is missing 'model'. Set it with configure_node." + ) if not cfg.get("instructions"): - errors.append(f"Agent '{node.id}' ({node.label or 'unnamed'}) is missing 'instructions'. Every agent needs a system prompt.") + errors.append( + f"Agent '{node.id}' ({node.label or 'unnamed'}) is missing 'instructions'. Every agent needs a system prompt." + ) if not cfg.get("description"): warnings.append(f"Agent '{node.id}' ({node.label or 'unnamed'}) has no 'description'.") elif ntype == "tool": @@ -423,19 +431,32 @@ async def auto_layout() -> str: if node: node.position = {"x": float(x), "y": float(layer_y + pos_idx * v_gap)} - return json.dumps({ - "status": "layout_complete", - "layers": len(layers), - "nodes_arranged": sum(len(layer) for layer in layers), - }) + return json.dumps( + { + "status": "layout_complete", + "layers": len(layers), + "nodes_arranged": sum(len(layer) for layer in layers), + } + ) - return [add_node, connect_nodes, configure_node, remove_node, list_nodes, list_edges, clear_canvas, validate_pipeline, auto_layout] + return [ + add_node, + connect_nodes, + configure_node, + remove_node, + list_nodes, + list_edges, + clear_canvas, + validate_pipeline, + auto_layout, + ] # --------------------------------------------------------------------------- # Registry query tools # --------------------------------------------------------------------------- + def create_registry_tools() -> list[BaseTool]: """Create tools that query the framework registries at runtime.""" @@ -447,6 +468,7 @@ def create_registry_tools() -> list[BaseTool]: async def list_registered_agents() -> str: """Query the agent registry for all available agents.""" from fireflyframework_genai.agents.registry import agent_registry + agents = agent_registry.list_agents() return json.dumps( [{"name": a.name, "version": a.version, "description": a.description, "tags": a.tags} for a in agents] @@ -460,9 +482,13 @@ async def list_registered_agents() -> str: async def list_registered_tools() -> str: """Query the tool registry for all available tools.""" from fireflyframework_genai.tools.registry import tool_registry + tools = tool_registry.list_tools() return json.dumps( - [{"name": t.name, "description": t.description, "tags": t.tags, "parameter_count": t.parameter_count} for t in tools] + [ + {"name": t.name, "description": t.description, "tags": t.tags, "parameter_count": t.parameter_count} + for t in tools + ] ) @firefly_tool( @@ -473,6 +499,7 @@ async def list_registered_tools() -> str: async def list_reasoning_patterns() -> str: """Query the reasoning pattern registry.""" from fireflyframework_genai.reasoning.registry import reasoning_registry + patterns = reasoning_registry.list_patterns() return json.dumps(patterns) @@ -496,6 +523,7 @@ async def get_framework_docs() -> str: # Framework version try: from fireflyframework_genai._version import __version__ + docs["version"] = __version__ except Exception: docs["version"] = "unknown" @@ -529,26 +557,25 @@ async def get_framework_docs() -> str: # Agent templates try: from fireflyframework_genai.agents import agent_registry + agents = agent_registry.list_agents() - docs["agent_templates"] = [ - {"name": a.name, "description": a.description} for a in agents - ] + docs["agent_templates"] = [{"name": a.name, "description": a.description} for a in agents] except Exception: docs["agent_templates"] = [] # Registered tools (including custom) try: from fireflyframework_genai.tools.registry import tool_registry as tr + tools = tr.list_tools() - docs["tools"] = [ - {"name": t.name, "description": t.description[:100]} for t in tools - ] + docs["tools"] = [{"name": t.name, "description": t.description[:100]} for t in tools] except Exception: docs["tools"] = [] # Reasoning patterns try: from fireflyframework_genai.reasoning.registry import reasoning_registry + docs["reasoning_patterns"] = reasoning_registry.list_patterns() except Exception: docs["reasoning_patterns"] = [] @@ -557,7 +584,8 @@ async def get_framework_docs() -> str: try: mod = importlib.import_module("fireflyframework_genai.memory") classes = [ - name for name, obj in inspect.getmembers(mod, inspect.isclass) + name + for name, obj in inspect.getmembers(mod, inspect.isclass) if "Memory" in name or "memory" in name.lower() ] docs["memory_classes"] = classes @@ -568,8 +596,7 @@ async def get_framework_docs() -> str: try: mod = importlib.import_module("fireflyframework_genai.pipeline") classes = [ - name for name, obj in inspect.getmembers(mod, inspect.isclass) - if "Node" in name or "Pipeline" in name + name for name, obj in inspect.getmembers(mod, inspect.isclass) if "Node" in name or "Pipeline" in name ] docs["pipeline_classes"] = classes except Exception: @@ -595,17 +622,35 @@ async def read_framework_doc(topic: str) -> str: docs_dir = Path(__file__).resolve().parents[4] / "docs" valid_topics = { - "agents", "architecture", "content", "experiments", "explainability", - "exposure-queues", "exposure-rest", "lab", "memory", "observability", - "pipeline", "prompts", "reasoning", "security", "studio", "templates", - "tools", "tutorial", "use-case-idp", "validation", + "agents", + "architecture", + "content", + "experiments", + "explainability", + "exposure-queues", + "exposure-rest", + "lab", + "memory", + "observability", + "pipeline", + "prompts", + "reasoning", + "security", + "studio", + "templates", + "tools", + "tutorial", + "use-case-idp", + "validation", } if topic not in valid_topics: - return json.dumps({ - "error": f"Unknown topic '{topic}'", - "available_topics": sorted(valid_topics), - }) + return json.dumps( + { + "error": f"Unknown topic '{topic}'", + "available_topics": sorted(valid_topics), + } + ) doc_path = docs_dir / f"{topic}.md" if not doc_path.exists(): @@ -644,10 +689,7 @@ async def get_tool_status() -> str: results = [] for tool_name, required_creds in _tool_credential_map.items(): - configured = [ - c for c in required_creds - if getattr(tc, c, None) - ] + configured = [c for c in required_creds if getattr(tc, c, None)] # Check if tool is registered try: tool_registry.get(tool_name) @@ -655,17 +697,26 @@ async def get_tool_status() -> str: except Exception: registered = False - results.append({ - "name": tool_name, - "registered": registered, - "has_credentials": len(configured) > 0, - "required_credentials": required_creds, - "configured_credentials": configured, - }) + results.append( + { + "name": tool_name, + "registered": registered, + "has_credentials": len(configured) > 0, + "required_credentials": required_creds, + "configured_credentials": configured, + } + ) return json.dumps(results, indent=2) - return [list_registered_agents, list_registered_tools, list_reasoning_patterns, get_framework_docs, read_framework_doc, get_tool_status] + return [ + list_registered_agents, + list_registered_tools, + list_reasoning_patterns, + get_framework_docs, + read_framework_doc, + get_tool_status, + ] # --------------------------------------------------------------------------- @@ -811,13 +862,15 @@ async def present_plan( # WebSocket layer in assistant.py. This tool's return value will be # replaced by the user's actual response before the agent continues. # For now, return a placeholder that signals the plan was presented. - return json.dumps({ - "status": "plan_presented", - "summary": summary, - "steps": steps, - "options": options, - "question": question, - }) + return json.dumps( + { + "status": "plan_presented", + "summary": summary, + "steps": steps, + "options": options, + "question": question, + } + ) return [present_plan] diff --git a/src/fireflyframework_genai/studio/assistant/oracle.py b/src/fireflyframework_genai/studio/assistant/oracle.py index 9f4688a9..0f532f84 100644 --- a/src/fireflyframework_genai/studio/assistant/oracle.py +++ b/src/fireflyframework_genai/studio/assistant/oracle.py @@ -164,12 +164,14 @@ async def analyze_pipeline() -> str: for oid in orphans: node = next((n for n in nodes if n["id"] == oid), None) if node and len(nodes) > 1: - issues.append({ - "severity": "warning", - "title": f"Disconnected node: {node.get('data', {}).get('label', oid)}", - "description": f"Node '{node.get('data', {}).get('label', oid)}' has no connections. It won't participate in the pipeline flow.", - "action": f"Connect node '{oid}' to the pipeline, or remove it if it's not needed.", - }) + issues.append( + { + "severity": "warning", + "title": f"Disconnected node: {node.get('data', {}).get('label', oid)}", + "description": f"Node '{node.get('data', {}).get('label', oid)}' has no connections. It won't participate in the pipeline flow.", + "action": f"Connect node '{oid}' to the pipeline, or remove it if it's not needed.", + } + ) # Check agent nodes for missing model/instructions for node in nodes: @@ -179,37 +181,45 @@ async def analyze_pipeline() -> str: if ntype == "agent": if not data.get("model"): - issues.append({ - "severity": "critical", - "title": f"Agent '{label}' has no model", - "description": "An agent without a model cannot execute. It needs a model like 'openai:gpt-4o'.", - "action": f"Configure node '{node['id']}' with key='model' value='openai:gpt-4o' (or your preferred model).", - }) + issues.append( + { + "severity": "critical", + "title": f"Agent '{label}' has no model", + "description": "An agent without a model cannot execute. It needs a model like 'openai:gpt-4o'.", + "action": f"Configure node '{node['id']}' with key='model' value='openai:gpt-4o' (or your preferred model).", + } + ) if not data.get("instructions"): - issues.append({ - "severity": "suggestion", - "title": f"Agent '{label}' has no instructions", - "description": "Without instructions, the agent has no guidance on how to behave. Consider adding a system prompt.", - "action": f"Configure node '{node['id']}' with key='instructions' and a clear system prompt.", - }) + issues.append( + { + "severity": "suggestion", + "title": f"Agent '{label}' has no instructions", + "description": "Without instructions, the agent has no guidance on how to behave. Consider adding a system prompt.", + "action": f"Configure node '{node['id']}' with key='instructions' and a clear system prompt.", + } + ) elif ntype == "tool": if not data.get("tool_name"): - issues.append({ - "severity": "critical", - "title": f"Tool '{label}' has no tool_name", - "description": "A tool node must specify which registered tool to use.", - "action": f"Configure node '{node['id']}' with key='tool_name' and a valid tool name.", - }) + issues.append( + { + "severity": "critical", + "title": f"Tool '{label}' has no tool_name", + "description": "A tool node must specify which registered tool to use.", + "action": f"Configure node '{node['id']}' with key='tool_name' and a valid tool name.", + } + ) elif ntype == "condition": if not data.get("branches"): - issues.append({ - "severity": "critical", - "title": f"Condition '{label}' has no branches", - "description": "A condition node needs branches to route flow.", - "action": f"Configure node '{node['id']}' with key='branches' as a JSON dict.", - }) + issues.append( + { + "severity": "critical", + "title": f"Condition '{label}' has no branches", + "description": "A condition node needs branches to route flow.", + "action": f"Configure node '{node['id']}' with key='branches' as a JSON dict.", + } + ) return json.dumps({"issues": issues, "node_count": len(nodes), "edge_count": len(edges)}) @@ -251,14 +261,16 @@ async def analyze_node_config(node_id: str) -> str: if not data.get("branches"): missing.append("branches") - return json.dumps({ - "node_id": node_id, - "type": ntype, - "label": data.get("label", node.get("id", "")), - "missing_required": missing, - "recommendations": recommendations, - "is_complete": len(missing) == 0, - }) + return json.dumps( + { + "node_id": node_id, + "type": ntype, + "label": data.get("label", node.get("id", "")), + "missing_required": missing, + "recommendations": recommendations, + "is_complete": len(missing) == 0, + } + ) @firefly_tool( "check_connectivity", @@ -285,14 +297,16 @@ async def check_connectivity() -> str: # Entry points (no incoming edges) entry_points = [nid for nid in node_ids if nid not in targets and nid in sources] - return json.dumps({ - "connected": len(orphans) == 0, - "orphans": orphans, - "dead_ends": dead_ends, - "entry_points": entry_points, - "total_nodes": len(nodes), - "total_edges": len(edges), - }) + return json.dumps( + { + "connected": len(orphans) == 0, + "orphans": orphans, + "dead_ends": dead_ends, + "entry_points": entry_points, + "total_nodes": len(nodes), + "total_edges": len(edges), + } + ) @firefly_tool( "suggest_improvement", @@ -306,13 +320,15 @@ async def suggest_improvement( action_instruction: str = "", ) -> str: """Create a structured suggestion for the user.""" - return json.dumps({ - "type": "suggestion", - "title": title, - "description": description, - "severity": severity, - "action_instruction": action_instruction, - }) + return json.dumps( + { + "type": "suggestion", + "title": title, + "description": description, + "severity": severity, + "action_instruction": action_instruction, + } + ) @firefly_tool( "get_pipeline_stats", @@ -341,13 +357,15 @@ async def get_pipeline_stats() -> str: ): configured += 1 - return json.dumps({ - "total_nodes": len(nodes), - "total_edges": len(edges), - "by_type": type_counts, - "configured_nodes": configured, - "configuration_coverage": f"{configured}/{len(nodes)}" if nodes else "0/0", - }) + return json.dumps( + { + "total_nodes": len(nodes), + "total_edges": len(edges), + "by_type": type_counts, + "configured_nodes": configured, + "configuration_coverage": f"{configured}/{len(nodes)}" if nodes else "0/0", + } + ) @firefly_tool( "review_agent_setup", @@ -379,20 +397,29 @@ async def review_agent_setup() -> str: if target_node and target_node.get("type") == "tool": connected_tools.append(target_node.get("data", {}).get("label", target_id)) - reviews.append({ - "node_id": agent["id"], - "label": label, - "has_model": bool(data.get("model")), - "model": data.get("model", ""), - "has_instructions": bool(data.get("instructions")), - "has_description": bool(data.get("description")), - "connected_tools": connected_tools, - "multimodal_enabled": bool((data.get("multimodal") or {}).get("vision_enabled")), - }) + reviews.append( + { + "node_id": agent["id"], + "label": label, + "has_model": bool(data.get("model")), + "model": data.get("model", ""), + "has_instructions": bool(data.get("instructions")), + "has_description": bool(data.get("description")), + "connected_tools": connected_tools, + "multimodal_enabled": bool((data.get("multimodal") or {}).get("vision_enabled")), + } + ) return json.dumps({"agent_count": len(agents), "reviews": reviews}) - return [analyze_pipeline, analyze_node_config, check_connectivity, suggest_improvement, get_pipeline_stats, review_agent_setup] + return [ + analyze_pipeline, + analyze_node_config, + check_connectivity, + suggest_improvement, + get_pipeline_stats, + review_agent_setup, + ] # --------------------------------------------------------------------------- @@ -404,9 +431,7 @@ def _build_oracle_instructions(user_name: str) -> str: """Build Oracle instructions personalised with the user's name.""" from fireflyframework_genai.studio.assistant.agent import _FRAMEWORK_KNOWLEDGE - personality = _THE_ORACLE_PERSONALITY.replace( - "{user_name_placeholder}", user_name or "friend" - ) + personality = _THE_ORACLE_PERSONALITY.replace("{user_name_placeholder}", user_name or "friend") return ( personality + "\n\n" diff --git a/src/fireflyframework_genai/studio/assistant/smith.py b/src/fireflyframework_genai/studio/assistant/smith.py index d0675331..ee3a43a0 100644 --- a/src/fireflyframework_genai/studio/assistant/smith.py +++ b/src/fireflyframework_genai/studio/assistant/smith.py @@ -48,6 +48,7 @@ def update_canvas_state(nodes: list, edges: list) -> None: _canvas_state["nodes"] = nodes _canvas_state["edges"] = edges + # --------------------------------------------------------------------------- # System prompt — canonical API patterns for code generation # --------------------------------------------------------------------------- @@ -329,9 +330,7 @@ def create_smith_agent(user_name: str = "") -> FireflyAgent: model = _resolve_assistant_model() - instructions = _SMITH_SYSTEM_PROMPT.replace( - "{user_name_placeholder}", user_name or "the user" - ) + instructions = _SMITH_SYSTEM_PROMPT.replace("{user_name_placeholder}", user_name or "the user") agent = FireflyAgent( "smith-codegen", @@ -346,10 +345,10 @@ def create_smith_agent(user_name: str = "") -> FireflyAgent: return agent -_BLOCKED_PATTERNS = ['sudo ', 'rm -rf /', 'chmod 777', 'mkfs ', 'dd if='] -_RISKY_PATTERNS = ['pip install', 'rm ', 'curl ', 'wget '] -_RISKY_CHARS = ['|', '>', ';', '&&', '||'] -_SAFE_PREFIXES = ['python ', 'python3 ', 'pytest', 'pip list', 'pip show', 'pip freeze'] +_BLOCKED_PATTERNS = ["sudo ", "rm -rf /", "chmod 777", "mkfs ", "dd if="] +_RISKY_PATTERNS = ["pip install", "rm ", "curl ", "wget "] +_RISKY_CHARS = ["|", ">", ";", "&&", "||"] +_SAFE_PREFIXES = ["python ", "python3 ", "pytest", "pip list", "pip show", "pip freeze"] def _classify_command(cmd: str) -> str: @@ -357,17 +356,17 @@ def _classify_command(cmd: str) -> str: cmd_stripped = cmd.strip() for pattern in _BLOCKED_PATTERNS: if pattern in cmd_stripped: - return 'blocked' + return "blocked" for pattern in _RISKY_PATTERNS: if pattern in cmd_stripped: - return 'risky' + return "risky" for char in _RISKY_CHARS: if char in cmd_stripped: - return 'risky' + return "risky" for prefix in _SAFE_PREFIXES: if cmd_stripped.startswith(prefix): - return 'safe' - return 'risky' + return "safe" + return "risky" def _create_smith_tools() -> list: @@ -384,6 +383,7 @@ async def get_framework_docs() -> str: docs: dict[str, Any] = {} try: from fireflyframework_genai._version import __version__ + docs["version"] = __version__ except Exception: docs["version"] = "unknown" @@ -405,6 +405,7 @@ async def get_framework_docs() -> str: try: from fireflyframework_genai.tools.registry import tool_registry as tr + tools = tr.list_tools() docs["tools"] = [{"name": t.name, "description": t.description[:80]} for t in tools] except Exception: @@ -412,6 +413,7 @@ async def get_framework_docs() -> str: try: from fireflyframework_genai.reasoning.registry import reasoning_registry + docs["reasoning_patterns"] = reasoning_registry.list_patterns() except Exception: docs["reasoning_patterns"] = [] @@ -433,10 +435,26 @@ async def read_framework_doc(topic: str) -> str: docs_dir = Path(__file__).resolve().parents[4] / "docs" valid_topics = { - "agents", "architecture", "content", "experiments", "explainability", - "exposure-queues", "exposure-rest", "lab", "memory", "observability", - "pipeline", "prompts", "reasoning", "security", "studio", "templates", - "tools", "tutorial", "use-case-idp", "validation", + "agents", + "architecture", + "content", + "experiments", + "explainability", + "exposure-queues", + "exposure-rest", + "lab", + "memory", + "observability", + "pipeline", + "prompts", + "reasoning", + "security", + "studio", + "templates", + "tools", + "tutorial", + "use-case-idp", + "validation", } if topic not in valid_topics: return json.dumps({"error": f"Unknown topic '{topic}'", "available_topics": sorted(valid_topics)}) @@ -468,19 +486,24 @@ async def get_tool_status() -> str: results.append({"name": tool_name, "has_credentials": len(configured) > 0}) return json.dumps(results) - @firefly_tool("validate_python", description="Validate Python code syntax without executing it", auto_register=False) + @firefly_tool( + "validate_python", description="Validate Python code syntax without executing it", auto_register=False + ) async def validate_python(code: str) -> str: import asyncio as _asyncio import os import sys import tempfile - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: f.write(code) tmp_path = f.name try: proc = await _asyncio.create_subprocess_exec( - sys.executable, '-m', 'py_compile', tmp_path, + sys.executable, + "-m", + "py_compile", + tmp_path, stdout=_asyncio.subprocess.PIPE, stderr=_asyncio.subprocess.PIPE, ) @@ -501,12 +524,13 @@ async def run_python(code: str) -> str: import sys import tempfile - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f: + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: f.write(code) tmp_path = f.name try: proc = await _asyncio.create_subprocess_exec( - sys.executable, tmp_path, + sys.executable, + tmp_path, stdout=_asyncio.subprocess.PIPE, stderr=_asyncio.subprocess.PIPE, ) @@ -515,16 +539,20 @@ async def run_python(code: str) -> str: except TimeoutError: proc.kill() await proc.communicate() - return json.dumps({ - "returncode": -1, - "stdout": "", - "stderr": "Execution timed out after 30s", - }) - return json.dumps({ - "returncode": proc.returncode, - "stdout": stdout.decode("utf-8", errors="replace")[:5000], - "stderr": stderr.decode("utf-8", errors="replace")[:2000], - }) + return json.dumps( + { + "returncode": -1, + "stdout": "", + "stderr": "Execution timed out after 30s", + } + ) + return json.dumps( + { + "returncode": proc.returncode, + "stdout": stdout.decode("utf-8", errors="replace")[:5000], + "stderr": stderr.decode("utf-8", errors="replace")[:2000], + } + ) finally: os.unlink(tmp_path) @@ -533,9 +561,9 @@ async def run_shell(command: str) -> str: import asyncio as _asyncio level = _classify_command(command) - if level == 'blocked': + if level == "blocked": return json.dumps({"error": "Command blocked for safety", "command": command}) - if level == 'risky': + if level == "risky": # Return approval_required so the API layer can intercept this # and send a WebSocket message to the frontend. return json.dumps({"approval_required": True, "command": command, "level": level}) @@ -549,16 +577,20 @@ async def run_shell(command: str) -> str: except TimeoutError: proc.kill() await proc.communicate() - return json.dumps({ - "returncode": -1, - "stdout": "", - "stderr": "Command timed out after 30s", - }) - return json.dumps({ - "returncode": proc.returncode, - "stdout": stdout.decode("utf-8", errors="replace")[:5000], - "stderr": stderr.decode("utf-8", errors="replace")[:2000], - }) + return json.dumps( + { + "returncode": -1, + "stdout": "", + "stderr": "Command timed out after 30s", + } + ) + return json.dumps( + { + "returncode": proc.returncode, + "stdout": stdout.decode("utf-8", errors="replace")[:5000], + "stderr": stderr.decode("utf-8", errors="replace")[:2000], + } + ) @firefly_tool("get_canvas_state", description="Get the current canvas pipeline state", auto_register=False) async def get_canvas_state() -> str: @@ -567,16 +599,28 @@ async def get_canvas_state() -> str: @firefly_tool("get_project_info", description="Get current project name and user profile", auto_register=False) async def get_project_info() -> str: from fireflyframework_genai.studio.settings import load_settings + try: settings = load_settings() - return json.dumps({ - "user": settings.user_profile.name, - "model": settings.model_defaults.default_model, - }) + return json.dumps( + { + "user": settings.user_profile.name, + "model": settings.model_defaults.default_model, + } + ) except Exception: return json.dumps({"user": "Unknown", "model": "openai:gpt-4o"}) - return [get_framework_docs, read_framework_doc, get_tool_status, validate_python, run_python, run_shell, get_canvas_state, get_project_info] + return [ + get_framework_docs, + read_framework_doc, + get_tool_status, + validate_python, + run_python, + run_shell, + get_canvas_state, + get_project_info, + ] # --------------------------------------------------------------------------- @@ -588,10 +632,7 @@ def _build_smith_prompt(graph: dict, settings: dict | None = None) -> str: """Convert a graph JSON into a structured prompt for Smith.""" default_model = "openai:gpt-4o" if settings: - default_model = ( - settings.get("model_defaults", {}).get("default_model", default_model) - or default_model - ) + default_model = settings.get("model_defaults", {}).get("default_model", default_model) or default_model lines = [ "Convert this visual pipeline graph into production Python code.", @@ -617,55 +658,63 @@ def _extract_files(text: str) -> list[dict[str, str]]: # Try multi-file format first: --- FILE: path --- file_pattern = re.compile( - r'---\s*FILE:\s*(.+?)\s*---\s*\n```(?:\w+)?\s*\n(.*?)```', + r"---\s*FILE:\s*(.+?)\s*---\s*\n```(?:\w+)?\s*\n(.*?)```", re.DOTALL, ) matches = file_pattern.findall(text) if matches: for path, content in matches: path = path.strip() - if path.endswith('.md'): - lang = 'markdown' - elif path.endswith('.json'): - lang = 'json' - elif path.endswith(('.yaml', '.yml')): - lang = 'yaml' + if path.endswith(".md"): + lang = "markdown" + elif path.endswith(".json"): + lang = "json" + elif path.endswith((".yaml", ".yml")): + lang = "yaml" else: - lang = 'python' - files.append({ - 'path': path, - 'content': content.strip(), - 'language': lang, - }) + lang = "python" + files.append( + { + "path": path, + "content": content.strip(), + "language": lang, + } + ) return files # Fallback: single python code block -> main.py - match = re.search(r'```python\s*\n(.*?)```', text, re.DOTALL) + match = re.search(r"```python\s*\n(.*?)```", text, re.DOTALL) if match: - files.append({ - 'path': 'main.py', - 'content': match.group(1).strip(), - 'language': 'python', - }) + files.append( + { + "path": "main.py", + "content": match.group(1).strip(), + "language": "python", + } + ) return files # Last resort: generic code block - match = re.search(r'```\s*\n(.*?)```', text, re.DOTALL) + match = re.search(r"```\s*\n(.*?)```", text, re.DOTALL) if match: - files.append({ - 'path': 'main.py', - 'content': match.group(1).strip(), - 'language': 'python', - }) + files.append( + { + "path": "main.py", + "content": match.group(1).strip(), + "language": "python", + } + ) return files # No fences -- treat entire text as main.py if it looks like code - if text.strip().startswith(('import ', 'from ', '#', 'async ', 'def ')): - files.append({ - 'path': 'main.py', - 'content': text.strip(), - 'language': 'python', - }) + if text.strip().startswith(("import ", "from ", "#", "async ", "def ")): + files.append( + { + "path": "main.py", + "content": text.strip(), + "language": "python", + } + ) return files diff --git a/src/fireflyframework_genai/studio/codegen/generator.py b/src/fireflyframework_genai/studio/codegen/generator.py index 3e686e47..dab6c4df 100644 --- a/src/fireflyframework_genai/studio/codegen/generator.py +++ b/src/fireflyframework_genai/studio/codegen/generator.py @@ -147,9 +147,7 @@ def _emit_imports(type_set: set[NodeType], has_edges: bool) -> list[str]: step_types.append("CallableStep") if step_types: - imports.append( - f"from fireflyframework_genai.pipeline.steps import {', '.join(sorted(step_types))}" - ) + imports.append(f"from fireflyframework_genai.pipeline.steps import {', '.join(sorted(step_types))}") imports.append("from fireflyframework_genai.pipeline.context import PipelineContext") @@ -184,12 +182,12 @@ def _emit_agent_node(node: GraphNode, default_model: str) -> str: instructions = node.data.get("instructions", "") description = node.data.get("description", "") - parts = [f'{name} = FireflyAgent('] + parts = [f"{name} = FireflyAgent("] parts.append(f' name="{node.id}",') parts.append(f' model="{model}",') - parts.append(f' instructions={_format_string_literal(instructions)},') + parts.append(f" instructions={_format_string_literal(instructions)},") if description: - parts.append(f' description={_format_string_literal(description)},') + parts.append(f" description={_format_string_literal(description)},") parts.append(")") return "\n".join(parts) @@ -202,13 +200,13 @@ def _emit_tool_node(node: GraphNode, _default_model: str) -> str: return f"# TOOL node {node.id!r} is missing 'tool_name' configuration" lines = [ - f'# Tool node: {node.label or node.id}', + f"# Tool node: {node.label or node.id}", f'{name}_tool = tool_registry.get("{tool_name}")', - '', - '', - f'async def {name}_execute(context: PipelineContext, inputs: dict) -> dict:', + "", + "", + f"async def {name}_execute(context: PipelineContext, inputs: dict) -> dict:", f' """Execute tool: {tool_name}."""', - f' return await {name}_tool.execute(**inputs)', + f" return await {name}_tool.execute(**inputs)", ] return "\n".join(lines) @@ -223,15 +221,14 @@ def _emit_reasoning_node(node: GraphNode, _default_model: str) -> str: return f"# REASONING node {node.id!r} is missing 'pattern_name' configuration" lines = [ - f'# Reasoning node: {node.label or node.id}', + f"# Reasoning node: {node.label or node.id}", f'{name}_pattern = reasoning_registry.get("{pattern_name}")', ] if agent_name: lines.append(f'{name}_agent = agent_registry.get("{agent_name}")') else: - lines.append(f'# Note: REASONING node {node.id!r} has no agent_name; ' - f'you may need to assign an agent here') - lines.append(f'{name}_agent = None # TODO: assign an agent') + lines.append(f"# Note: REASONING node {node.id!r} has no agent_name; you may need to assign an agent here") + lines.append(f"{name}_agent = None # TODO: assign an agent") return "\n".join(lines) @@ -248,15 +245,15 @@ def _emit_condition_node(node: GraphNode, _default_model: str) -> str: branches_repr = repr(branches) lines = [ - f'# Condition node: {node.label or node.id}', - f'{name}_branches = {branches_repr}', - f'{name}_default = next(iter({name}_branches.values()))', - '', - '', - f'def {name}_router(inputs: dict) -> str:', + f"# Condition node: {node.label or node.id}", + f"{name}_branches = {branches_repr}", + f"{name}_default = next(iter({name}_branches.values()))", + "", + "", + f"def {name}_router(inputs: dict) -> str:", f' """Route based on key {condition_key!r}."""', f' value = str(inputs.get("{condition_key}", ""))', - f' return {name}_branches.get(value, {name}_default)', + f" return {name}_branches.get(value, {name}_default)", ] return "\n".join(lines) @@ -267,23 +264,27 @@ def _emit_fan_out_node(node: GraphNode, _default_model: str) -> str: field = node.data.get("split_expression", "") lines = [ - f'# Fan-Out node: {node.label or node.id}', - '', - '', - f'def {name}_split(value):', + f"# Fan-Out node: {node.label or node.id}", + "", + "", + f"def {name}_split(value):", ' """Split input for parallel processing."""', ] if field: - lines.extend([ - ' if isinstance(value, dict):', - f' extracted = value.get("{field}", value)', - ' return list(extracted) if isinstance(extracted, list) else [extracted]', - ' return list(value) if isinstance(value, list) else [value]', - ]) + lines.extend( + [ + " if isinstance(value, dict):", + f' extracted = value.get("{field}", value)', + " return list(extracted) if isinstance(extracted, list) else [extracted]", + " return list(value) if isinstance(value, list) else [value]", + ] + ) else: - lines.extend([ - ' return list(value) if isinstance(value, list) else [value]', - ]) + lines.extend( + [ + " return list(value) if isinstance(value, list) else [value]", + ] + ) return "\n".join(lines) @@ -293,24 +294,26 @@ def _emit_fan_in_node(node: GraphNode, _default_model: str) -> str: merge_expr = node.data.get("merge_expression", "collect") lines = [ - f'# Fan-In node: {node.label or node.id}', - '', - '', - f'def {name}_merge(items: list):', + f"# Fan-In node: {node.label or node.id}", + "", + "", + f"def {name}_merge(items: list):", f' """Merge parallel outputs ({merge_expr})."""', ] if merge_expr == "concat": - lines.extend([ - ' result = []', - ' for item in items:', - ' if isinstance(item, list):', - ' result.extend(item)', - ' else:', - ' result.append(item)', - ' return result', - ]) + lines.extend( + [ + " result = []", + " for item in items:", + " if isinstance(item, list):", + " result.extend(item)", + " else:", + " result.append(item)", + " return result", + ] + ) else: - lines.append(' return items') + lines.append(" return items") return "\n".join(lines) @@ -320,32 +323,38 @@ def _emit_memory_node(node: GraphNode, _default_model: str) -> str: action = node.data.get("memory_action", "retrieve") lines = [ - f'# Memory node: {node.label or node.id} (action: {action})', - '', - '', - f'async def {name}_execute(context: PipelineContext, inputs: dict):', + f"# Memory node: {node.label or node.id} (action: {action})", + "", + "", + f"async def {name}_execute(context: PipelineContext, inputs: dict):", f' """Memory operation: {action}."""', - ' memory = context.memory', - ' if memory is None:', + " memory = context.memory", + " if memory is None:", ' return inputs.get("input")', ' key = inputs.get("key", "default")', ] if action == "store": - lines.extend([ - ' value = inputs.get("input", inputs.get("value"))', - ' memory.set_fact(key, value)', - ' return value', - ]) + lines.extend( + [ + ' value = inputs.get("input", inputs.get("value"))', + " memory.set_fact(key, value)", + " return value", + ] + ) elif action == "clear": - lines.extend([ - ' memory.working.delete(key)', - ' return None', - ]) + lines.extend( + [ + " memory.working.delete(key)", + " return None", + ] + ) else: # retrieve - lines.extend([ - ' return memory.get_fact(key)', - ]) + lines.extend( + [ + " return memory.get_fact(key)", + ] + ) return "\n".join(lines) @@ -355,41 +364,51 @@ def _emit_validator_node(node: GraphNode, _default_model: str) -> str: rule = node.data.get("validation_rule", "not_empty") lines = [ - f'# Validator node: {node.label or node.id} (rule: {rule})', - '', - '', - f'async def {name}_validate(context: PipelineContext, inputs: dict):', + f"# Validator node: {node.label or node.id} (rule: {rule})", + "", + "", + f"async def {name}_validate(context: PipelineContext, inputs: dict):", f' """Validate input: {rule}."""', ' value = inputs.get("input", context.inputs)', ] if rule == "not_empty": - lines.extend([ - ' if not value:', - ' raise ValueError("Validation failed: value is empty")', - ]) + lines.extend( + [ + " if not value:", + ' raise ValueError("Validation failed: value is empty")', + ] + ) elif rule == "is_string": - lines.extend([ - ' if not isinstance(value, str):', - ' raise TypeError(f"Expected string, got {type(value).__name__}")', - ]) + lines.extend( + [ + " if not isinstance(value, str):", + ' raise TypeError(f"Expected string, got {type(value).__name__}")', + ] + ) elif rule == "is_list": - lines.extend([ - ' if not isinstance(value, list):', - ' raise TypeError(f"Expected list, got {type(value).__name__}")', - ]) + lines.extend( + [ + " if not isinstance(value, list):", + ' raise TypeError(f"Expected list, got {type(value).__name__}")', + ] + ) elif rule == "is_dict": - lines.extend([ - ' if not isinstance(value, dict):', - ' raise TypeError(f"Expected dict, got {type(value).__name__}")', - ]) + lines.extend( + [ + " if not isinstance(value, dict):", + ' raise TypeError(f"Expected dict, got {type(value).__name__}")', + ] + ) elif rule: - lines.extend([ - f' if isinstance(value, dict) and "{rule}" not in value:', - f' raise KeyError("Missing required key: {rule}")', - ]) - - lines.append(' return value') + lines.extend( + [ + f' if isinstance(value, dict) and "{rule}" not in value:', + f' raise KeyError("Missing required key: {rule}")', + ] + ) + + lines.append(" return value") return "\n".join(lines) @@ -402,17 +421,19 @@ def _emit_custom_code_node(node: GraphNode, _default_model: str) -> str: return f"# CUSTOM_CODE node {node.id!r} has no code defined" lines = [ - f'# Custom Code node: {node.label or node.id}', - '# The code below must define: async def execute(context, inputs) -> Any', + f"# Custom Code node: {node.label or node.id}", + "# The code below must define: async def execute(context, inputs) -> Any", ] # Indent user code inside a namespace to avoid collisions for code_line in code.splitlines(): lines.append(code_line) - lines.extend([ - '', - f'{name}_execute = execute # bind to node variable', - ]) + lines.extend( + [ + "", + f"{name}_execute = execute # bind to node variable", + ] + ) return "\n".join(lines) @@ -422,10 +443,10 @@ def _emit_input_node(node: GraphNode, _default_model: str) -> str: trigger_type = node.data.get("trigger_type", "manual") lines = [ - f'# Input node: {node.label or node.id} (trigger: {trigger_type})', - '', - '', - f'async def {name}_step(context: PipelineContext, inputs: dict):', + f"# Input node: {node.label or node.id} (trigger: {trigger_type})", + "", + "", + f"async def {name}_step(context: PipelineContext, inputs: dict):", f' """Pipeline entry point ({trigger_type} trigger)."""', ' return inputs.get("input", context.inputs)', ] @@ -436,23 +457,23 @@ def _emit_input_node(node: GraphNode, _default_model: str) -> str: if http: method = http.get("method", "POST") path = http.get("path_suffix", "") - lines.insert(1, f'# HTTP config: {method} {path}') + lines.insert(1, f"# HTTP config: {method} {path}") elif trigger_type == "queue": q = node.data.get("queue_config", {}) if q: broker = q.get("broker", "") topic = q.get("topic_or_queue", "") - lines.insert(1, f'# Queue config: {broker} / {topic}') + lines.insert(1, f"# Queue config: {broker} / {topic}") elif trigger_type == "schedule": sched = node.data.get("schedule_config", {}) if sched: cron = sched.get("cron_expression", "") - lines.insert(1, f'# Schedule config: {cron}') + lines.insert(1, f"# Schedule config: {cron}") elif trigger_type == "file_upload": fc = node.data.get("file_config", {}) if fc: types = fc.get("accepted_types", ["*/*"]) - lines.insert(1, f'# File upload config: accepts {types}') + lines.insert(1, f"# File upload config: accepts {types}") return "\n".join(lines) @@ -463,10 +484,10 @@ def _emit_output_node(node: GraphNode, _default_model: str) -> str: dest_type = node.data.get("destination_type", "response") lines = [ - f'# Output node: {node.label or node.id} (destination: {dest_type})', - '', - '', - f'async def {name}_step(context: PipelineContext, inputs: dict):', + f"# Output node: {node.label or node.id} (destination: {dest_type})", + "", + "", + f"async def {name}_step(context: PipelineContext, inputs: dict):", f' """Pipeline exit point ({dest_type} destination)."""', f' context.metadata["_output_config"] = {repr(node.data)}', ' return inputs.get("input", inputs)', @@ -476,13 +497,13 @@ def _emit_output_node(node: GraphNode, _default_model: str) -> str: wh = node.data.get("webhook_config", {}) if wh: url = wh.get("url", "") - lines.insert(1, f'# Webhook config: {url}') + lines.insert(1, f"# Webhook config: {url}") elif dest_type == "store": sc = node.data.get("store_config", {}) if sc: storage = sc.get("storage_type", "file") path = sc.get("path_or_table", "") - lines.insert(1, f'# Store config: {storage} / {path}') + lines.insert(1, f"# Store config: {storage} / {path}") return "\n".join(lines) @@ -491,9 +512,9 @@ def _emit_pipeline_step_node(node: GraphNode, _default_model: str) -> str: """Emit a generic pipeline step (pass-through).""" name = _safe_var(node.id) return ( - f'# Pipeline step: {node.label or node.id}\n' - f'\n\n' - f'async def {name}_step(context: PipelineContext, inputs: dict):\n' + f"# Pipeline step: {node.label or node.id}\n" + f"\n\n" + f"async def {name}_step(context: PipelineContext, inputs: dict):\n" f' """Pass-through step."""\n' f' return inputs.get("input", context.inputs)' ) @@ -584,28 +605,30 @@ def _emit_main_block(graph: GraphModel) -> str: lines = [ 'if __name__ == "__main__":', - '', - ' async def main():', - ' from fireflyframework_genai.pipeline.context import PipelineContext', + "", + " async def main():", + " from fireflyframework_genai.pipeline.context import PipelineContext", ] has_memory = any(n.type == NodeType.MEMORY for n in graph.nodes) if has_memory: lines.append(' memory = MemoryManager(store=FileStore(base_dir="./memory"))') - lines.append(' context = PipelineContext(memory=memory)') + lines.append(" context = PipelineContext(memory=memory)") else: - lines.append(' context = PipelineContext()') + lines.append(" context = PipelineContext()") if input_nodes: trigger = input_nodes[0].data.get("trigger_type", "manual") - lines.append(f' # Trigger type: {trigger}') - - lines.extend([ - ' result = await pipeline.run(context, inputs={"input": "Hello, pipeline!"})', - ' print("Pipeline result:", result)', - '', - ' asyncio.run(main())', - ]) + lines.append(f" # Trigger type: {trigger}") + + lines.extend( + [ + ' result = await pipeline.run(context, inputs={"input": "Hello, pipeline!"})', + ' print("Pipeline result:", result)', + "", + " asyncio.run(main())", + ] + ) return "\n".join(lines) diff --git a/src/fireflyframework_genai/studio/custom_tools.py b/src/fireflyframework_genai/studio/custom_tools.py index a9db49d1..125d4005 100644 --- a/src/fireflyframework_genai/studio/custom_tools.py +++ b/src/fireflyframework_genai/studio/custom_tools.py @@ -93,9 +93,7 @@ class CustomToolManager: """ def __init__(self, base_dir: Path | None = None) -> None: - self._base_dir = ( - base_dir or Path.home() / ".firefly-studio" / "custom_tools" - ).resolve() + self._base_dir = (base_dir or Path.home() / ".firefly-studio" / "custom_tools").resolve() self._base_dir.mkdir(parents=True, exist_ok=True) def _safe_path(self, name: str) -> Path: @@ -152,9 +150,7 @@ def delete(self, name: str) -> None: # -- Runtime tool creation ---------------------------------------------- - def create_runtime_tool( - self, definition: CustomToolDefinition - ) -> _DecoratedTool: + def create_runtime_tool(self, definition: CustomToolDefinition) -> _DecoratedTool: """Convert a definition into a live BaseTool instance.""" tool_name = f"custom:{definition.name}" @@ -178,13 +174,9 @@ def _make_python_handler(self, definition: CustomToolDefinition): """Load an async ``run`` function from a Python file on disk.""" module_path = Path(definition.module_path).resolve() if not module_path.is_file(): - raise FileNotFoundError( - f"Python module not found: {definition.module_path}" - ) + raise FileNotFoundError(f"Python module not found: {definition.module_path}") - spec = importlib.util.spec_from_file_location( - f"custom_tool_{definition.name}", module_path - ) + spec = importlib.util.spec_from_file_location(f"custom_tool_{definition.name}", module_path) if spec is None or spec.loader is None: raise ImportError(f"Cannot load module from {module_path}") @@ -193,9 +185,7 @@ def _make_python_handler(self, definition: CustomToolDefinition): run_fn = getattr(module, "run", None) if run_fn is None: - raise AttributeError( - f"Module {module_path} must define an async 'run' function" - ) + raise AttributeError(f"Module {module_path} must define an async 'run' function") return run_fn def _make_webhook_handler(self, definition: CustomToolDefinition): diff --git a/src/fireflyframework_genai/studio/execution/compiler.py b/src/fireflyframework_genai/studio/execution/compiler.py index abf1644a..5f502bb7 100644 --- a/src/fireflyframework_genai/studio/execution/compiler.py +++ b/src/fireflyframework_genai/studio/execution/compiler.py @@ -78,13 +78,9 @@ def compile_graph( if input_nodes: if len(input_nodes) > 1: - raise CompilationError( - f"Pipeline must have exactly one Input node, found {len(input_nodes)}." - ) + raise CompilationError(f"Pipeline must have exactly one Input node, found {len(input_nodes)}.") if not output_nodes: - raise CompilationError( - "Pipeline with an Input node must have at least one Output node." - ) + raise CompilationError("Pipeline with an Input node must have at least one Output node.") name = graph.metadata.get("name", "studio-pipeline") builder = PipelineBuilder(name=name) diff --git a/src/fireflyframework_genai/studio/execution/io_nodes.py b/src/fireflyframework_genai/studio/execution/io_nodes.py index 3ba9b9f2..5ef380a4 100644 --- a/src/fireflyframework_genai/studio/execution/io_nodes.py +++ b/src/fireflyframework_genai/studio/execution/io_nodes.py @@ -76,9 +76,7 @@ class InputNodeConfig(BaseModel): @classmethod def _validate_trigger_type(cls, v: str) -> str: if v not in _VALID_TRIGGER_TYPES: - raise ValueError( - f"Invalid trigger_type '{v}'. Must be one of: {', '.join(sorted(_VALID_TRIGGER_TYPES))}" - ) + raise ValueError(f"Invalid trigger_type '{v}'. Must be one of: {', '.join(sorted(_VALID_TRIGGER_TYPES))}") return v diff --git a/src/fireflyframework_genai/studio/runtime.py b/src/fireflyframework_genai/studio/runtime.py index e33441de..6f96dc49 100644 --- a/src/fireflyframework_genai/studio/runtime.py +++ b/src/fireflyframework_genai/studio/runtime.py @@ -51,9 +51,11 @@ async def start(self, graph: GraphModel) -> None: await self._start_scheduler() self.status = "running" - logger.info("ProjectRuntime '%s' started (trigger=%s)", - self.project_name, - self._input_config.trigger_type if self._input_config else "none") + logger.info( + "ProjectRuntime '%s' started (trigger=%s)", + self.project_name, + self._input_config.trigger_type if self._input_config else "none", + ) async def stop(self) -> None: """Gracefully stop all background processes.""" @@ -103,6 +105,7 @@ async def _start_queue_consumer(self) -> None: if qc.broker == "kafka": from fireflyframework_genai.exposure.queues.kafka import KafkaAgentConsumer + consumer = KafkaAgentConsumer( agent_name, topic=qc.topic_or_queue, @@ -111,6 +114,7 @@ async def _start_queue_consumer(self) -> None: ) elif qc.broker == "rabbitmq": from fireflyframework_genai.exposure.queues.rabbitmq import RabbitMQAgentConsumer + consumer = RabbitMQAgentConsumer( agent_name, queue_name=qc.topic_or_queue, @@ -118,6 +122,7 @@ async def _start_queue_consumer(self) -> None: ) elif qc.broker == "redis": from fireflyframework_genai.exposure.queues.redis import RedisAgentConsumer + consumer = RedisAgentConsumer( agent_name, channel=qc.topic_or_queue, diff --git a/src/fireflyframework_genai/studio/server.py b/src/fireflyframework_genai/studio/server.py index 79e0dfd2..cfdc03a8 100644 --- a/src/fireflyframework_genai/studio/server.py +++ b/src/fireflyframework_genai/studio/server.py @@ -133,10 +133,12 @@ async def health() -> dict[str, str]: # -- Per-project runtime & execution API ------------------------------- from fireflyframework_genai.studio.api.project_api import create_project_api_router + app.include_router(create_project_api_router(project_manager)) # -- Version history endpoints ------------------------------------------- from fireflyframework_genai.studio.api.projects import create_versioning_router + app.include_router(create_versioning_router(project_manager)) # -- Custom tools endpoints -------------------------------------------- diff --git a/src/fireflyframework_genai/studio/versioning.py b/src/fireflyframework_genai/studio/versioning.py index 93a97a19..4fd462ab 100644 --- a/src/fireflyframework_genai/studio/versioning.py +++ b/src/fireflyframework_genai/studio/versioning.py @@ -50,11 +50,14 @@ def commit(self, message: str) -> str: return sha def get_history(self, limit: int = 50) -> list[dict]: - result = self._run([ - "git", "log", - f"--max-count={limit}", - "--format=%H|%s|%aI", - ]) + result = self._run( + [ + "git", + "log", + f"--max-count={limit}", + "--format=%H|%s|%aI", + ] + ) if result.returncode != 0: return [] @@ -67,12 +70,14 @@ def get_history(self, limit: int = 50) -> list[dict]: continue parts = line.split("|", 2) if len(parts) >= 3: - history.append({ - "sha": parts[0], - "message": parts[1], - "timestamp": parts[2], - "bookmarked": parts[0] in bookmarks, - }) + history.append( + { + "sha": parts[0], + "message": parts[1], + "timestamp": parts[2], + "bookmarked": parts[0] in bookmarks, + } + ) return history def restore(self, commit_sha: str) -> None: diff --git a/tests/test_studio/test_compiler_io.py b/tests/test_studio/test_compiler_io.py index 2e63599a..fdbaefaf 100644 --- a/tests/test_studio/test_compiler_io.py +++ b/tests/test_studio/test_compiler_io.py @@ -63,8 +63,13 @@ def test_multiple_input_nodes_raises(self): graph = GraphModel( nodes=[ _input_node(), - GraphNode(id="input_2", type=NodeType.INPUT, label="Input 2", - position={"x": 0, "y": 400}, data={"trigger_type": "http"}), + GraphNode( + id="input_2", + type=NodeType.INPUT, + label="Input 2", + position={"x": 0, "y": 400}, + data={"trigger_type": "http"}, + ), _step_node(), _output_node(), ], @@ -103,11 +108,16 @@ def test_input_node_with_schema_validates(self): def test_pipeline_without_io_nodes_still_works(self): """Backward compatibility: pipelines without IO nodes should still compile.""" graph = GraphModel( - nodes=[_step_node("s1"), GraphNode( - id="s2", type=NodeType.PIPELINE_STEP, label="Step 2", - position={"x": 300, "y": 200}, - data={}, - )], + nodes=[ + _step_node("s1"), + GraphNode( + id="s2", + type=NodeType.PIPELINE_STEP, + label="Step 2", + position={"x": 300, "y": 200}, + data={}, + ), + ], edges=[GraphEdge(id="e1", source="s1", target="s2")], ) engine = compile_graph(graph) diff --git a/tests/test_studio/test_dynamic_model.py b/tests/test_studio/test_dynamic_model.py index 67f45f29..94395ce1 100644 --- a/tests/test_studio/test_dynamic_model.py +++ b/tests/test_studio/test_dynamic_model.py @@ -11,10 +11,14 @@ class TestDynamicDefaultModel: def test_get_default_model_reads_settings(self, tmp_path): settings_file = tmp_path / "settings.json" - settings_file.write_text(json.dumps({ - "model_defaults": {"default_model": "anthropic:claude-sonnet-4-20250514"}, - "setup_complete": True, - })) + settings_file.write_text( + json.dumps( + { + "model_defaults": {"default_model": "anthropic:claude-sonnet-4-20250514"}, + "setup_complete": True, + } + ) + ) model = _get_default_model(settings_path=settings_file) assert model == "anthropic:claude-sonnet-4-20250514" @@ -24,12 +28,18 @@ def test_get_default_model_fallback(self, tmp_path): def test_codegen_uses_settings_model(self, tmp_path): settings_file = tmp_path / "settings.json" - settings_file.write_text(json.dumps({ - "model_defaults": {"default_model": "google-gla:gemini-2.5-flash"}, - })) + settings_file.write_text( + json.dumps( + { + "model_defaults": {"default_model": "google-gla:gemini-2.5-flash"}, + } + ) + ) node = GraphNode( - id="agent_1", type=NodeType.AGENT, label="Agent", + id="agent_1", + type=NodeType.AGENT, + label="Agent", position={"x": 0, "y": 0}, data={"instructions": "Help the user."}, # No model specified ) @@ -41,12 +51,17 @@ def test_codegen_uses_settings_model(self, tmp_path): class TestArchitectDefaultModel: def test_architect_instructions_include_default_model(self, tmp_path): settings_file = tmp_path / "settings.json" - settings_file.write_text(json.dumps({ - "model_defaults": {"default_model": "anthropic:claude-sonnet-4-20250514"}, - "user_profile": {"name": "TestUser"}, - "setup_complete": True, - })) + settings_file.write_text( + json.dumps( + { + "model_defaults": {"default_model": "anthropic:claude-sonnet-4-20250514"}, + "user_profile": {"name": "TestUser"}, + "setup_complete": True, + } + ) + ) from fireflyframework_genai.studio.assistant.agent import _build_instructions + instructions = _build_instructions(settings_path=settings_file) assert "anthropic:claude-sonnet-4-20250514" in instructions diff --git a/tests/test_studio/test_graphql.py b/tests/test_studio/test_graphql.py index 35edc64d..60283069 100644 --- a/tests/test_studio/test_graphql.py +++ b/tests/test_studio/test_graphql.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the GraphQL API endpoint.""" + from __future__ import annotations from pathlib import Path @@ -68,9 +69,7 @@ async def test_graphql_projects_query_empty(self, client: httpx.AsyncClient): async def test_graphql_projects_query_after_create(self, client: httpx.AsyncClient): """Creating a project via REST and querying via GraphQL returns it.""" # Create a project via the REST API - create_resp = await client.post( - "/api/projects", json={"name": "gql-test", "description": "GraphQL test"} - ) + create_resp = await client.post("/api/projects", json={"name": "gql-test", "description": "GraphQL test"}) assert create_resp.status_code == 200 resp = await client.post( @@ -86,15 +85,11 @@ async def test_graphql_projects_query_after_create(self, client: httpx.AsyncClie async def test_graphql_project_by_name(self, client: httpx.AsyncClient): """Querying a single project by name returns it.""" - await client.post( - "/api/projects", json={"name": "single-test", "description": "Single"} - ) + await client.post("/api/projects", json={"name": "single-test", "description": "Single"}) resp = await client.post( "/api/graphql", - json={ - "query": '{ project(name: "single-test") { name description } }' - }, + json={"query": '{ project(name: "single-test") { name description } }'}, ) assert resp.status_code == 200 body = resp.json() @@ -104,9 +99,7 @@ async def test_graphql_project_not_found(self, client: httpx.AsyncClient): """Querying a non-existent project returns null.""" resp = await client.post( "/api/graphql", - json={ - "query": '{ project(name: "nonexistent") { name } }' - }, + json={"query": '{ project(name: "nonexistent") { name } }'}, ) assert resp.status_code == 200 body = resp.json() @@ -116,9 +109,7 @@ async def test_graphql_runtime_status_stopped(self, client: httpx.AsyncClient): """Runtime status for a project with no active runtime is 'stopped'.""" resp = await client.post( "/api/graphql", - json={ - "query": '{ runtimeStatus(project: "any-project") { project status consumers schedulerActive } }' - }, + json={"query": '{ runtimeStatus(project: "any-project") { project status consumers schedulerActive } }'}, ) assert resp.status_code == 200 body = resp.json() @@ -152,9 +143,7 @@ async def test_graphql_query_type_fields(self, client: httpx.AsyncClient): """The Query type exposes expected field names.""" resp = await client.post( "/api/graphql", - json={ - "query": '{ __type(name: "Query") { fields { name } } }' - }, + json={"query": '{ __type(name: "Query") { fields { name } } }'}, ) assert resp.status_code == 200 body = resp.json() @@ -174,9 +163,7 @@ async def test_run_pipeline_missing_project(self, client: httpx.AsyncClient): """Running a pipeline for a non-existent project returns error status.""" resp = await client.post( "/api/graphql", - json={ - "query": 'mutation { runPipeline(project: "nope", input: "hello") { executionId status result } }' - }, + json={"query": 'mutation { runPipeline(project: "nope", input: "hello") { executionId status result } }'}, ) assert resp.status_code == 200 body = resp.json() @@ -186,9 +173,7 @@ async def test_run_pipeline_missing_project(self, client: httpx.AsyncClient): async def test_run_pipeline_no_pipeline_saved(self, client: httpx.AsyncClient): """Running a pipeline when no pipeline JSON exists returns error.""" - await client.post( - "/api/projects", json={"name": "empty-proj"} - ) + await client.post("/api/projects", json={"name": "empty-proj"}) resp = await client.post( "/api/graphql", json={ diff --git a/tests/test_studio/test_pipeline_events.py b/tests/test_studio/test_pipeline_events.py index 4040ce15..8346e235 100644 --- a/tests/test_studio/test_pipeline_events.py +++ b/tests/test_studio/test_pipeline_events.py @@ -53,9 +53,7 @@ def _make_node( label: str = "test-node", data: dict | None = None, ) -> GraphNode: - return GraphNode( - id=node_id, type=node_type, label=label, position=_POS, data=data or {} - ) + return GraphNode(id=node_id, type=node_type, label=label, position=_POS, data=data or {}) def _mock_agent(name: str = "mock-agent") -> MagicMock: diff --git a/tests/test_studio/test_pipeline_integration.py b/tests/test_studio/test_pipeline_integration.py index 2d6b7440..83b82399 100644 --- a/tests/test_studio/test_pipeline_integration.py +++ b/tests/test_studio/test_pipeline_integration.py @@ -65,9 +65,7 @@ def _make_node( label: str = "test-node", data: dict | None = None, ) -> GraphNode: - return GraphNode( - id=node_id, type=node_type, label=label, position=_POS, data=data or {} - ) + return GraphNode(id=node_id, type=node_type, label=label, position=_POS, data=data or {}) def _mock_agent(name: str = "mock-agent") -> MagicMock: diff --git a/tests/test_studio/test_project_api.py b/tests/test_studio/test_project_api.py index 1396ac56..1aaad6f1 100644 --- a/tests/test_studio/test_project_api.py +++ b/tests/test_studio/test_project_api.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for per-project auto-generated API endpoints.""" + from __future__ import annotations from pathlib import Path diff --git a/tests/test_studio/test_project_runtime.py b/tests/test_studio/test_project_runtime.py index 3da985c1..34c59d1f 100644 --- a/tests/test_studio/test_project_runtime.py +++ b/tests/test_studio/test_project_runtime.py @@ -16,13 +16,27 @@ def _simple_graph() -> GraphModel: return GraphModel( nodes=[ - GraphNode(id="input_1", type=NodeType.INPUT, label="Input", - position={"x": 0, "y": 200}, data={"trigger_type": "manual"}), - GraphNode(id="agent_1", type=NodeType.AGENT, label="Agent", - position={"x": 300, "y": 200}, - data={"model": "openai:gpt-4o", "instructions": "Echo input."}), - GraphNode(id="output_1", type=NodeType.OUTPUT, label="Output", - position={"x": 600, "y": 200}, data={"destination_type": "response"}), + GraphNode( + id="input_1", + type=NodeType.INPUT, + label="Input", + position={"x": 0, "y": 200}, + data={"trigger_type": "manual"}, + ), + GraphNode( + id="agent_1", + type=NodeType.AGENT, + label="Agent", + position={"x": 300, "y": 200}, + data={"model": "openai:gpt-4o", "instructions": "Echo input."}, + ), + GraphNode( + id="output_1", + type=NodeType.OUTPUT, + label="Output", + position={"x": 600, "y": 200}, + data={"destination_type": "response"}, + ), ], edges=[ GraphEdge(id="e1", source="input_1", target="agent_1"), diff --git a/tests/test_studio/test_tunnel_api.py b/tests/test_studio/test_tunnel_api.py index 03d3c392..5a2c7b2c 100644 --- a/tests/test_studio/test_tunnel_api.py +++ b/tests/test_studio/test_tunnel_api.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for the tunnel API endpoints.""" + from __future__ import annotations import pytest