diff --git a/scripts/compare_loss.py b/scripts/compare_loss.py index 8b581266..31b2a009 100755 --- a/scripts/compare_loss.py +++ b/scripts/compare_loss.py @@ -9,6 +9,7 @@ import sys from pathlib import Path from argparse import ArgumentParser +from compare_utils import collect_log_files, exit_if_duplicate_logs def get_dtype_from_filename(filename): """Determine dtype from filename. Returns 'bfloat16' or 'fp32'.""" @@ -62,8 +63,10 @@ def main(): args.threshold_fp32 = args.threshold args.threshold_bf16 = args.threshold - files1 = {f.name: f for f in args.dir1.glob('*.log') if not f.name.startswith('build')} - files2 = {f.name: f for f in args.dir2.glob('*.log') if not f.name.startswith('build')} + files1, duplicates1 = collect_log_files(args.dir1) + files2, duplicates2 = collect_log_files(args.dir2) + exit_if_duplicate_logs(args.dir1, duplicates1) + exit_if_duplicate_logs(args.dir2, duplicates2) only_in_1 = set(files1.keys()) - set(files2.keys()) only_in_2 = set(files2.keys()) - set(files1.keys()) diff --git a/scripts/compare_tps.py b/scripts/compare_tps.py index 270b1ddd..de6327de 100755 --- a/scripts/compare_tps.py +++ b/scripts/compare_tps.py @@ -9,6 +9,7 @@ import sys from pathlib import Path from argparse import ArgumentParser +from compare_utils import collect_log_files, exit_if_duplicate_logs def parse_log(file_path): """Extract step -> tok/s mapping from log file.""" @@ -55,8 +56,10 @@ def main(): parser.add_argument('--verbose', action='store_true', help='Print detailed output for all files, including passed ones') args = parser.parse_args() - files1 = {f.name: f for f in args.dir1.glob('*.log') if not f.name.startswith('build')} - files2 = {f.name: f for f in args.dir2.glob('*.log') if not f.name.startswith('build')} + files1, duplicates1 = collect_log_files(args.dir1) + files2, duplicates2 = collect_log_files(args.dir2) + exit_if_duplicate_logs(args.dir1, duplicates1) + exit_if_duplicate_logs(args.dir2, duplicates2) only_in_1 = set(files1.keys()) - set(files2.keys()) only_in_2 = set(files2.keys()) - set(files1.keys()) diff --git a/scripts/compare_utils.py b/scripts/compare_utils.py new file mode 100644 index 00000000..0831f7be --- /dev/null +++ b/scripts/compare_utils.py @@ -0,0 +1,31 @@ +from pathlib import Path +import sys + + +def collect_log_files(base_dir: Path): + """Collect comparable training logs keyed by basename.""" + files = {} + duplicates = {} + + for path in base_dir.rglob("*.log"): + if path.name.startswith("build") or path.name.endswith("_profile.log"): + continue + + key = path.name + if key in files: + duplicates.setdefault(key, [files[key]]).append(path) + continue + files[key] = path + + return files, duplicates + + +def exit_if_duplicate_logs(base_dir: Path, duplicates): + """Abort when duplicate basenames make comparison ambiguous.""" + if not duplicates: + return + + print(f"Found duplicate log basenames in {base_dir.resolve()}, cannot compare safely:") + for name, paths in sorted(duplicates.items()): + print(f" {name}: {', '.join(str(p.relative_to(base_dir)) for p in paths)}") + sys.exit(1) diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index 1cf27935..b183a936 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -3,7 +3,56 @@ set -e set -o pipefail -CONFIG_FILE="${1:-test_config.json}" +usage() { + cat <<'EOF' +Usage: run_models_and_profile.bash [--test-config path] [--only-run tag1,tag2] + +Options: + --test-config PATH Path to test config JSON. Default: test_config.json. + --only-run TAGS Only run the specified tag groups, separated by commas. + -h, --help Show this help message. +EOF +} + +CONFIG_FILE="test_config.json" +ONLY_RUN_TAGS="" + +while [[ $# -gt 0 ]]; do + case "$1" in + --test-config) + [[ $# -lt 2 ]] && { echo "Error: --test-config requires a file path."; exit 1; } + CONFIG_FILE="$2" + shift 2 + ;; + --test-config=*) + CONFIG_FILE="${1#*=}" + shift + ;; + --only-run) + [[ $# -lt 2 ]] && { echo "Error: --only-run requires a comma-separated tag list."; exit 1; } + ONLY_RUN_TAGS="$2" + shift 2 + ;; + --only-run=*) + ONLY_RUN_TAGS="${1#*=}" + shift + ;; + -h|--help) + usage + exit 0 + ;; + -*) + echo "Error: Unknown option: $1" + usage + exit 1 + ;; + *) + echo "Error: Unknown positional argument: $1" + usage + exit 1 + ;; + esac +done # Dependencies check if ! command -v jq >/dev/null 2>&1; then @@ -33,6 +82,28 @@ done < <(jq -r '.variables | to_entries[] | "\(.key)=\(.value)"' "$CONFIG_FILE") # Global variable to save the last cmake command LAST_CMAKE_CMD="" +declare -A SELECTED_TAGS=() + +normalize_tag() { + local raw="$1" + raw="${raw#"${raw%%[![:space:]]*}"}" + raw="${raw%"${raw##*[![:space:]]}"}" + printf '%s' "$raw" +} + +if [[ -n "$ONLY_RUN_TAGS" ]]; then + IFS=',' read -r -a requested_tags <<< "$ONLY_RUN_TAGS" + for raw_tag in "${requested_tags[@]}"; do + tag="$(normalize_tag "$raw_tag")" + [[ -z "$tag" ]] && continue + SELECTED_TAGS["$tag"]=1 + done + + if [[ ${#SELECTED_TAGS[@]} -eq 0 ]]; then + echo "Error: --only-run did not contain any valid tags." + exit 1 + fi +fi # Clean the build directory clean_build_dir() { @@ -46,9 +117,12 @@ run_and_log() { local cmd="$1" local log_name="$2" local is_profile="$3" + local tag="${4:-basic}" local timestamp timestamp=$(date '+%Y-%m-%d %H:%M:%S') - local log_path="$(realpath "${LOG_DIR}/${log_name}.log")" + local tag_log_dir="${LOG_DIR}/${tag}" + mkdir -p "$tag_log_dir" + local log_path="$(realpath "${tag_log_dir}/${log_name}.log")" echo -e "\033[1;32m============================================================\033[0m" echo -e "\033[1;36m[$timestamp] [Running] ${log_name}\033[0m" @@ -99,7 +173,7 @@ run_and_log() { # If profiling is enabled, move profiling files to the target directory if [[ "$is_profile" == "yes" ]]; then - move_profile_logs "$log_name" + move_profile_logs "$log_name" "$tag" fi } @@ -107,14 +181,17 @@ run_and_log() { # Move profiling output logs move_profile_logs() { local prefix="$1" + local tag="${2:-basic}" + local tag_profile_dir="${PROFILE_LOG_DIR}/${tag}" + mkdir -p "$tag_profile_dir" # Move *.report.rankN files for report_file in "${BUILD_DIR}"/*.report.rank*; do if [[ -f "$report_file" ]]; then local base_name base_name=$(basename "$report_file") - mv "$report_file" "${PROFILE_LOG_DIR}/${prefix}_${base_name}" - echo "Moved $base_name to ${PROFILE_LOG_DIR}/${prefix}_${base_name}" + mv "$report_file" "${tag_profile_dir}/${prefix}_${base_name}" + echo "Moved $base_name to ${tag_profile_dir}/${prefix}_${base_name}" fi done @@ -123,17 +200,18 @@ move_profile_logs() { if [[ -f "$record_file" ]]; then local base_name base_name=$(basename "$record_file") - mv "$record_file" "${PROFILE_LOG_DIR}/${prefix}_${base_name}" - echo "Moved $base_name to ${PROFILE_LOG_DIR}/${prefix}_${base_name}" + mv "$record_file" "${tag_profile_dir}/${prefix}_${base_name}" + echo "Moved $base_name to ${tag_profile_dir}/${prefix}_${base_name}" fi done } -# Build "--key value" arg string from tests[i].args (shell-escaped) +# Build "--key value" arg string from test_groups[gi].tests[ti].args (shell-escaped) args_string_for_test() { - local idx="$1" - jq -r --argjson i "$idx" ' - .tests[$i].args + local group_idx="$1" + local test_idx="$2" + jq -r --argjson g "$group_idx" --argjson t "$test_idx" ' + .test_groups[$g].tests[$t].args | to_entries[] | "--\(.key) \(.value|tostring)" ' "$CONFIG_FILE" | paste -sd' ' - @@ -141,7 +219,20 @@ args_string_for_test() { # Run tests num_builds=$(jq '.builds | length' "$CONFIG_FILE") -num_tests=$(jq '.tests | length' "$CONFIG_FILE") +num_groups=$(jq '.test_groups | length' "$CONFIG_FILE") + +selected_group_count=0 +for ((gi=0; gi sheet_id={sheet_id}") + + cmd_args, sheet_data = get_model_data(model_name=model_name, sheet_title=testcase, tag=tag) + + if not sheet_data: + print("No valid data generated, skipping") continue - remote_by_title[testcase] = sheet_id - sort_sheets = True - write_cmd = True - print(f"Created sheet '{testcase}' with id={sheet_id}") - - print(f"Processing testcase '{testcase}' -> sheet_id={sheet_id}") - - cmd_args, sheet_data = get_model_data(model_name=model_name, sheet_title=testcase) - - if not sheet_data: - print("No valid data generated, skipping") - continue - if write_cmd and cmd_args: - handler.write_cmd_args_to_header(spreadsheet_token, cmd_args, sheet_id) + if write_cmd and cmd_args: + handler.write_cmd_args_to_header(spreadsheet_token, cmd_args, sheet_id) - if handler.prepend_data(spreadsheet_token, sheet_id, sheet_data): - handler.post_process(spreadsheet_token, sheet_id) + if handler.prepend_data(spreadsheet_token, sheet_id, sheet_data): + handler.post_process(spreadsheet_token, sheet_id) - if sort_sheets: - handler.sort_sheets_by_title(spreadsheet_token, "模板") + if sort_sheets: + handler.sort_sheets_by_title(spreadsheet_token, "模板") print("\n=== All models and sheets processed ===")