Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions docker-compose.dev.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
version: '3.8'

services:
controller:
build:
Expand All @@ -8,7 +6,8 @@ services:
ports:
- "8000:8000"
environment:
- WORKER_URLS=http://worker1:8001,http://worker2:8002
- WORKER_URLS=http://worker1:8000,http://worker2:8000
command: python -m http.server 8000
depends_on:
- worker1
- worker2
Expand All @@ -22,6 +21,7 @@ services:
environment:
- PORT=8000
- HOST=0.0.0.0
command: python prototype/worker.py --host 0.0.0.0 --port 8000

worker2:
build:
Expand All @@ -32,6 +32,7 @@ services:
environment:
- PORT=8000
- HOST=0.0.0.0
command: python prototype/worker.py --host 0.0.0.0 --port 8000

# Secure worker for PQC testing
worker_secure:
Expand All @@ -43,3 +44,4 @@ services:
environment:
- PORT=8000
- HOST=0.0.0.0
command: python prototype/worker.py --host 0.0.0.0 --port 8000
13 changes: 5 additions & 8 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# - mohawk-gui service: API/health endpoint on port 8003
# - mohawk-worker service: Inference worker on port 8004

version: '3.8'

services:
mohawk-gui:
build:
Expand All @@ -32,10 +30,10 @@ services:
- PYTHONUNBUFFERED=1
- PYTHONDONTWRITEBYTECODE=1
- QT_QPA_PLATFORM=offscreen
# Health check service running in container
command: python -c "import time; print('GUI service ready'); [print(f'.', end='', flush=True) or time.sleep(1) for _ in range(300)]"
# Lightweight API backend used by the desktop GUI for local verification.
command: python -m uvicorn mohawk_gui.mock_backend:app --host 0.0.0.0 --port 8003
healthcheck:
test: ["CMD", "python", "-c", "print('healthy')"]
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8003/health', timeout=3)"]
interval: 30s
timeout: 10s
retries: 3
Expand All @@ -54,10 +52,9 @@ services:
environment:
- PYTHONUNBUFFERED=1
- QT_QPA_PLATFORM=offscreen
# Health check service running in container
command: python -c "import time; print('Worker service ready'); [print(f'.', end='', flush=True) or time.sleep(1) for _ in range(300)]"
command: python prototype/worker.py --host 0.0.0.0 --port 8003
healthcheck:
test: ["CMD", "python", "-c", "print('healthy')"]
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8003/health', timeout=3)"]
interval: 30s
timeout: 10s
retries: 3
Expand Down
76 changes: 68 additions & 8 deletions mohawk/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import os
import logging
import json
from pathlib import Path
from typing import Optional, Dict, Any
from enum import Enum
Expand Down Expand Up @@ -44,7 +45,51 @@ def __init__(self, cache_dir: Optional[str] = None):
"""
self.cache_dir = Path(cache_dir) if cache_dir else Path.home() / ".mohawk" / "models"
self.cache_dir.mkdir(parents=True, exist_ok=True)
self.library_file = self.cache_dir / "library.json"
self._library = self._load_library_index()
logger.info(f"ModelLoader initialized with cache_dir={self.cache_dir}")

def _load_library_index(self) -> Dict[str, Dict[str, Any]]:
    """Load persisted model library metadata.

    Reads ``self.library_file`` as UTF-8 encoded JSON.

    Returns:
        The decoded mapping, or an empty dict when the file is missing,
        cannot be read, is not valid JSON, or its top-level value is not
        a JSON object.
    """
    try:
        with self.library_file.open("r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # No index has been written yet; nothing to warn about.
        return {}
    except (OSError, ValueError) as exc:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError. Report the cause so a corrupt index can be diagnosed.
        logger.warning(
            "Failed to read model library index; starting fresh (%s)", exc
        )
        return {}

    # Anything other than a JSON object cannot be used as the index.
    return data if isinstance(data, dict) else {}

def _save_library_index(self) -> None:
    """Persist model library metadata to disk.

    The index is serialised to a sibling ``<name>.tmp`` file first and then
    renamed over ``self.library_file``, so a failed or interrupted write
    never leaves a truncated index behind.

    Raises:
        OSError: If the temporary file cannot be written or renamed.
        TypeError: If ``self._library`` holds a value JSON cannot encode.
    """
    tmp_file = self.library_file.with_name(self.library_file.name + ".tmp")
    try:
        with tmp_file.open("w", encoding="utf-8") as f:
            json.dump(self._library, f, indent=2, sort_keys=True)
        # Rename only after the full document has been written.
        tmp_file.replace(self.library_file)
    except Exception:
        # Best-effort cleanup of the partial temp file; the original error
        # is the one the caller needs to see.
        try:
            tmp_file.unlink()
        except OSError:
            pass
        raise

def add_to_library(self, model_id: str, local_path: str, source: str) -> Dict[str, Any]:
    """Record a model in the library index and persist the index.

    Args:
        model_id: Key under which the model is stored; an existing entry
            with the same key is overwritten.
        local_path: Location of the model; stored as a string.
        source: Label describing where the model came from.

    Returns:
        The entry dict that was stored in the index.
    """
    record = dict(model_id=model_id, local_path=str(local_path), source=source)
    self._library[model_id] = record
    self._save_library_index()
    return record

def add_local_model(self, model_path: str, alias: Optional[str] = None) -> Dict[str, Any]:
    """Add an existing local model directory or file to the model library.

    The path is expanded (``~``) and resolved to an absolute path before it
    is stored, so the persisted entry stays valid when the index is later
    read from a different working directory.

    Args:
        model_path: Path to an existing model file or directory.
        alias: Optional model ID; defaults to the final path component.

    Returns:
        The library entry created for the model.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        ValueError: If no alias is given and the path has no final
            component to use as the model ID (for example ``/``).
    """
    path = Path(model_path).expanduser()
    if not path.exists():
        raise FileNotFoundError(f"Local model path not found: {model_path}")

    # Resolve after the existence check so "." and ".." get a real name.
    path = path.resolve()
    model_id = alias or path.name
    if not model_id:
        raise ValueError(
            f"Cannot derive a model ID from {model_path!r}; pass an alias"
        )
    return self.add_to_library(model_id=model_id, local_path=str(path), source="local")

def list_library(self) -> list:
    """Return the registered library entries as a new list.

    The list itself is fresh on every call; the entry dicts in it are the
    same objects held by the index.
    """
    return [*self._library.values()]

def detect_format(self, model_path: str) -> ModelFormat:
"""
Expand Down Expand Up @@ -134,13 +179,19 @@ def _load_huggingface(self, model_path: str, **kwargs) -> Dict[str, Any]:
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
raise ImportError("transformers required for HuggingFace models. Install with: pip install transformers")

tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)

tokenizer_kwargs = dict(kwargs.pop("tokenizer_kwargs", {}))
tokenizer_kwargs.setdefault("local_files_only", kwargs.get("local_files_only", False))

model_kwargs = dict(kwargs.pop("model_kwargs", {}))
model_kwargs.setdefault("torch_dtype", kwargs.pop("torch_dtype", "auto"))
model_kwargs.setdefault("device_map", kwargs.pop("device_map", "auto"))
model_kwargs.update(kwargs)

tokenizer = AutoTokenizer.from_pretrained(model_path, **tokenizer_kwargs)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=kwargs.get("torch_dtype", "auto"),
device_map=kwargs.get("device_map", "auto"),
**kwargs,
**model_kwargs,
)

return {"model": model, "tokenizer": tokenizer, "format": "huggingface"}
Expand Down Expand Up @@ -172,17 +223,26 @@ def download(self, model_id: str, **kwargs) -> str:
Returns:
Local path to downloaded model
"""
if not model_id or not model_id.strip():
raise ValueError("model_id must be a non-empty HuggingFace repository ID")

from huggingface_hub import snapshot_download

cache_path = self.cache_dir / model_id.replace("/", "_")
model_id = model_id.strip()
cache_path = self.cache_dir / model_id.replace("/", "--")

logger.info(f"Downloading {model_id} to {cache_path}")

download_kwargs = dict(kwargs)
download_kwargs.setdefault("local_dir", str(cache_path))
download_kwargs.setdefault("local_dir_use_symlinks", False)

local_path = snapshot_download(
repo_id=model_id,
local_dir=str(cache_path),
**kwargs,
**download_kwargs,
)

self.add_to_library(model_id=model_id, local_path=local_path, source="huggingface")

return local_path

Expand Down
3 changes: 2 additions & 1 deletion mohawk_gui/metrics_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def __init__(self):
def get_or_create_buffer(self, session_id: str) -> MetricsBuffer:
"""Get existing buffer or create new one for session."""
if session_id not in self.session_buffers:
self.session_buffers[session_id] = MetricsBuffer()
# Aggregation should be deterministic across sessions; avoid sampling loss.
self.session_buffers[session_id] = MetricsBuffer(sample_rate=1.0)
return self.session_buffers[session_id]

async def add_metrics(self, session_id: str, metrics: Dict[str, Any]):
Expand Down
77 changes: 77 additions & 0 deletions mohawk_gui/mock_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Lightweight backend API used by docker-compose for local GUI verification."""

