From f0a70bb664948551202458132380e3b5176b3a92 Mon Sep 17 00:00:00 2001 From: Marc Olivier Bergeron Date: Tue, 26 May 2026 10:57:36 -0400 Subject: [PATCH] Added dynamic users for post files. Fixed a few path issues. --- ctf/commands/check.py | 2 +- ctf/commands/deploy.py | 10 ++-- ctf/commands/destroy.py | 58 +++++++++++++++------- ctf/commands/flags.py | 11 ++--- ctf/commands/new.py | 45 ++++++++--------- ctf/commands/post/new.py | 95 +++++++++++++++++++++++++++++------- ctf/commands/services.py | 11 ++--- ctf/commands/stats.py | 2 +- ctf/common/utils.py | 8 +-- ctf/validate_json_schemas.py | 4 +- 10 files changed, 159 insertions(+), 87 deletions(-) diff --git a/ctf/commands/check.py b/ctf/commands/check.py index dc91e30..892d11d 100644 --- a/ctf/commands/check.py +++ b/ctf/commands/check.py @@ -52,5 +52,5 @@ def check( # Check if Git LFS is installed on the system as it will be required for deployment. if not check_git_lfs(): LOG.warning( - msg="Git LFS is missing from your system. Install it before deploying." + "Git LFS is missing from your system. Install it before deploying." ) diff --git a/ctf/commands/deploy.py b/ctf/commands/deploy.py index d252a8b..8530016 100644 --- a/ctf/commands/deploy.py +++ b/ctf/commands/deploy.py @@ -121,7 +121,7 @@ def deploy( # Check if Git LFS is installed on the system as it is required for deployment. if not check_git_lfs(): LOG.critical( - msg="Git LFS is missing from your system. Install it before deploying." + "Git LFS is missing from your system. Install it before deploying." ) exit(1) @@ -376,18 +376,14 @@ def deploy( and (track_index := int(track_index)) and 0 < track_index <= len(tracks_list) ): - LOG.info( - msg=f"Running `incus project switch {tracks_list[track_index - 1]}`" - ) + LOG.info(f"Running `incus project switch {tracks_list[track_index - 1]}`") subprocess.run( args=["incus", "project", "switch", tracks_list[track_index - 1].name], check=True, env=ENV, ) elif track_index: - LOG.warning( - msg=f"Could not switch project, unrecognized input: {track_index}." - ) + LOG.warning(f"Could not switch project, unrecognized input: {track_index}.") def terraform_apply( diff --git a/ctf/commands/destroy.py b/ctf/commands/destroy.py index 9437c05..79b3c7c 100644 --- a/ctf/commands/destroy.py +++ b/ctf/commands/destroy.py @@ -3,6 +3,7 @@ import typer from pydantic import ValidationError +from rich.prompt import Confirm from typing_extensions import Annotated from ctf import ENV @@ -109,7 +110,7 @@ def destroy( project_list = list((projects - terraform_tracks)) if len(project_list) == 0: LOG.critical( - msg="No project to switch to. This should never happen as the default should always exists." + "No project to switch to. This should never happen as the default should always exists." ) exit(1) @@ -178,13 +179,22 @@ def destroy( ) } + network_zones = { + Track(name=network_zone["name"]) + for network_zone in json.loads( + s=subprocess.run( + args=["incus", "network", "zone", "list", "--format=json"], + check=False, + capture_output=True, + env=ENV, + ).stdout.decode() + ) + } + for module in terraform_tracks: if module in projects: LOG.warning(f"The project {module.name} was not destroyed properly.") - if ( - force - or (input("Do you want to destroy it? [Y/n] ").lower() or "y") == "y" - ): + if force or Confirm.ask("Do you want to destroy it?", default=True): subprocess.run( args=["incus", "project", "delete", module.name, "--force"], check=False, @@ -194,13 +204,8 @@ def destroy( ) if (tmp_module_name := module.name[:15]) in networks: - LOG.warning( - msg=f"The network {tmp_module_name} was not destroyed properly." - ) - if ( - force - or (input("Do you want to destroy it? [Y/n] ").lower() or "y") == "y" - ): + LOG.warning(f"The network {tmp_module_name} was not destroyed properly.") + if force or Confirm.ask("Do you want to destroy it?", default=True): subprocess.run( args=["incus", "network", "delete", tmp_module_name], check=False, @@ -212,12 +217,9 @@ def destroy( tmp_module := Track(name=f"{module.name}-default") ) in network_acls: LOG.warning( - msg=f"The network ACL {tmp_module.name} was not destroyed properly." + f"The network ACL {tmp_module.name} was not destroyed properly." ) - if ( - force - or (input("Do you want to destroy it? [Y/n] ").lower() or "y") == "y" - ): + if force or Confirm.ask("Do you want to destroy it?", default=True): subprocess.run( args=["incus", "network", "acl", "delete", tmp_module.name], check=False, @@ -225,6 +227,28 @@ def destroy( env=ENV, ) + if Track(name="ctf") in network_zones: + LOG.warning('The network zone "ctf" was not destroyed properly.') + if force or Confirm.ask("Do you want to destroy it?", default=True): + subprocess.run( + args=["incus", "network", "zone", "delete", "ctf"], + check=False, + capture_output=True, + env=ENV, + ) + + if Track(name="simulated-production-acl") in network_acls: + LOG.warning( + 'The network ACL "simulated-production-acl" was not destroyed properly.' + ) + if force or Confirm.ask("Do you want to destroy it?", default=True): + subprocess.run( + args=["incus", "network", "acl", "delete", "simulated-production-acl"], + check=False, + capture_output=True, + env=ENV, + ) + remove_tracks_from_terraform_modules( tracks=terraform_tracks, remote=remote, diff --git a/ctf/commands/flags.py b/ctf/commands/flags.py index 9bc5260..f97d54e 100644 --- a/ctf/commands/flags.py +++ b/ctf/commands/flags.py @@ -1,7 +1,6 @@ import csv import io import json -import os from enum import StrEnum import rich @@ -39,16 +38,16 @@ def flags( ) -> None: distinct_tracks: set[Track] = set() - for entry in os.listdir( + for entry in ( challenges_directory := (find_ctf_root_directory() / "challenges") - ): + ).iterdir(): if (track_directory := challenges_directory / entry).is_dir() and ( track_directory / "track.yaml" ).exists(): if not tracks: - distinct_tracks.add(Track(name=entry)) - elif entry in tracks: - distinct_tracks.add(Track(name=entry)) + distinct_tracks.add(Track(name=entry.name)) + elif entry.name in tracks: + distinct_tracks.add(Track(name=entry.name)) flags = [] for track in distinct_tracks: diff --git a/ctf/commands/new.py b/ctf/commands/new.py index 3f5fd21..97763ce 100644 --- a/ctf/commands/new.py +++ b/ctf/commands/new.py @@ -70,7 +70,7 @@ def new( LOG.info(f"Creating a new track: {name}") if not re.match(pattern=r"^[a-z][a-z0-9\-]{0,61}[a-z0-9]$", string=name): LOG.critical( - msg="""The track name Valid instance names must fulfill the following requirements: + """The track name Valid instance names must fulfill the following requirements: * The name must be between 1 and 63 characters long; * The name must contain only letters, numbers and dashes from the ASCII table; * The name must not start with a digit or a dash; @@ -136,8 +136,8 @@ def new( "with_virtual_machine": with_virtual_machine, } ) - with open( - (p := new_challenge_directory / "track.yaml"), mode="w", encoding="utf-8" + with (p := new_challenge_directory / "track.yaml").open( + mode="w", encoding="utf-8" ) as f: f.write(render) @@ -145,8 +145,8 @@ def new( readme_template = env.get_template(name=os.path.join("common", "README.md.j2")) render = readme_template.render(data={"name": name}) - with open( - (p := new_challenge_directory / "README.md"), mode="w", encoding="utf-8" + with (p := new_challenge_directory / "README.md").open( + mode="w", encoding="utf-8" ) as f: f.write(render) @@ -159,8 +159,8 @@ def new( track_template = env.get_template(name=os.path.join("common", "topic.yaml.j2")) render = track_template.render(data={"name": name}) - with open( - (p := posts_directory / f"{name}.yaml"), mode="w", encoding="utf-8" + with (p := posts_directory / f"{name}.yaml").open( + mode="w", encoding="utf-8" ) as f: f.write(render) @@ -168,8 +168,7 @@ def new( track_template = env.get_template(name=os.path.join("common", "post.yaml.j2")) render = track_template.render(data={"name": name}) - with open( - (p := os.path.join(posts_directory, f"{name}_flag1.yaml")), + with (p := posts_directory / f"{name}_flag1.yaml").open( mode="w", encoding="utf-8", ) as f: @@ -205,8 +204,8 @@ def new( "is_windows": template == Template.WINDOWS_VM, } ) - with open( - (p := terraform_directory / "main.tf"), mode="w", encoding="utf-8" + with (p := terraform_directory / "main.tf").open( + mode="w", encoding="utf-8" ) as f: f.write(render) @@ -243,8 +242,8 @@ def new( "with_virtual_machine": with_virtual_machine, } ) - with open( - (p := ansible_directory / "deploy.yaml"), mode="w", encoding="utf-8" + with (p := ansible_directory / "deploy.yaml").open( + mode="w", encoding="utf-8" ) as f: f.write(render) @@ -264,8 +263,8 @@ def new( data={"name": name, "with_build": with_build_container} ) - with open( - (p := ansible_directory / "build.yaml"), mode="w", encoding="utf-8" + with (p := ansible_directory / "build.yaml").open( + mode="w", encoding="utf-8" ) as f: f.write(render) LOG.debug(f"Wrote {p}.") @@ -279,8 +278,8 @@ def new( "is_windows": template == Template.WINDOWS_VM, } ) - with open( - (p := ansible_directory / "inventory"), mode="w", encoding="utf-8" + with (p := ansible_directory / "inventory").open( + mode="w", encoding="utf-8" ) as f: f.write(render) @@ -296,8 +295,7 @@ def new( os.path.join(Template.APACHE_PHP, "index.php.j2") ) render = track_template.render(data={"name": name}) - with open( - (p := ansible_challenge_directory / "index.php"), + with (p := ansible_challenge_directory / "index.php").open( mode="w", encoding="utf-8", ) as f: @@ -310,8 +308,7 @@ def new( os.path.join(Template.PYTHON_SERVICE, "app.py.j2") ) render = track_template.render(data={"name": name}) - with open( - (p := ansible_challenge_directory / "app.py"), + with (p := ansible_challenge_directory / "app.py").open( mode="w", encoding="utf-8", ) as f: @@ -319,8 +316,7 @@ def new( LOG.debug(f"Wrote {p}.") - with open( - (p := ansible_challenge_directory / "flag-1.txt"), + with (p := ansible_challenge_directory / "flag-1.txt").open( mode="w", encoding="utf-8", ) as f: @@ -341,8 +337,7 @@ def new( os.path.join(Template.RUST_WEBSERVICE, "Cargo.toml.j2") ) render = manifest_template.render(data={"name": name}) - with open( - (p := ansible_challenge_directory / "Cargo.toml"), + with (p := ansible_challenge_directory / "Cargo.toml").open( mode="w", encoding="utf-8", ) as f: diff --git a/ctf/commands/post/new.py b/ctf/commands/post/new.py index 5e6f4e4..92a6b2e 100644 --- a/ctf/commands/post/new.py +++ b/ctf/commands/post/new.py @@ -1,10 +1,11 @@ import datetime +import json import os from enum import StrEnum from pathlib import Path import typer -from pydantic import BaseModel, Field +from pydantic import BaseModel, field_validator from rich.prompt import Confirm, IntPrompt, Prompt from typing_extensions import Annotated @@ -12,19 +13,13 @@ from ctf.common.models import Track from ctf.common.utils import ( get_all_available_tracks, + get_ctf_script_schemas_directory, parse_track_yaml, ) -# TODO: Find a way to allow this year's users. Maybe this could be done with this issue https://github.com/nsec/ctf-script/issues/42. - app = typer.Typer() -class ApiUser(StrEnum): - NSEC = "nsec" - SYSTEM = "system" - - class TriggerType(StrEnum): FLAG = "flag" NONE = "none" @@ -33,8 +28,52 @@ class TriggerType(StrEnum): class ApiPost(BaseModel): - user: ApiUser = Field(default=ApiUser.NSEC) - body: str = Field(default="") + user: str + body: str + + @field_validator("user", mode="after") + @classmethod + def is_even(cls, value: str) -> str: + if value not in _get_api_users_from_schema(): + raise ValueError( + f"{value} is not a valid user from {_get_api_users_from_schema()}" + ) + return value + + +__API_USERS: list[str] = [] + + +def _get_api_users_from_schema(lowercase: bool = False) -> list[str]: + global __API_USERS + if not __API_USERS: + __API_USERS = json.load( + (get_ctf_script_schemas_directory() / "post.json").open( + mode="r", encoding="utf-8" + ) + )["properties"]["api"]["properties"]["user"]["enum"] + + if lowercase: + return [user.lower() if lowercase else user for user in __API_USERS] + + return __API_USERS + + +def _validate_user(value: str | None) -> str | None: + if value and value not in _get_api_users_from_schema(lowercase=True): + raise typer.BadParameter( + f"{value} is not a valid user from {_get_api_users_from_schema()}" + ) + + return value + + +def _autocomplete_user(value: str) -> list[str]: + completion: list[str] = [] + for name in _get_api_users_from_schema(): + if name.lower().startswith(value.lower()): + completion.append(name) + return completion def _format_yaml_block(text: str, indent: int = 2) -> str: @@ -150,7 +189,7 @@ def _render_post_yaml( lines.extend( [ " - api:", - f" user: {api_post.user.value}", + f" user: {api_post.user}", " body: |-", _format_yaml_block(api_post.body, indent=6), ] @@ -160,7 +199,7 @@ def _render_post_yaml( lines.extend( [ "api:", - f" user: {api_posts[0].user.value}", + f" user: {api_posts[0].user}", "body: |-", _format_yaml_block(api_posts[0].body), ] @@ -199,12 +238,15 @@ def new( ), ] = None, user: Annotated[ - ApiUser, + str | None, typer.Option( "--user", help="Discourse user posting this message. If multiple users, use --multiple-users instead.", + callback=_validate_user, + autocompletion=_autocomplete_user, + case_sensitive=False, ), - ] = ApiUser.NSEC, + ] = None, body: Annotated[ str, typer.Option( @@ -267,14 +309,31 @@ def new( while True: u = Prompt.ask( "user", - choices=[ApiUser.NSEC, ApiUser.SYSTEM], + choices=_get_api_users_from_schema(), show_choices=True, + case_sensitive=False, ) b = Prompt.ask("body") - api_posts.append(ApiPost(user=ApiUser(u), body=b)) + + api_posts.append(ApiPost(user=u, body=b)) + if not Confirm.ask("Adding more?"): break else: + if not user: + user = Prompt.ask( + "user", + choices=_get_api_users_from_schema(), + show_choices=True, + case_sensitive=False, + ) + + if ( + user not in _get_api_users_from_schema() + and user in _get_api_users_from_schema(lowercase=True) + ): + user = [u for u in _get_api_users_from_schema() if user == u.lower()][0] + api_posts.append(ApiPost(user=user, body=body)) match trigger: @@ -339,14 +398,14 @@ def new( ): raise typer.Exit(0) - post_file_path = _resolve_post_file_path( + post_file_path: Path = _resolve_post_file_path( posts_directory=posts_directory, track=track_obj, name=name, force=force, ) - post_yaml = _render_post_yaml( + post_yaml: str = _render_post_yaml( track=track_obj, api_posts=api_posts, trigger=trigger, diff --git a/ctf/commands/services.py b/ctf/commands/services.py index 03de2ad..f37ba14 100644 --- a/ctf/commands/services.py +++ b/ctf/commands/services.py @@ -1,4 +1,3 @@ -import os import socket import requests @@ -27,16 +26,16 @@ def services( ] = False, ) -> None: distinct_tracks: set[str] = set() - for entry in os.listdir( + for entry in ( challenges_directory := (find_ctf_root_directory() / "challenges") - ): + ).iterdir(): if (track_directory := (challenges_directory / entry)).is_dir() and ( track_directory / "track.yaml" ).exists(): if not tracks: - distinct_tracks.add(entry) - elif entry in tracks: - distinct_tracks.add(entry) + distinct_tracks.add(entry.name) + elif entry.name in tracks: + distinct_tracks.add(entry.name) all_services = [] diff --git a/ctf/commands/stats.py b/ctf/commands/stats.py index e1fd143..d61296c 100644 --- a/ctf/commands/stats.py +++ b/ctf/commands/stats.py @@ -423,6 +423,6 @@ def stats( def write_badge(name: str, svg: str) -> None: with open( - file=os.path.join(".badges", f"badge-{name}.svg"), mode="w", encoding="utf-8" + os.path.join(".badges", f"badge-{name}.svg"), mode="w", encoding="utf-8" ) as f: f.write(svg) diff --git a/ctf/common/utils.py b/ctf/common/utils.py index 909c226..7ae42c9 100644 --- a/ctf/common/utils.py +++ b/ctf/common/utils.py @@ -15,7 +15,7 @@ from ctf.common.logger import LOG from ctf.common.models import Track, TrackYaml -__CTF_ROOT_DIRECTORY = "" +__CTF_ROOT_DIRECTORY: Path | None = None def available_incus_remotes() -> list[str]: @@ -344,11 +344,11 @@ def get_ctf_script_schemas_directory() -> Path: def remove_ctf_script_root_directory_from_path(path: Path) -> Path: - return Path(os.path.relpath(path=path, start=find_ctf_root_directory())) + return Path(os.path.relpath(path, find_ctf_root_directory())) def load_yaml_file(file: Path) -> dict[str, Any]: - return yaml.safe_load(stream=open(file, mode="r", encoding="utf-8")) + return yaml.safe_load(file.open(mode="r", encoding="utf-8")) def parse_track_yaml(track_name: str) -> dict[str, Any]: @@ -378,7 +378,7 @@ def parse_post_yamls(track_name: str) -> list[dict]: def find_ctf_root_directory() -> Path: global __CTF_ROOT_DIRECTORY if __CTF_ROOT_DIRECTORY: - return Path(__CTF_ROOT_DIRECTORY) + return __CTF_ROOT_DIRECTORY path: Path = (Path(ENV.get("CTF_ROOT_DIR", "."))).expanduser().resolve() while not is_ctf_dir(path) and path != (path := (path / "..").resolve()): diff --git a/ctf/validate_json_schemas.py b/ctf/validate_json_schemas.py index 17fd452..e5a99a6 100644 --- a/ctf/validate_json_schemas.py +++ b/ctf/validate_json_schemas.py @@ -22,7 +22,7 @@ def validate_with_json_schemas(schema: Path, files_pattern: str) -> None: LOG.debug("Starting JSON Schema validator") LOG.debug(f"Schema: {schema}") - schema = json.load(open(schema, mode="r", encoding="utf-8")) + schema = json.load(schema.open(mode="r", encoding="utf-8")) if not isinstance(schema, dict): LOG.error(msg=f"Loaded schema was not a dictionary: {schema}") @@ -40,7 +40,7 @@ def validate_with_json_schemas(schema: Path, files_pattern: str) -> None: for file in files: LOG.debug(f"Validating {file}") yaml_document = yaml.safe_load( - stream=open(file=file, mode="r", encoding="utf-8") + stream=open(file, mode="r", encoding="utf-8") ) try: jsonschema.validate(instance=yaml_document, schema=schema)