From bc7dea2c41ad96528f199f6b500c33880ef7c05f Mon Sep 17 00:00:00 2001 From: Marc Olivier Bergeron Date: Mon, 25 May 2026 23:42:22 -0400 Subject: [PATCH] Added more post types. --- ctf/commands/post/new.py | 294 +++++++++++++++++++++++++++++---------- 1 file changed, 217 insertions(+), 77 deletions(-) diff --git a/ctf/commands/post/new.py b/ctf/commands/post/new.py index dd37690..5e6f4e4 100644 --- a/ctf/commands/post/new.py +++ b/ctf/commands/post/new.py @@ -1,9 +1,11 @@ +import datetime import os -import re from enum import StrEnum from pathlib import Path import typer +from pydantic import BaseModel, Field +from rich.prompt import Confirm, IntPrompt, Prompt from typing_extensions import Annotated from ctf.common.logger import LOG @@ -13,6 +15,8 @@ parse_track_yaml, ) +# TODO: Find a way to allow this year's users. Maybe this could be done with this issue https://github.com/nsec/ctf-script/issues/42. + app = typer.Typer() @@ -23,22 +27,19 @@ class ApiUser(StrEnum): class TriggerType(StrEnum): FLAG = "flag" + NONE = "none" + SCORE = "score" + TIMER = "timer" -def _format_yaml_block(text: str) -> str: - lines = text.splitlines() or [""] - return "\n".join(f" {line}" for line in lines) +class ApiPost(BaseModel): + user: ApiUser = Field(default=ApiUser.NSEC) + body: str = Field(default="") -def _default_post_filename(track: Track, tag: str) -> str: - normalized_track = track.name.replace("-", "_") - suffix = tag - if tag.startswith(normalized_track + "_"): - suffix = tag[len(normalized_track) + 1 :] - suffix = re.sub(r"[^a-zA-Z0-9_-]+", "-", suffix).strip("-_") - if not suffix: - suffix = "post" - return f"{track}-{suffix.replace('_', '-')}.yaml" +def _format_yaml_block(text: str, indent: int = 2) -> str: + lines = text.splitlines() or [""] + return "\n".join(f"{' ' * indent}{line}" for line in lines) def _get_available_discourse_tags(track: Track) -> list[str]: @@ -70,18 +71,9 @@ def _resolve_post_file_path( posts_directory: Path, track: Track, name: str | None, - tag: str | None, force: bool, ) -> Path: - filename = ( - f"{track}-{name}.yaml" - if name - else ( - _default_post_filename(track=track, tag=tag) - if tag - else f"{track}-post.yaml" - ) - ) + filename = f"{track}-{name}.yaml" if name else (f"{track}-post.yaml") if not force: filename = _add_counter_to_filename(posts_directory, filename) @@ -91,34 +83,89 @@ def _resolve_post_file_path( def _render_post_yaml( track: Track, - user: ApiUser, - body: str, - trigger: TriggerType | None = None, - tag: str | None = None, + api_posts: list[ApiPost], + trigger: TriggerType, + tags: list[str] = [], + score_value: int | None = None, + threshold: int | None = None, + timer_after: datetime.datetime | None = None, ) -> str: - lines = [ - "type: post", + lines: list[str] = [ + f"type: post{'s' if len(api_posts) > 1 else ''}", f"topic: {track}", + "", ] - if trigger == TriggerType.FLAG: + match trigger: + case TriggerType.FLAG: + lines.extend( + [ + "trigger:", + f" type: {trigger}", + ] + ) + if len(tags) > 1: + lines.append(" tags:") + for tag in tags: + lines.append(f" - {tag}") + else: + lines.append(f"{tags[0]}") + + if threshold: + lines.append(f" threshold: {threshold}") + case TriggerType.SCORE: + if not score_value: + LOG.critical( + "--value parameter is required when using the score trigger." + ) + raise typer.Exit(1) + + lines.extend( + [ + "trigger:", + f" type: {trigger}", + f" value: {score_value}", + ] + ) + case TriggerType.TIMER: + if not timer_after: + LOG.critical( + "--after parameter is required when using the timer trigger." + ) + raise typer.Exit(1) + + lines.extend( + [ + "trigger:", + f" type: {trigger}", + f" after: {timer_after.strftime('%Y/%m/%d %H:%M')}", + ] + ) + case TriggerType.NONE: + ... + + if len(api_posts) > 1: + lines.append("posts:") + for api_post in api_posts: + lines.extend( + [ + " - api:", + f" user: {api_post.user.value}", + " body: |-", + _format_yaml_block(api_post.body, indent=6), + ] + ) + + else: lines.extend( [ - "trigger:", - " type: flag", - f" tag: {tag}", + "api:", + f" user: {api_posts[0].user.value}", + "body: |-", + _format_yaml_block(api_posts[0].body), ] ) - lines.extend( - [ - "api:", - f" user: {user.value}", - "body: |-", - _format_yaml_block(body), - ] - ) - return "\n".join(lines) + "\n" @@ -131,46 +178,84 @@ def new( track: Annotated[ str, typer.Option( - "--track", "-t", + "--track", help="Track name (challenge directory name).", ), ], - tag: Annotated[ - str | None, - typer.Option( - "--tag", - help="Discourse trigger tag, usually from track.yaml flag tags.discourse. Required when --trigger flag is set.", - ), - ] = None, trigger: Annotated[ - TriggerType | None, + TriggerType, typer.Option( "--trigger", help="Trigger type for this post. If omitted, no trigger block is added.", ), - ] = None, + ] = TriggerType.NONE, name: Annotated[ str | None, typer.Option( - "--name", "-n", + "--name", help="Post file name. Defaults to a name derived from the track and tag.", ), ] = None, user: Annotated[ ApiUser, - typer.Option("--user", help="Discourse user posting this message."), + typer.Option( + "--user", + help="Discourse user posting this message. If multiple users, use --multiple-users instead.", + ), ] = ApiUser.NSEC, body: Annotated[ str, - typer.Option("--body", help="Post body. Markdown is supported."), + typer.Option( + "--body", + help="Post body. Markdown is supported. Do not use when using --multiple-users.", + ), ] = "CHANGE_ME", + tags: Annotated[ + list[str], + typer.Option( + "-T", + "--tags", + help="Discourse trigger tag, usually from track.yaml flag tags.discourse. Required when --trigger=flag is set.", + ), + ] = [], + threshold: Annotated[ + int | None, + typer.Option( + "--threshold", + help="Amount of flags (tags) required to trigger. Required when --trigger=flag is set. Must be lower than the amount of tags provided.", + ), + ] = None, + score_value: Annotated[ + int | None, + typer.Option( + "--value", + help="Score value. When the team has reached that score, the post will trigger. Required when --trigger=score is set.", + ), + ] = None, + timer_after: Annotated[ + datetime.datetime | None, + typer.Option( + "--after", + help="After a specific date. Required when --trigger=timer is set.", + formats=["%Y/%m/%d %H:%M"], + ), + ] = None, + multiple_users: Annotated[ + bool, + typer.Option( + "-M", + "--multiple-users", + help="Multiple users for the post file. This results in multiple posts in one post file.", + ), + ] = False, force: Annotated[ bool, typer.Option("--force", help="Overwrite the post file if it already exists."), ] = False, ) -> None: + api_posts: list[ApiPost] = [] if (track_obj := Track(name=track)) not in get_all_available_tracks(): LOG.critical(f"Track directory not found: {track_obj.name}. Verify --track.") raise typer.Exit(1) @@ -178,45 +263,100 @@ def new( posts_directory: Path = track_obj.location / "posts" os.makedirs(posts_directory, exist_ok=True) - # TODO: add support for other triggers - if trigger == TriggerType.FLAG and not tag: - LOG.critical("--tag is required when --trigger flag is provided.") - raise typer.Exit(1) - - if trigger != TriggerType.FLAG and tag: - LOG.critical("--tag can only be used with --trigger flag.") - raise typer.Exit(1) - - if trigger == TriggerType.FLAG and tag: - valid_tags = _get_available_discourse_tags(track=track_obj) - if tag not in valid_tags: - if valid_tags: + if multiple_users: + while True: + u = Prompt.ask( + "user", + choices=[ApiUser.NSEC, ApiUser.SYSTEM], + show_choices=True, + ) + b = Prompt.ask("body") + api_posts.append(ApiPost(user=ApiUser(u), body=b)) + if not Confirm.ask("Adding more?"): + break + else: + api_posts.append(ApiPost(user=user, body=body)) + + match trigger: + case TriggerType.FLAG: + if not tags: + LOG.critical("--tags is required when --trigger=flag is provided.") + raise typer.Exit(1) + + if not (valid_tags := _get_available_discourse_tags(track=track_obj)): LOG.critical( - f'Invalid --tag "{tag}" for track "{track_obj.name}". Valid tags: {", ".join(valid_tags)}' + f"No discourse tags were found in track.yaml flags[].tags.discourse for {track_obj.name}" ) - else: + raise typer.Exit(1) + + for tag in tags: + if tag not in valid_tags: + LOG.critical( + f'Invalid --tag "{tag}" for track "{track_obj.name}". Valid tags: {", ".join(valid_tags)}' + ) + raise typer.Exit(1) + + if threshold and (threshold <= 0 or threshold > len(tags)): LOG.critical( - f'Invalid --tag "{tag}" for track "{track_obj.name}". No discourse tags were found in track.yaml flags[].tags.discourse.' + "Threshold must be higher than 0 and lower than the amount of tags provided." ) - raise typer.Exit(1) + raise typer.Exit(1) + case TriggerType.SCORE: + if not score_value: + while ( + score_value := IntPrompt.ask( + "Please enter the score at which this post will trigger for teams [bold magenta]\\[x>0][/bold magenta]" + ) + ) <= 0: + LOG.warning("The score must be positive and above 0.") + case TriggerType.TIMER: + while True: + if timer_after: + if timer_after >= datetime.datetime.now(): + break + + LOG.warning("The date must be in the future.") + try: + timer_after = datetime.datetime.strptime( + Prompt.ask( + "Enter a datetime in the futur [bold magenta]\\[YYYY/MM/DD HH:MM][/bold magenta]" + ), + "%Y/%m/%d %H:%M", + ) + + if timer_after < datetime.datetime.now(): + LOG.warning("The date must be in the future.") + continue + + break + except ValueError: + LOG.warning("The provided string was not a valid date.") + + case TriggerType.NONE: + if not Confirm.ask( + "Without a trigger, the post will [bold red]automatically be submitted to all teams[/bold red]. This is usually used for [bold cyan]hints[/bold cyan]. Is this what you want?", + default=False, + ): + raise typer.Exit(0) post_file_path = _resolve_post_file_path( posts_directory=posts_directory, track=track_obj, name=name, - tag=tag, force=force, ) post_yaml = _render_post_yaml( track=track_obj, - user=user, - body=body, + api_posts=api_posts, trigger=trigger, - tag=tag, + tags=tags, + threshold=threshold, + score_value=score_value, + timer_after=timer_after, ) - with open(post_file_path, "w", encoding="utf-8") as f: + with post_file_path.open(mode="w", encoding="utf-8") as f: f.write(post_yaml) LOG.info(f"Created post file: {post_file_path}")