diff --git a/.github/workflows/prek-check.yml b/.github/workflows/prek-check.yml new file mode 100644 index 000000000..a1a60394c --- /dev/null +++ b/.github/workflows/prek-check.yml @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +name: Prek checks + +on: + push: + branches: [ master, next, dev-* ] + pull_request: + branches: [ master, next, dev-* ] + +permissions: + contents: read + +jobs: + prek: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: 1.92.0 + components: rustfmt + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.22' + + - name: Install prek + run: pip install prek + + - name: Run prek checks (PR) + if: github.event_name == 'pull_request' + run: prek run --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} --show-diff-on-failure + + - name: Run prek checks (push) + if: github.event_name == 'push' + run: prek run --from-ref ${{ github.event.before }} --to-ref ${{ github.event.after }} --show-diff-on-failure diff --git a/prek.toml b/prek.toml new file mode 100644 index 000000000..1e24d457a --- /dev/null +++ b/prek.toml @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: © 2025 Phala Network +# +# SPDX-License-Identifier: Apache-2.0 + +# ============================================================================= +# prek configuration for dstack +# ============================================================================= + +# --- General hooks (trailing whitespace, file checks, etc.) --- +[[repos]] +repo = "https://github.com/pre-commit/pre-commit-hooks" +rev = "v5.0.0" +hooks = [ + { id = "trailing-whitespace", args = ["--markdown-linebreak-ext=md"] }, + { id = "end-of-file-fixer" }, + { id = "check-yaml", args = ["--allow-multiple-documents"], exclude = "gateway/templates/" }, + { id = "check-toml" }, + { id = "check-json" }, + { id = "check-merge-conflict" }, + { id = "check-added-large-files", args = ["--maxkb=500"] }, + { id = "check-symlinks" }, + { id = "mixed-line-ending", args = ["--fix=lf"] }, +] + +# --- Rust: rustfmt --- +[[repos]] +repo = "local" + +[[repos.hooks]] +id = "cargo-fmt" +name = "cargo fmt" +entry = "cargo fmt --all" +language = "system" +types = ["rust"] +pass_filenames = false + +# --- Python: ruff (lint + format) --- +[[repos]] +repo = "https://github.com/astral-sh/ruff-pre-commit" +rev = "v0.11.4" +hooks = [ + { id = "ruff", args = ["--fix", "--select", "E,F,I,D", "--ignore", "D203,D213,E501"] }, + { id = "ruff-format" }, +] + +# --- Go: go vet --- +[[repos]] +repo = "local" + +[[repos.hooks]] +id = "go-vet" +name = "go vet" +entry = "bash -c 'cd sdk/go && go vet ./...'" +language = "system" +files = "sdk/go/.*\\.go$" +pass_filenames = false + +# --- Shell: shellcheck --- +[[repos]] +repo = "https://github.com/shellcheck-py/shellcheck-py" +rev = "v0.10.0.1" +hooks = [ + { id = "shellcheck" }, +] + +# --- Conventional commits (used by cliff.toml for changelog) --- +[[repos]] +repo = "https://github.com/compilerla/conventional-pre-commit" +rev = "v4.1.0" +hooks = [ + { id = "conventional-pre-commit", stages = ["commit-msg"] }, +] diff --git a/python/ct_monitor/ct_monitor.py b/python/ct_monitor/ct_monitor.py index 756636354..7cc37913c 100644 --- a/python/ct_monitor/ct_monitor.py +++ b/python/ct_monitor/ct_monitor.py @@ -1,11 +1,13 @@ # SPDX-FileCopyrightText: © 2024 Phala Network # # SPDX-License-Identifier: Apache-2.0 +"""Monitor certificate transparency logs for a given domain.""" +import argparse import sys import time + import requests -import argparse from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization @@ -14,38 +16,50 @@ class PoisonedLog(Exception): + """Indicate a poisoned certificate transparency log entry.""" + pass class Monitor: + """Monitor certificate transparency logs for a domain.""" + def __init__(self, domain: str): + """Initialize the monitor with a validated domain.""" if not self.validate_domain(domain): raise ValueError("Invalid domain name") self.domain = domain self.last_checked = None def get_logs(self, count: int = 100): + """Fetch recent certificate transparency log entries.""" url = f"{BASE_URL}/?q={self.domain}&output=json&limit={count}" response = requests.get(url) return response.json() - + def check_one_log(self, log: object): + """Fetch and inspect a single certificate log entry.""" log_id = log["id"] cert_url = f"{BASE_URL}/?d={log_id}" cert_data = requests.get(cert_url).text # Extract PEM-encoded certificate import re - pem_match = re.search(r'-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----', cert_data, re.DOTALL) + + pem_match = re.search( + r"-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----", + cert_data, + re.DOTALL, + ) if pem_match: pem_cert = pem_match.group(0) - + # Parse PEM certificate cert = x509.load_pem_x509_certificate(pem_cert.encode(), default_backend()) # Extract the public key public_key = cert.public_key() pem_public_key = public_key.public_bytes( encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) print("Public Key:") print(pem_public_key.hex()) @@ -62,18 +76,20 @@ def check_one_log(self, log: object): print("No valid certificate found in the response.") def check_new_logs(self): + """Check for new log entries since the last check.""" logs = self.get_logs(count=10000) print("num logs", len(logs)) for log in logs: - print(f"log id={log["id"]}") + print(f"log id={log['id']}") if log["id"] <= (self.last_checked or 0): break self.check_one_log(log) - print('next') + print("next") if len(logs) > 0: self.last_checked = logs[0]["id"] def run(self): + """Run the monitor loop indefinitely.""" print(f"Monitoring {self.domain}...") while True: try: @@ -87,12 +103,12 @@ def run(self): @staticmethod def validate_domain(domain: str): - # ensure domain is a valid DNS domain + """Validate that the given string is a well-formed DNS domain name.""" import re # Regular expression for validating domain names domain_regex = re.compile( - r'^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$' + r"^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$" ) if not domain_regex.match(domain): @@ -102,7 +118,10 @@ def validate_domain(domain: str): def main(): - parser = argparse.ArgumentParser(description="Monitor certificate transparency logs") + """Parse arguments and start the certificate transparency monitor.""" + parser = argparse.ArgumentParser( + description="Monitor certificate transparency logs" + ) parser.add_argument("-d", "--domain", help="The domain to monitor") args = parser.parse_args() monitor = Monitor(args.domain) diff --git a/scripts/add-spdx-attribution.py b/scripts/add-spdx-attribution.py index 67de938ba..d98e78ae8 100755 --- a/scripts/add-spdx-attribution.py +++ b/scripts/add-spdx-attribution.py @@ -2,11 +2,10 @@ # SPDX-FileCopyrightText: © 2025 Phala Network # # SPDX-License-Identifier: Apache-2.0 -""" -SPDX Header Attribution Script +"""SPDX header attribution script. -This script automatically analyzes git blame data to determine contributors -and adds appropriate SPDX-FileCopyrightText headers using the REUSE tool. +Analyze git blame data to determine contributors and add appropriate +SPDX-FileCopyrightText headers using the REUSE tool. Features: - Excludes third-party code based on .spdx-exclude patterns @@ -18,260 +17,297 @@ import argparse import fnmatch -import os import re import subprocess import sys from collections import defaultdict -from datetime import datetime from pathlib import Path -from typing import Dict, List, Set, Tuple, Optional +from typing import Dict, List, Optional, Set, Tuple class SPDXAttributor: + """Add SPDX attribution headers to source files based on git blame.""" + def __init__(self, repo_root: str, dry_run: bool = False): + """Initialize the attributor with a repository root and options.""" self.repo_root = Path(repo_root).resolve() self.dry_run = dry_run self.exclude_patterns = self._load_exclude_patterns() self.mailmap = self._parse_mailmap() self.company_domains = self._build_company_domain_map() - + def _load_exclude_patterns(self) -> List[str]: """Load exclusion patterns from .spdx-exclude file.""" - exclude_file = self.repo_root / '.spdx-exclude' + exclude_file = self.repo_root / ".spdx-exclude" patterns = [] - + if exclude_file.exists(): - with open(exclude_file, 'r') as f: + with open(exclude_file, "r") as f: for line in f: line = line.strip() - if line and not line.startswith('#'): + if line and not line.startswith("#"): patterns.append(line) - + return patterns - + def _parse_mailmap(self) -> Dict[str, Tuple[str, str]]: """Parse .mailmap file to get canonical name/email mappings.""" - mailmap_file = self.repo_root / '.mailmap' + mailmap_file = self.repo_root / ".mailmap" mailmap = {} - + if not mailmap_file.exists(): print("Warning: .mailmap file not found") return mailmap - - with open(mailmap_file, 'r') as f: + + with open(mailmap_file, "r") as f: for line in f: line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue - + # Parse mailmap format: "Proper Name " # or "Proper Name Commit Name " - if '>' in line: - parts = line.split('>') + if ">" in line: + parts = line.split(">") if len(parts) >= 2: proper_part = parts[0].strip() commit_part = parts[1].strip() - + # Extract proper name and email - if '<' in proper_part: - proper_name = proper_part.split('<')[0].strip() - proper_email = proper_part.split('<')[1].strip() + if "<" in proper_part: + proper_name = proper_part.split("<")[0].strip() + proper_email = proper_part.split("<")[1].strip() else: continue - + # Extract commit email (and possibly name) - if '<' in commit_part: - commit_email = commit_part.split('<')[1].split('>')[0].strip() + if "<" in commit_part: + commit_email = ( + commit_part.split("<")[1].split(">")[0].strip() + ) else: commit_email = commit_part.strip() - + mailmap[commit_email] = (proper_name, proper_email) - + return mailmap - + def _build_company_domain_map(self) -> Dict[str, str]: """Build mapping from email domains to company names.""" return { - 'phala.network': 'Phala Network', - 'near.ai': 'Near Foundation', - 'nethermind.io': 'Nethermind', - 'rizelabs.io': 'Rize Labs', - 'testinprod.io': 'Test in Prod', + "phala.network": "Phala Network", + "near.ai": "Near Foundation", + "nethermind.io": "Nethermind", + "rizelabs.io": "Rize Labs", + "testinprod.io": "Test in Prod", } - + def _is_excluded(self, file_path: Path) -> bool: """Check if a file should be excluded based on patterns.""" rel_path = file_path.relative_to(self.repo_root) rel_path_str = str(rel_path) - + for pattern in self.exclude_patterns: # Handle directory patterns - if pattern.endswith('/'): - if rel_path_str.startswith(pattern) or f"/{pattern}" in f"/{rel_path_str}/": + if pattern.endswith("/"): + if ( + rel_path_str.startswith(pattern) + or f"/{pattern}" in f"/{rel_path_str}/" + ): return True # Handle glob patterns elif fnmatch.fnmatch(rel_path_str, pattern): return True # Handle wildcard patterns - elif '*' in pattern and fnmatch.fnmatch(rel_path_str, pattern): + elif "*" in pattern and fnmatch.fnmatch(rel_path_str, pattern): return True - + return False - + def _get_canonical_contributor(self, email: str, name: str) -> Tuple[str, str]: """Get canonical name and email for a contributor using mailmap.""" if email in self.mailmap: return self.mailmap[email] return name.strip(), email.strip() - + def _get_company_name(self, email: str) -> Optional[str]: """Get company name for an email domain, or None for individual contributors.""" - domain = email.split('@')[-1].lower() + domain = email.split("@")[-1].lower() return self.company_domains.get(domain) - + def _get_company_contact_email(self, company: str) -> str: """Get the standardized contact email for a company.""" # Company-specific contact emails company_contacts = { - 'Phala Network': 'dstack@phala.network', - 'Near Foundation': 'contact@near.ai', - 'Nethermind': 'contact@nethermind.io', - 'Rize Labs': 'contact@rizelabs.io', - 'Test in Prod': 'contact@testinprod.io', + "Phala Network": "dstack@phala.network", + "Near Foundation": "contact@near.ai", + "Nethermind": "contact@nethermind.io", + "Rize Labs": "contact@rizelabs.io", + "Test in Prod": "contact@testinprod.io", } - - return company_contacts.get(company, f'contact@{company.lower().replace(" ", "")}.com') - + + return company_contacts.get( + company, f"contact@{company.lower().replace(' ', '')}.com" + ) + def _analyze_git_blame(self, file_path: Path) -> Dict[str, Set[int]]: """Analyze git blame to get contributors and their contribution years.""" try: # Run git blame to get author info and dates - result = subprocess.run([ - 'git', 'blame', '--porcelain', str(file_path) - ], capture_output=True, text=True, cwd=self.repo_root) - + result = subprocess.run( + ["git", "blame", "--porcelain", str(file_path)], + capture_output=True, + text=True, + cwd=self.repo_root, + ) + if result.returncode != 0: print(f"Warning: Could not run git blame on {file_path}") return {} - + contributors = defaultdict(set) current_commit = None - - for line in result.stdout.split('\n'): + + for line in result.stdout.split("\n"): line = line.strip() if not line: continue - + # Parse commit hash line - if re.match(r'^[0-9a-f]{40}', line): + if re.match(r"^[0-9a-f]{40}", line): current_commit = line.split()[0] # Parse author email - elif line.startswith('author-mail '): - email = line[12:].strip('<>') - + elif line.startswith("author-mail "): + email = line[12:].strip("<>") + # Skip SPDX-only commits to avoid counting license header changes as contributions if self._is_spdx_only_commit(current_commit): continue - + # Get the year for this commit year = self._get_commit_year(current_commit) if year and email: contributors[email].add(year) - + return contributors - + except subprocess.SubprocessError as e: print(f"Error running git blame on {file_path}: {e}") return {} - + def _is_spdx_only_commit(self, commit_hash: str) -> bool: """Check if a commit only contains SPDX header changes.""" try: # Get commit message - result = subprocess.run([ - 'git', 'show', '-s', '--format=%s%n%b', commit_hash - ], capture_output=True, text=True, cwd=self.repo_root) - + result = subprocess.run( + ["git", "show", "-s", "--format=%s%n%b", commit_hash], + capture_output=True, + text=True, + cwd=self.repo_root, + ) + if result.returncode != 0: return False - + commit_message = result.stdout.lower() - + # Check for SPDX-related keywords in commit message spdx_keywords = [ - 'spdx', 'license header', 'copyright header', 'add license', - 'update license', 'license annotation', 'reuse annotate', - 'add spdx', 'update spdx', 'copyright attribution' + "spdx", + "license header", + "copyright header", + "add license", + "update license", + "license annotation", + "reuse annotate", + "add spdx", + "update spdx", + "copyright attribution", ] - + if any(keyword in commit_message for keyword in spdx_keywords): # Get the diff to see if it's only header changes - diff_result = subprocess.run([ - 'git', 'show', '--format=', commit_hash - ], capture_output=True, text=True, cwd=self.repo_root) - + diff_result = subprocess.run( + ["git", "show", "--format=", commit_hash], + capture_output=True, + text=True, + cwd=self.repo_root, + ) + if diff_result.returncode == 0: diff_content = diff_result.stdout - + # Check if the diff only contains SPDX/copyright/license changes # Look for lines that are only adding/removing headers - diff_lines = diff_content.split('\n') + diff_lines = diff_content.split("\n") substantial_changes = 0 - + for line in diff_lines: - if line.startswith(('+', '-')) and not line.startswith(('+++', '---')): + if line.startswith(("+", "-")) and not line.startswith( + ("+++", "---") + ): # Skip lines that are just SPDX/copyright/license related line_content = line[1:].strip() - if line_content and not any(marker in line_content.lower() for marker in [ - 'spdx-', 'copyright', 'license-identifier', - 'filepyrighttext', '©', '(c)', 'all rights reserved' - ]): + if line_content and not any( + marker in line_content.lower() + for marker in [ + "spdx-", + "copyright", + "license-identifier", + "filepyrighttext", + "©", + "(c)", + "all rights reserved", + ] + ): substantial_changes += 1 - + # If we have very few substantial changes, likely an SPDX-only commit return substantial_changes <= 2 - + return False - + except subprocess.SubprocessError: return False - + def _get_commit_year(self, commit_hash: str) -> Optional[int]: """Get the year of a commit.""" try: - result = subprocess.run([ - 'git', 'show', '-s', '--format=%ad', '--date=format:%Y', commit_hash - ], capture_output=True, text=True, cwd=self.repo_root) - + result = subprocess.run( + ["git", "show", "-s", "--format=%ad", "--date=format:%Y", commit_hash], + capture_output=True, + text=True, + cwd=self.repo_root, + ) + if result.returncode == 0: return int(result.stdout.strip()) except (subprocess.SubprocessError, ValueError): pass - + return None - + def _generate_spdx_headers(self, contributors: Dict[str, Set[int]]) -> List[str]: """Generate SPDX-FileCopyrightText headers for contributors.""" headers = [] company_years = defaultdict(set) # Track years by company individual_contributors = {} # Track individual contributors - + # First pass: group contributors by company vs individual for email, years in contributors.items(): # Get canonical name and email name, canonical_email = self._get_canonical_contributor(email, "Unknown") - + # Check if this is a company contributor company = self._get_company_name(canonical_email) - + if company: # Accumulate years for this company company_years[company].update(years) else: # Individual contributor individual_contributors[canonical_email] = (name, years) - + # Generate company headers (one per company) for company, years in company_years.items(): year_list = sorted(years) @@ -279,11 +315,11 @@ def _generate_spdx_headers(self, contributors: Dict[str, Set[int]]) -> List[str] year_str = str(year_list[0]) else: year_str = f"{year_list[0]}-{year_list[-1]}" - + contact_email = self._get_company_contact_email(company) header = f"SPDX-FileCopyrightText: © {year_str} {company} <{contact_email}>" headers.append(header) - + # Generate individual headers for canonical_email, (name, years) in individual_contributors.items(): year_list = sorted(years) @@ -291,189 +327,217 @@ def _generate_spdx_headers(self, contributors: Dict[str, Set[int]]) -> List[str] year_str = str(year_list[0]) else: year_str = f"{year_list[0]}-{year_list[-1]}" - + header = f"SPDX-FileCopyrightText: © {year_str} {name} <{canonical_email}>" headers.append(header) - + return sorted(headers) - + def _remove_existing_spdx_headers(self, file_path: Path) -> bool: """Remove existing SPDX headers from a file.""" if not file_path.exists(): return False - + try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: lines = f.readlines() - + # Find and remove existing SPDX lines new_lines = [] for line in lines: - if not (line.strip().startswith('// SPDX-') or - line.strip().startswith('# SPDX-') or - line.strip().startswith('/* SPDX-') or - line.strip().startswith(' * SPDX-')): + if not ( + line.strip().startswith("// SPDX-") + or line.strip().startswith("# SPDX-") + or line.strip().startswith("/* SPDX-") + or line.strip().startswith(" * SPDX-") + ): new_lines.append(line) - + # Write back if changes were made if len(new_lines) != len(lines): if not self.dry_run: - with open(file_path, 'w', encoding='utf-8') as f: + with open(file_path, "w", encoding="utf-8") as f: f.writelines(new_lines) return True - + except (IOError, UnicodeDecodeError) as e: print(f"Warning: Could not process {file_path}: {e}") - + return False - + def _get_existing_license(self, file_path: Path) -> str: """Get the existing SPDX license identifier from a file.""" try: - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: content = f.read() - + # Look for existing SPDX-License-Identifier - lines = content.split('\n')[:20] # Check first 20 lines + lines = content.split("\n")[:20] # Check first 20 lines for line in lines: - if 'SPDX-License-Identifier:' in line: + if "SPDX-License-Identifier:" in line: # Extract the license identifier - parts = line.split('SPDX-License-Identifier:') + parts = line.split("SPDX-License-Identifier:") if len(parts) > 1: - return parts[1].strip().rstrip('*/') - + return parts[1].strip().rstrip("*/") + # Default to Apache-2.0 if no existing license found - return 'Apache-2.0' - + return "Apache-2.0" + except Exception: - return 'Apache-2.0' - + return "Apache-2.0" + def _apply_reuse_annotation(self, file_path: Path, headers: List[str]) -> bool: """Apply SPDX headers using the REUSE tool.""" if not headers: return False - + try: # Prepare REUSE command - pass full SPDX-FileCopyrightText headers directly - cmd = ['reuse', 'annotate'] - + cmd = ["reuse", "annotate"] + # Add all copyright lines with full SPDX-FileCopyrightText format for header in headers: - cmd.extend(['--copyright', header]) # Keep the full SPDX-FileCopyrightText: format - + cmd.extend( + ["--copyright", header] + ) # Keep the full SPDX-FileCopyrightText: format + # Preserve existing license or use Apache-2.0 as default existing_license = self._get_existing_license(file_path) - cmd.extend(['--license', existing_license]) - + cmd.extend(["--license", existing_license]) + # Handle Solidity files which need explicit style specification - if file_path.suffix == '.sol': - cmd.extend(['--style', 'c']) - + if file_path.suffix == ".sol": + cmd.extend(["--style", "c"]) + # Add the file cmd.append(str(file_path)) - + if self.dry_run: print(f"Would run: {' '.join(cmd)}") return True else: - result = subprocess.run(cmd, capture_output=True, text=True, cwd=self.repo_root) + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=self.repo_root + ) if result.returncode != 0: print(f"REUSE command failed for {file_path}: {result.stderr}") return False return True - + except subprocess.SubprocessError as e: print(f"Error running REUSE on {file_path}: {e}") return False - + def process_file(self, file_path: Path) -> bool: """Process a single file to add SPDX attribution.""" # Convert to absolute path if it's relative if not file_path.is_absolute(): file_path = self.repo_root / file_path - + if self._is_excluded(file_path): if self.dry_run: print(f"Excluded: {file_path}") return False - + # Analyze git blame contributors = self._analyze_git_blame(file_path) if not contributors: print(f"No contributors found for {file_path}") return False - + # Generate SPDX headers headers = self._generate_spdx_headers(contributors) - + if self.dry_run: print(f"\nFile: {file_path}") print(f"Contributors: {len(contributors)}") for header in headers: print(f" {header}") return True - + # Remove existing SPDX headers self._remove_existing_spdx_headers(file_path) - + # Apply new headers with REUSE success = self._apply_reuse_annotation(file_path, headers) - + if success: print(f"✓ Updated: {file_path}") else: print(f"✗ Failed: {file_path}") - + return success - + def find_source_files(self) -> List[Path]: """Find all source files that should be processed.""" - extensions = {'.rs', '.py', '.go', '.ts', '.js', '.c', '.h', '.cpp', '.hpp', '.sol'} + extensions = { + ".rs", + ".py", + ".go", + ".ts", + ".js", + ".c", + ".h", + ".cpp", + ".hpp", + ".sol", + } source_files = [] - + for ext in extensions: - for file_path in self.repo_root.rglob(f'*{ext}'): + for file_path in self.repo_root.rglob(f"*{ext}"): if file_path.is_file() and not self._is_excluded(file_path): source_files.append(file_path) - + return sorted(source_files) def main(): - parser = argparse.ArgumentParser(description='Add SPDX attribution headers to source files') - parser.add_argument('--dry-run', action='store_true', help='Show what would be done without making changes') - parser.add_argument('--file', type=str, help='Process a specific file instead of all source files') - parser.add_argument('--repo-root', type=str, default='.', help='Repository root directory') - + """Parse arguments and run SPDX attribution on source files.""" + parser = argparse.ArgumentParser( + description="Add SPDX attribution headers to source files" + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be done without making changes", + ) + parser.add_argument( + "--file", type=str, help="Process a specific file instead of all source files" + ) + parser.add_argument( + "--repo-root", type=str, default=".", help="Repository root directory" + ) + args = parser.parse_args() - + # Initialize the attributor attributor = SPDXAttributor(args.repo_root, dry_run=args.dry_run) - + if args.file: # Process single file file_path = Path(args.file) if not file_path.exists(): print(f"Error: File {file_path} does not exist") sys.exit(1) - + attributor.process_file(file_path) else: # Process all source files source_files = attributor.find_source_files() - + if args.dry_run: print(f"Found {len(source_files)} source files to process") print("\nDry run - showing what would be done:\n") - + success_count = 0 for file_path in source_files: if attributor.process_file(file_path): success_count += 1 - + if not args.dry_run: print(f"\nProcessed {success_count}/{len(source_files)} files successfully") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/sdk/python/test_outputs.py b/sdk/python/test_outputs.py index 6dee15233..f974d76a4 100644 --- a/sdk/python/test_outputs.py +++ b/sdk/python/test_outputs.py @@ -2,20 +2,20 @@ # # SPDX-License-Identifier: Apache-2.0 +"""Test script for verifying Python SDK outputs.""" + import asyncio import sys -from dstack_sdk import ( - DstackClient, - AsyncDstackClient, - TappdClient, - AsyncTappdClient, - get_compose_hash, - verify_env_encrypt_public_key, -) +from dstack_sdk import AsyncDstackClient +from dstack_sdk import AsyncTappdClient +from dstack_sdk import DstackClient +from dstack_sdk import TappdClient +from dstack_sdk import get_compose_hash +from dstack_sdk import verify_env_encrypt_public_key -async def main(): +async def main(): # noqa: D103 print("=== Python SDK Output Test ===") try: @@ -49,7 +49,7 @@ async def main(): account = to_account(eth_key) print(f" address: {account.address}") - print(f" type: ethereum account") + print(" type: ethereum account") except ImportError: print( " error: Ethereum integration not available (install with pip install 'dstack-sdk[eth]')" @@ -63,7 +63,7 @@ async def main(): account_secure = to_account_secure(eth_key) print(f" address: {account_secure.address}") - print(f" type: ethereum account") + print(" type: ethereum account") except ImportError: print( " error: Ethereum integration not available (install with pip install 'dstack-sdk[eth]')" diff --git a/tools/mock-cf-dns-api/app.py b/tools/mock-cf-dns-api/app.py index 5703db797..daf97894f 100644 --- a/tools/mock-cf-dns-api/app.py +++ b/tools/mock-cf-dns-api/app.py @@ -3,8 +3,7 @@ # # SPDX-License-Identifier: Apache-2.0 -""" -Mock Cloudflare DNS API Server +"""Mock Cloudflare DNS API server. A mock server that simulates Cloudflare's DNS API for testing purposes. Supports the following endpoints used by certbot: @@ -13,14 +12,14 @@ - DELETE /client/v4/zones/{zone_id}/dns_records/{record_id} - Delete DNS record """ +import json import os import uuid -import time -import json from datetime import datetime -from flask import Flask, request, jsonify, render_template_string from functools import wraps +from flask import Flask, jsonify, render_template_string, request + app = Flask(__name__) # In-memory storage for DNS records @@ -32,7 +31,11 @@ MAX_LOGS = 100 # Valid API tokens (for testing, accept any non-empty token or use env var) -VALID_TOKENS = os.environ.get("CF_API_TOKENS", "").split(",") if os.environ.get("CF_API_TOKENS") else None +VALID_TOKENS = ( + os.environ.get("CF_API_TOKENS", "").split(",") + if os.environ.get("CF_API_TOKENS") + else None +) def log_request(zone_id, method, path, req_data, resp_data, status_code): @@ -62,30 +65,36 @@ def get_current_time(): def verify_auth(f): - """Decorator to verify Bearer token authentication.""" + """Verify Bearer token authentication.""" + @wraps(f) def decorated(*args, **kwargs): auth_header = request.headers.get("Authorization", "") if not auth_header.startswith("Bearer "): - return jsonify({ - "success": False, - "errors": [{"code": 10000, "message": "Authentication error"}], - "messages": [], - "result": None - }), 401 + return jsonify( + { + "success": False, + "errors": [{"code": 10000, "message": "Authentication error"}], + "messages": [], + "result": None, + } + ), 401 token = auth_header[7:] # Remove "Bearer " prefix # If VALID_TOKENS is set, validate against it; otherwise accept any token if VALID_TOKENS and token not in VALID_TOKENS: - return jsonify({ - "success": False, - "errors": [{"code": 10000, "message": "Invalid API token"}], - "messages": [], - "result": None - }), 403 + return jsonify( + { + "success": False, + "errors": [{"code": 10000, "message": "Invalid API token"}], + "messages": [], + "result": None, + } + ), 403 return f(*args, **kwargs) + return decorated @@ -95,7 +104,7 @@ def cf_response(result, success=True, errors=None, messages=None): "success": success, "errors": errors or [], "messages": messages or [], - "result": result + "result": result, } @@ -106,6 +115,7 @@ def cf_error(message, code=1000): # ==================== DNS Record Endpoints ==================== + @app.route("/client/v4/zones//dns_records", methods=["POST"]) @verify_auth def create_dns_record(zone_id): @@ -148,8 +158,8 @@ def create_dns_record(zone_id): "meta": { "auto_added": False, "managed_by_apps": False, - "managed_by_argo_tunnel": False - } + "managed_by_argo_tunnel": False, + }, } # Handle different record types @@ -177,7 +187,9 @@ def create_dns_record(zone_id): resp = cf_response(record) log_request(zone_id, "POST", f"/zones/{zone_id}/dns_records", data, resp, 200) - print(f"[CREATE] Zone: {zone_id}, Record: {record_id}, Type: {record_type}, Name: {name}") + print( + f"[CREATE] Zone: {zone_id}, Record: {record_id}, Type: {record_type}, Name: {name}" + ) return jsonify(resp), 200 @@ -220,10 +232,12 @@ def list_dns_records(zone_id): "per_page": per_page, "count": len(page_records), "total_count": total_count, - "total_pages": total_pages - } + "total_pages": total_pages, + }, } - log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records", dict(request.args), resp, 200) + log_request( + zone_id, "GET", f"/zones/{zone_id}/dns_records", dict(request.args), resp, 200 + ) return jsonify(resp), 200 @@ -237,11 +251,15 @@ def get_dns_record(zone_id, record_id): if not record: resp = cf_error("Record not found", 81044) - log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + log_request( + zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404 + ) return jsonify(resp), 404 resp = cf_response(record) - log_request(zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200) + log_request( + zone_id, "GET", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200 + ) return jsonify(resp), 200 @@ -255,13 +273,17 @@ def update_dns_record(zone_id, record_id): if not record: resp = cf_error("Record not found", 81044) - log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + log_request( + zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404 + ) return jsonify(resp), 404 data = request.get_json() if not data: resp = cf_error("Invalid request body") - log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 400) + log_request( + zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 400 + ) return jsonify(resp), 400 # Update allowed fields @@ -272,7 +294,9 @@ def update_dns_record(zone_id, record_id): record["modified_on"] = get_current_time() resp = cf_response(record) - log_request(zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", data, resp, 200) + log_request( + zone_id, "PUT", f"/zones/{zone_id}/dns_records/{record_id}", data, resp, 200 + ) print(f"[UPDATE] Zone: {zone_id}, Record: {record_id}") @@ -287,13 +311,22 @@ def delete_dns_record(zone_id, record_id): if record_id not in zone_records: resp = cf_error("Record not found", 81044) - log_request(zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 404) + log_request( + zone_id, + "DELETE", + f"/zones/{zone_id}/dns_records/{record_id}", + None, + resp, + 404, + ) return jsonify(resp), 404 del zone_records[record_id] resp = cf_response({"id": record_id}) - log_request(zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200) + log_request( + zone_id, "DELETE", f"/zones/{zone_id}/dns_records/{record_id}", None, resp, 200 + ) print(f"[DELETE] Zone: {zone_id}, Record: {record_id}") @@ -321,7 +354,7 @@ def get_configured_zones(): try: return json.loads(zones_json) except json.JSONDecodeError: - print(f"Warning: Invalid MOCK_ZONES JSON, using defaults") + print("Warning: Invalid MOCK_ZONES JSON, using defaults") return DEFAULT_ZONES @@ -342,20 +375,19 @@ def list_zones(): # Build full zone objects full_zones = [] for z in zones: - full_zones.append({ - "id": z["id"], - "name": z["name"], - "status": "active", - "paused": False, - "type": "full", - "development_mode": 0, - "name_servers": [ - "ns1.mock-cloudflare.com", - "ns2.mock-cloudflare.com" - ], - "created_on": "2024-01-01T00:00:00.000000Z", - "modified_on": get_current_time(), - }) + full_zones.append( + { + "id": z["id"], + "name": z["name"], + "status": "active", + "paused": False, + "type": "full", + "development_mode": 0, + "name_servers": ["ns1.mock-cloudflare.com", "ns2.mock-cloudflare.com"], + "created_on": "2024-01-01T00:00:00.000000Z", + "modified_on": get_current_time(), + } + ) # Pagination total_count = len(full_zones) @@ -374,12 +406,14 @@ def list_zones(): "per_page": per_page, "count": len(page_zones), "total_count": total_count, - "total_pages": total_pages - } + "total_pages": total_pages, + }, } log_request("*", "GET", "/zones", dict(request.args), result, 200) - print(f"[LIST ZONES] page={page}, per_page={per_page}, count={len(page_zones)}, total={total_count}") + print( + f"[LIST ZONES] page={page}, per_page={per_page}, count={len(page_zones)}, total={total_count}" + ) return jsonify(result), 200 @@ -405,10 +439,7 @@ def get_zone(zone_id): "paused": False, "type": "full", "development_mode": 0, - "name_servers": [ - "ns1.mock-cloudflare.com", - "ns2.mock-cloudflare.com" - ], + "name_servers": ["ns1.mock-cloudflare.com", "ns2.mock-cloudflare.com"], "created_on": "2024-01-01T00:00:00.000000Z", "modified_on": get_current_time(), } @@ -789,12 +820,13 @@ def management_ui(): request_count=len(request_logs), records=all_records, logs=request_logs[:20], - port=os.environ.get("PORT", 8080) + port=os.environ.get("PORT", 8080), ) # ==================== Management API ==================== + @app.route("/api/records", methods=["DELETE"]) def clear_all_records(): """Clear all DNS records.""" @@ -829,7 +861,9 @@ def get_all_records(): @app.route("/health") def health(): """Health check endpoint.""" - return jsonify({"status": "healthy", "records": sum(len(r) for r in dns_records.values())}) + return jsonify( + {"status": "healthy", "records": sum(len(r) for r in dns_records.values())} + ) if __name__ == "__main__": diff --git a/vmm/src/vmm-cli.py b/vmm/src/vmm-cli.py index 1d5963bf2..6b2254e03 100755 --- a/vmm/src/vmm-cli.py +++ b/vmm/src/vmm-cli.py @@ -1,95 +1,97 @@ #!/usr/bin/env python3 +"""CLI tool for managing dstack-vmm virtual machines.""" # SPDX-FileCopyrightText: © 2025 Phala Network # # SPDX-License-Identifier: Apache-2.0 -import os -import sys -import json import argparse +import base64 import hashlib +import http.client +import json +import os import re import socket -import http.client -import urllib.parse import ssl -import base64 - -from typing import Optional, Dict, List, Tuple, Union, BinaryIO, Any +import sys +import urllib.parse +from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union # Optional cryptography imports - only needed for encrypted environment variables try: + from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import x25519 from cryptography.hazmat.primitives.ciphers.aead import AESGCM - from cryptography.hazmat.primitives import serialization from eth_keys import keys from eth_utils import keccak + CRYPTO_AVAILABLE = True except ImportError: CRYPTO_AVAILABLE = False # Default config file locations DEFAULT_CONFIG_PATH = os.path.expanduser("~/.dstack-vmm/config.json") -DEFAULT_KMS_WHITELIST_PATH = os.path.expanduser( - "~/.dstack-vmm/kms-whitelist.json") +DEFAULT_KMS_WHITELIST_PATH = os.path.expanduser("~/.dstack-vmm/kms-whitelist.json") # VMM discovery directory DISCOVERY_DIR = "/run/dstack-vmm" def load_config() -> Dict[str, Any]: - """ - Load configuration from the default config file. + """Load configuration from the default config file. Returns: Dictionary with configuration values (url, auth_user, auth_password) + """ if not os.path.exists(DEFAULT_CONFIG_PATH): return {} try: - with open(DEFAULT_CONFIG_PATH, 'r') as f: + with open(DEFAULT_CONFIG_PATH, "r") as f: return json.load(f) except (json.JSONDecodeError, FileNotFoundError): return {} def discover_vmm_instances() -> List[Dict[str, Any]]: - """ - Discover all running VMM instances from the discovery directory. + """Discover all running VMM instances from the discovery directory. Returns: List of VMM instance info dicts, sorted by started_at. + """ instances = [] if not os.path.isdir(DISCOVERY_DIR): return instances for fname in os.listdir(DISCOVERY_DIR): - if not fname.endswith('.json'): + if not fname.endswith(".json"): continue fpath = os.path.join(DISCOVERY_DIR, fname) try: - with open(fpath, 'r') as f: + with open(fpath, "r") as f: info = json.load(f) # Check if process is still alive - pid = info.get('pid') - if pid and not os.path.exists(f'/proc/{pid}'): + pid = info.get("pid") + if pid and not os.path.exists(f"/proc/{pid}"): # Stale file, skip continue instances.append(info) except (json.JSONDecodeError, FileNotFoundError, PermissionError): continue - instances.sort(key=lambda x: x.get('started_at', 0)) + instances.sort(key=lambda x: x.get("started_at", 0)) return instances -def resolve_vmm_url(instances: List[Dict[str, Any]], config: Dict[str, Any], - explicit_url: Optional[str] = None) -> str: - """ - Resolve the VMM URL to connect to. +def resolve_vmm_url( + instances: List[Dict[str, Any]], + config: Dict[str, Any], + explicit_url: Optional[str] = None, +) -> str: + """Resolve the VMM URL to connect to. Priority: 1. Explicit --url flag @@ -102,54 +104,54 @@ def resolve_vmm_url(instances: List[Dict[str, Any]], config: Dict[str, Any], Returns the URL string. """ # If user explicitly provided --url or env var, honor it - env_url = os.environ.get('DSTACK_VMM_URL') - if explicit_url and explicit_url != 'http://localhost:8080': + env_url = os.environ.get("DSTACK_VMM_URL") + if explicit_url and explicit_url != "http://localhost:8080": return explicit_url if env_url: return env_url - config_url = config.get('url') + config_url = config.get("url") if config_url: return config_url # Try auto-discovery - active_id = config.get('active_vmm') + active_id = config.get("active_vmm") if active_id: for inst in instances: - if inst['id'] == active_id or inst['id'].startswith(active_id): + if inst["id"] == active_id or inst["id"].startswith(active_id): return vmm_address_to_url(inst) if len(instances) == 1: return vmm_address_to_url(instances[0]) - return 'http://localhost:8080' + return "http://localhost:8080" def vmm_address_to_url(instance: Dict[str, Any]) -> str: """Convert a VMM instance info dict to a connection URL.""" - addr = instance.get('address', '') - if addr.startswith('unix:'): + addr = instance.get("address", "") + if addr.startswith("unix:"): # Resolve relative socket path against working directory socket_path = addr[5:] if not os.path.isabs(socket_path): - working_dir = instance.get('working_dir', '') + working_dir = instance.get("working_dir", "") socket_path = os.path.join(working_dir, socket_path) - return f'unix:{socket_path}' - elif addr.startswith('http://') or addr.startswith('https://'): + return f"unix:{socket_path}" + elif addr.startswith("http://") or addr.startswith("https://"): return addr else: # host:port format from discovery - host, _, port = addr.rpartition(':') - if host == '0.0.0.0': - host = '127.0.0.1' - return f'http://{host}:{port}' + host, _, port = addr.rpartition(":") + if host == "0.0.0.0": + host = "127.0.0.1" + return f"http://{host}:{port}" def save_active_vmm(vmm_id: str): """Save the active VMM instance ID to the config file.""" config = load_config() - config['active_vmm'] = vmm_id + config["active_vmm"] = vmm_id os.makedirs(os.path.dirname(DEFAULT_CONFIG_PATH), exist_ok=True) - with open(DEFAULT_CONFIG_PATH, 'w') as f: + with open(DEFAULT_CONFIG_PATH, "w") as f: json.dump(config, f, indent=2) @@ -157,40 +159,48 @@ def cmd_ls_vmm(args): """List all discovered VMM instances.""" instances = discover_vmm_instances() config = load_config() - active_id = config.get('active_vmm') + active_id = config.get("active_vmm") if not instances: print("No running VMM instances found.") print(f" (discovery directory: {DISCOVERY_DIR})") return - if getattr(args, 'json', False): + if getattr(args, "json", False): print(json.dumps(instances, indent=2)) return # Table output - from datetime import datetime fmt = " {active} {id:<12s} {pid:<8s} {node:<12s} {address:<24s} {workdir}" - print(fmt.format( - active='', id='ID', pid='PID', node='NAME', - address='ADDRESS', workdir='WORKING DIR')) + print( + fmt.format( + active="", + id="ID", + pid="PID", + node="NAME", + address="ADDRESS", + workdir="WORKING DIR", + ) + ) print(" " + "-" * 90) for inst in instances: - short_id = inst['id'][:8] - is_active = '*' if active_id and inst['id'].startswith(active_id) else ' ' - node_name = inst.get('node_name', '') or '-' - address = inst.get('address', '?') - - print(fmt.format( - active=is_active, - id=short_id, - pid=str(inst.get('pid', '?')), - node=node_name[:12], - address=address[:24], - workdir=inst.get('working_dir', '?'), - )) + short_id = inst["id"][:8] + is_active = "*" if active_id and inst["id"].startswith(active_id) else " " + node_name = inst.get("node_name", "") or "-" + address = inst.get("address", "?") + + print( + fmt.format( + active=is_active, + id=short_id, + pid=str(inst.get("pid", "?")), + node=node_name[:12], + address=address[:24], + workdir=inst.get("working_dir", "?"), + ) + ) def cmd_switch_vmm(args): @@ -203,12 +213,14 @@ def cmd_switch_vmm(args): return # Find matching instance by prefix - matches = [i for i in instances if i['id'].startswith(target)] + matches = [i for i in instances if i["id"].startswith(target)] if len(matches) == 0: print(f"No VMM instance matching '{target}'.") print("Available instances:") for inst in instances: - print(f" {inst['id'][:8]} {inst.get('address', '?')} {inst.get('working_dir', '?')}") + print( + f" {inst['id'][:8]} {inst.get('address', '?')} {inst.get('working_dir', '?')}" + ) return if len(matches) > 1: print(f"Ambiguous ID '{target}', matches multiple instances:") @@ -217,14 +229,13 @@ def cmd_switch_vmm(args): return selected = matches[0] - save_active_vmm(selected['id']) + save_active_vmm(selected["id"]) url = vmm_address_to_url(selected) print(f"Switched to VMM {selected['id'][:8]} ({url})") def encrypt_env(envs, hex_public_key: str) -> str: - """ - Encrypts environment variables using a one-time X25519 key exchange and AES-GCM. + """Encrypts environment variables using a one-time X25519 key exchange and AES-GCM. This function does the following: 1. Converts the given environment variables to JSON bytes. @@ -242,6 +253,7 @@ def encrypt_env(envs, hex_public_key: str) -> str: Returns: A hexadecimal string that is the concatenation of: (ephemeral public key || IV || ciphertext). + """ if not CRYPTO_AVAILABLE: raise ImportError( @@ -264,8 +276,7 @@ def encrypt_env(envs, hex_public_key: str) -> str: ephemeral_public_key = ephemeral_private_key.public_key() # Compute the shared secret using X25519. - peer_public_key = x25519.X25519PublicKey.from_public_bytes( - remote_pubkey_bytes) + peer_public_key = x25519.X25519PublicKey.from_public_bytes(remote_pubkey_bytes) shared = ephemeral_private_key.exchange(peer_public_key) # Use the shared secret as a key for AES-GCM encryption (AES-256 needs 32 bytes). @@ -275,8 +286,7 @@ def encrypt_env(envs, hex_public_key: str) -> str: # Serialize the ephemeral public key to raw bytes. ephemeral_public_bytes = ephemeral_public_key.public_bytes( - encoding=serialization.Encoding.Raw, - format=serialization.PublicFormat.Raw + encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw ) # Combine ephemeral public key, IV, and ciphertext. @@ -287,38 +297,42 @@ def encrypt_env(envs, hex_public_key: str) -> str: def parse_port_mapping(port_str: str) -> Dict: - """Parse a port mapping string into a dictionary""" - parts = port_str.split(':') + """Parse a port mapping string into a dictionary.""" + parts = port_str.split(":") if len(parts) == 3: return { "protocol": parts[0], "host_address": "127.0.0.1", "host_port": int(parts[1]), - "vm_port": int(parts[2]) + "vm_port": int(parts[2]), } elif len(parts) == 4: return { "protocol": parts[0], "host_address": parts[1], "host_port": int(parts[2]), - "vm_port": int(parts[3]) + "vm_port": int(parts[3]), } else: - raise argparse.ArgumentTypeError( - f"Invalid port mapping format: {port_str}") + raise argparse.ArgumentTypeError(f"Invalid port mapping format: {port_str}") + def read_utf8(filepath: str) -> str: - with open(filepath, 'rb') as f: - return f.read().decode('utf-8') + """Read a file and return its contents as a UTF-8 string.""" + with open(filepath, "rb") as f: + return f.read().decode("utf-8") + class UnixSocketHTTPConnection(http.client.HTTPConnection): """HTTPConnection that connects to a Unix domain socket.""" def __init__(self, socket_path, timeout=None): - super().__init__('localhost', timeout=timeout) + """Initialize with the given Unix socket path.""" + super().__init__("localhost", timeout=timeout) self.socket_path = socket_path def connect(self): + """Connect to the Unix domain socket.""" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) if self.timeout: sock.settimeout(self.timeout) @@ -329,9 +343,15 @@ def connect(self): class VmmClient: """A unified HTTP client that supports both regular HTTP and Unix Domain Sockets.""" - def __init__(self, base_url: str, auth_user: Optional[str] = None, auth_password: Optional[str] = None): - self.base_url = base_url.rstrip('/') - self.use_uds = self.base_url.startswith('unix:') + def __init__( + self, + base_url: str, + auth_user: Optional[str] = None, + auth_password: Optional[str] = None, + ): + """Initialize the client with a base URL and optional auth credentials.""" + self.base_url = base_url.rstrip("/") + self.use_uds = self.base_url.startswith("unix:") self.auth_user = auth_user self.auth_password = auth_password @@ -341,12 +361,17 @@ def __init__(self, base_url: str, auth_user: Optional[str] = None, auth_password # Parse the base URL for regular HTTP connections self.parsed_url = urllib.parse.urlparse(self.base_url) self.host = self.parsed_url.netloc - self.is_https = self.parsed_url.scheme == 'https' + self.is_https = self.parsed_url.scheme == "https" - def request(self, method: str, path: str, headers: Dict[str, str] = None, - body: Any = None, stream: bool = False) -> Tuple[int, Union[Dict, str, BinaryIO]]: - """ - Make an HTTP request to the server. + def request( + self, + method: str, + path: str, + headers: Dict[str, str] = None, + body: Any = None, + stream: bool = False, + ) -> Tuple[int, Union[Dict, str, BinaryIO]]: + """Make an HTTP request to the server. Args: method: HTTP method (GET, POST, etc.) @@ -357,6 +382,7 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, Returns: Tuple of (status_code, response_data) + """ if headers is None: headers = {} @@ -364,15 +390,16 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, # Add Basic Authentication header if credentials are provided if self.auth_user and self.auth_password: credentials = f"{self.auth_user}:{self.auth_password}" - encoded_credentials = base64.b64encode( - credentials.encode('utf-8')).decode('ascii') - headers['Authorization'] = f'Basic {encoded_credentials}' + encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode( + "ascii" + ) + headers["Authorization"] = f"Basic {encoded_credentials}" # Prepare the body if isinstance(body, dict): - body = json.dumps(body).encode('utf-8') - if 'Content-Type' not in headers: - headers['Content-Type'] = 'application/json' + body = json.dumps(body).encode("utf-8") + if "Content-Type" not in headers: + headers["Content-Type"] = "application/json" # Create the appropriate connection if self.use_uds: @@ -401,15 +428,19 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, data = response.read() # Try to parse as JSON if it looks like JSON - content_type = response.getheader('Content-Type', '') - if 'application/json' in content_type or data.startswith(b'{') or data.startswith(b'['): + content_type = response.getheader("Content-Type", "") + if ( + "application/json" in content_type + or data.startswith(b"{") + or data.startswith(b"[") + ): try: - return status, json.loads(data.decode('utf-8')) + return status, json.loads(data.decode("utf-8")) except json.JSONDecodeError: pass # Return as string if not JSON - return status, data.decode('utf-8') + return status, data.decode("utf-8") except Exception as e: if not stream: conn.close() @@ -419,18 +450,25 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None, class VmmCLI: - def __init__(self, base_url: str, auth_user: Optional[str] = None, auth_password: Optional[str] = None): - self.base_url = base_url.rstrip('/') - self.headers = { - 'Content-Type': 'application/json' - } + """Command-line interface for the dstack-vmm API.""" + + def __init__( + self, + base_url: str, + auth_user: Optional[str] = None, + auth_password: Optional[str] = None, + ): + """Initialize the CLI with a base URL and optional auth credentials.""" + self.base_url = base_url.rstrip("/") + self.headers = {"Content-Type": "application/json"} self.client = VmmClient(base_url, auth_user, auth_password) def rpc_call(self, method: str, params: Optional[Dict] = None) -> Dict: - """Make an RPC call to the dstack-vmm API""" + """Make an RPC call to the dstack-vmm API.""" path = f"/prpc/{method}?json" status, response = self.client.request( - 'POST', path, headers=self.headers, body=params or {}) + "POST", path, headers=self.headers, body=params or {} + ) if status != 200: if isinstance(response, str): @@ -442,9 +480,9 @@ def rpc_call(self, method: str, params: Optional[Dict] = None) -> Dict: return response def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: - """List all VMs and their status""" - response = self.rpc_call('Status') - vms = response['vms'] + """List all VMs and their status.""" + response = self.rpc_call("Status") + vms = response["vms"] if json_output: # Return raw JSON data for automation/testing @@ -455,69 +493,71 @@ def list_vms(self, verbose: bool = False, json_output: bool = False) -> None: print("No VMs found") return - headers = ['VM ID', 'App ID', 'Name', 'Status', 'Uptime'] + headers = ["VM ID", "App ID", "Name", "Status", "Uptime"] if verbose: - headers.extend(['Instance ID', 'vCPU', 'Memory', 'Disk', 'Image', 'GPUs']) + headers.extend(["Instance ID", "vCPU", "Memory", "Disk", "Image", "GPUs"]) rows = [] for vm in vms: row = [ - vm['id'], - vm['app_id'], - vm['name'], - vm['status'], - vm.get('uptime', '-') + vm["id"], + vm["app_id"], + vm["name"], + vm["status"], + vm.get("uptime", "-"), ] if verbose: - config = vm.get('configuration', {}) - gpu_info = self._format_gpu_info(config.get('gpus')) - row.extend([ - vm.get('instance_id', '-') or '-', - config.get('vcpu', '-'), - f"{config.get('memory', '-')}MB", - f"{config.get('disk_size', '-')}GB", - config.get('image', '-'), - gpu_info - ]) + config = vm.get("configuration", {}) + gpu_info = self._format_gpu_info(config.get("gpus")) + row.extend( + [ + vm.get("instance_id", "-") or "-", + config.get("vcpu", "-"), + f"{config.get('memory', '-')}MB", + f"{config.get('disk_size', '-')}GB", + config.get("image", "-"), + gpu_info, + ] + ) rows.append(row) print(format_table(rows, headers)) def _format_gpu_info(self, gpu_config): - """Format GPU configuration for display""" + """Format GPU configuration for display.""" if not gpu_config: - return '-' + return "-" - attach_mode = gpu_config.get('attach_mode', '') - gpus = gpu_config.get('gpus', []) + attach_mode = gpu_config.get("attach_mode", "") + gpus = gpu_config.get("gpus", []) - if attach_mode == 'all': - return 'All GPUs' - elif attach_mode == 'listed' and gpus: - gpu_slots = [gpu.get('slot', 'Unknown') for gpu in gpus] - return ', '.join(gpu_slots) + if attach_mode == "all": + return "All GPUs" + elif attach_mode == "listed" and gpus: + gpu_slots = [gpu.get("slot", "Unknown") for gpu in gpus] + return ", ".join(gpu_slots) else: - return '-' + return "-" def start_vm(self, vm_id: str) -> None: - """Start a VM""" - self.rpc_call('StartVm', {'id': vm_id}) + """Start a VM.""" + self.rpc_call("StartVm", {"id": vm_id}) print(f"Started VM {vm_id}") def stop_vm(self, vm_id: str, force: bool = False) -> None: - """Stop a VM""" + """Stop a VM.""" if force: - self.rpc_call('StopVm', {'id': vm_id}) + self.rpc_call("StopVm", {"id": vm_id}) print(f"Forcefully stopped VM {vm_id}") else: - self.rpc_call('ShutdownVm', {'id': vm_id}) + self.rpc_call("ShutdownVm", {"id": vm_id}) print(f"Gracefully shutting down VM {vm_id}") def remove_vm(self, vm_id: str) -> None: - """Remove a VM""" - self.rpc_call('RemoveVm', {'id': vm_id}) + """Remove a VM.""" + self.rpc_call("RemoveVm", {"id": vm_id}) print(f"Removed VM {vm_id}") def resize_vm( @@ -528,7 +568,7 @@ def resize_vm( disk_size: Optional[int] = None, image: Optional[str] = None, ) -> None: - """Resize a VM""" + """Resize a VM.""" params = {"id": vm_id} if vcpu is not None: params["vcpu"] = vcpu @@ -548,11 +588,12 @@ def resize_vm( print(f"Resized VM {vm_id}") def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: - """Show VM logs""" + """Show VM logs.""" path = f"/logs?id={vm_id}&follow={str(follow).lower()}&ansi=false&lines={lines}" status, response = self.client.request( - 'GET', path, headers=self.headers, stream=follow) + "GET", path, headers=self.headers, stream=follow + ) if status != 200: if isinstance(response, str): @@ -569,7 +610,7 @@ def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: line = response.readline() if not line: break - print(line.decode('utf-8').rstrip()) + print(line.decode("utf-8").rstrip()) except KeyboardInterrupt: # Allow clean exit with Ctrl+C return @@ -581,9 +622,9 @@ def show_logs(self, vm_id: str, lines: int = 20, follow: bool = False) -> None: print(response) def list_images(self, json_output: bool = False) -> None: - """Get list of available images""" - response = self.rpc_call('ListImages') - images = response['images'] + """List available images.""" + response = self.rpc_call("ListImages") + images = response["images"] if json_output: # Return raw JSON data for automation/testing @@ -594,46 +635,54 @@ def list_images(self, json_output: bool = False) -> None: print("No images found") return - headers = ['Name', 'Version'] - rows = [[img['name'], img.get('version', '-')] for img in images] + headers = ["Name", "Version"] + rows = [[img["name"], img.get("version", "-")] for img in images] print(format_table(rows, headers)) - def get_app_env_encrypt_pub_key(self, app_id: str, kms_url: Optional[str] = None) -> Dict: - """Get the encryption public key for the specified application ID""" + def get_app_env_encrypt_pub_key( + self, app_id: str, kms_url: Optional[str] = None + ) -> Dict: + """Get the encryption public key for the specified application ID.""" if kms_url: client = VmmClient(kms_url) - path = f"/prpc/GetAppEnvEncryptPubKey?json" + path = "/prpc/GetAppEnvEncryptPubKey?json" status, response = client.request( - 'POST', path, headers={ - 'Content-Type': 'application/json' - }, body={'app_id': app_id}) + "POST", + path, + headers={"Content-Type": "application/json"}, + body={"app_id": app_id}, + ) print(f"Getting encryption public key for {app_id} from {kms_url}") else: - response = self.rpc_call( - 'GetAppEnvEncryptPubKey', {'app_id': app_id}) + response = self.rpc_call("GetAppEnvEncryptPubKey", {"app_id": app_id}) # Verify the signature if available - if 'signature' not in response and 'signature_v1' not in response: + if "signature" not in response and "signature_v1" not in response: if not self.confirm_untrusted_signer("none"): raise Exception("Aborted due to invalid signature") - return response['public_key'] + return response["public_key"] - public_key = bytes.fromhex(response['public_key']) + public_key = bytes.fromhex(response["public_key"]) # Prefer signature_v1 (with timestamp) if available signer_pubkey = None - if 'signature_v1' in response and 'timestamp' in response: - signature_v1 = bytes.fromhex(response['signature_v1']) - timestamp = response['timestamp'] - signer_pubkey = verify_signature_v1(public_key, signature_v1, app_id, timestamp) + if "signature_v1" in response and "timestamp" in response: + signature_v1 = bytes.fromhex(response["signature_v1"]) + timestamp = response["timestamp"] + signer_pubkey = verify_signature_v1( + public_key, signature_v1, app_id, timestamp + ) if signer_pubkey: print(f"Verified signature_v1 (with timestamp) from: {signer_pubkey}") # Fall back to legacy signature if signature_v1 verification failed or not available - if not signer_pubkey and 'signature' in response: - print("WARNING: Using legacy signature without timestamp protection. " - "Consider upgrading your KMS to support signature_v1.", file=sys.stderr) - signature = bytes.fromhex(response['signature']) + if not signer_pubkey and "signature" in response: + print( + "WARNING: Using legacy signature without timestamp protection. " + "Consider upgrading your KMS to support signature_v1.", + file=sys.stderr, + ) + signature = bytes.fromhex(response["signature"]) signer_pubkey = verify_signature(public_key, signature, app_id) if signer_pubkey: print(f"Verified legacy signature from: {signer_pubkey}") @@ -642,7 +691,8 @@ def get_app_env_encrypt_pub_key(self, app_id: str, kms_url: Optional[str] = None whitelist = load_whitelist() if whitelist and signer_pubkey not in whitelist: print( - f"WARNING: Signer {signer_pubkey} is not in the trusted whitelist!") + f"WARNING: Signer {signer_pubkey} is not in the trusted whitelist!" + ) if not self.confirm_untrusted_signer(signer_pubkey): raise Exception("Aborted due to untrusted signer") else: @@ -650,18 +700,18 @@ def get_app_env_encrypt_pub_key(self, app_id: str, kms_url: Optional[str] = None if not self.confirm_untrusted_signer("unknown"): raise Exception("Aborted due to invalid signature") - return response['public_key'] + return response["public_key"] def confirm_untrusted_signer(self, signer: str) -> bool: - """Ask user to confirm using an untrusted signer""" + """Ask user to confirm using an untrusted signer.""" response = input(f"Continue with untrusted signer {signer}? (y/N): ") - return response.lower() in ('y', 'yes') + return response.lower() in ("y", "yes") def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: - """Manage the whitelist of trusted signers""" + """Manage the whitelist of trusted signers.""" whitelist = load_whitelist() - if action == 'list': + if action == "list": if not whitelist: print("Whitelist is empty") else: @@ -671,7 +721,7 @@ def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: return # Normalize pubkey format - trim 0x prefix if present - if pubkey and pubkey.startswith('0x'): + if pubkey and pubkey.startswith("0x"): pubkey = pubkey[2:] # Convert to bytes for validation try: @@ -682,7 +732,7 @@ def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: raise Exception(f"Invalid public key length: {len(pubkey)}") pubkey = pubkey.hex() - if action == 'add': + if action == "add": if pubkey in whitelist: print(f"Public key {pubkey} is already in the whitelist") else: @@ -690,7 +740,7 @@ def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: save_whitelist(whitelist) print(f"Added {pubkey} to the whitelist") - elif action == 'remove': + elif action == "remove": if pubkey not in whitelist: print(f"Public key {pubkey} is not in the whitelist") else: @@ -702,23 +752,27 @@ def manage_kms_whitelist(self, action: str, pubkey: str = None) -> None: raise Exception(f"Unknown action: {action}") def calc_app_id(self, compose_file: str) -> str: - """Calculate the application ID from the compose file""" + """Calculate the application ID from the compose file.""" compose_hash = hashlib.sha256(compose_file.encode()).hexdigest() return compose_hash[:40] def create_app_compose(self, args) -> None: - """Create a new app compose file""" + """Create a new app compose file.""" envs = parse_env_file(args.env_file) or {} # Validate: --env-file requires --kms if envs and not args.kms: - raise Exception("--env-file requires --kms to enable KMS for environment variable decryption") + raise Exception( + "--env-file requires --kms to enable KMS for environment variable decryption" + ) app_compose = { "manifest_version": 2, "name": args.name, "runner": "docker-compose", - "docker_compose_file": open(args.docker_compose, 'rb').read().decode('utf-8'), + "docker_compose_file": open(args.docker_compose, "rb") + .read() + .decode("utf-8"), "kms_enabled": args.kms, "gateway_enabled": args.gateway, "local_key_provider_enabled": args.local_key_provider, @@ -732,8 +786,9 @@ def create_app_compose(self, args) -> None: if args.key_provider: app_compose["key_provider"] = args.key_provider if args.prelaunch_script: - app_compose["pre_launch_script"] = open( - args.prelaunch_script, 'rb').read().decode('utf-8') + app_compose["pre_launch_script"] = ( + open(args.prelaunch_script, "rb").read().decode("utf-8") + ) if args.swap is not None: swap_bytes = max(0, int(round(args.swap)) * 1024 * 1024) if swap_bytes > 0: @@ -741,16 +796,17 @@ def create_app_compose(self, args) -> None: else: app_compose.pop("swap_size", None) - compose_file = json.dumps( - app_compose, indent=4, ensure_ascii=False).encode('utf-8') + compose_file = json.dumps(app_compose, indent=4, ensure_ascii=False).encode( + "utf-8" + ) compose_hash = hashlib.sha256(compose_file).hexdigest() - with open(args.output, 'wb') as f: + with open(args.output, "wb") as f: f.write(compose_file) print(f"App compose file created at: {args.output}") print(f"Compose hash: {compose_hash}") def create_vm(self, args) -> None: - """Create a new VM""" + """Create a new VM.""" # Read and validate compose file if not os.path.exists(args.compose): raise Exception(f"Compose file not found: {args.compose}") @@ -762,11 +818,15 @@ def create_vm(self, args) -> None: # Validate: --env-file requires --kms-url and kms_enabled in compose if envs: if not args.kms_url: - raise Exception("--env-file requires --kms-url to encrypt environment variables") + raise Exception( + "--env-file requires --kms-url to encrypt environment variables" + ) try: compose_json = json.loads(compose_content) - if not compose_json.get('kms_enabled', False): - raise Exception("--env-file requires kms_enabled=true in the compose file (use --kms when creating compose)") + if not compose_json.get("kms_enabled", False): + raise Exception( + "--env-file requires kms_enabled=true in the compose file (use --kms when creating compose)" + ) except json.JSONDecodeError: pass # Let the server handle invalid JSON @@ -797,13 +857,11 @@ def create_vm(self, args) -> None: params["swap_size"] = swap_bytes if args.ppcie or (args.gpu and "all" in args.gpu): - params["gpus"] = { - "attach_mode": "all" - } + params["gpus"] = {"attach_mode": "all"} elif args.gpu: params["gpus"] = { "attach_mode": "listed", - "gpus": [{"slot": gpu} for gpu in args.gpu or []] + "gpus": [{"slot": gpu} for gpu in args.gpu or []], } if args.kms_url: params["kms_urls"] = args.kms_url @@ -816,42 +874,45 @@ def create_vm(self, args) -> None: print(f"App ID: {app_id}") if envs: encrypt_pubkey = self.get_app_env_encrypt_pub_key( - app_id, args.kms_url[0] if args.kms_url else None) - print( - f"Encrypting environment variables with key: {encrypt_pubkey}") + app_id, args.kms_url[0] if args.kms_url else None + ) + print(f"Encrypting environment variables with key: {encrypt_pubkey}") envs_list = [{"key": k, "value": v} for k, v in envs.items()] params["encrypted_env"] = encrypt_env(envs_list, encrypt_pubkey) - response = self.rpc_call('CreateVm', params) + response = self.rpc_call("CreateVm", params) print(f"Created VM with ID: {response.get('id')}") - return response.get('id') + return response.get("id") - def update_vm_env(self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[List[str]] = None) -> None: - """Update environment variables for a VM""" + def update_vm_env( + self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[List[str]] = None + ) -> None: + """Update environment variables for a VM.""" # Validate: requires --kms-url if not kms_urls: raise Exception("--kms-url is required to encrypt environment variables") envs = envs or {} # First get the VM info to retrieve the app_id - vm_info_response = self.rpc_call('GetInfo', {'id': vm_id}) + vm_info_response = self.rpc_call("GetInfo", {"id": vm_id}) - if not vm_info_response.get('found', False) or 'info' not in vm_info_response: + if not vm_info_response.get("found", False) or "info" not in vm_info_response: raise Exception(f"VM with ID {vm_id} not found") - app_id = vm_info_response['info']['app_id'] + app_id = vm_info_response["info"]["app_id"] print(f"Retrieved app ID: {app_id}") - vm_configuration = vm_info_response['info'].get('configuration') or {} - compose_file = vm_configuration.get('compose_file') + vm_configuration = vm_info_response["info"].get("configuration") or {} + compose_file = vm_configuration.get("compose_file") # Now get the encryption key for the app encrypt_pubkey = self.get_app_env_encrypt_pub_key( - app_id, kms_urls[0] if kms_urls else None) + app_id, kms_urls[0] if kms_urls else None + ) print(f"Encrypting environment variables with key: {encrypt_pubkey}") envs_list = [{"key": k, "value": v} for k, v in envs.items()] encrypted_env = encrypt_env(envs_list, encrypt_pubkey) # Use UpdateApp with the VM ID - payload = {'id': vm_id, 'encrypted_env': encrypted_env} + payload = {"id": vm_id, "encrypted_env": encrypted_env} if compose_file: try: @@ -860,42 +921,40 @@ def update_vm_env(self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[Lis app_compose = {} compose_changed = False allowed_envs = list(envs.keys()) - if app_compose.get('allowed_envs') != allowed_envs: - app_compose['allowed_envs'] = allowed_envs + if app_compose.get("allowed_envs") != allowed_envs: + app_compose["allowed_envs"] = allowed_envs compose_changed = True - launch_token_value = envs.get('APP_LAUNCH_TOKEN') + launch_token_value = envs.get("APP_LAUNCH_TOKEN") if launch_token_value is not None: launch_token_hash = hashlib.sha256( - launch_token_value.encode('utf-8') + launch_token_value.encode("utf-8") ).hexdigest() - if app_compose.get('launch_token_hash') != launch_token_hash: - app_compose['launch_token_hash'] = launch_token_hash + if app_compose.get("launch_token_hash") != launch_token_hash: + app_compose["launch_token_hash"] = launch_token_hash compose_changed = True if compose_changed: - payload['compose_file'] = json.dumps( - app_compose, indent=4, ensure_ascii=False) + payload["compose_file"] = json.dumps( + app_compose, indent=4, ensure_ascii=False + ) - self.rpc_call('UpgradeApp', payload) + self.rpc_call("UpgradeApp", payload) print(f"Environment variables updated for VM {vm_id}") def update_vm_user_config(self, vm_id: str, user_config: str) -> None: - """Update user config for a VM""" - self.rpc_call('UpgradeApp', {'id': vm_id, - 'user_config': user_config}) + """Update user config for a VM.""" + self.rpc_call("UpgradeApp", {"id": vm_id, "user_config": user_config}) print(f"User config updated for VM {vm_id}") def update_vm_app_compose(self, vm_id: str, app_compose: str) -> None: - """Update app compose for a VM""" - self.rpc_call('UpgradeApp', {'id': vm_id, - 'compose_file': app_compose}) + """Update app compose for a VM.""" + self.rpc_call("UpgradeApp", {"id": vm_id, "compose_file": app_compose}) print(f"App compose updated for VM {vm_id}") - + def update_vm_ports(self, vm_id: str, ports: List[str]) -> None: - """Update port mapping for a VM""" + """Update port mapping for a VM.""" port_mappings = [parse_port_mapping(port) for port in ports] self.rpc_call( - "UpgradeApp", {"id": vm_id, - "update_ports": True, "ports": port_mappings} + "UpgradeApp", {"id": vm_id, "update_ports": True, "ports": port_mappings} ) print(f"Port mapping updated for VM {vm_id}") @@ -919,10 +978,12 @@ def update_vm( kms_urls: Optional[List[str]] = None, no_tee: Optional[bool] = None, ) -> None: - """Update multiple aspects of a VM in one command""" + """Update multiple aspects of a VM in one command.""" # Validate: --env-file requires --kms-url if env_file and not kms_urls: - raise Exception("--env-file requires --kms-url to encrypt environment variables") + raise Exception( + "--env-file requires --kms-url to encrypt environment variables" + ) updates = [] @@ -949,56 +1010,68 @@ def update_vm( upgrade_params = {"id": vm_id} # handle compose file updates (docker-compose, prelaunch script, swap) - needs_compose_update = docker_compose_content or prelaunch_script is not None or swap_size is not None + needs_compose_update = ( + docker_compose_content + or prelaunch_script is not None + or swap_size is not None + ) vm_info_response = None if needs_compose_update or env_file: - vm_info_response = self.rpc_call('GetInfo', {'id': vm_id}) - if not vm_info_response.get('found', False) or 'info' not in vm_info_response: + vm_info_response = self.rpc_call("GetInfo", {"id": vm_id}) + if ( + not vm_info_response.get("found", False) + or "info" not in vm_info_response + ): raise Exception(f"VM with ID {vm_id} not found") if needs_compose_update: - vm_configuration = vm_info_response['info'].get('configuration') or {} - compose_file_content = vm_configuration.get('compose_file') + vm_configuration = vm_info_response["info"].get("configuration") or {} + compose_file_content = vm_configuration.get("compose_file") try: - app_compose = json.loads(compose_file_content) if compose_file_content else {} + app_compose = ( + json.loads(compose_file_content) if compose_file_content else {} + ) except json.JSONDecodeError: app_compose = {} if docker_compose_content: - app_compose['docker_compose_file'] = docker_compose_content + app_compose["docker_compose_file"] = docker_compose_content updates.append("docker compose") if prelaunch_script is not None: script_stripped = prelaunch_script.strip() if script_stripped: - app_compose['pre_launch_script'] = script_stripped + app_compose["pre_launch_script"] = script_stripped updates.append("prelaunch script") - elif 'pre_launch_script' in app_compose: - del app_compose['pre_launch_script'] + elif "pre_launch_script" in app_compose: + del app_compose["pre_launch_script"] updates.append("prelaunch script (removed)") if swap_size is not None: swap_bytes = max(0, int(round(swap_size)) * 1024 * 1024) if swap_bytes > 0: - app_compose['swap_size'] = swap_bytes + app_compose["swap_size"] = swap_bytes updates.append(f"swap: {swap_size}MB") - elif 'swap_size' in app_compose: - del app_compose['swap_size'] + elif "swap_size" in app_compose: + del app_compose["swap_size"] updates.append("swap (disabled)") - upgrade_params['compose_file'] = json.dumps(app_compose, indent=4, ensure_ascii=False) + upgrade_params["compose_file"] = json.dumps( + app_compose, indent=4, ensure_ascii=False + ) if env_file: envs = parse_env_file(env_file) if envs: - app_id = vm_info_response['info']['app_id'] - vm_configuration = vm_info_response['info'].get('configuration') or {} - compose_file_content = vm_configuration.get('compose_file') + app_id = vm_info_response["info"]["app_id"] + vm_configuration = vm_info_response["info"].get("configuration") or {} + compose_file_content = vm_configuration.get("compose_file") encrypt_pubkey = self.get_app_env_encrypt_pub_key( - app_id, kms_urls[0] if kms_urls else None) + app_id, kms_urls[0] if kms_urls else None + ) envs_list = [{"key": k, "value": v} for k, v in envs.items()] upgrade_params["encrypted_env"] = encrypt_env(envs_list, encrypt_pubkey) updates.append("environment variables") @@ -1011,20 +1084,21 @@ def update_vm( app_compose = {} compose_changed = False allowed_envs = list(envs.keys()) - if app_compose.get('allowed_envs') != allowed_envs: - app_compose['allowed_envs'] = allowed_envs + if app_compose.get("allowed_envs") != allowed_envs: + app_compose["allowed_envs"] = allowed_envs compose_changed = True - launch_token_value = envs.get('APP_LAUNCH_TOKEN') + launch_token_value = envs.get("APP_LAUNCH_TOKEN") if launch_token_value is not None: launch_token_hash = hashlib.sha256( - launch_token_value.encode('utf-8') + launch_token_value.encode("utf-8") ).hexdigest() - if app_compose.get('launch_token_hash') != launch_token_hash: - app_compose['launch_token_hash'] = launch_token_hash + if app_compose.get("launch_token_hash") != launch_token_hash: + app_compose["launch_token_hash"] = launch_token_hash compose_changed = True if compose_changed: - upgrade_params['compose_file'] = json.dumps( - app_compose, indent=4, ensure_ascii=False) + upgrade_params["compose_file"] = json.dumps( + app_compose, indent=4, ensure_ascii=False + ) if user_config: upgrade_params["user_config"] = user_config @@ -1052,23 +1126,17 @@ def update_vm( gpu_config = {"attach_mode": "all"} updates.append("GPUs (all)") elif no_gpus: - gpu_config = { - "attach_mode": "listed", - "gpus": [] - } + gpu_config = {"attach_mode": "listed", "gpus": []} updates.append("GPUs (detached)") elif gpu_slots: gpu_config = { "attach_mode": "listed", - "gpus": [{"slot": gpu} for gpu in gpu_slots] + "gpus": [{"slot": gpu} for gpu in gpu_slots], } updates.append(f"GPUs ({len(gpu_slots)} devices)") else: # gpu_slots is an empty list ([] not None) - shouldn't happen with mutually exclusive group - gpu_config = { - "attach_mode": "listed", - "gpus": [] - } + gpu_config = {"attach_mode": "listed", "gpus": []} updates.append("GPUs (none)") upgrade_params["gpus"] = gpu_config @@ -1085,20 +1153,20 @@ def update_vm( print(f"No updates specified for VM {vm_id}") def show_info(self, vm_id: str, json_output: bool = False) -> None: - """Show detailed information about a VM""" - response = self.rpc_call('GetInfo', {'id': vm_id}) + """Show detailed information about a VM.""" + response = self.rpc_call("GetInfo", {"id": vm_id}) - if not response.get('found', False) or 'info' not in response: + if not response.get("found", False) or "info" not in response: print(f"VM with ID {vm_id} not found") return - info = response['info'] + info = response["info"] if json_output: print(json.dumps(info, indent=2)) return - config = info.get('configuration', {}) + config = info.get("configuration", {}) print(f"VM ID: {info.get('id', '-')}") print(f"Name: {info.get('name', '-')}") @@ -1114,24 +1182,26 @@ def show_info(self, vm_id: str, json_output: bool = False) -> None: print(f"Disk: {config.get('disk_size', '-')}GB") print(f"GPUs: {self._format_gpu_info(config.get('gpus'))}") print(f"Boot Progress: {info.get('boot_progress', '-')}") - if info.get('boot_error'): + if info.get("boot_error"): print(f"Boot Error: {info['boot_error']}") - if info.get('exited_at'): + if info.get("exited_at"): print(f"Exited At: {info['exited_at']}") - if info.get('shutdown_progress'): + if info.get("shutdown_progress"): print(f"Shutdown: {info['shutdown_progress']}") - events = info.get('events', []) + events = info.get("events", []) if events: - print(f"\nRecent Events:") + print("\nRecent Events:") for event in events[-10:]: - ts = event.get('timestamp', 0) - print(f" [{event.get('event', '')}] {event.get('body', '')} (ts: {ts})") + ts = event.get("timestamp", 0) + print( + f" [{event.get('event', '')}] {event.get('body', '')} (ts: {ts})" + ) def list_gpus(self, json_output: bool = False) -> None: - """List all available GPUs""" - response = self.rpc_call('ListGpus') - gpus = response.get('gpus', []) + """List all available GPUs.""" + response = self.rpc_call("ListGpus") + gpus = response.get("gpus", []) if json_output: # Return raw JSON data for automation/testing @@ -1142,14 +1212,14 @@ def list_gpus(self, json_output: bool = False) -> None: print("No GPUs found") return - headers = ['Slot', 'Product ID', 'Description', 'Available'] + headers = ["Slot", "Product ID", "Description", "Available"] rows = [] for gpu in gpus: row = [ - gpu.get('slot', '-'), - gpu.get('product_id', '-'), - gpu.get('description', '-'), - 'Yes' if gpu.get('is_free', False) else 'No' + gpu.get("slot", "-"), + gpu.get("product_id", "-"), + gpu.get("description", "-"), + "Yes" if gpu.get("is_free", False) else "No", ] rows.append(row) @@ -1157,7 +1227,7 @@ def list_gpus(self, json_output: bool = False) -> None: def format_table(rows, headers): - """Simple table formatter""" + """Format rows and headers into a table string.""" if not rows: return "" @@ -1174,11 +1244,7 @@ def format_table(rows, headers): bottom_border = "└─" + "─┴─".join("─" * w for w in widths) + "─┘" # Build table - table = [ - top_border, - row_format.format(*headers), - separator - ] + table = [top_border, row_format.format(*headers), separator] for row in rows: table.append(row_format.format(*[str(cell) for cell in row])) table.append(bottom_border) @@ -1187,32 +1253,32 @@ def format_table(rows, headers): def parse_env_file(file_path: str) -> Dict[str, str]: - """ - Parse an environment file where each line is formatted as: - KEY=Value + """Parse an environment file into a dictionary of key-value pairs. + Each line should be formatted as KEY=Value. Lines that are empty or start with '#' are ignored. + """ if not file_path: return {} envs = {} - with open(file_path, 'r') as f: + with open(file_path, "r") as f: for line in f: line = line.strip() - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue - if '=' not in line: + if "=" not in line: continue - key, value = line.split('=', 1) + key, value = line.split("=", 1) envs[key.strip()] = value.strip() return envs def parse_size(s: str, target_unit: str) -> int: - """ - Parse a human-readable size string (e.g. "1G", "100M") and return the size - in the specified target unit. + """Parse a human-readable size string and return the size in the target unit. + + Accept strings like "1G" or "100M". Args: s: The size string provided. @@ -1224,9 +1290,10 @@ def parse_size(s: str, target_unit: str) -> int: Raises: argparse.ArgumentTypeError: if the format is invalid or if conversion turns out to be fractional. + """ s = s.strip() - m = re.fullmatch(r'(\d+(?:\.\d+)?)([a-zA-Z]{1,2})?', s) + m = re.fullmatch(r"(\d+(?:\.\d+)?)([a-zA-Z]{1,2})?", s) if not m: raise argparse.ArgumentTypeError(f"Invalid size format: '{s}'") number_str, unit = m.groups() @@ -1251,7 +1318,8 @@ def parse_size(s: str, target_unit: str) -> int: factor = 1024 * 1024 else: raise argparse.ArgumentTypeError( - f"Invalid size unit '{unit}' for memory. Use M, G, or T.") + f"Invalid size unit '{unit}' for memory. Use M, G, or T." + ) elif target_unit == "GB": # For disk, if no suffix is provided we assume GB. if unit in ["G", "GB"]: @@ -1260,14 +1328,16 @@ def parse_size(s: str, target_unit: str) -> int: factor = 1024 else: raise argparse.ArgumentTypeError( - f"Invalid size unit '{unit}' for disk. Use G, T.") + f"Invalid size unit '{unit}' for disk. Use G, T." + ) else: raise ValueError("Unsupported target unit") value = number * factor if not value.is_integer(): raise argparse.ArgumentTypeError( - f"Size must be an integer number of {target_unit}. Got {value}.") + f"Size must be an integer number of {target_unit}. Got {value}." + ) return int(value) @@ -1281,9 +1351,10 @@ def parse_disk_size(s: str) -> int: return parse_size(s, "GB") -def verify_signature_v1(public_key: bytes, signature: bytes, app_id: str, timestamp: int) -> Optional[str]: - """ - Verify the v1 signature (with timestamp) of a public key. +def verify_signature_v1( + public_key: bytes, signature: bytes, app_id: str, timestamp: int +) -> Optional[str]: + """Verify the v1 signature (with timestamp) of a public key. Args: public_key: The public key bytes to verify @@ -1293,6 +1364,7 @@ def verify_signature_v1(public_key: bytes, signature: bytes, app_id: str, timest Returns: The compressed public key if valid, None otherwise + """ if not CRYPTO_AVAILABLE: raise ImportError( @@ -1308,7 +1380,7 @@ def verify_signature_v1(public_key: bytes, signature: bytes, app_id: str, timest prefix = b"dstack-env-encrypt-pubkey" if app_id.startswith("0x"): app_id = app_id[2:] - timestamp_bytes = timestamp.to_bytes(8, byteorder='big') + timestamp_bytes = timestamp.to_bytes(8, byteorder="big") message = prefix + b":" + bytes.fromhex(app_id) + timestamp_bytes + public_key # Hash the message with Keccak-256 @@ -1318,15 +1390,14 @@ def verify_signature_v1(public_key: bytes, signature: bytes, app_id: str, timest try: sig = keys.Signature(signature_bytes=signature) recovered_key = sig.recover_public_key_from_msg_hash(message_hash) - return '0x' + recovered_key.to_compressed_bytes().hex() + return "0x" + recovered_key.to_compressed_bytes().hex() except Exception as e: print(f"Signature v1 verification failed: {e}", file=sys.stderr) return None def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Optional[str]: - """ - Verify the legacy signature (without timestamp) of a public key. + """Verify the legacy signature (without timestamp) of a public key. Args: public_key: The public key bytes to verify @@ -1343,6 +1414,7 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option >>> compressed_pubkey = verify_signature(public_key, signature, app_id) >>> print(compressed_pubkey) 0x0217610d74cbd39b6143842c6d8bc310d79da1d82cc9d17f8876376221eda0c38f + """ if not CRYPTO_AVAILABLE: raise ImportError( @@ -1369,45 +1441,46 @@ def verify_signature(public_key: bytes, signature: bytes, app_id: str) -> Option sig = keys.Signature(signature_bytes=signature) recovered_key = sig.recover_public_key_from_msg_hash(message_hash) - return '0x' + recovered_key.to_compressed_bytes().hex() + return "0x" + recovered_key.to_compressed_bytes().hex() except Exception as e: print(f"Signature verification failed: {e}", file=sys.stderr) return None def load_whitelist() -> List[str]: - """ - Load the whitelist of trusted signers from a file. + """Load the whitelist of trusted signers from a file. Returns: List of trusted Ethereum addresses + """ if not os.path.exists(DEFAULT_KMS_WHITELIST_PATH): os.makedirs(os.path.dirname(DEFAULT_KMS_WHITELIST_PATH), exist_ok=True) return [] try: - with open(DEFAULT_KMS_WHITELIST_PATH, 'r') as f: + with open(DEFAULT_KMS_WHITELIST_PATH, "r") as f: data = json.load(f) - return data.get('trusted_signers', []) + return data.get("trusted_signers", []) except (json.JSONDecodeError, FileNotFoundError): return [] def save_whitelist(whitelist: List[str]) -> None: - """ - Save the whitelist of trusted signers to a file. + """Save the whitelist of trusted signers to a file. Args: whitelist: List of trusted Ethereum addresses + """ os.makedirs(os.path.dirname(DEFAULT_KMS_WHITELIST_PATH), exist_ok=True) - with open(DEFAULT_KMS_WHITELIST_PATH, 'w') as f: - json.dump({'trusted_signers': whitelist}, f, indent=2) + with open(DEFAULT_KMS_WHITELIST_PATH, "w") as f: + json.dump({"trusted_signers": whitelist}, f, indent=2) def main(): - parser = argparse.ArgumentParser(description='dstack-vmm CLI - Manage VMs') + """Parse arguments and dispatch to the appropriate command handler.""" + parser = argparse.ArgumentParser(description="dstack-vmm CLI - Manage VMs") # Load config file defaults config = load_config() @@ -1417,66 +1490,78 @@ def main(): # Priority: command line > environment variable > config file > auto-discovery > default default_url = os.environ.get( - 'DSTACK_VMM_URL', - config.get('url', 'http://localhost:8080')) - default_auth_user = os.environ.get( - 'DSTACK_VMM_AUTH_USER', - config.get('auth_user')) + "DSTACK_VMM_URL", config.get("url", "http://localhost:8080") + ) + default_auth_user = os.environ.get("DSTACK_VMM_AUTH_USER", config.get("auth_user")) default_auth_password = os.environ.get( - 'DSTACK_VMM_AUTH_PASSWORD', - config.get('auth_password')) + "DSTACK_VMM_AUTH_PASSWORD", config.get("auth_password") + ) parser.add_argument( - '--url', default=default_url, - help='dstack-vmm API URL (can also be set via DSTACK_VMM_URL env var or config file)') + "--url", + default=default_url, + help="dstack-vmm API URL (can also be set via DSTACK_VMM_URL env var or config file)", + ) # Basic authentication arguments parser.add_argument( - '--auth-user', default=default_auth_user, - help='Basic auth username (can also be set via DSTACK_VMM_AUTH_USER env var or config file)') + "--auth-user", + default=default_auth_user, + help="Basic auth username (can also be set via DSTACK_VMM_AUTH_USER env var or config file)", + ) parser.add_argument( - '--auth-password', default=default_auth_password, - help='Basic auth password (can also be set via DSTACK_VMM_AUTH_PASSWORD env var or config file)') + "--auth-password", + default=default_auth_password, + help="Basic auth password (can also be set via DSTACK_VMM_AUTH_PASSWORD env var or config file)", + ) - subparsers = parser.add_subparsers(dest='command', help='Commands') + subparsers = parser.add_subparsers(dest="command", help="Commands") # VMM discovery commands ls_vmm_parser = subparsers.add_parser( - 'ls-vmm', help='List all running VMM instances on this host') + "ls-vmm", help="List all running VMM instances on this host" + ) ls_vmm_parser.add_argument( - '--json', action='store_true', help='Output in JSON format') + "--json", action="store_true", help="Output in JSON format" + ) switch_vmm_parser = subparsers.add_parser( - 'switch-vmm', help='Switch active VMM instance') + "switch-vmm", help="Switch active VMM instance" + ) switch_vmm_parser.add_argument( - 'vmm_id', help='VMM instance ID (prefix match supported)') + "vmm_id", help="VMM instance ID (prefix match supported)" + ) # List command - lsvm_parser = subparsers.add_parser('lsvm', help='List VMs') + lsvm_parser = subparsers.add_parser("lsvm", help="List VMs") lsvm_parser.add_argument( - '-v', '--verbose', action='store_true', help='Show detailed information') + "-v", "--verbose", action="store_true", help="Show detailed information" + ) lsvm_parser.add_argument( - '--json', action='store_true', help='Output in JSON format for automation') + "--json", action="store_true", help="Output in JSON format for automation" + ) # Info command - info_parser = subparsers.add_parser('info', help='Show detailed VM information') - info_parser.add_argument('vm_id', help='VM ID to show info for') + info_parser = subparsers.add_parser("info", help="Show detailed VM information") + info_parser.add_argument("vm_id", help="VM ID to show info for") info_parser.add_argument( - '--json', action='store_true', help='Output in JSON format for automation') + "--json", action="store_true", help="Output in JSON format for automation" + ) # Start command - start_parser = subparsers.add_parser('start', help='Start a VM') - start_parser.add_argument('vm_id', help='VM ID to start') + start_parser = subparsers.add_parser("start", help="Start a VM") + start_parser.add_argument("vm_id", help="VM ID to start") # Stop command - stop_parser = subparsers.add_parser('stop', help='Stop a VM') - stop_parser.add_argument('vm_id', help='VM ID to stop') + stop_parser = subparsers.add_parser("stop", help="Stop a VM") + stop_parser.add_argument("vm_id", help="VM ID to stop") stop_parser.add_argument( - '-f', '--force', action='store_true', help='Force stop the VM') + "-f", "--force", action="store_true", help="Force stop the VM" + ) # Remove command - remove_parser = subparsers.add_parser('remove', help='Remove a VM') - remove_parser.add_argument('vm_id', help='VM ID to remove') + remove_parser = subparsers.add_parser("remove", help="Remove a VM") + remove_parser.add_argument("vm_id", help="VM ID to remove") # Resize command resize_parser = subparsers.add_parser("resize", help="Resize a VM") @@ -1491,147 +1576,210 @@ def main(): resize_parser.add_argument("--image", type=str, help="Image name") # Logs command - logs_parser = subparsers.add_parser('logs', help='Show VM logs') - logs_parser.add_argument('vm_id', help='VM ID to show logs for') - logs_parser.add_argument('-n', '--lines', type=int, - default=20, help='Number of lines to show') + logs_parser = subparsers.add_parser("logs", help="Show VM logs") + logs_parser.add_argument("vm_id", help="VM ID to show logs for") logs_parser.add_argument( - '-f', '--follow', action='store_true', help='Follow log output') + "-n", "--lines", type=int, default=20, help="Number of lines to show" + ) + logs_parser.add_argument( + "-f", "--follow", action="store_true", help="Follow log output" + ) # Compose command compose_parser = subparsers.add_parser( - 'compose', help='Create a new app-compose.json file') - compose_parser.add_argument('--name', required=True, help='VM image name') - compose_parser.add_argument( - '--docker-compose', required=True, help='Path to docker-compose.yml file') + "compose", help="Create a new app-compose.json file" + ) + compose_parser.add_argument("--name", required=True, help="VM image name") compose_parser.add_argument( - '--prelaunch-script', default=None, help='Path to prelaunch script') + "--docker-compose", required=True, help="Path to docker-compose.yml file" + ) compose_parser.add_argument( - '--kms', action='store_true', help='Enable KMS') + "--prelaunch-script", default=None, help="Path to prelaunch script" + ) + compose_parser.add_argument("--kms", action="store_true", help="Enable KMS") compose_parser.add_argument( - '--gateway', action='store_true', help='Enable dstack-gateway') + "--gateway", action="store_true", help="Enable dstack-gateway" + ) compose_parser.add_argument( - '--local-key-provider', action='store_true', help='Enable local key provider') + "--local-key-provider", action="store_true", help="Enable local key provider" + ) compose_parser.add_argument( - '--key-provider', choices=['none', 'kms', 'local'], default=None, - help='Override key provider type (none, kms, local)') + "--key-provider", + choices=["none", "kms", "local"], + default=None, + help="Override key provider type (none, kms, local)", + ) compose_parser.add_argument( - '--key-provider-id', default=None, help='Key provider ID if you want to bind to a specific key provider') + "--key-provider-id", + default=None, + help="Key provider ID if you want to bind to a specific key provider", + ) compose_parser.add_argument( - '--public-logs', action='store_true', help='Enable public logs') + "--public-logs", action="store_true", help="Enable public logs" + ) compose_parser.add_argument( - '--public-sysinfo', action='store_true', help='Enable public sysinfo') + "--public-sysinfo", action="store_true", help="Enable public sysinfo" + ) compose_parser.add_argument( - '--env-file', help='File with environment variables to encrypt', default=None) + "--env-file", help="File with environment variables to encrypt", default=None + ) compose_parser.add_argument( - '--no-instance-id', action='store_true', help='Disable instance ID') + "--no-instance-id", action="store_true", help="Disable instance ID" + ) compose_parser.add_argument( - '--secure-time', action='store_true', help='Enable secure time') + "--secure-time", action="store_true", help="Enable secure time" + ) compose_parser.add_argument( - '--swap', type=parse_memory_size, default=None, - help='Swap size (e.g. 4G). Set to 0 to disable') + "--swap", + type=parse_memory_size, + default=None, + help="Swap size (e.g. 4G). Set to 0 to disable", + ) compose_parser.add_argument( - '--output', required=True, help='Path to output app-compose.json file') + "--output", required=True, help="Path to output app-compose.json file" + ) # Deploy command - deploy_parser = subparsers.add_parser('deploy', help='Deploy a new VM') - deploy_parser.add_argument('--name', required=True, help='VM name') - deploy_parser.add_argument('--image', required=True, help='VM image') + deploy_parser = subparsers.add_parser("deploy", help="Deploy a new VM") + deploy_parser.add_argument("--name", required=True, help="VM name") + deploy_parser.add_argument("--image", required=True, help="VM image") + deploy_parser.add_argument( + "--compose", required=True, help="Path to app-compose.json file" + ) + deploy_parser.add_argument("--vcpu", type=int, default=1, help="Number of vCPUs") deploy_parser.add_argument( - '--compose', required=True, help='Path to app-compose.json file') + "--memory", + type=parse_memory_size, + default=1024, + help="Memory size (e.g. 1G, 100M)", + ) deploy_parser.add_argument( - '--vcpu', type=int, default=1, help='Number of vCPUs') + "--disk", type=parse_disk_size, default=20, help="Disk size (e.g. 1G, 100M)" + ) deploy_parser.add_argument( - '--memory', type=parse_memory_size, default=1024, help='Memory size (e.g. 1G, 100M)') + "--swap", + type=parse_memory_size, + default=None, + help="Swap size (e.g. 4G). Set to 0 to disable", + ) deploy_parser.add_argument( - '--disk', type=parse_disk_size, default=20, help='Disk size (e.g. 1G, 100M)') + "--env-file", help="File with environment variables to encrypt", default=None + ) deploy_parser.add_argument( - '--swap', type=parse_memory_size, default=None, - help='Swap size (e.g. 4G). Set to 0 to disable') + "--user-config", help="Path to user config file", default=None + ) + deploy_parser.add_argument("--app-id", help="Application ID", default=None) deploy_parser.add_argument( - '--env-file', help='File with environment variables to encrypt', default=None) + "--port", + action="append", + type=str, + help="Port mapping in format: protocol[:address]:from:to", + ) deploy_parser.add_argument( - '--user-config', help='Path to user config file', default=None) - deploy_parser.add_argument('--app-id', help='Application ID', default=None) - deploy_parser.add_argument('--port', action='append', type=str, - help='Port mapping in format: protocol[:address]:from:to') - deploy_parser.add_argument('--gpu', action='append', type=str, - help='GPU slot to attach (can be used multiple times), or "all" to attach all GPUs') - deploy_parser.add_argument('--ppcie', action='store_true', - help='Enable PPCIE (Protected PCIe) mode - attach all available GPUs') - deploy_parser.add_argument('--pin-numa', action='store_true', - help='Pin VM to specific NUMA node') - deploy_parser.add_argument('--hugepages', action='store_true', - help='Enable hugepages for the VM') - deploy_parser.add_argument('--kms-url', action='append', type=str, - help='KMS URL') - deploy_parser.add_argument('--gateway-url', action='append', type=str, - help='Gateway URL') - deploy_parser.add_argument('--stopped', action='store_true', - help='Create VM in stopped state (requires dstack-vmm >= 0.5.4)') - deploy_parser.add_argument('--no-tee', dest='no_tee', action='store_true', - help='Disable Intel TDX / run without TEE') - deploy_parser.add_argument('--tee', dest='no_tee', action='store_false', - help='Force-enable Intel TDX (default)') + "--gpu", + action="append", + type=str, + help='GPU slot to attach (can be used multiple times), or "all" to attach all GPUs', + ) + deploy_parser.add_argument( + "--ppcie", + action="store_true", + help="Enable PPCIE (Protected PCIe) mode - attach all available GPUs", + ) + deploy_parser.add_argument( + "--pin-numa", action="store_true", help="Pin VM to specific NUMA node" + ) + deploy_parser.add_argument( + "--hugepages", action="store_true", help="Enable hugepages for the VM" + ) + deploy_parser.add_argument("--kms-url", action="append", type=str, help="KMS URL") + deploy_parser.add_argument( + "--gateway-url", action="append", type=str, help="Gateway URL" + ) + deploy_parser.add_argument( + "--stopped", + action="store_true", + help="Create VM in stopped state (requires dstack-vmm >= 0.5.4)", + ) + deploy_parser.add_argument( + "--no-tee", + dest="no_tee", + action="store_true", + help="Disable Intel TDX / run without TEE", + ) + deploy_parser.add_argument( + "--tee", + dest="no_tee", + action="store_false", + help="Force-enable Intel TDX (default)", + ) deploy_parser.set_defaults(no_tee=False) - deploy_parser.add_argument('--net', choices=['bridge', 'user'], - help='Networking mode (default: use global config)') - + deploy_parser.add_argument( + "--net", + choices=["bridge", "user"], + help="Networking mode (default: use global config)", + ) # Images command - lsimage_parser = subparsers.add_parser( - 'lsimage', help='List available images') + lsimage_parser = subparsers.add_parser("lsimage", help="List available images") lsimage_parser.add_argument( - '--json', action='store_true', help='Output in JSON format for automation') + "--json", action="store_true", help="Output in JSON format for automation" + ) # GPU command - lsgpu_parser = subparsers.add_parser('lsgpu', help='List available GPUs') + lsgpu_parser = subparsers.add_parser("lsgpu", help="List available GPUs") lsgpu_parser.add_argument( - '--json', action='store_true', help='Output in JSON format for automation') + "--json", action="store_true", help="Output in JSON format for automation" + ) # Update environment variables command update_env_parser = subparsers.add_parser( - 'update-env', help='Update environment variables for a VM') - update_env_parser.add_argument('vm_id', help='VM ID to update') + "update-env", help="Update environment variables for a VM" + ) + update_env_parser.add_argument("vm_id", help="VM ID to update") update_env_parser.add_argument( - '--env-file', required=True, help='File with environment variables to encrypt') + "--env-file", required=True, help="File with environment variables to encrypt" + ) update_env_parser.add_argument( - '--kms-url', action='append', type=str, - help='KMS URL') + "--kms-url", action="append", type=str, help="KMS URL" + ) # Whitelist command - kms_parser = subparsers.add_parser( - 'kms', help='Manage trusted KMS whitelist') - kms_subparsers = kms_parser.add_subparsers( - dest='kms_action', help='KMS actions') + kms_parser = subparsers.add_parser("kms", help="Manage trusted KMS whitelist") + kms_subparsers = kms_parser.add_subparsers(dest="kms_action", help="KMS actions") # List whitelist - list_kms_parser = kms_subparsers.add_parser( - 'list', help='List trusted signers') + kms_subparsers.add_parser("list", help="List trusted signers") # Add to whitelist add_kms_parser = kms_subparsers.add_parser( - 'add', help='Add public key to trusted signers') - add_kms_parser.add_argument('pubkey', help='Public key to add') + "add", help="Add public key to trusted signers" + ) + add_kms_parser.add_argument("pubkey", help="Public key to add") # Remove from whitelist remove_kms_parser = kms_subparsers.add_parser( - 'remove', help='Remove public key from trusted signers') - remove_kms_parser.add_argument('pubkey', help='Public key to remove') + "remove", help="Remove public key from trusted signers" + ) + remove_kms_parser.add_argument("pubkey", help="Public key to remove") # Update app compose update_app_compose_parser = subparsers.add_parser( - 'update-app-compose', help='Update app compose for a VM') - update_app_compose_parser.add_argument('vm_id', help='VM ID to update') + "update-app-compose", help="Update app compose for a VM" + ) + update_app_compose_parser.add_argument("vm_id", help="VM ID to update") update_app_compose_parser.add_argument( - 'compose', help='Path to app-compose.json file') + "compose", help="Path to app-compose.json file" + ) # Update user config update_user_config_parser = subparsers.add_parser( - 'update-user-config', help='Update user config for a VM') - update_user_config_parser.add_argument('vm_id', help='VM ID to update') + "update-user-config", help="Update user config for a VM" + ) + update_user_config_parser.add_argument("vm_id", help="VM ID to update") update_user_config_parser.add_argument( - 'user_config', help='Path to user config file') + "user_config", help="Path to user config file" + ) # Update port mapping update_ports_parser = subparsers.add_parser( @@ -1653,32 +1801,24 @@ def main(): update_parser.add_argument("vm_id", help="VM ID to update") # Resource options (requires VM to be stopped) - update_parser.add_argument( - "--vcpu", type=int, help="Number of vCPUs" - ) + update_parser.add_argument("--vcpu", type=int, help="Number of vCPUs") update_parser.add_argument( "--memory", type=parse_memory_size, help="Memory size (e.g. 1G, 100M)" ) update_parser.add_argument( "--disk", type=parse_disk_size, help="Disk size (e.g. 20G, 1T)" ) - update_parser.add_argument( - "--image", type=str, help="Image name" - ) + update_parser.add_argument("--image", type=str, help="Image name") # Application options - update_parser.add_argument( - "--compose", help="Path to app-compose.json file" - ) + update_parser.add_argument("--compose", help="Path to app-compose.json file") update_parser.add_argument( "--prelaunch-script", help="Path to pre-launch script file" ) update_parser.add_argument( "--env-file", help="File with environment variables to encrypt" ) - update_parser.add_argument( - "--user-config", help="Path to user config file" - ) + update_parser.add_argument("--user-config", help="Path to user config file") # Port mapping options (mutually exclusive with --no-ports) port_group = update_parser.add_mutually_exclusive_group() port_group.add_argument( @@ -1693,7 +1833,9 @@ def main(): help="Remove all port mappings from the VM", ) update_parser.add_argument( - "--swap", type=parse_memory_size, help="Swap size (e.g. 4G). Set to 0 to disable" + "--swap", + type=parse_memory_size, + help="Swap size (e.g. 4G). Set to 0 to disable", ) # GPU options (mutually exclusive) @@ -1702,7 +1844,7 @@ def main(): "--gpu", action="append", type=str, - help="GPU slot to attach (can be used multiple times), or \"all\" to attach all GPUs", + help='GPU slot to attach (can be used multiple times), or "all" to attach all GPUs', ) gpu_group.add_argument( "--ppcie", @@ -1732,17 +1874,15 @@ def main(): update_parser.set_defaults(no_tee=None) # KMS URL for environment encryption - update_parser.add_argument( - "--kms-url", action="append", type=str, help="KMS URL" - ) + update_parser.add_argument("--kms-url", action="append", type=str, help="KMS URL") args = parser.parse_args() # Handle discovery commands before creating CLI (they don't need a connection) - if args.command == 'ls-vmm': + if args.command == "ls-vmm": cmd_ls_vmm(args) return - elif args.command == 'switch-vmm': + elif args.command == "switch-vmm": cmd_switch_vmm(args) return @@ -1750,17 +1890,17 @@ def main(): url = resolve_vmm_url(instances, config, args.url) cli = VmmCLI(url, args.auth_user, args.auth_password) - if args.command == 'lsvm': + if args.command == "lsvm": cli.list_vms(args.verbose, args.json) - elif args.command == 'info': + elif args.command == "info": cli.show_info(args.vm_id, args.json) - elif args.command == 'start': + elif args.command == "start": cli.start_vm(args.vm_id) - elif args.command == 'stop': + elif args.command == "stop": cli.stop_vm(args.vm_id, args.force) - elif args.command == 'remove': + elif args.command == "remove": cli.remove_vm(args.vm_id) - elif args.command == 'resize': + elif args.command == "resize": cli.resize_vm( args.vm_id, vcpu=args.vcpu, @@ -1768,24 +1908,24 @@ def main(): disk_size=args.disk, image=args.image, ) - elif args.command == 'logs': + elif args.command == "logs": cli.show_logs(args.vm_id, args.lines, args.follow) - elif args.command == 'compose': + elif args.command == "compose": cli.create_app_compose(args) - elif args.command == 'deploy': + elif args.command == "deploy": cli.create_vm(args) - elif args.command == 'lsimage': + elif args.command == "lsimage": cli.list_images(args.json) - elif args.command == 'lsgpu': + elif args.command == "lsgpu": cli.list_gpus(args.json) - elif args.command == 'update-env': - cli.update_vm_env(args.vm_id, parse_env_file( - args.env_file), kms_urls=args.kms_url) - elif args.command == 'update-user-config': - cli.update_vm_user_config( - args.vm_id, open(args.user_config, 'r').read()) - elif args.command == 'update-app-compose': - cli.update_vm_app_compose(args.vm_id, open(args.compose, 'r').read()) + elif args.command == "update-env": + cli.update_vm_env( + args.vm_id, parse_env_file(args.env_file), kms_urls=args.kms_url + ) + elif args.command == "update-user-config": + cli.update_vm_user_config(args.vm_id, open(args.user_config, "r").read()) + elif args.command == "update-app-compose": + cli.update_vm_app_compose(args.vm_id, open(args.compose, "r").read()) elif args.command == "update-ports": cli.update_vm_ports(args.vm_id, args.port) elif args.command == "update": @@ -1793,7 +1933,7 @@ def main(): if args.compose: compose_content = read_utf8(args.compose) prelaunch_content = None - if hasattr(args, 'prelaunch_script') and args.prelaunch_script: + if hasattr(args, "prelaunch_script") and args.prelaunch_script: prelaunch_content = read_utf8(args.prelaunch_script) user_config_content = None if args.user_config: @@ -1806,28 +1946,28 @@ def main(): image=args.image, docker_compose_content=compose_content, prelaunch_script=prelaunch_content, - swap_size=args.swap if hasattr(args, 'swap') else None, + swap_size=args.swap if hasattr(args, "swap") else None, env_file=args.env_file, user_config=user_config_content, ports=args.port, - no_ports=args.no_ports if hasattr(args, 'no_ports') else False, + no_ports=args.no_ports if hasattr(args, "no_ports") else False, gpu_slots=args.gpu, attach_all=args.ppcie, - no_gpus=args.no_gpus if hasattr(args, 'no_gpus') else False, + no_gpus=args.no_gpus if hasattr(args, "no_gpus") else False, kms_urls=args.kms_url, no_tee=args.no_tee, ) - elif args.command == 'kms': + elif args.command == "kms": if not args.kms_action: kms_parser.print_help() else: cli.manage_kms_whitelist( action=args.kms_action, - pubkey=getattr(args, 'pubkey', None), + pubkey=getattr(args, "pubkey", None), ) else: parser.print_help() -if __name__ == '__main__': +if __name__ == "__main__": main()