from datetime import datetime
from fastapi import FastAPI

# Module-level FastAPI application; the route handlers below register on it.
app = FastAPI(title="Mohawk GUI Mock Backend", version="1.0.0")


@app.get("/health")
async def health() -> dict:
    """Return a fixed payload reporting the service as healthy."""
    payload = {"status": "healthy", "service": "mohawk-gui-backend"}
    return payload


@app.get("/api/workers")
async def list_workers() -> dict:
    """Return a fixed list of two mock workers under the ``workers`` key."""
    first_worker = {
        "id": "worker_0",
        "host": "localhost",
        "port": 8003,
        "status": "Connected",
        "model": "Llama-3-8B",
        "threads": 8,
        "load": 25,
    }
    second_worker = {
        "id": "worker_1",
        "host": "localhost",
        "port": 8004,
        "status": "Connected",
        "model": "Mistral-7B",
        "threads": 8,
        "load": 18,
    }
    return {"workers": [first_worker, second_worker]}


@app.post("/api/workers/connect")
async def connect_workers() -> dict:
    """Return a fixed acknowledgement reporting two connected workers."""
    return dict(status="ok", connected=2)


@app.get("/api/metrics")
async def metrics() -> dict:
    """Return fixed mock metrics stamped with the current UTC time.

    Returns:
        A dict with a single ``metrics`` key holding throughput, cpu,
        memory and gpu values plus an ISO-8601 ``timestamp`` ending in
        ``Z``.
    """
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the emitted format stays "<naive ISO>Z".
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)

    # Keep values in GUI progress bar ranges.
    return {
        "metrics": {
            "throughput": 420,
            "cpu": 31,
            "memory": 44,
            "gpu": 27,
            "timestamp": now_utc.isoformat() + "Z",
        }
    }


@app.get("/api/sessions")
async def sessions() -> dict:
    """Return a fixed list with one mock session under the ``sessions`` key."""
    running_session = {
        "id": "sess_001",
        "model": "Llama-3-8B",
        "status": "Running",
        "throughput": 420,
        "latency": 23,
        "tokens": 1980,
    }
    return {"sessions": [running_session]}


@app.post("/api/sessions/{session_id}/cancel")
async def cancel_session(session_id: str) -> dict:
    """Acknowledge a cancel request by echoing the session ID back.

    Args:
        session_id: Session identifier taken from the URL path.
    """
    response = {"status": "ok"}
    response["cancelled"] = session_id
    return response
Loading
Loading