diff --git a/run_inference.py b/run_inference.py
index f3ab727b6..5e88f8256 100644
--- a/run_inference.py
+++ b/run_inference.py
@@ -4,13 +4,36 @@
 import platform
 import argparse
 import subprocess
+import logging
+def setup_logging():
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        handlers=[
+            logging.StreamHandler()
+        ]
+    )
+
+def check_cuda():
+    """Check if CUDA is available."""
+    try:
+        import torch
+        if torch.cuda.is_available():
+            logging.info(f"CUDA is available: {torch.cuda.get_device_name(0)}")
+            return True
+        else:
+            logging.warning("CUDA is not available. Falling back to CPU.")
+    except ImportError:
+        logging.warning("PyTorch not installed. CUDA check skipped.")
+    return False
 
 def run_command(command, shell=False):
     """Run a system command and ensure it succeeds."""
     try:
+        logging.info(f"Executing command: {command if isinstance(command, str) else ' '.join(command)}")
         subprocess.run(command, shell=shell, check=True)
     except subprocess.CalledProcessError as e:
-        print(f"Error occurred while running command: {e}")
+        logging.error(f"Error occurred while running command: {e}")
         sys.exit(1)
 
 def run_inference():
@@ -21,28 +44,38 @@ def run_inference():
         main_path = os.path.join(build_dir, "bin", "llama-cli")
     else:
         main_path = os.path.join(build_dir, "bin", "llama-cli")
+
+    if not os.path.exists(main_path):
+        logging.error(f"The executable {main_path} does not exist. Please ensure the build directory is correct.")
+        sys.exit(1)
+
     command = [
         f'{main_path}',
         '-m', args.model,
         '-n', str(args.n_predict),
         '-t', str(args.threads),
         '-p', args.prompt,
-        '-ngl', '0',
+        '-ngl', '1' if check_cuda() else '0',
         '-c', str(args.ctx_size),
         '--temp', str(args.temperature),
         "-b", "1",
     ]
+
     if args.conversation:
         command.append("-cnv")
+
+    logging.info("Starting inference process...")
     run_command(command)
 
 def signal_handler(sig, frame):
-    print("Ctrl+C pressed, exiting...")
+    logging.info("Ctrl+C pressed, exiting...")
     sys.exit(0)
 
 if __name__ == "__main__":
+    setup_logging()
+    logging.info("Initializing inference script.")
     signal.signal(signal.SIGINT, signal_handler)
-    # Usage: python run_inference.py -p "Microsoft Corporation is an American multinational corporation and technology company headquartered in Redmond, Washington."
+
     parser = argparse.ArgumentParser(description='Run inference')
     parser.add_argument("-m", "--model", type=str, help="Path to model file", required=False, default="models/bitnet_b1_58-3B/ggml-model-i2_s.gguf")
     parser.add_argument("-n", "--n-predict", type=int, help="Number of tokens to predict when generating text", required=False, default=128)
@@ -53,4 +86,5 @@ def signal_handler(sig, frame):
     parser.add_argument("-cnv", "--conversation", action='store_true', help="Whether to enable chat mode or not (for instruct models.)")
     args = parser.parse_args()
+    logging.info("Parsed arguments successfully.")
 
     run_inference()
\ No newline at end of file