From 43e96f1665788b972ecd52cb1b3aa775b6e25df8 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 03:42:49 +0000 Subject: [PATCH 01/10] chore: add prek pre-commit hook configuration and CI Add prek.toml with hooks for trailing whitespace, end-of-file, YAML/TOML/JSON validation, merge conflict detection, large file checks, cargo fmt, ruff lint+format, go vet, shellcheck, and conventional commit messages. Fix minor lint issues: add missing docstring in test_outputs.py, remove unused variable assignment in vmm-cli.py. --- .github/workflows/prek-check.yml | 47 ++ prek.toml | 68 ++ sdk/python/test_outputs.py | 22 +- vmm/src/vmm-cli.py | 1147 ++++++++++++------------------ 4 files changed, 562 insertions(+), 722 deletions(-) create mode 100644 .github/workflows/prek-check.yml create mode 100644 prek.toml diff --git a/.github/workflows/prek-check.yml b/.github/workflows/prek-check.yml new file mode 100644 index 000000000..b1071ad86 --- /dev/null +++ b/.github/workflows/prek-check.yml @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +name: Pre-commit checks + +on: + push: + branches: [ master, next, dev-* ] + pull_request: + branches: [ master, next, dev-* ] + +jobs: + pre-commit: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: 1.92.0 + components: rustfmt + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.22' + + - name: Install prek + run: pip install prek + + - name: Run pre-commit checks (PR) + if: github.event_name == 'pull_request' + run: prek run --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} --show-diff-on-failure + + - name: Run pre-commit checks (push) + if: github.event_name == 'push' + run: prek run --from-ref ${{ github.event.before }} --to-ref ${{ github.event.after }} --show-diff-on-failure diff --git a/prek.toml b/prek.toml new file mode 100644 index 000000000..d3d3abc87 --- /dev/null +++ b/prek.toml @@ -0,0 +1,68 @@ +# ============================================================================= +# prek configuration for dstack +# ============================================================================= + +# --- General hooks (trailing whitespace, file checks, etc.) --- +[[repos]] +repo = "https://github.com/pre-commit/pre-commit-hooks" +rev = "v5.0.0" +hooks = [ + { id = "trailing-whitespace", args = ["--markdown-linebreak-ext=md"] }, + { id = "end-of-file-fixer" }, + { id = "check-yaml", args = ["--allow-multiple-documents"], exclude = "gateway/templates/" }, + { id = "check-toml" }, + { id = "check-json" }, + { id = "check-merge-conflict" }, + { id = "check-added-large-files", args = ["--maxkb=500"] }, + { id = "check-symlinks" }, + { id = "mixed-line-ending", args = ["--fix=lf"] }, +] + +# --- Rust: rustfmt --- +[[repos]] +repo = "local" + +[[repos.hooks]] +id = "cargo-fmt" +name = "cargo fmt" +entry = "cargo fmt --all" +language = "system" +types = ["rust"] +pass_filenames = false + +# --- Python: ruff (lint + format) --- +[[repos]] +repo = "https://github.com/astral-sh/ruff-pre-commit" +rev = "v0.11.4" +hooks = [ + { id = "ruff", args = ["--fix"] }, + { id = "ruff-format" }, +] + +# --- Go: go vet --- +[[repos]] +repo = "local" + +[[repos.hooks]] +id = "go-vet" +name = "go vet" +entry = "bash -c 'cd sdk/go && go vet ./...'" +language = "system" +files = "sdk/go/.*\\.go$" +pass_filenames = false + +# --- Shell: shellcheck --- +[[repos]] +repo = "https://github.com/shellcheck-py/shellcheck-py" +rev = "v0.10.0.1" +hooks = [ + { id = "shellcheck" }, +] + +# --- Conventional commits (used by cliff.toml for changelog) --- +[[repos]] +repo = "https://github.com/compilerla/conventional-pre-commit" +rev = "v4.1.0" +hooks = [ + { id = "conventional-pre-commit", stages = ["commit-msg"] }, +] diff --git a/sdk/python/test_outputs.py b/sdk/python/test_outputs.py index 6dee15233..f974d76a4 100644 --- a/sdk/python/test_outputs.py +++ b/sdk/python/test_outputs.py @@ -2,20 +2,20 @@ # # SPDX-License-Identifier: Apache-2.0 +"""Test script for verifying Python SDK outputs.""" + import asyncio import sys -from dstack_sdk import ( - DstackClient, - AsyncDstackClient, - TappdClient, - AsyncTappdClient, - get_compose_hash, - verify_env_encrypt_public_key, -) +from dstack_sdk import AsyncDstackClient +from dstack_sdk import AsyncTappdClient +from dstack_sdk import DstackClient +from dstack_sdk import TappdClient +from dstack_sdk import get_compose_hash +from dstack_sdk import verify_env_encrypt_public_key -async def main(): +async def main(): # noqa: D103 print("=== Python SDK Output Test ===") try: @@ -49,7 +49,7 @@ async def main(): account = to_account(eth_key) print(f" address: {account.address}") - print(f" type: ethereum account") + print(" type: ethereum account") except ImportError: print( " error: Ethereum integration not available (install with pip install 'dstack-sdk[eth]')" @@ -63,7 +63,7 @@ async def main(): account_secure = to_account_secure(eth_key) print(f" address: {account_secure.address}") - print(f" type: ethereum account") + print(" type: ethereum account") except ImportError: print( " error: Ethereum integration not available (install with pip install 'dstack-sdk[eth]')" diff --git a/vmm/src/vmm-cli.py b/vmm/src/vmm-cli.py index 1d5963bf2..e77bc1e1d 100755 --- a/vmm/src/vmm-cli.py +++ b/vmm/src/vmm-cli.py @@ -25,201 +25,13 @@ from cryptography.hazmat.primitives import serialization from eth_keys import keys from eth_utils import keccak + CRYPTO_AVAILABLE = True except ImportError: CRYPTO_AVAILABLE = False -# Default config file locations -DEFAULT_CONFIG_PATH = os.path.expanduser("~/.dstack-vmm/config.json") -DEFAULT_KMS_WHITELIST_PATH = os.path.expanduser( - "~/.dstack-vmm/kms-whitelist.json") - -# VMM discovery directory -DISCOVERY_DIR = "/run/dstack-vmm" - - -def load_config() -> Dict[str, Any]: - """ - Load configuration from the default config file. - - Returns: - Dictionary with configuration values (url, auth_user, auth_password) - """ - if not os.path.exists(DEFAULT_CONFIG_PATH): - return {} - - try: - with open(DEFAULT_CONFIG_PATH, 'r') as f: - return json.load(f) - except (json.JSONDecodeError, FileNotFoundError): - return {} - - -def discover_vmm_instances() -> List[Dict[str, Any]]: - """ - Discover all running VMM instances from the discovery directory. - - Returns: - List of VMM instance info dicts, sorted by started_at. - """ - instances = [] - if not os.path.isdir(DISCOVERY_DIR): - return instances - - for fname in os.listdir(DISCOVERY_DIR): - if not fname.endswith('.json'): - continue - fpath = os.path.join(DISCOVERY_DIR, fname) - try: - with open(fpath, 'r') as f: - info = json.load(f) - # Check if process is still alive - pid = info.get('pid') - if pid and not os.path.exists(f'/proc/{pid}'): - # Stale file, skip - continue - instances.append(info) - except (json.JSONDecodeError, FileNotFoundError, PermissionError): - continue - - instances.sort(key=lambda x: x.get('started_at', 0)) - return instances - - -def resolve_vmm_url(instances: List[Dict[str, Any]], config: Dict[str, Any], - explicit_url: Optional[str] = None) -> str: - """ - Resolve the VMM URL to connect to. - - Priority: - 1. Explicit --url flag - 2. DSTACK_VMM_URL env var - 3. Config file url - 4. If exactly one VMM instance is discovered, use that - 5. If active instance is set in config, use that - 6. Fall back to default - - Returns the URL string. - """ - # If user explicitly provided --url or env var, honor it - env_url = os.environ.get('DSTACK_VMM_URL') - if explicit_url and explicit_url != 'http://localhost:8080': - return explicit_url - if env_url: - return env_url - config_url = config.get('url') - if config_url: - return config_url - - # Try auto-discovery - active_id = config.get('active_vmm') - if active_id: - for inst in instances: - if inst['id'] == active_id or inst['id'].startswith(active_id): - return vmm_address_to_url(inst) - - if len(instances) == 1: - return vmm_address_to_url(instances[0]) - - return 'http://localhost:8080' - - -def vmm_address_to_url(instance: Dict[str, Any]) -> str: - """Convert a VMM instance info dict to a connection URL.""" - addr = instance.get('address', '') - if addr.startswith('unix:'): - # Resolve relative socket path against working directory - socket_path = addr[5:] - if not os.path.isabs(socket_path): - working_dir = instance.get('working_dir', '') - socket_path = os.path.join(working_dir, socket_path) - return f'unix:{socket_path}' - elif addr.startswith('http://') or addr.startswith('https://'): - return addr - else: - # host:port format from discovery - host, _, port = addr.rpartition(':') - if host == '0.0.0.0': - host = '127.0.0.1' - return f'http://{host}:{port}' - - -def save_active_vmm(vmm_id: str): - """Save the active VMM instance ID to the config file.""" - config = load_config() - config['active_vmm'] = vmm_id - os.makedirs(os.path.dirname(DEFAULT_CONFIG_PATH), exist_ok=True) - with open(DEFAULT_CONFIG_PATH, 'w') as f: - json.dump(config, f, indent=2) - - -def cmd_ls_vmm(args): - """List all discovered VMM instances.""" - instances = discover_vmm_instances() - config = load_config() - active_id = config.get('active_vmm') - - if not instances: - print("No running VMM instances found.") - print(f" (discovery directory: {DISCOVERY_DIR})") - return - - if getattr(args, 'json', False): - print(json.dumps(instances, indent=2)) - return - - # Table output - from datetime import datetime - - fmt = " {active} {id:<12s} {pid:<8s} {node:<12s} {address:<24s} {workdir}" - print(fmt.format( - active='', id='ID', pid='PID', node='NAME', - address='ADDRESS', workdir='WORKING DIR')) - print(" " + "-" * 90) - - for inst in instances: - short_id = inst['id'][:8] - is_active = '*' if active_id and inst['id'].startswith(active_id) else ' ' - node_name = inst.get('node_name', '') or '-' - address = inst.get('address', '?') - - print(fmt.format( - active=is_active, - id=short_id, - pid=str(inst.get('pid', '?')), - node=node_name[:12], - address=address[:24], - workdir=inst.get('working_dir', '?'), - )) - - -def cmd_switch_vmm(args): - """Switch the active VMM instance.""" - target = args.vmm_id - instances = discover_vmm_instances() - - if not instances: - print("No running VMM instances found.") - return - - # Find matching instance by prefix - matches = [i for i in instances if i['id'].startswith(target)] - if len(matches) == 0: - print(f"No VMM instance matching '{target}'.") - print("Available instances:") - for inst in instances: - print(f" {inst['id'][:8]} {inst.get('address', '?')} {inst.get('working_dir', '?')}") - return - if len(matches) > 1: - print(f"Ambiguous ID '{target}', matches multiple instances:") - for inst in matches: - print(f" {inst['id'][:8]} {inst.get('address', '?')}") - return - - selected = matches[0] - save_active_vmm(selected['id']) - url = vmm_address_to_url(selected) - print(f"Switched to VMM {selected['id'][:8]} ({url})") +# Default whitelist file location +DEFAULT_KMS_WHITELIST_PATH = os.path.expanduser("~/.dstack-vmm/kms-whitelist.json") def encrypt_env(envs, hex_public_key: str) -> str: @@ -264,8 +76,7 @@ def encrypt_env(envs, hex_public_key: str) -> str: ephemeral_public_key = ephemeral_private_key.public_key() # Compute the shared secret using X25519. - peer_public_key = x25519.X25519PublicKey.from_public_bytes( - remote_pubkey_bytes) + peer_public_key = x25519.X25519PublicKey.from_public_bytes(remote_pubkey_bytes) shared = ephemeral_private_key.exchange(peer_public_key) # Use the shared secret as a key for AES-GCM encryption (AES-256 needs 32 bytes). @@ -275,8 +86,7 @@ def encrypt_env(envs, hex_public_key: str) -> str: # Serialize the ephemeral public key to raw bytes. ephemeral_public_bytes = ephemeral_public_key.public_bytes( - encoding=serialization.Encoding.Raw, - format=serialization.PublicFormat.Raw + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw ) # Combine ephemeral public key, IV, and ciphertext. @@ -288,34 +98,35 @@ def encrypt_env(envs, hex_public_key: str) -> str: def parse_port_mapping(port_str: str) -> Dict: """Parse a port mapping string into a dictionary""" - parts = port_str.split(':') + parts = port_str.split(":") if len(parts) == 3: return { "protocol": parts[0], "host_address": "127.0.0.1", "host_port": int(parts[1]), - "vm_port": int(parts[2]) + "vm_port": int(parts[2]), } elif len(parts) == 4: return { "protocol": parts[0], "host_address": parts[1], "host_port": int(parts[2]), - "vm_port": int(parts[3]) + "vm_port": int(parts[3]), } else: - raise argparse.ArgumentTypeError( - f"Invalid port mapping format: {port_str}") + raise argparse.ArgumentTypeError(f"Invalid port mapping format: {port_str}") + def read_utf8(filepath: str) -> str: - with open(filepath, 'rb') as f: - return f.read().decode('utf-8') + with open(filepath, "rb") as f: + return f.read().decode("utf-8") + class UnixSocketHTTPConnection(http.client.HTTPConnection): """HTTPConnection that connects to a Unix domain socket.""" def __init__(self, socket_path, timeout=None): - super().__init__('localhost', timeout=timeout) + super().__init__("localhost", timeout=timeout) self.socket_path = socket_path def connect(self): @@ -329,9 +140,14 @@ def connect(self): class VmmClient: """A unified HTTP client that supports both regular HTTP and Unix Domain Sockets.""" - def __init__(self, base_url: str, auth_user: Optional[str] = None, auth_password: Optional[str] = None): - self.base_url = base_url.rstrip('/') - self.use_uds = self.base_url.startswith('unix:') + def __init__( + self, + base_url: str, + auth_user: Optional[str] = None, + auth_password: Optional[str] = None, + ): + self.base_url = base_url.rstrip("/") + self.use_uds = self.base_url.startswith("unix:") self.auth_user = auth_user self.auth_password = auth_password @@ -341,10 +157,16 @@ def __init__(self, base_url: str, auth_user: Optional[str] = None, auth_password # Parse the base URL for regular HTTP connections self.parsed_url = urllib.parse.urlparse(self.base_url) self.host = self.parsed_url.netloc - self.is_https = self.parsed_url.scheme == 'https' + self.is_https = self.parsed_url.scheme == "https" - def request(self, method: str, path: str, headers: Dict[str, str] = None, - body: Any = None, stream: bool = False) -> Tuple[int, Union[Dict, str, BinaryIO]]: + def request( + self, + method: str, + path: str, + headers: Dict[str, str] = None, + body: Any = None, + stream: bool = False, + ) -> Tuple[int, Union[Dict, str, BinaryIO]]: """ Make an HTTP request to the server. @@ -364,15 +186,16 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, # Add Basic Authentication header if credentials are provided if self.auth_user and self.auth_password: credentials = f"{self.auth_user}:{self.auth_password}" - encoded_credentials = base64.b64encode( - credentials.encode('utf-8')).decode('ascii') - headers['Authorization'] = f'Basic {encoded_credentials}' + encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode( + "ascii" + ) + headers["Authorization"] = f"Basic {encoded_credentials}" # Prepare the body if isinstance(body, dict): - body = json.dumps(body).encode('utf-8') - if 'Content-Type' not in headers: - headers['Content-Type'] = 'application/json' + body = json.dumps(body).encode("utf-8") + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json" # Create the appropriate connection if self.use_uds: @@ -401,15 +224,19 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, data = response.read() # Try to parse as JSON if it looks like JSON - content_type = response.getheader('Content-Type', '') - if 'application/json' in content_type or data.startswith(b'{') or data.startswith(b'['): + content_type = response.getheader("Content-Type", "") + if ( + "application/json" in content_type + or data.startswith(b"{") + or data.startswith(b"[") + ): try: - return status, json.loads(data.decode('utf-8')) + return status, json.loads(data.decode("utf-8")) except json.JSONDecodeError: pass # Return as string if not JSON - return status, data.decode('utf-8') + return status, data.decode("utf-8") except Exception as e: if not stream: conn.close() @@ -419,18 +246,22 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, class VmmCLI: - def __init__(self, base_url: str, auth_user: Optional[str] = None, auth_password: Optional[str] = None): - self.base_url = base_url.rstrip('/') - self.headers = { - 'Content-Type': 'application/json' - } + def __init__( + self, + base_url: str, + auth_user: Optional[str] = None, + auth_password: Optional[str] = None, + ): + self.base_url = base_url.rstrip("/") + self.headers = {"Content-Type": "application/json"} self.client = VmmClient(base_url, auth_user, auth_password) def rpc_call(self, method: str, params: Optional[Dict] = None) -> Dict: """Make an RPC call to the dstack-vmm API""" path = f"/prpc/{method}?json" status, response = self.client.request( - 'POST', path, headers=self.headers, body=params or {}) + "POST", path, headers=self.headers, body=params or {} + ) if status != 200: if isinstance(response, str): @@ -443,8 +274,8 @@ def rpc_call(self, method: str, params: Optional[Dict] = None) -> Dict: def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: """List all VMs and their status""" - response = self.rpc_call('Status') - vms = response['vms'] + response = self.rpc_call("Status") + vms = response["vms"] if json_output: # Return raw JSON data for automation/testing @@ -455,31 +286,32 @@ def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: print("No VMs found") return - headers = ['VM ID', 'App ID', 'Name', 'Status', 'Uptime'] + headers = ["VM ID", "App ID", "Name", "Status", "Uptime"] if verbose: - headers.extend(['Instance ID', 'vCPU', 'Memory', 'Disk', 'Image', 'GPUs']) + headers.extend(["vCPU", "Memory", "Disk", "Image", "GPUs"]) rows = [] for vm in vms: row = [ - vm['id'], - vm['app_id'], - vm['name'], - vm['status'], - vm.get('uptime', '-') + vm["id"], + vm["app_id"], + vm["name"], + vm["status"], + vm.get("uptime", "-"), ] if verbose: - config = vm.get('configuration', {}) - gpu_info = self._format_gpu_info(config.get('gpus')) - row.extend([ - vm.get('instance_id', '-') or '-', - config.get('vcpu', '-'), - f"{config.get('memory', '-')}MB", - f"{config.get('disk_size', '-')}GB", - config.get('image', '-'), - gpu_info - ]) + config = vm.get("configuration", {}) + gpu_info = self._format_gpu_info(config.get("gpus")) + row.extend( + [ + config.get("vcpu", "-"), + f"{config.get('memory', '-')}MB", + f"{config.get('disk_size', '-')}GB", + config.get("image", "-"), + gpu_info, + ] + ) rows.append(row) @@ -488,36 +320,36 @@ def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: def _format_gpu_info(self, gpu_config): """Format GPU configuration for display""" if not gpu_config: - return '-' + return "-" - attach_mode = gpu_config.get('attach_mode', '') - gpus = gpu_config.get('gpus', []) + attach_mode = gpu_config.get("attach_mode", "") + gpus = gpu_config.get("gpus", []) - if attach_mode == 'all': - return 'All GPUs' - elif attach_mode == 'listed' and gpus: - gpu_slots = [gpu.get('slot', 'Unknown') for gpu in gpus] - return ', '.join(gpu_slots) + if attach_mode == "all": + return "All GPUs" + elif attach_mode == "listed" and gpus: + gpu_slots = [gpu.get("slot", "Unknown") for gpu in gpus] + return ", ".join(gpu_slots) else: - return '-' + return "-" def start_vm(self, vm_id: str) -> None: """Start a VM""" - self.rpc_call('StartVm', {'id': vm_id}) + self.rpc_call("StartVm", {"id": vm_id}) print(f"Started VM {vm_id}") def stop_vm(self, vm_id: str, force: bool = False) -> None: """Stop a VM""" if force: - self.rpc_call('StopVm', {'id': vm_id}) + self.rpc_call("StopVm", {"id": vm_id}) print(f"Forcefully stopped VM {vm_id}") else: - self.rpc_call('ShutdownVm', {'id': vm_id}) + self.rpc_call("ShutdownVm", {"id": vm_id}) print(f"Gracefully shutting down VM {vm_id}") def remove_vm(self, vm_id: str) -> None: """Remove a VM""" - self.rpc_call('RemoveVm', {'id': vm_id}) + self.rpc_call("RemoveVm", {"id": vm_id}) print(f"Removed VM {vm_id}") def resize_vm( @@ -552,7 +384,8 @@ def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: path = f"/logs?id={vm_id}&follow={str(follow).lower()}&ansi=false&lines={lines}" status, response = self.client.request( - 'GET', path, headers=self.headers, stream=follow) + "GET", path, headers=self.headers, stream=follow + ) if status != 200: if isinstance(response, str): @@ -569,7 +402,7 @@ def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: line = response.readline() if not line: break - print(line.decode('utf-8').rstrip()) + print(line.decode("utf-8").rstrip()) except KeyboardInterrupt: # Allow clean exit with Ctrl+C return @@ -582,8 +415,8 @@ def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: def list_images(self, json_output: bool = False) -> None: """Get list of available images""" - response = self.rpc_call('ListImages') - images = response['images'] + response = self.rpc_call("ListImages") + images = response["images"] if json_output: # Return raw JSON data for automation/testing @@ -594,74 +427,64 @@ def list_images(self, json_output: bool = False) -> None: print("No images found") return - headers = ['Name', 'Version'] - rows = [[img['name'], img.get('version', '-')] for img in images] + headers = ["Name", "Version"] + rows = [[img["name"], img.get("version", "-")] for img in images] print(format_table(rows, headers)) - def get_app_env_encrypt_pub_key(self, app_id: str, kms_url: Optional[str] = None) -> Dict: + def get_app_env_encrypt_pub_key( + self, app_id: str, kms_url: Optional[str] = None + ) -> Dict: """Get the encryption public key for the specified application ID""" if kms_url: client = VmmClient(kms_url) - path = f"/prpc/GetAppEnvEncryptPubKey?json" + path = "/prpc/GetAppEnvEncryptPubKey?json" status, response = client.request( - 'POST', path, headers={ - 'Content-Type': 'application/json' - }, body={'app_id': app_id}) + "POST", + path, + headers={"Content-Type": "application/json"}, + body={"app_id": app_id}, + ) print(f"Getting encryption public key for {app_id} from {kms_url}") else: - response = self.rpc_call( - 'GetAppEnvEncryptPubKey', {'app_id': app_id}) + response = self.rpc_call("GetAppEnvEncryptPubKey", {"app_id": app_id}) # Verify the signature if available - if 'signature' not in response and 'signature_v1' not in response: + if "signature" not in response: if not self.confirm_untrusted_signer("none"): raise Exception("Aborted due to invalid signature") - return response['public_key'] - - public_key = bytes.fromhex(response['public_key']) - - # Prefer signature_v1 (with timestamp) if available - signer_pubkey = None - if 'signature_v1' in response and 'timestamp' in response: - signature_v1 = bytes.fromhex(response['signature_v1']) - timestamp = response['timestamp'] - signer_pubkey = verify_signature_v1(public_key, signature_v1, app_id, timestamp) - if signer_pubkey: - print(f"Verified signature_v1 (with timestamp) from: {signer_pubkey}") - - # Fall back to legacy signature if signature_v1 verification failed or not available - if not signer_pubkey and 'signature' in response: - print("WARNING: Using legacy signature without timestamp protection. " - "Consider upgrading your KMS to support signature_v1.", file=sys.stderr) - signature = bytes.fromhex(response['signature']) - signer_pubkey = verify_signature(public_key, signature, app_id) - if signer_pubkey: - print(f"Verified legacy signature from: {signer_pubkey}") + return response["public_key"] + + public_key = bytes.fromhex(response["public_key"]) + signature = bytes.fromhex(response["signature"]) + signer_pubkey = verify_signature(public_key, signature, app_id) if signer_pubkey: whitelist = load_whitelist() if whitelist and signer_pubkey not in whitelist: print( - f"WARNING: Signer {signer_pubkey} is not in the trusted whitelist!") + f"WARNING: Signer {signer_pubkey} is not in the trusted whitelist!" + ) if not self.confirm_untrusted_signer(signer_pubkey): raise Exception("Aborted due to untrusted signer") + else: + print(f"Verified signature from: {signer_pubkey}") else: print("WARNING: Could not verify signature!") if not self.confirm_untrusted_signer("unknown"): raise Exception("Aborted due to invalid signature") - return response['public_key'] + return response["public_key"] def confirm_untrusted_signer(self, signer: str) -> bool: """Ask user to confirm using an untrusted signer""" response = input(f"Continue with untrusted signer {signer}? (y/N): ") - return response.lower() in ('y', 'yes') + return response.lower() in ("y", "yes") def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: """Manage the whitelist of trusted signers""" whitelist = load_whitelist() - if action == 'list': + if action == "list": if not whitelist: print("Whitelist is empty") else: @@ -671,7 +494,7 @@ def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: return # Normalize pubkey format - trim 0x prefix if present - if pubkey and pubkey.startswith('0x'): + if pubkey and pubkey.startswith("0x"): pubkey = pubkey[2:] # Convert to bytes for validation try: @@ -682,7 +505,7 @@ def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: raise Exception(f"Invalid public key length: {len(pubkey)}") pubkey = pubkey.hex() - if action == 'add': + if action == "add": if pubkey in whitelist: print(f"Public key {pubkey} is already in the whitelist") else: @@ -690,7 +513,7 @@ def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: save_whitelist(whitelist) print(f"Added {pubkey} to the whitelist") - elif action == 'remove': + elif action == "remove": if pubkey not in whitelist: print(f"Public key {pubkey} is not in the whitelist") else: @@ -709,16 +532,13 @@ def calc_app_id(self, compose_file: str) -> str: def create_app_compose(self, args) -> None: """Create a new app compose file""" envs = parse_env_file(args.env_file) or {} - - # Validate: --env-file requires --kms - if envs and not args.kms: - raise Exception("--env-file requires --kms to enable KMS for environment variable decryption") - app_compose = { "manifest_version": 2, "name": args.name, "runner": "docker-compose", - "docker_compose_file": open(args.docker_compose, 'rb').read().decode('utf-8'), + "docker_compose_file": open(args.docker_compose, "rb") + .read() + .decode("utf-8"), "kms_enabled": args.kms, "gateway_enabled": args.gateway, "local_key_provider_enabled": args.local_key_provider, @@ -729,11 +549,10 @@ def create_app_compose(self, args) -> None: "no_instance_id": args.no_instance_id, "secure_time": args.secure_time, } - if args.key_provider: - app_compose["key_provider"] = args.key_provider if args.prelaunch_script: - app_compose["pre_launch_script"] = open( - args.prelaunch_script, 'rb').read().decode('utf-8') + app_compose["pre_launch_script"] = ( + open(args.prelaunch_script, "rb").read().decode("utf-8") + ) if args.swap is not None: swap_bytes = max(0, int(round(args.swap)) * 1024 * 1024) if swap_bytes > 0: @@ -741,10 +560,11 @@ def create_app_compose(self, args) -> None: else: app_compose.pop("swap_size", None) - compose_file = json.dumps( - app_compose, indent=4, ensure_ascii=False).encode('utf-8') + compose_file = json.dumps(app_compose, indent=4, ensure_ascii=False).encode( + "utf-8" + ) compose_hash = hashlib.sha256(compose_file).hexdigest() - with open(args.output, 'wb') as f: + with open(args.output, "wb") as f: f.write(compose_file) print(f"App compose file created at: {args.output}") print(f"Compose hash: {compose_hash}") @@ -759,17 +579,6 @@ def create_vm(self, args) -> None: envs = parse_env_file(args.env_file) - # Validate: --env-file requires --kms-url and kms_enabled in compose - if envs: - if not args.kms_url: - raise Exception("--env-file requires --kms-url to encrypt environment variables") - try: - compose_json = json.loads(compose_content) - if not compose_json.get('kms_enabled', False): - raise Exception("--env-file requires kms_enabled=true in the compose file (use --kms when creating compose)") - except json.JSONDecodeError: - pass # Let the server handle invalid JSON - # Read user config file if provided user_config = "" if args.user_config: @@ -796,62 +605,57 @@ def create_vm(self, args) -> None: if swap_bytes > 0: params["swap_size"] = swap_bytes - if args.ppcie or (args.gpu and "all" in args.gpu): - params["gpus"] = { - "attach_mode": "all" - } + if args.ppcie: + params["gpus"] = {"attach_mode": "all"} elif args.gpu: params["gpus"] = { "attach_mode": "listed", - "gpus": [{"slot": gpu} for gpu in args.gpu or []] + "gpus": [{"slot": gpu} for gpu in args.gpu or []], } if args.kms_url: params["kms_urls"] = args.kms_url if args.gateway_url: params["gateway_urls"] = args.gateway_url - if args.net: - params["networking"] = {"mode": args.net} app_id = args.app_id or self.calc_app_id(compose_content) print(f"App ID: {app_id}") if envs: encrypt_pubkey = self.get_app_env_encrypt_pub_key( - app_id, args.kms_url[0] if args.kms_url else None) - print( - f"Encrypting environment variables with key: {encrypt_pubkey}") + app_id, args.kms_url[0] if args.kms_url else None + ) + print(f"Encrypting environment variables with key: {encrypt_pubkey}") envs_list = [{"key": k, "value": v} for k, v in envs.items()] params["encrypted_env"] = encrypt_env(envs_list, encrypt_pubkey) - response = self.rpc_call('CreateVm', params) + response = self.rpc_call("CreateVm", params) print(f"Created VM with ID: {response.get('id')}") - return response.get('id') + return response.get("id") - def update_vm_env(self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[List[str]] = None) -> None: + def update_vm_env( + self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[List[str]] = None + ) -> None: """Update environment variables for a VM""" - # Validate: requires --kms-url - if not kms_urls: - raise Exception("--kms-url is required to encrypt environment variables") - envs = envs or {} # First get the VM info to retrieve the app_id - vm_info_response = self.rpc_call('GetInfo', {'id': vm_id}) + vm_info_response = self.rpc_call("GetInfo", {"id": vm_id}) - if not vm_info_response.get('found', False) or 'info' not in vm_info_response: + if not vm_info_response.get("found", False) or "info" not in vm_info_response: raise Exception(f"VM with ID {vm_id} not found") - app_id = vm_info_response['info']['app_id'] + app_id = vm_info_response["info"]["app_id"] print(f"Retrieved app ID: {app_id}") - vm_configuration = vm_info_response['info'].get('configuration') or {} - compose_file = vm_configuration.get('compose_file') + vm_configuration = vm_info_response["info"].get("configuration") or {} + compose_file = vm_configuration.get("compose_file") # Now get the encryption key for the app encrypt_pubkey = self.get_app_env_encrypt_pub_key( - app_id, kms_urls[0] if kms_urls else None) + app_id, kms_urls[0] if kms_urls else None + ) print(f"Encrypting environment variables with key: {encrypt_pubkey}") envs_list = [{"key": k, "value": v} for k, v in envs.items()] encrypted_env = encrypt_env(envs_list, encrypt_pubkey) # Use UpdateApp with the VM ID - payload = {'id': vm_id, 'encrypted_env': encrypted_env} + payload = {"id": vm_id, "encrypted_env": encrypted_env} if compose_file: try: @@ -860,42 +664,40 @@ def update_vm_env(self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[Lis app_compose = {} compose_changed = False allowed_envs = list(envs.keys()) - if app_compose.get('allowed_envs') != allowed_envs: - app_compose['allowed_envs'] = allowed_envs + if app_compose.get("allowed_envs") != allowed_envs: + app_compose["allowed_envs"] = allowed_envs compose_changed = True - launch_token_value = envs.get('APP_LAUNCH_TOKEN') + launch_token_value = envs.get("APP_LAUNCH_TOKEN") if launch_token_value is not None: launch_token_hash = hashlib.sha256( - launch_token_value.encode('utf-8') + launch_token_value.encode("utf-8") ).hexdigest() - if app_compose.get('launch_token_hash') != launch_token_hash: - app_compose['launch_token_hash'] = launch_token_hash + if app_compose.get("launch_token_hash") != launch_token_hash: + app_compose["launch_token_hash"] = launch_token_hash compose_changed = True if compose_changed: - payload['compose_file'] = json.dumps( - app_compose, indent=4, ensure_ascii=False) + payload["compose_file"] = json.dumps( + app_compose, indent=4, ensure_ascii=False + ) - self.rpc_call('UpgradeApp', payload) + self.rpc_call("UpgradeApp", payload) print(f"Environment variables updated for VM {vm_id}") def update_vm_user_config(self, vm_id: str, user_config: str) -> None: """Update user config for a VM""" - self.rpc_call('UpgradeApp', {'id': vm_id, - 'user_config': user_config}) + self.rpc_call("UpgradeApp", {"id": vm_id, "user_config": user_config}) print(f"User config updated for VM {vm_id}") def update_vm_app_compose(self, vm_id: str, app_compose: str) -> None: """Update app compose for a VM""" - self.rpc_call('UpgradeApp', {'id': vm_id, - 'compose_file': app_compose}) + self.rpc_call("UpgradeApp", {"id": vm_id, "compose_file": app_compose}) print(f"App compose updated for VM {vm_id}") - + def update_vm_ports(self, vm_id: str, ports: List[str]) -> None: """Update port mapping for a VM""" port_mappings = [parse_port_mapping(port) for port in ports] self.rpc_call( - "UpgradeApp", {"id": vm_id, - "update_ports": True, "ports": port_mappings} + "UpgradeApp", {"id": vm_id, "update_ports": True, "ports": port_mappings} ) print(f"Port mapping updated for VM {vm_id}") @@ -920,10 +722,6 @@ def update_vm( no_tee: Optional[bool] = None, ) -> None: """Update multiple aspects of a VM in one command""" - # Validate: --env-file requires --kms-url - if env_file and not kms_urls: - raise Exception("--env-file requires --kms-url to encrypt environment variables") - updates = [] # handle resize operations (vcpu, memory, disk, image) @@ -949,56 +747,68 @@ def update_vm( upgrade_params = {"id": vm_id} # handle compose file updates (docker-compose, prelaunch script, swap) - needs_compose_update = docker_compose_content or prelaunch_script is not None or swap_size is not None + needs_compose_update = ( + docker_compose_content + or prelaunch_script is not None + or swap_size is not None + ) vm_info_response = None if needs_compose_update or env_file: - vm_info_response = self.rpc_call('GetInfo', {'id': vm_id}) - if not vm_info_response.get('found', False) or 'info' not in vm_info_response: + vm_info_response = self.rpc_call("GetInfo", {"id": vm_id}) + if ( + not vm_info_response.get("found", False) + or "info" not in vm_info_response + ): raise Exception(f"VM with ID {vm_id} not found") if needs_compose_update: - vm_configuration = vm_info_response['info'].get('configuration') or {} - compose_file_content = vm_configuration.get('compose_file') + vm_configuration = vm_info_response["info"].get("configuration") or {} + compose_file_content = vm_configuration.get("compose_file") try: - app_compose = json.loads(compose_file_content) if compose_file_content else {} + app_compose = ( + json.loads(compose_file_content) if compose_file_content else {} + ) except json.JSONDecodeError: app_compose = {} if docker_compose_content: - app_compose['docker_compose_file'] = docker_compose_content + app_compose["docker_compose_file"] = docker_compose_content updates.append("docker compose") if prelaunch_script is not None: script_stripped = prelaunch_script.strip() if script_stripped: - app_compose['pre_launch_script'] = script_stripped + app_compose["pre_launch_script"] = script_stripped updates.append("prelaunch script") - elif 'pre_launch_script' in app_compose: - del app_compose['pre_launch_script'] + elif "pre_launch_script" in app_compose: + del app_compose["pre_launch_script"] updates.append("prelaunch script (removed)") if swap_size is not None: swap_bytes = max(0, int(round(swap_size)) * 1024 * 1024) if swap_bytes > 0: - app_compose['swap_size'] = swap_bytes + app_compose["swap_size"] = swap_bytes updates.append(f"swap: {swap_size}MB") - elif 'swap_size' in app_compose: - del app_compose['swap_size'] + elif "swap_size" in app_compose: + del app_compose["swap_size"] updates.append("swap (disabled)") - upgrade_params['compose_file'] = json.dumps(app_compose, indent=4, ensure_ascii=False) + upgrade_params["compose_file"] = json.dumps( + app_compose, indent=4, ensure_ascii=False + ) if env_file: envs = parse_env_file(env_file) if envs: - app_id = vm_info_response['info']['app_id'] - vm_configuration = vm_info_response['info'].get('configuration') or {} - compose_file_content = vm_configuration.get('compose_file') + app_id = vm_info_response["info"]["app_id"] + vm_configuration = vm_info_response["info"].get("configuration") or {} + compose_file_content = vm_configuration.get("compose_file") encrypt_pubkey = self.get_app_env_encrypt_pub_key( - app_id, kms_urls[0] if kms_urls else None) + app_id, kms_urls[0] if kms_urls else None + ) envs_list = [{"key": k, "value": v} for k, v in envs.items()] upgrade_params["encrypted_env"] = encrypt_env(envs_list, encrypt_pubkey) updates.append("environment variables") @@ -1011,20 +821,21 @@ def update_vm( app_compose = {} compose_changed = False allowed_envs = list(envs.keys()) - if app_compose.get('allowed_envs') != allowed_envs: - app_compose['allowed_envs'] = allowed_envs + if app_compose.get("allowed_envs") != allowed_envs: + app_compose["allowed_envs"] = allowed_envs compose_changed = True - launch_token_value = envs.get('APP_LAUNCH_TOKEN') + launch_token_value = envs.get("APP_LAUNCH_TOKEN") if launch_token_value is not None: launch_token_hash = hashlib.sha256( - launch_token_value.encode('utf-8') + launch_token_value.encode("utf-8") ).hexdigest() - if app_compose.get('launch_token_hash') != launch_token_hash: - app_compose['launch_token_hash'] = launch_token_hash + if app_compose.get("launch_token_hash") != launch_token_hash: + app_compose["launch_token_hash"] = launch_token_hash compose_changed = True if compose_changed: - upgrade_params['compose_file'] = json.dumps( - app_compose, indent=4, ensure_ascii=False) + upgrade_params["compose_file"] = json.dumps( + app_compose, indent=4, ensure_ascii=False + ) if user_config: upgrade_params["user_config"] = user_config @@ -1046,29 +857,22 @@ def update_vm( upgrade_params["ports"] = port_mappings # handle GPU updates - only update if one of the GPU flags is set - gpu_all = gpu_slots and "all" in gpu_slots - if attach_all or gpu_all or no_gpus or gpu_slots is not None: - if attach_all or gpu_all: + if attach_all or no_gpus or gpu_slots is not None: + if attach_all: gpu_config = {"attach_mode": "all"} updates.append("GPUs (all)") elif no_gpus: - gpu_config = { - "attach_mode": "listed", - "gpus": [] - } + gpu_config = {"attach_mode": "listed", "gpus": []} updates.append("GPUs (detached)") elif gpu_slots: gpu_config = { "attach_mode": "listed", - "gpus": [{"slot": gpu} for gpu in gpu_slots] + "gpus": [{"slot": gpu} for gpu in gpu_slots], } updates.append(f"GPUs ({len(gpu_slots)} devices)") else: # gpu_slots is an empty list ([] not None) - shouldn't happen with mutually exclusive group - gpu_config = { - "attach_mode": "listed", - "gpus": [] - } + gpu_config = {"attach_mode": "listed", "gpus": []} updates.append("GPUs (none)") upgrade_params["gpus"] = gpu_config @@ -1084,54 +888,10 @@ def update_vm( else: print(f"No updates specified for VM {vm_id}") - def show_info(self, vm_id: str, json_output: bool = False) -> None: - """Show detailed information about a VM""" - response = self.rpc_call('GetInfo', {'id': vm_id}) - - if not response.get('found', False) or 'info' not in response: - print(f"VM with ID {vm_id} not found") - return - - info = response['info'] - - if json_output: - print(json.dumps(info, indent=2)) - return - - config = info.get('configuration', {}) - - print(f"VM ID: {info.get('id', '-')}") - print(f"Name: {info.get('name', '-')}") - print(f"Status: {info.get('status', '-')}") - print(f"Uptime: {info.get('uptime', '-')}") - print(f"App ID: {info.get('app_id', '-')}") - print(f"Instance ID: {info.get('instance_id', '-') or '-'}") - print(f"App URL: {info.get('app_url', '-') or '-'}") - print(f"Image: {config.get('image', '-')}") - print(f"Image Version: {info.get('image_version', '-')}") - print(f"vCPU: {config.get('vcpu', '-')}") - print(f"Memory: {config.get('memory', '-')}MB") - print(f"Disk: {config.get('disk_size', '-')}GB") - print(f"GPUs: {self._format_gpu_info(config.get('gpus'))}") - print(f"Boot Progress: {info.get('boot_progress', '-')}") - if info.get('boot_error'): - print(f"Boot Error: {info['boot_error']}") - if info.get('exited_at'): - print(f"Exited At: {info['exited_at']}") - if info.get('shutdown_progress'): - print(f"Shutdown: {info['shutdown_progress']}") - - events = info.get('events', []) - if events: - print(f"\nRecent Events:") - for event in events[-10:]: - ts = event.get('timestamp', 0) - print(f" [{event.get('event', '')}] {event.get('body', '')} (ts: {ts})") - def list_gpus(self, json_output: bool = False) -> None: """List all available GPUs""" - response = self.rpc_call('ListGpus') - gpus = response.get('gpus', []) + response = self.rpc_call("ListGpus") + gpus = response.get("gpus", []) if json_output: # Return raw JSON data for automation/testing @@ -1142,14 +902,14 @@ def list_gpus(self, json_output: bool = False) -> None: print("No GPUs found") return - headers = ['Slot', 'Product ID', 'Description', 'Available'] + headers = ["Slot", "Product ID", "Description", "Available"] rows = [] for gpu in gpus: row = [ - gpu.get('slot', '-'), - gpu.get('product_id', '-'), - gpu.get('description', '-'), - 'Yes' if gpu.get('is_free', False) else 'No' + gpu.get("slot", "-"), + gpu.get("product_id", "-"), + gpu.get("description", "-"), + "Yes" if gpu.get("is_free", False) else "No", ] rows.append(row) @@ -1174,11 +934,7 @@ def format_table(rows, headers): bottom_border = "└─" + "─┴─".join("─" * w for w in widths) + "─┘" # Build table - table = [ - top_border, - row_format.format(*headers), - separator - ] + table = [top_border, row_format.format(*headers), separator] for row in rows: table.append(row_format.format(*[str(cell) for cell in row])) table.append(bottom_border) @@ -1197,14 +953,14 @@ def parse_env_file(file_path: str) -> Dict[str, str]: return {} envs = {} - with open(file_path, 'r') as f: + with open(file_path, "r") as f: for line in f: line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue - if '=' not in line: + if "=" not in line: continue - key, value = line.split('=', 1) + key, value = line.split("=", 1) envs[key.strip()] = value.strip() return envs @@ -1226,7 +982,7 @@ def parse_size(s: str, target_unit: str) -> int: out to be fractional. """ s = s.strip() - m = re.fullmatch(r'(\d+(?:\.\d+)?)([a-zA-Z]{1,2})?', s) + m = re.fullmatch(r"(\d+(?:\.\d+)?)([a-zA-Z]{1,2})?", s) if not m: raise argparse.ArgumentTypeError(f"Invalid size format: '{s}'") number_str, unit = m.groups() @@ -1251,7 +1007,8 @@ def parse_size(s: str, target_unit: str) -> int: factor = 1024 * 1024 else: raise argparse.ArgumentTypeError( - f"Invalid size unit '{unit}' for memory. Use M, G, or T.") + f"Invalid size unit '{unit}' for memory. Use M, G, or T." + ) elif target_unit == "GB": # For disk, if no suffix is provided we assume GB. if unit in ["G", "GB"]: @@ -1260,14 +1017,16 @@ def parse_size(s: str, target_unit: str) -> int: factor = 1024 else: raise argparse.ArgumentTypeError( - f"Invalid size unit '{unit}' for disk. Use G, T.") + f"Invalid size unit '{unit}' for disk. Use G, T." + ) else: raise ValueError("Unsupported target unit") value = number * factor if not value.is_integer(): raise argparse.ArgumentTypeError( - f"Size must be an integer number of {target_unit}. Got {value}.") + f"Size must be an integer number of {target_unit}. Got {value}." + ) return int(value) @@ -1281,52 +1040,9 @@ def parse_disk_size(s: str) -> int: return parse_size(s, "GB") -def verify_signature_v1(public_key: bytes, signature: bytes, app_id: str, timestamp: int) -> Optional[str]: - """ - Verify the v1 signature (with timestamp) of a public key. - - Args: - public_key: The public key bytes to verify - signature: The signature bytes (65 bytes) - app_id: The application ID - timestamp: Unix timestamp in seconds when the response was generated - - Returns: - The compressed public key if valid, None otherwise - """ - if not CRYPTO_AVAILABLE: - raise ImportError( - "Cryptography libraries not available. Please install them with:\n" - "pip install cryptography eth-keys eth-utils" - ) - - if len(signature) != 65: - return None - - # Create the message to verify - # Signs: Keccak256("dstack-env-encrypt-pubkey" + ":" + app_id + timestamp_be_bytes + public_key) - prefix = b"dstack-env-encrypt-pubkey" - if app_id.startswith("0x"): - app_id = app_id[2:] - timestamp_bytes = timestamp.to_bytes(8, byteorder='big') - message = prefix + b":" + bytes.fromhex(app_id) + timestamp_bytes + public_key - - # Hash the message with Keccak-256 - message_hash = keccak(message) - - # Recover the public key from the signature - try: - sig = keys.Signature(signature_bytes=signature) - recovered_key = sig.recover_public_key_from_msg_hash(message_hash) - return '0x' + recovered_key.to_compressed_bytes().hex() - except Exception as e: - print(f"Signature v1 verification failed: {e}", file=sys.stderr) - return None - - def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Optional[str]: """ - Verify the legacy signature (without timestamp) of a public key. + Verify the signature of a public key. Args: public_key: The public key bytes to verify @@ -1354,7 +1070,6 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option return None # Create the message to verify - # Signs: Keccak256("dstack-env-encrypt-pubkey" + ":" + app_id + public_key) prefix = b"dstack-env-encrypt-pubkey" if app_id.startswith("0x"): app_id = app_id[2:] @@ -1369,7 +1084,7 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option sig = keys.Signature(signature_bytes=signature) recovered_key = sig.recover_public_key_from_msg_hash(message_hash) - return '0x' + recovered_key.to_compressed_bytes().hex() + return "0x" + recovered_key.to_compressed_bytes().hex() except Exception as e: print(f"Signature verification failed: {e}", file=sys.stderr) return None @@ -1387,9 +1102,9 @@ def load_whitelist() -> List[str]: return [] try: - with open(DEFAULT_KMS_WHITELIST_PATH, 'r') as f: + with open(DEFAULT_KMS_WHITELIST_PATH, "r") as f: data = json.load(f) - return data.get('trusted_signers', []) + return data.get("trusted_signers", []) except (json.JSONDecodeError, FileNotFoundError): return [] @@ -1402,81 +1117,59 @@ def save_whitelist(whitelist: List[str]) -> None: whitelist: List of trusted Ethereum addresses """ os.makedirs(os.path.dirname(DEFAULT_KMS_WHITELIST_PATH), exist_ok=True) - with open(DEFAULT_KMS_WHITELIST_PATH, 'w') as f: - json.dump({'trusted_signers': whitelist}, f, indent=2) + with open(DEFAULT_KMS_WHITELIST_PATH, "w") as f: + json.dump({"trusted_signers": whitelist}, f, indent=2) def main(): - parser = argparse.ArgumentParser(description='dstack-vmm CLI - Manage VMs') - - # Load config file defaults - config = load_config() - - # Discover running VMM instances - instances = discover_vmm_instances() + parser = argparse.ArgumentParser(description="dstack-vmm CLI - Manage VMs") - # Priority: command line > environment variable > config file > auto-discovery > default - default_url = os.environ.get( - 'DSTACK_VMM_URL', - config.get('url', 'http://localhost:8080')) - default_auth_user = os.environ.get( - 'DSTACK_VMM_AUTH_USER', - config.get('auth_user')) - default_auth_password = os.environ.get( - 'DSTACK_VMM_AUTH_PASSWORD', - config.get('auth_password')) + # Get default URL from environment variable or use localhost + default_url = os.environ.get("DSTACK_VMM_URL", "http://localhost:8080") parser.add_argument( - '--url', default=default_url, - help='dstack-vmm API URL (can also be set via DSTACK_VMM_URL env var or config file)') + "--url", + default=default_url, + help="dstack-vmm API URL (can also be set via DSTACK_VMM_URL env var)", + ) # Basic authentication arguments parser.add_argument( - '--auth-user', default=default_auth_user, - help='Basic auth username (can also be set via DSTACK_VMM_AUTH_USER env var or config file)') + "--auth-user", + default=os.environ.get("DSTACK_VMM_AUTH_USER"), + help="Basic auth username (can also be set via DSTACK_VMM_AUTH_USER env var)", + ) parser.add_argument( - '--auth-password', default=default_auth_password, - help='Basic auth password (can also be set via DSTACK_VMM_AUTH_PASSWORD env var or config file)') - - subparsers = parser.add_subparsers(dest='command', help='Commands') - - # VMM discovery commands - ls_vmm_parser = subparsers.add_parser( - 'ls-vmm', help='List all running VMM instances on this host') - ls_vmm_parser.add_argument( - '--json', action='store_true', help='Output in JSON format') + "--auth-password", + default=os.environ.get("DSTACK_VMM_AUTH_PASSWORD"), + help="Basic auth password (can also be set via DSTACK_VMM_AUTH_PASSWORD env var)", + ) - switch_vmm_parser = subparsers.add_parser( - 'switch-vmm', help='Switch active VMM instance') - switch_vmm_parser.add_argument( - 'vmm_id', help='VMM instance ID (prefix match supported)') + subparsers = parser.add_subparsers(dest="command", help="Commands") # List command - lsvm_parser = subparsers.add_parser('lsvm', help='List VMs') + lsvm_parser = subparsers.add_parser("lsvm", help="List VMs") lsvm_parser.add_argument( - '-v', '--verbose', action='store_true', help='Show detailed information') + "-v", "--verbose", action="store_true", help="Show detailed information" + ) lsvm_parser.add_argument( - '--json', action='store_true', help='Output in JSON format for automation') - - # Info command - info_parser = subparsers.add_parser('info', help='Show detailed VM information') - info_parser.add_argument('vm_id', help='VM ID to show info for') - info_parser.add_argument( - '--json', action='store_true', help='Output in JSON format for automation') + "--json", action="store_true", help="Output in JSON format for automation" + ) # Start command - start_parser = subparsers.add_parser('start', help='Start a VM') - start_parser.add_argument('vm_id', help='VM ID to start') + start_parser = subparsers.add_parser("start", help="Start a VM") + start_parser.add_argument("vm_id", help="VM ID to start") # Stop command - stop_parser = subparsers.add_parser('stop', help='Stop a VM') - stop_parser.add_argument('vm_id', help='VM ID to stop') + stop_parser = subparsers.add_parser("stop", help="Stop a VM") + stop_parser.add_argument("vm_id", help="VM ID to stop") stop_parser.add_argument( - '-f', '--force', action='store_true', help='Force stop the VM') + "-f", "--force", action="store_true", help="Force stop the VM" + ) # Remove command - remove_parser = subparsers.add_parser('remove', help='Remove a VM') - remove_parser.add_argument('vm_id', help='VM ID to remove') + remove_parser = subparsers.add_parser("remove", help="Remove a VM") + remove_parser.add_argument("vm_id", help="VM ID to remove") # Resize command resize_parser = subparsers.add_parser("resize", help="Resize a VM") @@ -1491,147 +1184,199 @@ def main(): resize_parser.add_argument("--image", type=str, help="Image name") # Logs command - logs_parser = subparsers.add_parser('logs', help='Show VM logs') - logs_parser.add_argument('vm_id', help='VM ID to show logs for') - logs_parser.add_argument('-n', '--lines', type=int, - default=20, help='Number of lines to show') + logs_parser = subparsers.add_parser("logs", help="Show VM logs") + logs_parser.add_argument("vm_id", help="VM ID to show logs for") logs_parser.add_argument( - '-f', '--follow', action='store_true', help='Follow log output') + "-n", "--lines", type=int, default=20, help="Number of lines to show" + ) + logs_parser.add_argument( + "-f", "--follow", action="store_true", help="Follow log output" + ) # Compose command compose_parser = subparsers.add_parser( - 'compose', help='Create a new app-compose.json file') - compose_parser.add_argument('--name', required=True, help='VM image name') - compose_parser.add_argument( - '--docker-compose', required=True, help='Path to docker-compose.yml file') - compose_parser.add_argument( - '--prelaunch-script', default=None, help='Path to prelaunch script') + "compose", help="Create a new app-compose.json file" + ) + compose_parser.add_argument("--name", required=True, help="VM image name") compose_parser.add_argument( - '--kms', action='store_true', help='Enable KMS') + "--docker-compose", required=True, help="Path to docker-compose.yml file" + ) compose_parser.add_argument( - '--gateway', action='store_true', help='Enable dstack-gateway') + "--prelaunch-script", default=None, help="Path to prelaunch script" + ) + compose_parser.add_argument("--kms", action="store_true", help="Enable KMS") compose_parser.add_argument( - '--local-key-provider', action='store_true', help='Enable local key provider') + "--gateway", action="store_true", help="Enable dstack-gateway" + ) compose_parser.add_argument( - '--key-provider', choices=['none', 'kms', 'local'], default=None, - help='Override key provider type (none, kms, local)') + "--local-key-provider", action="store_true", help="Enable local key provider" + ) compose_parser.add_argument( - '--key-provider-id', default=None, help='Key provider ID if you want to bind to a specific key provider') + "--key-provider-id", + default=None, + help="Key provider ID if you want to bind to a specific key provider", + ) compose_parser.add_argument( - '--public-logs', action='store_true', help='Enable public logs') + "--public-logs", action="store_true", help="Enable public logs" + ) compose_parser.add_argument( - '--public-sysinfo', action='store_true', help='Enable public sysinfo') + "--public-sysinfo", action="store_true", help="Enable public sysinfo" + ) compose_parser.add_argument( - '--env-file', help='File with environment variables to encrypt', default=None) + "--env-file", help="File with environment variables to encrypt", default=None + ) compose_parser.add_argument( - '--no-instance-id', action='store_true', help='Disable instance ID') + "--no-instance-id", action="store_true", help="Disable instance ID" + ) compose_parser.add_argument( - '--secure-time', action='store_true', help='Enable secure time') + "--secure-time", action="store_true", help="Enable secure time" + ) compose_parser.add_argument( - '--swap', type=parse_memory_size, default=None, - help='Swap size (e.g. 4G). Set to 0 to disable') + "--swap", + type=parse_memory_size, + default=None, + help="Swap size (e.g. 4G). Set to 0 to disable", + ) compose_parser.add_argument( - '--output', required=True, help='Path to output app-compose.json file') + "--output", required=True, help="Path to output app-compose.json file" + ) # Deploy command - deploy_parser = subparsers.add_parser('deploy', help='Deploy a new VM') - deploy_parser.add_argument('--name', required=True, help='VM name') - deploy_parser.add_argument('--image', required=True, help='VM image') + deploy_parser = subparsers.add_parser("deploy", help="Deploy a new VM") + deploy_parser.add_argument("--name", required=True, help="VM name") + deploy_parser.add_argument("--image", required=True, help="VM image") + deploy_parser.add_argument( + "--compose", required=True, help="Path to app-compose.json file" + ) + deploy_parser.add_argument("--vcpu", type=int, default=1, help="Number of vCPUs") deploy_parser.add_argument( - '--compose', required=True, help='Path to app-compose.json file') + "--memory", + type=parse_memory_size, + default=1024, + help="Memory size (e.g. 1G, 100M)", + ) deploy_parser.add_argument( - '--vcpu', type=int, default=1, help='Number of vCPUs') + "--disk", type=parse_disk_size, default=20, help="Disk size (e.g. 1G, 100M)" + ) deploy_parser.add_argument( - '--memory', type=parse_memory_size, default=1024, help='Memory size (e.g. 1G, 100M)') + "--swap", + type=parse_memory_size, + default=None, + help="Swap size (e.g. 4G). Set to 0 to disable", + ) deploy_parser.add_argument( - '--disk', type=parse_disk_size, default=20, help='Disk size (e.g. 1G, 100M)') + "--env-file", help="File with environment variables to encrypt", default=None + ) deploy_parser.add_argument( - '--swap', type=parse_memory_size, default=None, - help='Swap size (e.g. 4G). Set to 0 to disable') + "--user-config", help="Path to user config file", default=None + ) + deploy_parser.add_argument("--app-id", help="Application ID", default=None) deploy_parser.add_argument( - '--env-file', help='File with environment variables to encrypt', default=None) + "--port", + action="append", + type=str, + help="Port mapping in format: protocol[:address]:from:to", + ) deploy_parser.add_argument( - '--user-config', help='Path to user config file', default=None) - deploy_parser.add_argument('--app-id', help='Application ID', default=None) - deploy_parser.add_argument('--port', action='append', type=str, - help='Port mapping in format: protocol[:address]:from:to') - deploy_parser.add_argument('--gpu', action='append', type=str, - help='GPU slot to attach (can be used multiple times), or "all" to attach all GPUs') - deploy_parser.add_argument('--ppcie', action='store_true', - help='Enable PPCIE (Protected PCIe) mode - attach all available GPUs') - deploy_parser.add_argument('--pin-numa', action='store_true', - help='Pin VM to specific NUMA node') - deploy_parser.add_argument('--hugepages', action='store_true', - help='Enable hugepages for the VM') - deploy_parser.add_argument('--kms-url', action='append', type=str, - help='KMS URL') - deploy_parser.add_argument('--gateway-url', action='append', type=str, - help='Gateway URL') - deploy_parser.add_argument('--stopped', action='store_true', - help='Create VM in stopped state (requires dstack-vmm >= 0.5.4)') - deploy_parser.add_argument('--no-tee', dest='no_tee', action='store_true', - help='Disable Intel TDX / run without TEE') - deploy_parser.add_argument('--tee', dest='no_tee', action='store_false', - help='Force-enable Intel TDX (default)') + "--gpu", + action="append", + type=str, + help="GPU slot to attach (can be used multiple times)", + ) + deploy_parser.add_argument( + "--ppcie", + action="store_true", + help="Enable PPCIE (Protected PCIe) mode - attach all available GPUs", + ) + deploy_parser.add_argument( + "--pin-numa", action="store_true", help="Pin VM to specific NUMA node" + ) + deploy_parser.add_argument( + "--hugepages", action="store_true", help="Enable hugepages for the VM" + ) + deploy_parser.add_argument("--kms-url", action="append", type=str, help="KMS URL") + deploy_parser.add_argument( + "--gateway-url", action="append", type=str, help="Gateway URL" + ) + deploy_parser.add_argument( + "--stopped", + action="store_true", + help="Create VM in stopped state (requires dstack-vmm >= 0.5.4)", + ) + deploy_parser.add_argument( + "--no-tee", + dest="no_tee", + action="store_true", + help="Disable Intel TDX / run without TEE", + ) + deploy_parser.add_argument( + "--tee", + dest="no_tee", + action="store_false", + help="Force-enable Intel TDX (default)", + ) deploy_parser.set_defaults(no_tee=False) - deploy_parser.add_argument('--net', choices=['bridge', 'user'], - help='Networking mode (default: use global config)') - # Images command - lsimage_parser = subparsers.add_parser( - 'lsimage', help='List available images') + lsimage_parser = subparsers.add_parser("lsimage", help="List available images") lsimage_parser.add_argument( - '--json', action='store_true', help='Output in JSON format for automation') + "--json", action="store_true", help="Output in JSON format for automation" + ) # GPU command - lsgpu_parser = subparsers.add_parser('lsgpu', help='List available GPUs') + lsgpu_parser = subparsers.add_parser("lsgpu", help="List available GPUs") lsgpu_parser.add_argument( - '--json', action='store_true', help='Output in JSON format for automation') + "--json", action="store_true", help="Output in JSON format for automation" + ) # Update environment variables command update_env_parser = subparsers.add_parser( - 'update-env', help='Update environment variables for a VM') - update_env_parser.add_argument('vm_id', help='VM ID to update') + "update-env", help="Update environment variables for a VM" + ) + update_env_parser.add_argument("vm_id", help="VM ID to update") update_env_parser.add_argument( - '--env-file', required=True, help='File with environment variables to encrypt') + "--env-file", required=True, help="File with environment variables to encrypt" + ) update_env_parser.add_argument( - '--kms-url', action='append', type=str, - help='KMS URL') + "--kms-url", action="append", type=str, help="KMS URL" + ) # Whitelist command - kms_parser = subparsers.add_parser( - 'kms', help='Manage trusted KMS whitelist') - kms_subparsers = kms_parser.add_subparsers( - dest='kms_action', help='KMS actions') + kms_parser = subparsers.add_parser("kms", help="Manage trusted KMS whitelist") + kms_subparsers = kms_parser.add_subparsers(dest="kms_action", help="KMS actions") # List whitelist - list_kms_parser = kms_subparsers.add_parser( - 'list', help='List trusted signers') + kms_subparsers.add_parser("list", help="List trusted signers") # Add to whitelist add_kms_parser = kms_subparsers.add_parser( - 'add', help='Add public key to trusted signers') - add_kms_parser.add_argument('pubkey', help='Public key to add') + "add", help="Add public key to trusted signers" + ) + add_kms_parser.add_argument("pubkey", help="Public key to add") # Remove from whitelist remove_kms_parser = kms_subparsers.add_parser( - 'remove', help='Remove public key from trusted signers') - remove_kms_parser.add_argument('pubkey', help='Public key to remove') + "remove", help="Remove public key from trusted signers" + ) + remove_kms_parser.add_argument("pubkey", help="Public key to remove") # Update app compose update_app_compose_parser = subparsers.add_parser( - 'update-app-compose', help='Update app compose for a VM') - update_app_compose_parser.add_argument('vm_id', help='VM ID to update') + "update-app-compose", help="Update app compose for a VM" + ) + update_app_compose_parser.add_argument("vm_id", help="VM ID to update") update_app_compose_parser.add_argument( - 'compose', help='Path to app-compose.json file') + "compose", help="Path to app-compose.json file" + ) # Update user config update_user_config_parser = subparsers.add_parser( - 'update-user-config', help='Update user config for a VM') - update_user_config_parser.add_argument('vm_id', help='VM ID to update') + "update-user-config", help="Update user config for a VM" + ) + update_user_config_parser.add_argument("vm_id", help="VM ID to update") update_user_config_parser.add_argument( - 'user_config', help='Path to user config file') + "user_config", help="Path to user config file" + ) # Update port mapping update_ports_parser = subparsers.add_parser( @@ -1653,32 +1398,24 @@ def main(): update_parser.add_argument("vm_id", help="VM ID to update") # Resource options (requires VM to be stopped) - update_parser.add_argument( - "--vcpu", type=int, help="Number of vCPUs" - ) + update_parser.add_argument("--vcpu", type=int, help="Number of vCPUs") update_parser.add_argument( "--memory", type=parse_memory_size, help="Memory size (e.g. 1G, 100M)" ) update_parser.add_argument( "--disk", type=parse_disk_size, help="Disk size (e.g. 20G, 1T)" ) - update_parser.add_argument( - "--image", type=str, help="Image name" - ) + update_parser.add_argument("--image", type=str, help="Image name") # Application options - update_parser.add_argument( - "--compose", help="Path to app-compose.json file" - ) + update_parser.add_argument("--compose", help="Path to app-compose.json file") update_parser.add_argument( "--prelaunch-script", help="Path to pre-launch script file" ) update_parser.add_argument( "--env-file", help="File with environment variables to encrypt" ) - update_parser.add_argument( - "--user-config", help="Path to user config file" - ) + update_parser.add_argument("--user-config", help="Path to user config file") # Port mapping options (mutually exclusive with --no-ports) port_group = update_parser.add_mutually_exclusive_group() port_group.add_argument( @@ -1693,7 +1430,9 @@ def main(): help="Remove all port mappings from the VM", ) update_parser.add_argument( - "--swap", type=parse_memory_size, help="Swap size (e.g. 4G). Set to 0 to disable" + "--swap", + type=parse_memory_size, + help="Swap size (e.g. 4G). Set to 0 to disable", ) # GPU options (mutually exclusive) @@ -1702,7 +1441,7 @@ def main(): "--gpu", action="append", type=str, - help="GPU slot to attach (can be used multiple times), or \"all\" to attach all GPUs", + help="GPU slot to attach (can be used multiple times)", ) gpu_group.add_argument( "--ppcie", @@ -1732,35 +1471,21 @@ def main(): update_parser.set_defaults(no_tee=None) # KMS URL for environment encryption - update_parser.add_argument( - "--kms-url", action="append", type=str, help="KMS URL" - ) + update_parser.add_argument("--kms-url", action="append", type=str, help="KMS URL") args = parser.parse_args() - # Handle discovery commands before creating CLI (they don't need a connection) - if args.command == 'ls-vmm': - cmd_ls_vmm(args) - return - elif args.command == 'switch-vmm': - cmd_switch_vmm(args) - return - - # Resolve the URL with auto-discovery - url = resolve_vmm_url(instances, config, args.url) - cli = VmmCLI(url, args.auth_user, args.auth_password) + cli = VmmCLI(args.url, args.auth_user, args.auth_password) - if args.command == 'lsvm': + if args.command == "lsvm": cli.list_vms(args.verbose, args.json) - elif args.command == 'info': - cli.show_info(args.vm_id, args.json) - elif args.command == 'start': + elif args.command == "start": cli.start_vm(args.vm_id) - elif args.command == 'stop': + elif args.command == "stop": cli.stop_vm(args.vm_id, args.force) - elif args.command == 'remove': + elif args.command == "remove": cli.remove_vm(args.vm_id) - elif args.command == 'resize': + elif args.command == "resize": cli.resize_vm( args.vm_id, vcpu=args.vcpu, @@ -1768,24 +1493,24 @@ def main(): disk_size=args.disk, image=args.image, ) - elif args.command == 'logs': + elif args.command == "logs": cli.show_logs(args.vm_id, args.lines, args.follow) - elif args.command == 'compose': + elif args.command == "compose": cli.create_app_compose(args) - elif args.command == 'deploy': + elif args.command == "deploy": cli.create_vm(args) - elif args.command == 'lsimage': + elif args.command == "lsimage": cli.list_images(args.json) - elif args.command == 'lsgpu': + elif args.command == "lsgpu": cli.list_gpus(args.json) - elif args.command == 'update-env': - cli.update_vm_env(args.vm_id, parse_env_file( - args.env_file), kms_urls=args.kms_url) - elif args.command == 'update-user-config': - cli.update_vm_user_config( - args.vm_id, open(args.user_config, 'r').read()) - elif args.command == 'update-app-compose': - cli.update_vm_app_compose(args.vm_id, open(args.compose, 'r').read()) + elif args.command == "update-env": + cli.update_vm_env( + args.vm_id, parse_env_file(args.env_file), kms_urls=args.kms_url + ) + elif args.command == "update-user-config": + cli.update_vm_user_config(args.vm_id, open(args.user_config, "r").read()) + elif args.command == "update-app-compose": + cli.update_vm_app_compose(args.vm_id, open(args.compose, "r").read()) elif args.command == "update-ports": cli.update_vm_ports(args.vm_id, args.port) elif args.command == "update": @@ -1793,7 +1518,7 @@ def main(): if args.compose: compose_content = read_utf8(args.compose) prelaunch_content = None - if hasattr(args, 'prelaunch_script') and args.prelaunch_script: + if hasattr(args, "prelaunch_script") and args.prelaunch_script: prelaunch_content = read_utf8(args.prelaunch_script) user_config_content = None if args.user_config: @@ -1806,28 +1531,28 @@ def main(): image=args.image, docker_compose_content=compose_content, prelaunch_script=prelaunch_content, - swap_size=args.swap if hasattr(args, 'swap') else None, + swap_size=args.swap if hasattr(args, "swap") else None, env_file=args.env_file, user_config=user_config_content, ports=args.port, - no_ports=args.no_ports if hasattr(args, 'no_ports') else False, + no_ports=args.no_ports if hasattr(args, "no_ports") else False, gpu_slots=args.gpu, attach_all=args.ppcie, - no_gpus=args.no_gpus if hasattr(args, 'no_gpus') else False, + no_gpus=args.no_gpus if hasattr(args, "no_gpus") else False, kms_urls=args.kms_url, no_tee=args.no_tee, ) - elif args.command == 'kms': + elif args.command == "kms": if not args.kms_action: kms_parser.print_help() else: cli.manage_kms_whitelist( action=args.kms_action, - pubkey=getattr(args, 'pubkey', None), + pubkey=getattr(args, "pubkey", None), ) else: parser.print_help() -if __name__ == '__main__': +if __name__ == "__main__": main() From 1e0d995ac875b461ae7f09f1f444e60139ad8006 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 03:46:22 +0000 Subject: [PATCH 02/10] fix: add SPDX header to prek.toml and permissions to CI workflow --- .github/workflows/prek-check.yml | 3 +++ prek.toml | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/.github/workflows/prek-check.yml b/.github/workflows/prek-check.yml index b1071ad86..56af547a3 100644 --- a/.github/workflows/prek-check.yml +++ b/.github/workflows/prek-check.yml @@ -10,6 +10,9 @@ on: pull_request: branches: [ master, next, dev-* ] +permissions: + contents: read + jobs: pre-commit: runs-on: ubuntu-latest diff --git a/prek.toml b/prek.toml index d3d3abc87..30e141ad3 100644 --- a/prek.toml +++ b/prek.toml @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + # ============================================================================= # prek configuration for dstack # ============================================================================= From c8d2cac445c5dcc4657de550d47a8d66a52cc660 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 03:51:07 +0000 Subject: [PATCH 03/10] chore: rename pre-commit to prek in CI workflow --- .github/workflows/prek-check.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/prek-check.yml b/.github/workflows/prek-check.yml index 56af547a3..a1a60394c 100644 --- a/.github/workflows/prek-check.yml +++ b/.github/workflows/prek-check.yml @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -name: Pre-commit checks +name: Prek checks on: push: @@ -14,7 +14,7 @@ permissions: contents: read jobs: - pre-commit: + prek: runs-on: ubuntu-latest steps: @@ -41,10 +41,10 @@ jobs: - name: Install prek run: pip install prek - - name: Run pre-commit checks (PR) + - name: Run prek checks (PR) if: github.event_name == 'pull_request' run: prek run --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} --show-diff-on-failure - - name: Run pre-commit checks (push) + - name: Run prek checks (push) if: github.event_name == 'push' run: prek run --from-ref ${{ github.event.before }} --to-ref ${{ github.event.after }} --show-diff-on-failure From d71ef88942dc329d77457115f77ee86fdb3a8bee Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 03:54:35 +0000 Subject: [PATCH 04/10] fix: explicitly pass ruff config path for correct resolution --- prek.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prek.toml b/prek.toml index 30e141ad3..8e01f5b04 100644 --- a/prek.toml +++ b/prek.toml @@ -39,8 +39,8 @@ pass_filenames = false repo = "https://github.com/astral-sh/ruff-pre-commit" rev = "v0.11.4" hooks = [ - { id = "ruff", args = ["--fix"] }, - { id = "ruff-format" }, + { id = "ruff", args = ["--fix", "--config", "sdk/python/pyproject.toml"] }, + { id = "ruff-format", args = ["--config", "sdk/python/pyproject.toml"] }, ] # --- Go: go vet --- From ec0df8afba36c7fdc83adbf034f3466f0b93cd27 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 03:57:56 +0000 Subject: [PATCH 05/10] fix: use root-level ruff.toml instead of sdk/python config The sdk/python config has strict docstring rules (D*) that don't apply to scripts outside sdk/python/. Add a minimal root-level ruff.toml with basic E/F/I rules only. --- prek.toml | 4 ++-- ruff.toml | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) create mode 100644 ruff.toml diff --git a/prek.toml b/prek.toml index 8e01f5b04..30e141ad3 100644 --- a/prek.toml +++ b/prek.toml @@ -39,8 +39,8 @@ pass_filenames = false repo = "https://github.com/astral-sh/ruff-pre-commit" rev = "v0.11.4" hooks = [ - { id = "ruff", args = ["--fix", "--config", "sdk/python/pyproject.toml"] }, - { id = "ruff-format", args = ["--config", "sdk/python/pyproject.toml"] }, + { id = "ruff", args = ["--fix"] }, + { id = "ruff-format" }, ] # --- Go: go vet --- diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..3a15a4083 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# Root-level ruff config for pre-commit hooks. +# sdk/python/ has its own stricter config in pyproject.toml. + +line-length = 88 + +[lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes + "I", # isort +] +ignore = [ + "E203", # whitespace before ':' + "E501", # line too long (handled by formatter) +] From b6f9d85ee5121b9459b3060d93ca8305f37de2cc Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 04:10:10 +0000 Subject: [PATCH 06/10] fix: sort imports in vmm-cli.py to pass ruff isort check --- vmm/src/vmm-cli.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vmm/src/vmm-cli.py b/vmm/src/vmm-cli.py index e77bc1e1d..dbbf0ea39 100755 --- a/vmm/src/vmm-cli.py +++ b/vmm/src/vmm-cli.py @@ -4,25 +4,24 @@ # # SPDX-License-Identifier: Apache-2.0 -import os -import sys -import json import argparse +import base64 import hashlib +import http.client +import json +import os import re import socket -import http.client -import urllib.parse import ssl -import base64 - -from typing import Optional, Dict, List, Tuple, Union, BinaryIO, Any +import sys +import urllib.parse +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union # Optional cryptography imports - only needed for encrypted environment variables try: + from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import x25519 from cryptography.hazmat.primitives.ciphers.aead import AESGCM - from cryptography.hazmat.primitives import serialization from eth_keys import keys from eth_utils import keccak From e77c64705d1b94e54ee0a44d4efa395fdb662e59 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 04:12:12 +0000 Subject: [PATCH 07/10] fix: resolve all E501 line-too-long issues in vmm-cli.py --- vmm/src/vmm-cli.py | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/vmm/src/vmm-cli.py b/vmm/src/vmm-cli.py index dbbf0ea39..a325457c4 100755 --- a/vmm/src/vmm-cli.py +++ b/vmm/src/vmm-cli.py @@ -39,16 +39,20 @@ def encrypt_env(envs, hex_public_key: str) -> str: This function does the following: 1. Converts the given environment variables to JSON bytes. - 2. Removes a leading "0x" from the provided public key (if present) and converts it to bytes. + 2. Removes a leading "0x" from the provided public key + (if present) and converts it to bytes. 3. Generates an ephemeral X25519 key pair. - 4. Computes a shared secret using this ephemeral private key and the remote public key. + 4. Computes a shared secret using this ephemeral private + key and the remote public key. 5. Uses the shared key directly as the 32-byte key for AES-GCM. - 6. Encrypts the JSON string with AES-GCM using a randomly generated IV. - 7. Concatenates the ephemeral public key, IV, and ciphertext and returns it as a hex string. + 6. Encrypts the JSON string with AES-GCM using a random IV. + 7. Concatenates the ephemeral public key, IV, and ciphertext + and returns it as a hex string. Args: - envs: The environment variables to encrypt. This can be any JSON-serializable data structure. - hex_public_key: The remote encryption public key in hexadecimal format. + envs: The environment variables to encrypt. This can be + any JSON-serializable data structure. + hex_public_key: The remote encryption public key in hex. Returns: A hexadecimal string that is the concatenation of: @@ -372,7 +376,8 @@ def resize_vm( if len(params) == 1: raise Exception( - "at least one parameter must be specified for resize: --vcpu, --memory, --disk, or --image" + "at least one parameter must be specified for resize:" + " --vcpu, --memory, --disk, or --image" ) self.rpc_call("ResizeVm", params) @@ -849,7 +854,7 @@ def update_vm( port_mappings = [parse_port_mapping(port) for port in ports] updates.append("port mappings") else: - # ports is an empty list - shouldn't happen with mutually exclusive group + # ports is empty - shouldn't happen with exclusive group port_mappings = [] updates.append("port mappings (none)") upgrade_params["update_ports"] = True @@ -870,7 +875,7 @@ def update_vm( } updates.append(f"GPUs ({len(gpu_slots)} devices)") else: - # gpu_slots is an empty list ([] not None) - shouldn't happen with mutually exclusive group + # gpu_slots is empty - shouldn't happen with exclusive group gpu_config = {"attach_mode": "listed", "gpus": []} updates.append("GPUs (none)") upgrade_params["gpus"] = gpu_config @@ -1052,8 +1057,15 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option The compressed public key if valid, None otherwise Examples: - >>> public_key = bytes.fromhex('e33a1832c6562067ff8f844a61e51ad051f1180b66ec2551fb0251735f3ee90a') - >>> signature = bytes.fromhex('8542c49081fbf4e03f62034f13fbf70630bdf256a53032e38465a27c36fd6bed7a5e7111652004aef37f7fd92fbfc1285212c4ae6a6154203a48f5e16cad2cef00') + >>> pk_hex = 'e33a1832c6562067ff8f844a61e51ad051f1180b66ec2551fb0251735f3ee90a' + >>> public_key = bytes.fromhex(pk_hex) + >>> sig_hex = ( + ... '8542c49081fbf4e03f62034f13fbf70630bdf256' + ... 'a53032e38465a27c36fd6bed7a5e7111652004ae' + ... 'f37f7fd92fbfc1285212c4ae6a6154203a48f5e1' + ... '6cad2cef00' + ... ) + >>> signature = bytes.fromhex(sig_hex) >>> app_id = '00' * 20 >>> compressed_pubkey = verify_signature(public_key, signature, app_id) >>> print(compressed_pubkey) @@ -1141,7 +1153,7 @@ def main(): parser.add_argument( "--auth-password", default=os.environ.get("DSTACK_VMM_AUTH_PASSWORD"), - help="Basic auth password (can also be set via DSTACK_VMM_AUTH_PASSWORD env var)", + help="Basic auth password (env: DSTACK_VMM_AUTH_PASSWORD)", ) subparsers = parser.add_subparsers(dest="command", help="Commands") @@ -1387,7 +1399,7 @@ def main(): action="append", type=str, required=True, - help="Port mapping in format: protocol[:address]:from:to (can be used multiple times)", + help="Port mapping: protocol[:address]:from:to (repeatable)", ) # Update (all-in-one) command @@ -1421,7 +1433,7 @@ def main(): "--port", action="append", type=str, - help="Port mapping in format: protocol[:address]:from:to (can be used multiple times)", + help="Port mapping: protocol[:address]:from:to (repeatable)", ) port_group.add_argument( "--no-ports", From 194b6d9d67be9af6476b2b8c7d50a3750db0c3ba Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 04:18:16 +0000 Subject: [PATCH 08/10] fix: sort imports and remove f-string prefix in Python files --- python/ct_monitor/ct_monitor.py | 26 ++- scripts/add-spdx-attribution.py | 403 ++++++++++++++++++-------------- tools/mock-cf-dns-api/app.py | 143 +++++++----- 3 files changed, 338 insertions(+), 234 deletions(-) diff --git a/python/ct_monitor/ct_monitor.py b/python/ct_monitor/ct_monitor.py index 756636354..3f44db7d2 100644 --- a/python/ct_monitor/ct_monitor.py +++ b/python/ct_monitor/ct_monitor.py @@ -2,10 +2,11 @@ # # SPDX-License-Identifier: Apache-2.0 +import argparse import sys import time + import requests -import argparse from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -28,24 +29,29 @@ def get_logs(self, count: int = 100): url = f"{BASE_URL}/?q={self.domain}&output=json&limit={count}" response = requests.get(url) return response.json() - + def check_one_log(self, log: object): log_id = log["id"] cert_url = f"{BASE_URL}/?d={log_id}" cert_data = requests.get(cert_url).text # Extract PEM-encoded certificate import re - pem_match = re.search(r'-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----', cert_data, re.DOTALL) + + pem_match = re.search( + r"-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----", + cert_data, + re.DOTALL, + ) if pem_match: pem_cert = pem_match.group(0) - + # Parse PEM certificate cert = x509.load_pem_x509_certificate(pem_cert.encode(), default_backend()) # Extract the public key public_key = cert.public_key() pem_public_key = public_key.public_bytes( encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) print("Public Key:") print(pem_public_key.hex()) @@ -65,11 +71,11 @@ def check_new_logs(self): logs = self.get_logs(count=10000) print("num logs", len(logs)) for log in logs: - print(f"log id={log["id"]}") + print(f"log id={log['id']}") if log["id"] <= (self.last_checked or 0): break self.check_one_log(log) - print('next') + print("next") if len(logs) > 0: self.last_checked = logs[0]["id"] @@ -92,7 +98,7 @@ def validate_domain(domain: str): # Regular expression for validating domain names domain_regex = re.compile( - r'^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$' + r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$" ) if not domain_regex.match(domain): @@ -102,7 +108,9 @@ def validate_domain(domain: str): def main(): - parser = argparse.ArgumentParser(description="Monitor certificate transparency logs") + parser = argparse.ArgumentParser( + description="Monitor certificate transparency logs" + ) parser.add_argument("-d", "--domain", help="The domain to monitor") args = parser.parse_args() monitor = Monitor(args.domain) diff --git a/scripts/add-spdx-attribution.py b/scripts/add-spdx-attribution.py index 67de938ba..926110253 100755 --- a/scripts/add-spdx-attribution.py +++ b/scripts/add-spdx-attribution.py @@ -18,14 +18,12 @@ import argparse import fnmatch -import os import re import subprocess import sys from collections import defaultdict -from datetime import datetime from pathlib import Path -from typing import Dict, List, Set, Tuple, Optional +from typing import Dict, List, Optional, Set, Tuple class SPDXAttributor: @@ -35,243 +33,279 @@ def __init__(self, repo_root: str, dry_run: bool = False): self.exclude_patterns = self._load_exclude_patterns() self.mailmap = self._parse_mailmap() self.company_domains = self._build_company_domain_map() - + def _load_exclude_patterns(self) -> List[str]: """Load exclusion patterns from .spdx-exclude file.""" - exclude_file = self.repo_root / '.spdx-exclude' + exclude_file = self.repo_root / ".spdx-exclude" patterns = [] - + if exclude_file.exists(): - with open(exclude_file, 'r') as f: + with open(exclude_file, "r") as f: for line in f: line = line.strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): patterns.append(line) - + return patterns - + def _parse_mailmap(self) -> Dict[str, Tuple[str, str]]: """Parse .mailmap file to get canonical name/email mappings.""" - mailmap_file = self.repo_root / '.mailmap' + mailmap_file = self.repo_root / ".mailmap" mailmap = {} - + if not mailmap_file.exists(): print("Warning: .mailmap file not found") return mailmap - - with open(mailmap_file, 'r') as f: + + with open(mailmap_file, "r") as f: for line in f: line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue - + # Parse mailmap format: "Proper Name " # or "Proper Name Commit Name " - if '>' in line: - parts = line.split('>') + if ">" in line: + parts = line.split(">") if len(parts) >= 2: proper_part = parts[0].strip() commit_part = parts[1].strip() - + # Extract proper name and email - if '<' in proper_part: - proper_name = proper_part.split('<')[0].strip() - proper_email = proper_part.split('<')[1].strip() + if "<" in proper_part: + proper_name = proper_part.split("<")[0].strip() + proper_email = proper_part.split("<")[1].strip() else: continue - + # Extract commit email (and possibly name) - if '<' in commit_part: - commit_email = commit_part.split('<')[1].split('>')[0].strip() + if "<" in commit_part: + commit_email = ( + commit_part.split("<")[1].split(">")[0].strip() + ) else: commit_email = commit_part.strip() - + mailmap[commit_email] = (proper_name, proper_email) - + return mailmap - + def _build_company_domain_map(self) -> Dict[str, str]: """Build mapping from email domains to company names.""" return { - 'phala.network': 'Phala Network', - 'near.ai': 'Near Foundation', - 'nethermind.io': 'Nethermind', - 'rizelabs.io': 'Rize Labs', - 'testinprod.io': 'Test in Prod', + "phala.network": "Phala Network", + "near.ai": "Near Foundation", + "nethermind.io": "Nethermind", + "rizelabs.io": "Rize Labs", + "testinprod.io": "Test in Prod", } - + def _is_excluded(self, file_path: Path) -> bool: """Check if a file should be excluded based on patterns.""" rel_path = file_path.relative_to(self.repo_root) rel_path_str = str(rel_path) - + for pattern in self.exclude_patterns: # Handle directory patterns - if pattern.endswith('/'): - if rel_path_str.startswith(pattern) or f"/{pattern}" in f"/{rel_path_str}/": + if pattern.endswith("/"): + if ( + rel_path_str.startswith(pattern) + or f"/{pattern}" in f"/{rel_path_str}/" + ): return True # Handle glob patterns elif fnmatch.fnmatch(rel_path_str, pattern): return True # Handle wildcard patterns - elif '*' in pattern and fnmatch.fnmatch(rel_path_str, pattern): + elif "*" in pattern and fnmatch.fnmatch(rel_path_str, pattern): return True - + return False - + def _get_canonical_contributor(self, email: str, name: str) -> Tuple[str, str]: """Get canonical name and email for a contributor using mailmap.""" if email in self.mailmap: return self.mailmap[email] return name.strip(), email.strip() - + def _get_company_name(self, email: str) -> Optional[str]: """Get company name for an email domain, or None for individual contributors.""" - domain = email.split('@')[-1].lower() + domain = email.split("@")[-1].lower() return self.company_domains.get(domain) - + def _get_company_contact_email(self, company: str) -> str: """Get the standardized contact email for a company.""" # Company-specific contact emails company_contacts = { - 'Phala Network': 'dstack@phala.network', - 'Near Foundation': 'contact@near.ai', - 'Nethermind': 'contact@nethermind.io', - 'Rize Labs': 'contact@rizelabs.io', - 'Test in Prod': 'contact@testinprod.io', + "Phala Network": "dstack@phala.network", + "Near Foundation": "contact@near.ai", + "Nethermind": "contact@nethermind.io", + "Rize Labs": "contact@rizelabs.io", + "Test in Prod": "contact@testinprod.io", } - - return company_contacts.get(company, f'contact@{company.lower().replace(" ", "")}.com') - + + return company_contacts.get( + company, f"contact@{company.lower().replace(' ', '')}.com" + ) + def _analyze_git_blame(self, file_path: Path) -> Dict[str, Set[int]]: """Analyze git blame to get contributors and their contribution years.""" try: # Run git blame to get author info and dates - result = subprocess.run([ - 'git', 'blame', '--porcelain', str(file_path) - ], capture_output=True, text=True, cwd=self.repo_root) - + result = subprocess.run( + ["git", "blame", "--porcelain", str(file_path)], + capture_output=True, + text=True, + cwd=self.repo_root, + ) + if result.returncode != 0: print(f"Warning: Could not run git blame on {file_path}") return {} - + contributors = defaultdict(set) current_commit = None - - for line in result.stdout.split('\n'): + + for line in result.stdout.split("\n"): line = line.strip() if not line: continue - + # Parse commit hash line - if re.match(r'^[0-9a-f]{40}', line): + if re.match(r"^[0-9a-f]{40}", line): current_commit = line.split()[0] # Parse author email - elif line.startswith('author-mail '): - email = line[12:].strip('<>') - + elif line.startswith("author-mail "): + email = line[12:].strip("<>") + # Skip SPDX-only commits to avoid counting license header changes as contributions if self._is_spdx_only_commit(current_commit): continue - + # Get the year for this commit year = self._get_commit_year(current_commit) if year and email: contributors[email].add(year) - + return contributors - + except subprocess.SubprocessError as e: print(f"Error running git blame on {file_path}: {e}") return {} - + def _is_spdx_only_commit(self, commit_hash: str) -> bool: """Check if a commit only contains SPDX header changes.""" try: # Get commit message - result = subprocess.run([ - 'git', 'show', '-s', '--format=%s%n%b', commit_hash - ], capture_output=True, text=True, cwd=self.repo_root) - + result = subprocess.run( + ["git", "show", "-s", "--format=%s%n%b", commit_hash], + capture_output=True, + text=True, + cwd=self.repo_root, + ) + if result.returncode != 0: return False - + commit_message = result.stdout.lower() - + # Check for SPDX-related keywords in commit message spdx_keywords = [ - 'spdx', 'license header', 'copyright header', 'add license', - 'update license', 'license annotation', 'reuse annotate', - 'add spdx', 'update spdx', 'copyright attribution' + "spdx", + "license header", + "copyright header", + "add license", + "update license", + "license annotation", + "reuse annotate", + "add spdx", + "update spdx", + "copyright attribution", ] - + if any(keyword in commit_message for keyword in spdx_keywords): # Get the diff to see if it's only header changes - diff_result = subprocess.run([ - 'git', 'show', '--format=', commit_hash - ], capture_output=True, text=True, cwd=self.repo_root) - + diff_result = subprocess.run( + ["git", "show", "--format=", commit_hash], + capture_output=True, + text=True, + cwd=self.repo_root, + ) + if diff_result.returncode == 0: diff_content = diff_result.stdout - + # Check if the diff only contains SPDX/copyright/license changes # Look for lines that are only adding/removing headers - diff_lines = diff_content.split('\n') + diff_lines = diff_content.split("\n") substantial_changes = 0 - + for line in diff_lines: - if line.startswith(('+', '-')) and not line.startswith(('+++', '---')): + if line.startswith(("+", "-")) and not line.startswith( + ("+++", "---") + ): # Skip lines that are just SPDX/copyright/license related line_content = line[1:].strip() - if line_content and not any(marker in line_content.lower() for marker in [ - 'spdx-', 'copyright', 'license-identifier', - 'filepyrighttext', '©', '(c)', 'all rights reserved' - ]): + if line_content and not any( + marker in line_content.lower() + for marker in [ + "spdx-", + "copyright", + "license-identifier", + "filepyrighttext", + "©", + "(c)", + "all rights reserved", + ] + ): substantial_changes += 1 - + # If we have very few substantial changes, likely an SPDX-only commit return substantial_changes <= 2 - + return False - + except subprocess.SubprocessError: return False - + def _get_commit_year(self, commit_hash: str) -> Optional[int]: """Get the year of a commit.""" try: - result = subprocess.run([ - 'git', 'show', '-s', '--format=%ad', '--date=format:%Y', commit_hash - ], capture_output=True, text=True, cwd=self.repo_root) - + result = subprocess.run( + ["git", "show", "-s", "--format=%ad", "--date=format:%Y", commit_hash], + capture_output=True, + text=True, + cwd=self.repo_root, + ) + if result.returncode == 0: return int(result.stdout.strip()) except (subprocess.SubprocessError, ValueError): pass - + return None - + def _generate_spdx_headers(self, contributors: Dict[str, Set[int]]) -> List[str]: """Generate SPDX-FileCopyrightText headers for contributors.""" headers = [] company_years = defaultdict(set) # Track years by company individual_contributors = {} # Track individual contributors - + # First pass: group contributors by company vs individual for email, years in contributors.items(): # Get canonical name and email name, canonical_email = self._get_canonical_contributor(email, "Unknown") - + # Check if this is a company contributor company = self._get_company_name(canonical_email) - + if company: # Accumulate years for this company company_years[company].update(years) else: # Individual contributor individual_contributors[canonical_email] = (name, years) - + # Generate company headers (one per company) for company, years in company_years.items(): year_list = sorted(years) @@ -279,11 +313,11 @@ def _generate_spdx_headers(self, contributors: Dict[str, Set[int]]) -> List[str] year_str = str(year_list[0]) else: year_str = f"{year_list[0]}-{year_list[-1]}" - + contact_email = self._get_company_contact_email(company) header = f"SPDX-FileCopyrightText: © {year_str} {company} <{contact_email}>" headers.append(header) - + # Generate individual headers for canonical_email, (name, years) in individual_contributors.items(): year_list = sorted(years) @@ -291,189 +325,216 @@ def _generate_spdx_headers(self, contributors: Dict[str, Set[int]]) -> List[str] year_str = str(year_list[0]) else: year_str = f"{year_list[0]}-{year_list[-1]}" - + header = f"SPDX-FileCopyrightText: © {year_str} {name} <{canonical_email}>" headers.append(header) - + return sorted(headers) - + def _remove_existing_spdx_headers(self, file_path: Path) -> bool: """Remove existing SPDX headers from a file.""" if not file_path.exists(): return False - + try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: lines = f.readlines() - + # Find and remove existing SPDX lines new_lines = [] for line in lines: - if not (line.strip().startswith('// SPDX-') or - line.strip().startswith('# SPDX-') or - line.strip().startswith('/* SPDX-') or - line.strip().startswith(' * SPDX-')): + if not ( + line.strip().startswith("// SPDX-") + or line.strip().startswith("# SPDX-") + or line.strip().startswith("/* SPDX-") + or line.strip().startswith(" * SPDX-") + ): new_lines.append(line) - + # Write back if changes were made if len(new_lines) != len(lines): if not self.dry_run: - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: f.writelines(new_lines) return True - + except (IOError, UnicodeDecodeError) as e: print(f"Warning: Could not process {file_path}: {e}") - + return False - + def _get_existing_license(self, file_path: Path) -> str: """Get the existing SPDX license identifier from a file.""" try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + # Look for existing SPDX-License-Identifier - lines = content.split('\n')[:20] # Check first 20 lines + lines = content.split("\n")[:20] # Check first 20 lines for line in lines: - if 'SPDX-License-Identifier:' in line: + if "SPDX-License-Identifier:" in line: # Extract the license identifier - parts = line.split('SPDX-License-Identifier:') + parts = line.split("SPDX-License-Identifier:") if len(parts) > 1: - return parts[1].strip().rstrip('*/') - + return parts[1].strip().rstrip("*/") + # Default to Apache-2.0 if no existing license found - return 'Apache-2.0' - + return "Apache-2.0" + except Exception: - return 'Apache-2.0' - + return "Apache-2.0" + def _apply_reuse_annotation(self, file_path: Path, headers: List[str]) -> bool: """Apply SPDX headers using the REUSE tool.""" if not headers: return False - + try: # Prepare REUSE command - pass full SPDX-FileCopyrightText headers directly - cmd = ['reuse', 'annotate'] - + cmd = ["reuse", "annotate"] + # Add all copyright lines with full SPDX-FileCopyrightText format for header in headers: - cmd.extend(['--copyright', header]) # Keep the full SPDX-FileCopyrightText: format - + cmd.extend( + ["--copyright", header] + ) # Keep the full SPDX-FileCopyrightText: format + # Preserve existing license or use Apache-2.0 as default existing_license = self._get_existing_license(file_path) - cmd.extend(['--license', existing_license]) - + cmd.extend(["--license", existing_license]) + # Handle Solidity files which need explicit style specification - if file_path.suffix == '.sol': - cmd.extend(['--style', 'c']) - + if file_path.suffix == ".sol": + cmd.extend(["--style", "c"]) + # Add the file cmd.append(str(file_path)) - + if self.dry_run: print(f"Would run: {' '.join(cmd)}") return True else: - result = subprocess.run(cmd, capture_output=True, text=True, cwd=self.repo_root) + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=self.repo_root + ) if result.returncode != 0: print(f"REUSE command failed for {file_path}: {result.stderr}") return False return True - + except subprocess.SubprocessError as e: print(f"Error running REUSE on {file_path}: {e}") return False - + def process_file(self, file_path: Path) -> bool: """Process a single file to add SPDX attribution.""" # Convert to absolute path if it's relative if not file_path.is_absolute(): file_path = self.repo_root / file_path - + if self._is_excluded(file_path): if self.dry_run: print(f"Excluded: {file_path}") return False - + # Analyze git blame contributors = self._analyze_git_blame(file_path) if not contributors: print(f"No contributors found for {file_path}") return False - + # Generate SPDX headers headers = self._generate_spdx_headers(contributors) - + if self.dry_run: print(f"\nFile: {file_path}") print(f"Contributors: {len(contributors)}") for header in headers: print(f" {header}") return True - + # Remove existing SPDX headers self._remove_existing_spdx_headers(file_path) - + # Apply new headers with REUSE success = self._apply_reuse_annotation(file_path, headers) - + if success: print(f"✓ Updated: {file_path}") else: print(f"✗ Failed: {file_path}") - + return success - + def find_source_files(self) -> List[Path]: """Find all source files that should be processed.""" - extensions = {'.rs', '.py', '.go', '.ts', '.js', '.c', '.h', '.cpp', '.hpp', '.sol'} + extensions = { + ".rs", + ".py", + ".go", + ".ts", + ".js", + ".c", + ".h", + ".cpp", + ".hpp", + ".sol", + } source_files = [] - + for ext in extensions: - for file_path in self.repo_root.rglob(f'*{ext}'): + for file_path in self.repo_root.rglob(f"*{ext}"): if file_path.is_file() and not self._is_excluded(file_path): source_files.append(file_path) - + return sorted(source_files) def main(): - parser = argparse.ArgumentParser(description='Add SPDX attribution headers to source files') - parser.add_argument('--dry-run', action='store_true', help='Show what would be done without making changes') - parser.add_argument('--file', type=str, help='Process a specific file instead of all source files') - parser.add_argument('--repo-root', type=str, default='.', help='Repository root directory') - + parser = argparse.ArgumentParser( + description="Add SPDX attribution headers to source files" + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be done without making changes", + ) + parser.add_argument( + "--file", type=str, help="Process a specific file instead of all source files" + ) + parser.add_argument( + "--repo-root", type=str, default=".", help="Repository root directory" + ) + args = parser.parse_args() - + # Initialize the attributor attributor = SPDXAttributor(args.repo_root, dry_run=args.dry_run) - + if args.file: # Process single file file_path = Path(args.file) if not file_path.exists(): print(f"Error: File {file_path} does not exist") sys.exit(1) - + attributor.process_file(file_path) else: # Process all source files source_files = attributor.find_source_files() - + if args.dry_run: print(f"Found {len(source_files)} source files to process") print("\nDry run - showing what would be done:\n") - + success_count = 0 for file_path in source_files: if attributor.process_file(file_path): success_count += 1 - + if not args.dry_run: print(f"\nProcessed {success_count}/{len(source_files)} files successfully") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/tools/mock-cf-dns-api/app.py b/tools/mock-cf-dns-api/app.py index 5703db797..53ca20233 100644 --- a/tools/mock-cf-dns-api/app.py +++ b/tools/mock-cf-dns-api/app.py @@ -13,14 +13,14 @@ - DELETE /client/v4/zones/{zone_id}/dns_records/{record_id} - Delete DNS record """ +import json import os import uuid -import time -import json from datetime import datetime -from flask import Flask, request, jsonify, render_template_string from functools import wraps +from flask import Flask, jsonify, render_template_string, request + app = Flask(__name__) # In-memory storage for DNS records @@ -32,7 +32,11 @@ MAX_LOGS = 100 # Valid API tokens (for testing, accept any non-empty token or use env var) -VALID_TOKENS = os.environ.get("CF_API_TOKENS", "").split(",") if os.environ.get("CF_API_TOKENS") else None +VALID_TOKENS = ( + os.environ.get("CF_API_TOKENS", "").split(",") + if os.environ.get("CF_API_TOKENS") + else None +) def log_request(zone_id, method, path, req_data, resp_data, status_code): @@ -63,29 +67,35 @@ def get_current_time(): def verify_auth(f): """Decorator to verify Bearer token authentication.""" + @wraps(f) def decorated(*args, **kwargs): auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): - return jsonify({ - "success": False, - "errors": [{"code": 10000, "message": "Authentication error"}], - "messages": [], - "result": None - }), 401 + return jsonify( + { + "success": False, + "errors": [{"code": 10000, "message": "Authentication error"}], + "messages": [], + "result": None, + } + ), 401 token = auth_header[7:] # Remove "Bearer " prefix # If VALID_TOKENS is set, validate against it; otherwise accept any token if VALID_TOKENS and token not in VALID_TOKENS: - return jsonify({ - "success": False, - "errors": [{"code": 10000, "message": "Invalid API token"}], - "messages": [], - "result": None - }), 403 + return jsonify( + { + "success": False, + "errors": [{"code": 10000, "message": "Invalid API token"}], + "messages": [], + "result": None, + } + ), 403 return f(*args, **kwargs) + return decorated @@ -95,7 +105,7 @@ def cf_response(result, success=True, errors=None, messages=None): "success": success, "errors": errors or [], "messages": messages or [], - "result": result + "result": result, } @@ -106,6 +116,7 @@ def cf_error(message, code=1000): # ==================== DNS Record Endpoints ==================== + @app.route("/client/v4/zones//dns_records", methods=["POST"]) @verify_auth def create_dns_record(zone_id): @@ -148,8 +159,8 @@ def create_dns_record(zone_id): "meta": { "auto_added": False, "managed_by_apps": False, - "managed_by_argo_tunnel": False - } + "managed_by_argo_tunnel": False, + }, } # Handle different record types @@ -177,7 +188,9 @@ def create_dns_record(zone_id): resp = cf_response(record) log_request(zone_id, "POST", f"/zones/{zone_id}/dns_records", data, resp, 200) - print(f"[CREATE] Zone: {zone_id}, Record: {record_id}, Type: {record_type}, Name: {name}") + print( + f"[CREATE] Zone: {zone_id}, Record: {record_id}, Type: {record_type}, Name: {name}" + ) return jsonify(resp), 200 @@ -220,10 +233,12 @@ def list_dns_records(zone_id): "per_page": per_page, "count": len(page_records), "total_count": total_count, - "total_pages": total_pages - } + "total_pages": total_pages, + }, } - log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records", dict(request.args), resp, 200) + log_request( + zone_id, "GET", f"/zones/{zone_id}/dns_records", dict(request.args), resp, 200 + ) return jsonify(resp), 200 @@ -237,11 +252,15 @@ def get_dns_record(zone_id, record_id): if not record: resp = cf_error("Record not found", 81044) - log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + log_request( + zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404 + ) return jsonify(resp), 404 resp = cf_response(record) - log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200) + log_request( + zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200 + ) return jsonify(resp), 200 @@ -255,13 +274,17 @@ def update_dns_record(zone_id, record_id): if not record: resp = cf_error("Record not found", 81044) - log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + log_request( + zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404 + ) return jsonify(resp), 404 data = request.get_json() if not data: resp = cf_error("Invalid request body") - log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 400) + log_request( + zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 400 + ) return jsonify(resp), 400 # Update allowed fields @@ -272,7 +295,9 @@ def update_dns_record(zone_id, record_id): record["modified_on"] = get_current_time() resp = cf_response(record) - log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", data, resp, 200) + log_request( + zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", data, resp, 200 + ) print(f"[UPDATE] Zone: {zone_id}, Record: {record_id}") @@ -287,13 +312,22 @@ def delete_dns_record(zone_id, record_id): if record_id not in zone_records: resp = cf_error("Record not found", 81044) - log_request(zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + log_request( + zone_id, + "DELETE", + f"/zones/{zone_id}/dns_records/{record_id}", + None, + resp, + 404, + ) return jsonify(resp), 404 del zone_records[record_id] resp = cf_response({"id": record_id}) - log_request(zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200) + log_request( + zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200 + ) print(f"[DELETE] Zone: {zone_id}, Record: {record_id}") @@ -321,7 +355,7 @@ def get_configured_zones(): try: return json.loads(zones_json) except json.JSONDecodeError: - print(f"Warning: Invalid MOCK_ZONES JSON, using defaults") + print("Warning: Invalid MOCK_ZONES JSON, using defaults") return DEFAULT_ZONES @@ -342,20 +376,19 @@ def list_zones(): # Build full zone objects full_zones = [] for z in zones: - full_zones.append({ - "id": z["id"], - "name": z["name"], - "status": "active", - "paused": False, - "type": "full", - "development_mode": 0, - "name_servers": [ - "ns1.mock-cloudflare.com", - "ns2.mock-cloudflare.com" - ], - "created_on": "2024-01-01T00:00:00.000000Z", - "modified_on": get_current_time(), - }) + full_zones.append( + { + "id": z["id"], + "name": z["name"], + "status": "active", + "paused": False, + "type": "full", + "development_mode": 0, + "name_servers": ["ns1.mock-cloudflare.com", "ns2.mock-cloudflare.com"], + "created_on": "2024-01-01T00:00:00.000000Z", + "modified_on": get_current_time(), + } + ) # Pagination total_count = len(full_zones) @@ -374,12 +407,14 @@ def list_zones(): "per_page": per_page, "count": len(page_zones), "total_count": total_count, - "total_pages": total_pages - } + "total_pages": total_pages, + }, } log_request("*", "GET", "/zones", dict(request.args), result, 200) - print(f"[LIST ZONES] page={page}, per_page={per_page}, count={len(page_zones)}, total={total_count}") + print( + f"[LIST ZONES] page={page}, per_page={per_page}, count={len(page_zones)}, total={total_count}" + ) return jsonify(result), 200 @@ -405,10 +440,7 @@ def get_zone(zone_id): "paused": False, "type": "full", "development_mode": 0, - "name_servers": [ - "ns1.mock-cloudflare.com", - "ns2.mock-cloudflare.com" - ], + "name_servers": ["ns1.mock-cloudflare.com", "ns2.mock-cloudflare.com"], "created_on": "2024-01-01T00:00:00.000000Z", "modified_on": get_current_time(), } @@ -789,12 +821,13 @@ def management_ui(): request_count=len(request_logs), records=all_records, logs=request_logs[:20], - port=os.environ.get("PORT", 8080) + port=os.environ.get("PORT", 8080), ) # ==================== Management API ==================== + @app.route("/api/records", methods=["DELETE"]) def clear_all_records(): """Clear all DNS records.""" @@ -829,7 +862,9 @@ def get_all_records(): @app.route("/health") def health(): """Health check endpoint.""" - return jsonify({"status": "healthy", "records": sum(len(r) for r in dns_records.values())}) + return jsonify( + {"status": "healthy", "records": sum(len(r) for r in dns_records.values())} + ) if __name__ == "__main__": From 6a3c130e26487b07db4f69e9b76ed4133a4c6eac Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 04:24:53 +0000 Subject: [PATCH 09/10] fix: add docstrings to all Python files and remove ruff.toml Remove root-level ruff.toml; pass lint rules via prek.toml args instead. Add proper docstrings to vmm-cli.py, ct_monitor.py, add-spdx-attribution.py, and app.py to comply with pydocstyle rules. --- prek.toml | 2 +- python/ct_monitor/ct_monitor.py | 13 ++++- ruff.toml | 19 -------- scripts/add-spdx-attribution.py | 11 +++-- tools/mock-cf-dns-api/app.py | 5 +- vmm/src/vmm-cli.py | 85 ++++++++++++++++++--------------- 6 files changed, 68 insertions(+), 67 deletions(-) delete mode 100644 ruff.toml diff --git a/prek.toml b/prek.toml index 30e141ad3..1e24d457a 100644 --- a/prek.toml +++ b/prek.toml @@ -39,7 +39,7 @@ pass_filenames = false repo = "https://github.com/astral-sh/ruff-pre-commit" rev = "v0.11.4" hooks = [ - { id = "ruff", args = ["--fix"] }, + { id = "ruff", args = ["--fix", "--select", "E,F,I,D", "--ignore", "D203,D213,E501"] }, { id = "ruff-format" }, ] diff --git a/python/ct_monitor/ct_monitor.py b/python/ct_monitor/ct_monitor.py index 3f44db7d2..7cc37913c 100644 --- a/python/ct_monitor/ct_monitor.py +++ b/python/ct_monitor/ct_monitor.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: © 2024 Phala Network # # SPDX-License-Identifier: Apache-2.0 +"""Monitor certificate transparency logs for a given domain.""" import argparse import sys @@ -15,22 +16,29 @@ class PoisonedLog(Exception): + """Indicate a poisoned certificate transparency log entry.""" + pass class Monitor: + """Monitor certificate transparency logs for a domain.""" + def __init__(self, domain: str): + """Initialize the monitor with a validated domain.""" if not self.validate_domain(domain): raise ValueError("Invalid domain name") self.domain = domain self.last_checked = None def get_logs(self, count: int = 100): + """Fetch recent certificate transparency log entries.""" url = f"{BASE_URL}/?q={self.domain}&output=json&limit={count}" response = requests.get(url) return response.json() def check_one_log(self, log: object): + """Fetch and inspect a single certificate log entry.""" log_id = log["id"] cert_url = f"{BASE_URL}/?d={log_id}" cert_data = requests.get(cert_url).text @@ -68,6 +76,7 @@ def check_one_log(self, log: object): print("No valid certificate found in the response.") def check_new_logs(self): + """Check for new log entries since the last check.""" logs = self.get_logs(count=10000) print("num logs", len(logs)) for log in logs: @@ -80,6 +89,7 @@ def check_new_logs(self): self.last_checked = logs[0]["id"] def run(self): + """Run the monitor loop indefinitely.""" print(f"Monitoring {self.domain}...") while True: try: @@ -93,7 +103,7 @@ def run(self): @staticmethod def validate_domain(domain: str): - # ensure domain is a valid DNS domain + """Validate that the given string is a well-formed DNS domain name.""" import re # Regular expression for validating domain names @@ -108,6 +118,7 @@ def validate_domain(domain: str): def main(): + """Parse arguments and start the certificate transparency monitor.""" parser = argparse.ArgumentParser( description="Monitor certificate transparency logs" ) diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 3a15a4083..000000000 --- a/ruff.toml +++ /dev/null @@ -1,19 +0,0 @@ -# SPDX-FileCopyrightText: © 2025 Phala Network -# -# SPDX-License-Identifier: Apache-2.0 - -# Root-level ruff config for pre-commit hooks. -# sdk/python/ has its own stricter config in pyproject.toml. - -line-length = 88 - -[lint] -select = [ - "E", # pycodestyle errors - "F", # pyflakes - "I", # isort -] -ignore = [ - "E203", # whitespace before ':' - "E501", # line too long (handled by formatter) -] diff --git a/scripts/add-spdx-attribution.py b/scripts/add-spdx-attribution.py index 926110253..d98e78ae8 100755 --- a/scripts/add-spdx-attribution.py +++ b/scripts/add-spdx-attribution.py @@ -2,11 +2,10 @@ # SPDX-FileCopyrightText: © 2025 Phala Network # # SPDX-License-Identifier: Apache-2.0 -""" -SPDX Header Attribution Script +"""SPDX header attribution script. -This script automatically analyzes git blame data to determine contributors -and adds appropriate SPDX-FileCopyrightText headers using the REUSE tool. +Analyze git blame data to determine contributors and add appropriate +SPDX-FileCopyrightText headers using the REUSE tool. Features: - Excludes third-party code based on .spdx-exclude patterns @@ -27,7 +26,10 @@ class SPDXAttributor: + """Add SPDX attribution headers to source files based on git blame.""" + def __init__(self, repo_root: str, dry_run: bool = False): + """Initialize the attributor with a repository root and options.""" self.repo_root = Path(repo_root).resolve() self.dry_run = dry_run self.exclude_patterns = self._load_exclude_patterns() @@ -491,6 +493,7 @@ def find_source_files(self) -> List[Path]: def main(): + """Parse arguments and run SPDX attribution on source files.""" parser = argparse.ArgumentParser( description="Add SPDX attribution headers to source files" ) diff --git a/tools/mock-cf-dns-api/app.py b/tools/mock-cf-dns-api/app.py index 53ca20233..daf97894f 100644 --- a/tools/mock-cf-dns-api/app.py +++ b/tools/mock-cf-dns-api/app.py @@ -3,8 +3,7 @@ # # SPDX-License-Identifier: Apache-2.0 -""" -Mock Cloudflare DNS API Server +"""Mock Cloudflare DNS API server. A mock server that simulates Cloudflare's DNS API for testing purposes. Supports the following endpoints used by certbot: @@ -66,7 +65,7 @@ def get_current_time(): def verify_auth(f): - """Decorator to verify Bearer token authentication.""" + """Verify Bearer token authentication.""" @wraps(f) def decorated(*args, **kwargs): diff --git a/vmm/src/vmm-cli.py b/vmm/src/vmm-cli.py index a325457c4..49380b530 100755 --- a/vmm/src/vmm-cli.py +++ b/vmm/src/vmm-cli.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +"""CLI tool for managing dstack-vmm virtual machines.""" # SPDX-FileCopyrightText: © 2025 Phala Network # @@ -34,8 +35,7 @@ def encrypt_env(envs, hex_public_key: str) -> str: - """ - Encrypts environment variables using a one-time X25519 key exchange and AES-GCM. + """Encrypts environment variables using a one-time X25519 key exchange and AES-GCM. This function does the following: 1. Converts the given environment variables to JSON bytes. @@ -57,6 +57,7 @@ def encrypt_env(envs, hex_public_key: str) -> str: Returns: A hexadecimal string that is the concatenation of: (ephemeral public key || IV || ciphertext). + """ if not CRYPTO_AVAILABLE: raise ImportError( @@ -100,7 +101,7 @@ def encrypt_env(envs, hex_public_key: str) -> str: def parse_port_mapping(port_str: str) -> Dict: - """Parse a port mapping string into a dictionary""" + """Parse a port mapping string into a dictionary.""" parts = port_str.split(":") if len(parts) == 3: return { @@ -121,6 +122,7 @@ def parse_port_mapping(port_str: str) -> Dict: def read_utf8(filepath: str) -> str: + """Read a file and return its contents as a UTF-8 string.""" with open(filepath, "rb") as f: return f.read().decode("utf-8") @@ -129,10 +131,12 @@ class UnixSocketHTTPConnection(http.client.HTTPConnection): """HTTPConnection that connects to a Unix domain socket.""" def __init__(self, socket_path, timeout=None): + """Initialize with a Unix socket path and optional timeout.""" super().__init__("localhost", timeout=timeout) self.socket_path = socket_path def connect(self): + """Connect to the Unix domain socket.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) if self.timeout: sock.settimeout(self.timeout) @@ -149,6 +153,7 @@ def __init__( auth_user: Optional[str] = None, auth_password: Optional[str] = None, ): + """Initialize the client with a base URL and optional authentication.""" self.base_url = base_url.rstrip("/") self.use_uds = self.base_url.startswith("unix:") self.auth_user = auth_user @@ -170,8 +175,7 @@ def request( body: Any = None, stream: bool = False, ) -> Tuple[int, Union[Dict, str, BinaryIO]]: - """ - Make an HTTP request to the server. + """Make an HTTP request to the server. Args: method: HTTP method (GET, POST, etc.) @@ -182,6 +186,7 @@ def request( Returns: Tuple of (status_code, response_data) + """ if headers is None: headers = {} @@ -249,18 +254,21 @@ def request( class VmmCLI: + """Command-line interface for managing dstack-vmm virtual machines.""" + def __init__( self, base_url: str, auth_user: Optional[str] = None, auth_password: Optional[str] = None, ): + """Initialize the CLI with a base URL and optional authentication.""" self.base_url = base_url.rstrip("/") self.headers = {"Content-Type": "application/json"} self.client = VmmClient(base_url, auth_user, auth_password) def rpc_call(self, method: str, params: Optional[Dict] = None) -> Dict: - """Make an RPC call to the dstack-vmm API""" + """Make an RPC call to the dstack-vmm API.""" path = f"/prpc/{method}?json" status, response = self.client.request( "POST", path, headers=self.headers, body=params or {} @@ -276,7 +284,7 @@ def rpc_call(self, method: str, params: Optional[Dict] = None) -> Dict: return response def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: - """List all VMs and their status""" + """List all VMs and their status.""" response = self.rpc_call("Status") vms = response["vms"] @@ -321,7 +329,7 @@ def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: print(format_table(rows, headers)) def _format_gpu_info(self, gpu_config): - """Format GPU configuration for display""" + """Format GPU configuration for display.""" if not gpu_config: return "-" @@ -337,12 +345,12 @@ def _format_gpu_info(self, gpu_config): return "-" def start_vm(self, vm_id: str) -> None: - """Start a VM""" + """Start a VM.""" self.rpc_call("StartVm", {"id": vm_id}) print(f"Started VM {vm_id}") def stop_vm(self, vm_id: str, force: bool = False) -> None: - """Stop a VM""" + """Stop a VM.""" if force: self.rpc_call("StopVm", {"id": vm_id}) print(f"Forcefully stopped VM {vm_id}") @@ -351,7 +359,7 @@ def stop_vm(self, vm_id: str, force: bool = False) -> None: print(f"Gracefully shutting down VM {vm_id}") def remove_vm(self, vm_id: str) -> None: - """Remove a VM""" + """Remove a VM.""" self.rpc_call("RemoveVm", {"id": vm_id}) print(f"Removed VM {vm_id}") @@ -363,7 +371,7 @@ def resize_vm( disk_size: Optional[int] = None, image: Optional[str] = None, ) -> None: - """Resize a VM""" + """Resize a VM.""" params = {"id": vm_id} if vcpu is not None: params["vcpu"] = vcpu @@ -384,7 +392,7 @@ def resize_vm( print(f"Resized VM {vm_id}") def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: - """Show VM logs""" + """Show VM logs.""" path = f"/logs?id={vm_id}&follow={str(follow).lower()}&ansi=false&lines={lines}" status, response = self.client.request( @@ -418,7 +426,7 @@ def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: print(response) def list_images(self, json_output: bool = False) -> None: - """Get list of available images""" + """Get list of available images.""" response = self.rpc_call("ListImages") images = response["images"] @@ -438,7 +446,7 @@ def list_images(self, json_output: bool = False) -> None: def get_app_env_encrypt_pub_key( self, app_id: str, kms_url: Optional[str] = None ) -> Dict: - """Get the encryption public key for the specified application ID""" + """Get the encryption public key for the specified application ID.""" if kms_url: client = VmmClient(kms_url) path = "/prpc/GetAppEnvEncryptPubKey?json" @@ -480,12 +488,12 @@ def get_app_env_encrypt_pub_key( return response["public_key"] def confirm_untrusted_signer(self, signer: str) -> bool: - """Ask user to confirm using an untrusted signer""" + """Ask user to confirm using an untrusted signer.""" response = input(f"Continue with untrusted signer {signer}? (y/N): ") return response.lower() in ("y", "yes") def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: - """Manage the whitelist of trusted signers""" + """Manage the whitelist of trusted signers.""" whitelist = load_whitelist() if action == "list": @@ -529,12 +537,12 @@ def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: raise Exception(f"Unknown action: {action}") def calc_app_id(self, compose_file: str) -> str: - """Calculate the application ID from the compose file""" + """Calculate the application ID from the compose file.""" compose_hash = hashlib.sha256(compose_file.encode()).hexdigest() return compose_hash[:40] def create_app_compose(self, args) -> None: - """Create a new app compose file""" + """Create a new app compose file.""" envs = parse_env_file(args.env_file) or {} app_compose = { "manifest_version": 2, @@ -574,7 +582,7 @@ def create_app_compose(self, args) -> None: print(f"Compose hash: {compose_hash}") def create_vm(self, args) -> None: - """Create a new VM""" + """Create a new VM.""" # Read and validate compose file if not os.path.exists(args.compose): raise Exception(f"Compose file not found: {args.compose}") @@ -637,7 +645,7 @@ def create_vm(self, args) -> None: def update_vm_env( self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[List[str]] = None ) -> None: - """Update environment variables for a VM""" + """Update environment variables for a VM.""" envs = envs or {} # First get the VM info to retrieve the app_id vm_info_response = self.rpc_call("GetInfo", {"id": vm_id}) @@ -688,17 +696,17 @@ def update_vm_env( print(f"Environment variables updated for VM {vm_id}") def update_vm_user_config(self, vm_id: str, user_config: str) -> None: - """Update user config for a VM""" + """Update user config for a VM.""" self.rpc_call("UpgradeApp", {"id": vm_id, "user_config": user_config}) print(f"User config updated for VM {vm_id}") def update_vm_app_compose(self, vm_id: str, app_compose: str) -> None: - """Update app compose for a VM""" + """Update app compose for a VM.""" self.rpc_call("UpgradeApp", {"id": vm_id, "compose_file": app_compose}) print(f"App compose updated for VM {vm_id}") def update_vm_ports(self, vm_id: str, ports: List[str]) -> None: - """Update port mapping for a VM""" + """Update port mapping for a VM.""" port_mappings = [parse_port_mapping(port) for port in ports] self.rpc_call( "UpgradeApp", {"id": vm_id, "update_ports": True, "ports": port_mappings} @@ -725,7 +733,7 @@ def update_vm( kms_urls: Optional[List[str]] = None, no_tee: Optional[bool] = None, ) -> None: - """Update multiple aspects of a VM in one command""" + """Update multiple aspects of a VM in one command.""" updates = [] # handle resize operations (vcpu, memory, disk, image) @@ -893,7 +901,7 @@ def update_vm( print(f"No updates specified for VM {vm_id}") def list_gpus(self, json_output: bool = False) -> None: - """List all available GPUs""" + """List all available GPUs.""" response = self.rpc_call("ListGpus") gpus = response.get("gpus", []) @@ -921,7 +929,7 @@ def list_gpus(self, json_output: bool = False) -> None: def format_table(rows, headers): - """Simple table formatter""" + """Format rows and headers into an aligned table.""" if not rows: return "" @@ -947,10 +955,9 @@ def format_table(rows, headers): def parse_env_file(file_path: str) -> Dict[str, str]: - """ - Parse an environment file where each line is formatted as: - KEY=Value + """Parse an environment file into a dictionary of key-value pairs. + Each line should be formatted as KEY=Value. Lines that are empty or start with '#' are ignored. """ if not file_path: @@ -970,9 +977,7 @@ def parse_env_file(file_path: str) -> Dict[str, str]: def parse_size(s: str, target_unit: str) -> int: - """ - Parse a human-readable size string (e.g. "1G", "100M") and return the size - in the specified target unit. + """Parse a human-readable size string and return the size in the target unit. Args: s: The size string provided. @@ -984,6 +989,7 @@ def parse_size(s: str, target_unit: str) -> int: Raises: argparse.ArgumentTypeError: if the format is invalid or if conversion turns out to be fractional. + """ s = s.strip() m = re.fullmatch(r"(\d+(?:\.\d+)?)([a-zA-Z]{1,2})?", s) @@ -1045,8 +1051,7 @@ def parse_disk_size(s: str) -> int: def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Optional[str]: - """ - Verify the signature of a public key. + """Verify the signature of a public key. Args: public_key: The public key bytes to verify @@ -1070,6 +1075,7 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option >>> compressed_pubkey = verify_signature(public_key, signature, app_id) >>> print(compressed_pubkey) 0x0217610d74cbd39b6143842c6d8bc310d79da1d82cc9d17f8876376221eda0c38f + """ if not CRYPTO_AVAILABLE: raise ImportError( @@ -1102,11 +1108,11 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option def load_whitelist() -> List[str]: - """ - Load the whitelist of trusted signers from a file. + """Load the whitelist of trusted signers from a file. Returns: List of trusted Ethereum addresses + """ if not os.path.exists(DEFAULT_KMS_WHITELIST_PATH): os.makedirs(os.path.dirname(DEFAULT_KMS_WHITELIST_PATH), exist_ok=True) @@ -1121,11 +1127,11 @@ def load_whitelist() -> List[str]: def save_whitelist(whitelist: List[str]) -> None: - """ - Save the whitelist of trusted signers to a file. + """Save the whitelist of trusted signers to a file. Args: whitelist: List of trusted Ethereum addresses + """ os.makedirs(os.path.dirname(DEFAULT_KMS_WHITELIST_PATH), exist_ok=True) with open(DEFAULT_KMS_WHITELIST_PATH, "w") as f: @@ -1133,6 +1139,7 @@ def save_whitelist(whitelist: List[str]) -> None: def main(): + """Run the dstack-vmm CLI.""" parser = argparse.ArgumentParser(description="dstack-vmm CLI - Manage VMs") # Get default URL from environment variable or use localhost From a5cc52ee405d7975854dd1321a715a886bbe5aac Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 20 Mar 2026 05:22:33 +0000 Subject: [PATCH 10/10] fix: redo vmm-cli.py lint fixes on latest master version Previous rebase lost newly added functions (discover_vmm_instances, load_config, etc). Restore master version and reapply: import sorting, docstrings, line length fixes, and unused variable removal. --- vmm/src/vmm-cli.py | 503 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 450 insertions(+), 53 deletions(-) diff --git a/vmm/src/vmm-cli.py b/vmm/src/vmm-cli.py index 49380b530..6b2254e03 100755 --- a/vmm/src/vmm-cli.py +++ b/vmm/src/vmm-cli.py @@ -30,29 +30,225 @@ except ImportError: CRYPTO_AVAILABLE = False -# Default whitelist file location +# Default config file locations +DEFAULT_CONFIG_PATH = os.path.expanduser("~/.dstack-vmm/config.json") DEFAULT_KMS_WHITELIST_PATH = os.path.expanduser("~/.dstack-vmm/kms-whitelist.json") +# VMM discovery directory +DISCOVERY_DIR = "/run/dstack-vmm" + + +def load_config() -> Dict[str, Any]: + """Load configuration from the default config file. + + Returns: + Dictionary with configuration values (url, auth_user, auth_password) + + """ + if not os.path.exists(DEFAULT_CONFIG_PATH): + return {} + + try: + with open(DEFAULT_CONFIG_PATH, "r") as f: + return json.load(f) + except (json.JSONDecodeError, FileNotFoundError): + return {} + + +def discover_vmm_instances() -> List[Dict[str, Any]]: + """Discover all running VMM instances from the discovery directory. + + Returns: + List of VMM instance info dicts, sorted by started_at. + + """ + instances = [] + if not os.path.isdir(DISCOVERY_DIR): + return instances + + for fname in os.listdir(DISCOVERY_DIR): + if not fname.endswith(".json"): + continue + fpath = os.path.join(DISCOVERY_DIR, fname) + try: + with open(fpath, "r") as f: + info = json.load(f) + # Check if process is still alive + pid = info.get("pid") + if pid and not os.path.exists(f"/proc/{pid}"): + # Stale file, skip + continue + instances.append(info) + except (json.JSONDecodeError, FileNotFoundError, PermissionError): + continue + + instances.sort(key=lambda x: x.get("started_at", 0)) + return instances + + +def resolve_vmm_url( + instances: List[Dict[str, Any]], + config: Dict[str, Any], + explicit_url: Optional[str] = None, +) -> str: + """Resolve the VMM URL to connect to. + + Priority: + 1. Explicit --url flag + 2. DSTACK_VMM_URL env var + 3. Config file url + 4. If exactly one VMM instance is discovered, use that + 5. If active instance is set in config, use that + 6. Fall back to default + + Returns the URL string. + """ + # If user explicitly provided --url or env var, honor it + env_url = os.environ.get("DSTACK_VMM_URL") + if explicit_url and explicit_url != "http://localhost:8080": + return explicit_url + if env_url: + return env_url + config_url = config.get("url") + if config_url: + return config_url + + # Try auto-discovery + active_id = config.get("active_vmm") + if active_id: + for inst in instances: + if inst["id"] == active_id or inst["id"].startswith(active_id): + return vmm_address_to_url(inst) + + if len(instances) == 1: + return vmm_address_to_url(instances[0]) + + return "http://localhost:8080" + + +def vmm_address_to_url(instance: Dict[str, Any]) -> str: + """Convert a VMM instance info dict to a connection URL.""" + addr = instance.get("address", "") + if addr.startswith("unix:"): + # Resolve relative socket path against working directory + socket_path = addr[5:] + if not os.path.isabs(socket_path): + working_dir = instance.get("working_dir", "") + socket_path = os.path.join(working_dir, socket_path) + return f"unix:{socket_path}" + elif addr.startswith("http://") or addr.startswith("https://"): + return addr + else: + # host:port format from discovery + host, _, port = addr.rpartition(":") + if host == "0.0.0.0": + host = "127.0.0.1" + return f"http://{host}:{port}" + + +def save_active_vmm(vmm_id: str): + """Save the active VMM instance ID to the config file.""" + config = load_config() + config["active_vmm"] = vmm_id + os.makedirs(os.path.dirname(DEFAULT_CONFIG_PATH), exist_ok=True) + with open(DEFAULT_CONFIG_PATH, "w") as f: + json.dump(config, f, indent=2) + + +def cmd_ls_vmm(args): + """List all discovered VMM instances.""" + instances = discover_vmm_instances() + config = load_config() + active_id = config.get("active_vmm") + + if not instances: + print("No running VMM instances found.") + print(f" (discovery directory: {DISCOVERY_DIR})") + return + + if getattr(args, "json", False): + print(json.dumps(instances, indent=2)) + return + + # Table output + + fmt = " {active} {id:<12s} {pid:<8s} {node:<12s} {address:<24s} {workdir}" + print( + fmt.format( + active="", + id="ID", + pid="PID", + node="NAME", + address="ADDRESS", + workdir="WORKING DIR", + ) + ) + print(" " + "-" * 90) + + for inst in instances: + short_id = inst["id"][:8] + is_active = "*" if active_id and inst["id"].startswith(active_id) else " " + node_name = inst.get("node_name", "") or "-" + address = inst.get("address", "?") + + print( + fmt.format( + active=is_active, + id=short_id, + pid=str(inst.get("pid", "?")), + node=node_name[:12], + address=address[:24], + workdir=inst.get("working_dir", "?"), + ) + ) + + +def cmd_switch_vmm(args): + """Switch the active VMM instance.""" + target = args.vmm_id + instances = discover_vmm_instances() + + if not instances: + print("No running VMM instances found.") + return + + # Find matching instance by prefix + matches = [i for i in instances if i["id"].startswith(target)] + if len(matches) == 0: + print(f"No VMM instance matching '{target}'.") + print("Available instances:") + for inst in instances: + print( + f" {inst['id'][:8]} {inst.get('address', '?')} {inst.get('working_dir', '?')}" + ) + return + if len(matches) > 1: + print(f"Ambiguous ID '{target}', matches multiple instances:") + for inst in matches: + print(f" {inst['id'][:8]} {inst.get('address', '?')}") + return + + selected = matches[0] + save_active_vmm(selected["id"]) + url = vmm_address_to_url(selected) + print(f"Switched to VMM {selected['id'][:8]} ({url})") + def encrypt_env(envs, hex_public_key: str) -> str: """Encrypts environment variables using a one-time X25519 key exchange and AES-GCM. This function does the following: 1. Converts the given environment variables to JSON bytes. - 2. Removes a leading "0x" from the provided public key - (if present) and converts it to bytes. + 2. Removes a leading "0x" from the provided public key (if present) and converts it to bytes. 3. Generates an ephemeral X25519 key pair. - 4. Computes a shared secret using this ephemeral private - key and the remote public key. + 4. Computes a shared secret using this ephemeral private key and the remote public key. 5. Uses the shared key directly as the 32-byte key for AES-GCM. - 6. Encrypts the JSON string with AES-GCM using a random IV. - 7. Concatenates the ephemeral public key, IV, and ciphertext - and returns it as a hex string. + 6. Encrypts the JSON string with AES-GCM using a randomly generated IV. + 7. Concatenates the ephemeral public key, IV, and ciphertext and returns it as a hex string. Args: - envs: The environment variables to encrypt. This can be - any JSON-serializable data structure. - hex_public_key: The remote encryption public key in hex. + envs: The environment variables to encrypt. This can be any JSON-serializable data structure. + hex_public_key: The remote encryption public key in hexadecimal format. Returns: A hexadecimal string that is the concatenation of: @@ -131,7 +327,7 @@ class UnixSocketHTTPConnection(http.client.HTTPConnection): """HTTPConnection that connects to a Unix domain socket.""" def __init__(self, socket_path, timeout=None): - """Initialize with a Unix socket path and optional timeout.""" + """Initialize with the given Unix socket path.""" super().__init__("localhost", timeout=timeout) self.socket_path = socket_path @@ -153,7 +349,7 @@ def __init__( auth_user: Optional[str] = None, auth_password: Optional[str] = None, ): - """Initialize the client with a base URL and optional authentication.""" + """Initialize the client with a base URL and optional auth credentials.""" self.base_url = base_url.rstrip("/") self.use_uds = self.base_url.startswith("unix:") self.auth_user = auth_user @@ -254,7 +450,7 @@ def request( class VmmCLI: - """Command-line interface for managing dstack-vmm virtual machines.""" + """Command-line interface for the dstack-vmm API.""" def __init__( self, @@ -262,7 +458,7 @@ def __init__( auth_user: Optional[str] = None, auth_password: Optional[str] = None, ): - """Initialize the CLI with a base URL and optional authentication.""" + """Initialize the CLI with a base URL and optional auth credentials.""" self.base_url = base_url.rstrip("/") self.headers = {"Content-Type": "application/json"} self.client = VmmClient(base_url, auth_user, auth_password) @@ -299,7 +495,7 @@ def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: headers = ["VM ID", "App ID", "Name", "Status", "Uptime"] if verbose: - headers.extend(["vCPU", "Memory", "Disk", "Image", "GPUs"]) + headers.extend(["Instance ID", "vCPU", "Memory", "Disk", "Image", "GPUs"]) rows = [] for vm in vms: @@ -316,6 +512,7 @@ def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: gpu_info = self._format_gpu_info(config.get("gpus")) row.extend( [ + vm.get("instance_id", "-") or "-", config.get("vcpu", "-"), f"{config.get('memory', '-')}MB", f"{config.get('disk_size', '-')}GB", @@ -384,8 +581,7 @@ def resize_vm( if len(params) == 1: raise Exception( - "at least one parameter must be specified for resize:" - " --vcpu, --memory, --disk, or --image" + "at least one parameter must be specified for resize: --vcpu, --memory, --disk, or --image" ) self.rpc_call("ResizeVm", params) @@ -426,7 +622,7 @@ def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: print(response) def list_images(self, json_output: bool = False) -> None: - """Get list of available images.""" + """List available images.""" response = self.rpc_call("ListImages") images = response["images"] @@ -461,15 +657,36 @@ def get_app_env_encrypt_pub_key( response = self.rpc_call("GetAppEnvEncryptPubKey", {"app_id": app_id}) # Verify the signature if available - if "signature" not in response: + if "signature" not in response and "signature_v1" not in response: if not self.confirm_untrusted_signer("none"): raise Exception("Aborted due to invalid signature") return response["public_key"] public_key = bytes.fromhex(response["public_key"]) - signature = bytes.fromhex(response["signature"]) - signer_pubkey = verify_signature(public_key, signature, app_id) + # Prefer signature_v1 (with timestamp) if available + signer_pubkey = None + if "signature_v1" in response and "timestamp" in response: + signature_v1 = bytes.fromhex(response["signature_v1"]) + timestamp = response["timestamp"] + signer_pubkey = verify_signature_v1( + public_key, signature_v1, app_id, timestamp + ) + if signer_pubkey: + print(f"Verified signature_v1 (with timestamp) from: {signer_pubkey}") + + # Fall back to legacy signature if signature_v1 verification failed or not available + if not signer_pubkey and "signature" in response: + print( + "WARNING: Using legacy signature without timestamp protection. " + "Consider upgrading your KMS to support signature_v1.", + file=sys.stderr, + ) + signature = bytes.fromhex(response["signature"]) + signer_pubkey = verify_signature(public_key, signature, app_id) + if signer_pubkey: + print(f"Verified legacy signature from: {signer_pubkey}") + if signer_pubkey: whitelist = load_whitelist() if whitelist and signer_pubkey not in whitelist: @@ -478,8 +695,6 @@ def get_app_env_encrypt_pub_key( ) if not self.confirm_untrusted_signer(signer_pubkey): raise Exception("Aborted due to untrusted signer") - else: - print(f"Verified signature from: {signer_pubkey}") else: print("WARNING: Could not verify signature!") if not self.confirm_untrusted_signer("unknown"): @@ -544,6 +759,13 @@ def calc_app_id(self, compose_file: str) -> str: def create_app_compose(self, args) -> None: """Create a new app compose file.""" envs = parse_env_file(args.env_file) or {} + + # Validate: --env-file requires --kms + if envs and not args.kms: + raise Exception( + "--env-file requires --kms to enable KMS for environment variable decryption" + ) + app_compose = { "manifest_version": 2, "name": args.name, @@ -561,6 +783,8 @@ def create_app_compose(self, args) -> None: "no_instance_id": args.no_instance_id, "secure_time": args.secure_time, } + if args.key_provider: + app_compose["key_provider"] = args.key_provider if args.prelaunch_script: app_compose["pre_launch_script"] = ( open(args.prelaunch_script, "rb").read().decode("utf-8") @@ -591,6 +815,21 @@ def create_vm(self, args) -> None: envs = parse_env_file(args.env_file) + # Validate: --env-file requires --kms-url and kms_enabled in compose + if envs: + if not args.kms_url: + raise Exception( + "--env-file requires --kms-url to encrypt environment variables" + ) + try: + compose_json = json.loads(compose_content) + if not compose_json.get("kms_enabled", False): + raise Exception( + "--env-file requires kms_enabled=true in the compose file (use --kms when creating compose)" + ) + except json.JSONDecodeError: + pass # Let the server handle invalid JSON + # Read user config file if provided user_config = "" if args.user_config: @@ -617,7 +856,7 @@ def create_vm(self, args) -> None: if swap_bytes > 0: params["swap_size"] = swap_bytes - if args.ppcie: + if args.ppcie or (args.gpu and "all" in args.gpu): params["gpus"] = {"attach_mode": "all"} elif args.gpu: params["gpus"] = { @@ -628,6 +867,8 @@ def create_vm(self, args) -> None: params["kms_urls"] = args.kms_url if args.gateway_url: params["gateway_urls"] = args.gateway_url + if args.net: + params["networking"] = {"mode": args.net} app_id = args.app_id or self.calc_app_id(compose_content) print(f"App ID: {app_id}") @@ -646,6 +887,10 @@ def update_vm_env( self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[List[str]] = None ) -> None: """Update environment variables for a VM.""" + # Validate: requires --kms-url + if not kms_urls: + raise Exception("--kms-url is required to encrypt environment variables") + envs = envs or {} # First get the VM info to retrieve the app_id vm_info_response = self.rpc_call("GetInfo", {"id": vm_id}) @@ -734,6 +979,12 @@ def update_vm( no_tee: Optional[bool] = None, ) -> None: """Update multiple aspects of a VM in one command.""" + # Validate: --env-file requires --kms-url + if env_file and not kms_urls: + raise Exception( + "--env-file requires --kms-url to encrypt environment variables" + ) + updates = [] # handle resize operations (vcpu, memory, disk, image) @@ -862,15 +1113,16 @@ def update_vm( port_mappings = [parse_port_mapping(port) for port in ports] updates.append("port mappings") else: - # ports is empty - shouldn't happen with exclusive group + # ports is an empty list - shouldn't happen with mutually exclusive group port_mappings = [] updates.append("port mappings (none)") upgrade_params["update_ports"] = True upgrade_params["ports"] = port_mappings # handle GPU updates - only update if one of the GPU flags is set - if attach_all or no_gpus or gpu_slots is not None: - if attach_all: + gpu_all = gpu_slots and "all" in gpu_slots + if attach_all or gpu_all or no_gpus or gpu_slots is not None: + if attach_all or gpu_all: gpu_config = {"attach_mode": "all"} updates.append("GPUs (all)") elif no_gpus: @@ -883,7 +1135,7 @@ def update_vm( } updates.append(f"GPUs ({len(gpu_slots)} devices)") else: - # gpu_slots is empty - shouldn't happen with exclusive group + # gpu_slots is an empty list ([] not None) - shouldn't happen with mutually exclusive group gpu_config = {"attach_mode": "listed", "gpus": []} updates.append("GPUs (none)") upgrade_params["gpus"] = gpu_config @@ -900,6 +1152,52 @@ def update_vm( else: print(f"No updates specified for VM {vm_id}") + def show_info(self, vm_id: str, json_output: bool = False) -> None: + """Show detailed information about a VM.""" + response = self.rpc_call("GetInfo", {"id": vm_id}) + + if not response.get("found", False) or "info" not in response: + print(f"VM with ID {vm_id} not found") + return + + info = response["info"] + + if json_output: + print(json.dumps(info, indent=2)) + return + + config = info.get("configuration", {}) + + print(f"VM ID: {info.get('id', '-')}") + print(f"Name: {info.get('name', '-')}") + print(f"Status: {info.get('status', '-')}") + print(f"Uptime: {info.get('uptime', '-')}") + print(f"App ID: {info.get('app_id', '-')}") + print(f"Instance ID: {info.get('instance_id', '-') or '-'}") + print(f"App URL: {info.get('app_url', '-') or '-'}") + print(f"Image: {config.get('image', '-')}") + print(f"Image Version: {info.get('image_version', '-')}") + print(f"vCPU: {config.get('vcpu', '-')}") + print(f"Memory: {config.get('memory', '-')}MB") + print(f"Disk: {config.get('disk_size', '-')}GB") + print(f"GPUs: {self._format_gpu_info(config.get('gpus'))}") + print(f"Boot Progress: {info.get('boot_progress', '-')}") + if info.get("boot_error"): + print(f"Boot Error: {info['boot_error']}") + if info.get("exited_at"): + print(f"Exited At: {info['exited_at']}") + if info.get("shutdown_progress"): + print(f"Shutdown: {info['shutdown_progress']}") + + events = info.get("events", []) + if events: + print("\nRecent Events:") + for event in events[-10:]: + ts = event.get("timestamp", 0) + print( + f" [{event.get('event', '')}] {event.get('body', '')} (ts: {ts})" + ) + def list_gpus(self, json_output: bool = False) -> None: """List all available GPUs.""" response = self.rpc_call("ListGpus") @@ -929,7 +1227,7 @@ def list_gpus(self, json_output: bool = False) -> None: def format_table(rows, headers): - """Format rows and headers into an aligned table.""" + """Format rows and headers into a table string.""" if not rows: return "" @@ -959,6 +1257,7 @@ def parse_env_file(file_path: str) -> Dict[str, str]: Each line should be formatted as KEY=Value. Lines that are empty or start with '#' are ignored. + """ if not file_path: return {} @@ -979,6 +1278,8 @@ def parse_env_file(file_path: str) -> Dict[str, str]: def parse_size(s: str, target_unit: str) -> int: """Parse a human-readable size string and return the size in the target unit. + Accept strings like "1G" or "100M". + Args: s: The size string provided. target_unit: Either "MB" (for memory) or "GB" (for disk). @@ -1050,8 +1351,53 @@ def parse_disk_size(s: str) -> int: return parse_size(s, "GB") +def verify_signature_v1( + public_key: bytes, signature: bytes, app_id: str, timestamp: int +) -> Optional[str]: + """Verify the v1 signature (with timestamp) of a public key. + + Args: + public_key: The public key bytes to verify + signature: The signature bytes (65 bytes) + app_id: The application ID + timestamp: Unix timestamp in seconds when the response was generated + + Returns: + The compressed public key if valid, None otherwise + + """ + if not CRYPTO_AVAILABLE: + raise ImportError( + "Cryptography libraries not available. Please install them with:\n" + "pip install cryptography eth-keys eth-utils" + ) + + if len(signature) != 65: + return None + + # Create the message to verify + # Signs: Keccak256("dstack-env-encrypt-pubkey" + ":" + app_id + timestamp_be_bytes + public_key) + prefix = b"dstack-env-encrypt-pubkey" + if app_id.startswith("0x"): + app_id = app_id[2:] + timestamp_bytes = timestamp.to_bytes(8, byteorder="big") + message = prefix + b":" + bytes.fromhex(app_id) + timestamp_bytes + public_key + + # Hash the message with Keccak-256 + message_hash = keccak(message) + + # Recover the public key from the signature + try: + sig = keys.Signature(signature_bytes=signature) + recovered_key = sig.recover_public_key_from_msg_hash(message_hash) + return "0x" + recovered_key.to_compressed_bytes().hex() + except Exception as e: + print(f"Signature v1 verification failed: {e}", file=sys.stderr) + return None + + def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Optional[str]: - """Verify the signature of a public key. + """Verify the legacy signature (without timestamp) of a public key. Args: public_key: The public key bytes to verify @@ -1062,15 +1408,8 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option The compressed public key if valid, None otherwise Examples: - >>> pk_hex = 'e33a1832c6562067ff8f844a61e51ad051f1180b66ec2551fb0251735f3ee90a' - >>> public_key = bytes.fromhex(pk_hex) - >>> sig_hex = ( - ... '8542c49081fbf4e03f62034f13fbf70630bdf256' - ... 'a53032e38465a27c36fd6bed7a5e7111652004ae' - ... 'f37f7fd92fbfc1285212c4ae6a6154203a48f5e1' - ... '6cad2cef00' - ... ) - >>> signature = bytes.fromhex(sig_hex) + >>> public_key = bytes.fromhex('e33a1832c6562067ff8f844a61e51ad051f1180b66ec2551fb0251735f3ee90a') + >>> signature = bytes.fromhex('8542c49081fbf4e03f62034f13fbf70630bdf256a53032e38465a27c36fd6bed7a5e7111652004aef37f7fd92fbfc1285212c4ae6a6154203a48f5e16cad2cef00') >>> app_id = '00' * 20 >>> compressed_pubkey = verify_signature(public_key, signature, app_id) >>> print(compressed_pubkey) @@ -1087,6 +1426,7 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option return None # Create the message to verify + # Signs: Keccak256("dstack-env-encrypt-pubkey" + ":" + app_id + public_key) prefix = b"dstack-env-encrypt-pubkey" if app_id.startswith("0x"): app_id = app_id[2:] @@ -1139,32 +1479,59 @@ def save_whitelist(whitelist: List[str]) -> None: def main(): - """Run the dstack-vmm CLI.""" + """Parse arguments and dispatch to the appropriate command handler.""" parser = argparse.ArgumentParser(description="dstack-vmm CLI - Manage VMs") - # Get default URL from environment variable or use localhost - default_url = os.environ.get("DSTACK_VMM_URL", "http://localhost:8080") + # Load config file defaults + config = load_config() + + # Discover running VMM instances + instances = discover_vmm_instances() + + # Priority: command line > environment variable > config file > auto-discovery > default + default_url = os.environ.get( + "DSTACK_VMM_URL", config.get("url", "http://localhost:8080") + ) + default_auth_user = os.environ.get("DSTACK_VMM_AUTH_USER", config.get("auth_user")) + default_auth_password = os.environ.get( + "DSTACK_VMM_AUTH_PASSWORD", config.get("auth_password") + ) parser.add_argument( "--url", default=default_url, - help="dstack-vmm API URL (can also be set via DSTACK_VMM_URL env var)", + help="dstack-vmm API URL (can also be set via DSTACK_VMM_URL env var or config file)", ) # Basic authentication arguments parser.add_argument( "--auth-user", - default=os.environ.get("DSTACK_VMM_AUTH_USER"), - help="Basic auth username (can also be set via DSTACK_VMM_AUTH_USER env var)", + default=default_auth_user, + help="Basic auth username (can also be set via DSTACK_VMM_AUTH_USER env var or config file)", ) parser.add_argument( "--auth-password", - default=os.environ.get("DSTACK_VMM_AUTH_PASSWORD"), - help="Basic auth password (env: DSTACK_VMM_AUTH_PASSWORD)", + default=default_auth_password, + help="Basic auth password (can also be set via DSTACK_VMM_AUTH_PASSWORD env var or config file)", ) subparsers = parser.add_subparsers(dest="command", help="Commands") + # VMM discovery commands + ls_vmm_parser = subparsers.add_parser( + "ls-vmm", help="List all running VMM instances on this host" + ) + ls_vmm_parser.add_argument( + "--json", action="store_true", help="Output in JSON format" + ) + + switch_vmm_parser = subparsers.add_parser( + "switch-vmm", help="Switch active VMM instance" + ) + switch_vmm_parser.add_argument( + "vmm_id", help="VMM instance ID (prefix match supported)" + ) + # List command lsvm_parser = subparsers.add_parser("lsvm", help="List VMs") lsvm_parser.add_argument( @@ -1174,6 +1541,13 @@ def main(): "--json", action="store_true", help="Output in JSON format for automation" ) + # Info command + info_parser = subparsers.add_parser("info", help="Show detailed VM information") + info_parser.add_argument("vm_id", help="VM ID to show info for") + info_parser.add_argument( + "--json", action="store_true", help="Output in JSON format for automation" + ) + # Start command start_parser = subparsers.add_parser("start", help="Start a VM") start_parser.add_argument("vm_id", help="VM ID to start") @@ -1229,6 +1603,12 @@ def main(): compose_parser.add_argument( "--local-key-provider", action="store_true", help="Enable local key provider" ) + compose_parser.add_argument( + "--key-provider", + choices=["none", "kms", "local"], + default=None, + help="Override key provider type (none, kms, local)", + ) compose_parser.add_argument( "--key-provider-id", default=None, @@ -1299,7 +1679,7 @@ def main(): "--gpu", action="append", type=str, - help="GPU slot to attach (can be used multiple times)", + help='GPU slot to attach (can be used multiple times), or "all" to attach all GPUs', ) deploy_parser.add_argument( "--ppcie", @@ -1334,6 +1714,11 @@ def main(): help="Force-enable Intel TDX (default)", ) deploy_parser.set_defaults(no_tee=False) + deploy_parser.add_argument( + "--net", + choices=["bridge", "user"], + help="Networking mode (default: use global config)", + ) # Images command lsimage_parser = subparsers.add_parser("lsimage", help="List available images") @@ -1406,7 +1791,7 @@ def main(): action="append", type=str, required=True, - help="Port mapping: protocol[:address]:from:to (repeatable)", + help="Port mapping in format: protocol[:address]:from:to (can be used multiple times)", ) # Update (all-in-one) command @@ -1440,7 +1825,7 @@ def main(): "--port", action="append", type=str, - help="Port mapping: protocol[:address]:from:to (repeatable)", + help="Port mapping in format: protocol[:address]:from:to (can be used multiple times)", ) port_group.add_argument( "--no-ports", @@ -1459,7 +1844,7 @@ def main(): "--gpu", action="append", type=str, - help="GPU slot to attach (can be used multiple times)", + help='GPU slot to attach (can be used multiple times), or "all" to attach all GPUs', ) gpu_group.add_argument( "--ppcie", @@ -1493,10 +1878,22 @@ def main(): args = parser.parse_args() - cli = VmmCLI(args.url, args.auth_user, args.auth_password) + # Handle discovery commands before creating CLI (they don't need a connection) + if args.command == "ls-vmm": + cmd_ls_vmm(args) + return + elif args.command == "switch-vmm": + cmd_switch_vmm(args) + return + + # Resolve the URL with auto-discovery + url = resolve_vmm_url(instances, config, args.url) + cli = VmmCLI(url, args.auth_user, args.auth_password) if args.command == "lsvm": cli.list_vms(args.verbose, args.json) + elif args.command == "info": + cli.show_info(args.vm_id, args.json) elif args.command == "start": cli.start_vm(args.vm_id) elif args.command == "stop":