diff --git a/src/modelinfo/cli.py b/src/modelinfo/cli.py index 474bfd3..18aa340 100644 --- a/src/modelinfo/cli.py +++ b/src/modelinfo/cli.py @@ -3,7 +3,6 @@ import os import sys from typing import Sequence - from modelinfo.architecture import identify_architecture_name from modelinfo.calculator import calculate_footprint from modelinfo.parsers.gguf import parse_gguf_header @@ -12,6 +11,29 @@ from modelinfo.ui import console, print_model_info, print_compare_info +class VersionAction(argparse.Action): + def __init__(self, option_strings, dest=argparse.SUPPRESS, default=argparse.SUPPRESS, help="show program's version number and exit"): + super().__init__( + option_strings=option_strings, + dest=dest, + default=default, + nargs=0, + help=help, + ) + + def __call__(self, parser, namespace, values, option_string=None): + from importlib.metadata import PackageNotFoundError, version + from modelinfo import __version__ + + try: + ver = version("modelinfo-cli") + except PackageNotFoundError: + ver = __version__ + + print(f"{parser.prog} {ver}") + parser.exit() + + def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser( prog="modelinfo", @@ -72,6 +94,11 @@ def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: default=0.9, help="vLLM gpu_memory_utilization ratio (default 0.9). Reserves 10 percent for PyTorch context.", ) + parser.add_argument( + "-v", + "--version", + action=VersionAction, + ) return parser.parse_args(argv) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..1a29100 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,12 @@ +import pytest + +from modelinfo import __version__ +from modelinfo.cli import parse_args + + +def test_version_flag_prints_installed_version(capsys): + with pytest.raises(SystemExit) as exc_info: + parse_args(["--version"]) + + assert exc_info.value.code == 0 + assert f"modelinfo {__version__}" in capsys.readouterr().out