diff --git a/py/src/braintrust/context.py b/py/src/braintrust/context.py index 051864c1..ed596148 100644 --- a/py/src/braintrust/context.py +++ b/py/src/braintrust/context.py @@ -1,12 +1,13 @@ """Abstract context management interface for Braintrust.""" import logging -import os from abc import ABC, abstractmethod from contextvars import ContextVar from dataclasses import dataclass from typing import Any +from .env import BraintrustEnv + @dataclass class SpanInfo: @@ -120,7 +121,7 @@ def get_context_manager() -> ContextManager: """ # Check if OTEL should be explicitly enabled via environment variable - if os.environ.get("BRAINTRUST_OTEL_COMPAT", "").lower() in ("1", "true", "yes"): + if BraintrustEnv.OTEL_COMPAT.get(False): try: from braintrust.otel.context import ContextManager as OtelContextManager diff --git a/py/src/braintrust/env.py b/py/src/braintrust/env.py new file mode 100644 index 00000000..a4933c3e --- /dev/null +++ b/py/src/braintrust/env.py @@ -0,0 +1,93 @@ +import math +import os +from collections.abc import Callable +from dataclasses import dataclass +from enum import Enum +from typing import TypeVar, cast + + +T = TypeVar("T") +EnvValue = bool | float | int | str +_Parser = Callable[[str], EnvValue | None] + + +def parse_float(value: str) -> float | None: + """Parse a finite float from a string.""" + try: + result = float(value) + except (ValueError, TypeError): + return None + if math.isnan(result) or math.isinf(result): + return None + return result + + +def parse_int(value: str) -> int | None: + """Parse an integer from a string.""" + try: + return int(value) + except (ValueError, TypeError): + return None + + +def parse_bool(value: str) -> bool | None: + """Parse common boolean environment variable values. + + Accepted true values: true, 1, yes, y, on. + Accepted false values: false, 0, no, n, off. + Empty or unrecognized values are invalid and fall back to the EnvVar default. + """ + normalized = value.strip().lower() + if normalized in ("true", "1", "yes", "y", "on"): + return True + if normalized in ("false", "0", "no", "n", "off"): + return False + return None + + +def parse_string(value: str) -> str | None: + """Parse a string environment variable. + + Empty strings are treated as unset so callers fall back to their default. + """ + return value or None + + +class EnvParser(Enum): + FLOAT = (parse_float,) + INT = (parse_int,) + BOOL = (parse_bool,) + STRING = (parse_string,) + + def __init__(self, parser: _Parser): + self.parser = parser + + +@dataclass(frozen=True) +class EnvVar: + name: str + parser: EnvParser + + def get(self, default: T) -> T: + value = os.environ.get(self.name) + if value is None: + return default + + parsed = self.parser.parser(value) + if parsed is None: + return default + return cast(T, parsed) + + +class BraintrustEnv: + HTTP_TIMEOUT = EnvVar("BRAINTRUST_HTTP_TIMEOUT", EnvParser.FLOAT) + SYNC_FLUSH = EnvVar("BRAINTRUST_SYNC_FLUSH", EnvParser.BOOL) + MAX_REQUEST_SIZE = EnvVar("BRAINTRUST_MAX_REQUEST_SIZE", EnvParser.INT) + DEFAULT_BATCH_SIZE = EnvVar("BRAINTRUST_DEFAULT_BATCH_SIZE", EnvParser.INT) + NUM_RETRIES = EnvVar("BRAINTRUST_NUM_RETRIES", EnvParser.INT) + QUEUE_SIZE = EnvVar("BRAINTRUST_QUEUE_SIZE", EnvParser.INT) + QUEUE_DROP_LOGGING_PERIOD = EnvVar("BRAINTRUST_QUEUE_DROP_LOGGING_PERIOD", EnvParser.FLOAT) + FAILED_PUBLISH_PAYLOADS_DIR = EnvVar("BRAINTRUST_FAILED_PUBLISH_PAYLOADS_DIR", EnvParser.STRING) + ALL_PUBLISH_PAYLOADS_DIR = EnvVar("BRAINTRUST_ALL_PUBLISH_PAYLOADS_DIR", EnvParser.STRING) + DISABLE_ATEXIT_FLUSH = EnvVar("BRAINTRUST_DISABLE_ATEXIT_FLUSH", EnvParser.BOOL) + OTEL_COMPAT = EnvVar("BRAINTRUST_OTEL_COMPAT", EnvParser.BOOL) diff --git a/py/src/braintrust/id_gen.py b/py/src/braintrust/id_gen.py index 9aab1517..b9006f6e 100644 --- a/py/src/braintrust/id_gen.py +++ b/py/src/braintrust/id_gen.py @@ -1,8 +1,9 @@ -import os import secrets import uuid from abc import ABC, abstractmethod +from .env import BraintrustEnv + def get_id_generator(): """Factory function that creates a new ID generator instance each time. @@ -10,7 +11,7 @@ def get_id_generator(): This eliminates global state and makes tests parallelizable. Each caller gets their own generator instance. """ - use_otel = os.getenv("BRAINTRUST_OTEL_COMPAT", "false").lower() == "true" + use_otel = BraintrustEnv.OTEL_COMPAT.get(False) return OTELIDGenerator() if use_otel else UUIDGenerator() diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index 9a65dee3..a908050d 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -52,6 +52,7 @@ TRANSACTION_ID_FIELD, VALID_SOURCES, ) +from .env import BraintrustEnv from .generated_types import ( AttachmentReference, AttachmentStatus, @@ -91,7 +92,6 @@ get_signature, mask_api_key, merge_dicts, - parse_env_var_float, response_raise_for_status, ) from .xact_ids import prettify_xact @@ -147,7 +147,7 @@ class ParametersRef(TypedDict, total=False): def _get_exporter(): """Return the active exporter (e.g. the version of SpanComponentsv*)""" - use_v4 = os.getenv("BRAINTRUST_OTEL_COMPAT", "false").lower() == "true" + use_v4 = BraintrustEnv.OTEL_COMPAT.get(False) return SpanComponentsV4 if use_v4 else SpanComponentsV3 @@ -746,7 +746,7 @@ def ping(self) -> bool: def make_long_lived(self) -> None: if not self.adapter: - timeout_secs = parse_env_var_float("BRAINTRUST_HTTP_TIMEOUT", 60.0) + timeout_secs = BraintrustEnv.HTTP_TIMEOUT.get(60.0) self.adapter = RetryRequestExceptionsAdapter( base_num_retries=10, backoff_factor=0.5, default_timeout_secs=timeout_secs ) @@ -1013,52 +1013,19 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]): self._max_request_size_result: dict[str, Any] | None = None self._max_request_size_lock = threading.Lock() - try: - self.sync_flush = bool(int(os.environ["BRAINTRUST_SYNC_FLUSH"])) - except: - self.sync_flush = False - - try: - self._max_request_size_override = int(os.environ["BRAINTRUST_MAX_REQUEST_SIZE"]) - except: - pass - - try: - self.default_batch_size = int(os.environ["BRAINTRUST_DEFAULT_BATCH_SIZE"]) - except: - self.default_batch_size = 100 - - try: - self.num_tries = int(os.environ["BRAINTRUST_NUM_RETRIES"]) + 1 - except: - self.num_tries = 3 - - try: - self.queue_maxsize = int(os.environ["BRAINTRUST_QUEUE_SIZE"]) - except: - self.queue_maxsize = DEFAULT_QUEUE_SIZE - - try: - self.queue_drop_logging_period = float(os.environ["BRAINTRUST_QUEUE_DROP_LOGGING_PERIOD"]) - except: - self.queue_drop_logging_period = 60 + self.sync_flush = BraintrustEnv.SYNC_FLUSH.get(False) + self._max_request_size_override = BraintrustEnv.MAX_REQUEST_SIZE.get(None) + self.default_batch_size = BraintrustEnv.DEFAULT_BATCH_SIZE.get(100) + self.num_tries = BraintrustEnv.NUM_RETRIES.get(2) + 1 + queue_maxsize = BraintrustEnv.QUEUE_SIZE.get(None) + self.queue_maxsize = DEFAULT_QUEUE_SIZE if queue_maxsize is None else queue_maxsize + self.queue_drop_logging_period = BraintrustEnv.QUEUE_DROP_LOGGING_PERIOD.get(60.0) self._queue_drop_logging_state = dict(lock=threading.Lock(), num_dropped=0, last_logged_timestamp=0) - try: - self.failed_publish_payloads_dir = os.environ["BRAINTRUST_FAILED_PUBLISH_PAYLOADS_DIR"] - except: - self.failed_publish_payloads_dir = None - - try: - self.all_publish_payloads_dir = os.environ["BRAINTRUST_ALL_PUBLISH_PAYLOADS_DIR"] - except: - self.all_publish_payloads_dir = None - - try: - disable_atexit_flush = os.environ["BRAINTRUST_DISABLE_ATEXIT_FLUSH"].lower() in ("true", "1", "yes") - except: - disable_atexit_flush = False + self.failed_publish_payloads_dir = BraintrustEnv.FAILED_PUBLISH_PAYLOADS_DIR.get(None) + self.all_publish_payloads_dir = BraintrustEnv.ALL_PUBLISH_PAYLOADS_DIR.get(None) + disable_atexit_flush = BraintrustEnv.DISABLE_ATEXIT_FLUSH.get(False) self.start_thread_lock = threading.RLock() self.thread = threading.Thread(target=self._publisher, daemon=True) @@ -4406,7 +4373,7 @@ def export(self) -> str: compute_object_metadata_args = None # Choose SpanComponents version based on BRAINTRUST_OTEL_COMPAT env var - use_v4 = os.getenv("BRAINTRUST_OTEL_COMPAT", "false").lower() == "true" + use_v4 = BraintrustEnv.OTEL_COMPAT.get(False) span_components_class = SpanComponentsV4 if use_v4 else SpanComponentsV3 # Disable span cache since remote function spans won't be in the local cache diff --git a/py/src/braintrust/test_env.py b/py/src/braintrust/test_env.py new file mode 100644 index 00000000..7d983902 --- /dev/null +++ b/py/src/braintrust/test_env.py @@ -0,0 +1,68 @@ +from .env import BraintrustEnv, EnvParser, EnvVar, parse_bool, parse_float, parse_int, parse_string + + +class TestEnvParsers: + def test_parse_float(self): + assert parse_float("123.45") == 123.45 + assert parse_float("nan") is None + assert parse_float("inf") is None + assert parse_float("") is None + assert parse_float("not_a_number") is None + + def test_parse_int(self): + assert parse_int("123") == 123 + assert parse_int("-5") == -5 + assert parse_int("") is None + assert parse_int("1.2") is None + assert parse_int("not_an_int") is None + + def test_parse_bool(self): + for value in ("true", "True", "1", "yes", "y", "on"): + assert parse_bool(value) is True + for value in ("false", "False", "0", "no", "n", "off"): + assert parse_bool(value) is False + assert parse_bool("") is None + assert parse_bool("maybe") is None + + def test_parse_string(self): + assert parse_string("value") == "value" + assert parse_string("") is None + + +class TestEnvVar: + def test_returns_default_when_env_not_set(self, monkeypatch): + monkeypatch.delenv("TEST_ENV_VAR", raising=False) + assert EnvVar("TEST_ENV_VAR", EnvParser.INT).get(42) == 42 + + def test_returns_default_when_env_invalid(self, monkeypatch): + monkeypatch.setenv("TEST_ENV_VAR", "invalid") + assert EnvVar("TEST_ENV_VAR", EnvParser.INT).get(42) == 42 + + def test_reads_environment_lazily(self, monkeypatch): + env_var = EnvVar("TEST_ENV_VAR", EnvParser.INT) + monkeypatch.setenv("TEST_ENV_VAR", "1") + assert env_var.get(42) == 1 + monkeypatch.setenv("TEST_ENV_VAR", "2") + assert env_var.get(42) == 2 + + def test_default_is_supplied_by_call_site(self, monkeypatch): + env_var = EnvVar("TEST_ENV_VAR", EnvParser.INT) + monkeypatch.delenv("TEST_ENV_VAR", raising=False) + assert env_var.get(1) == 1 + assert env_var.get(2) == 2 + + +class TestBraintrustEnv: + def test_centralized_env_definitions_are_lazy(self, monkeypatch): + monkeypatch.delenv("BRAINTRUST_HTTP_TIMEOUT", raising=False) + assert BraintrustEnv.HTTP_TIMEOUT.get(60.0) == 60.0 + monkeypatch.setenv("BRAINTRUST_HTTP_TIMEOUT", "0.2") + assert BraintrustEnv.HTTP_TIMEOUT.get(60.0) == 0.2 + + def test_otel_compat_uses_shared_bool_parser(self, monkeypatch): + for value in ("true", "1", "yes"): + monkeypatch.setenv("BRAINTRUST_OTEL_COMPAT", value) + assert BraintrustEnv.OTEL_COMPAT.get(False) is True + + monkeypatch.setenv("BRAINTRUST_OTEL_COMPAT", "false") + assert BraintrustEnv.OTEL_COMPAT.get(True) is False diff --git a/py/src/braintrust/test_util.py b/py/src/braintrust/test_util.py index 0dd27568..86bec1c1 100644 --- a/py/src/braintrust/test_util.py +++ b/py/src/braintrust/test_util.py @@ -1,65 +1,8 @@ -import os import unittest import pytest -from .util import LazyValue, mask_api_key, merge_dicts_with_paths, parse_env_var_float - - -class TestParseEnvVarFloat: - """Tests for parse_env_var_float helper.""" - - def test_returns_default_when_env_not_set(self): - assert parse_env_var_float("NONEXISTENT_VAR_12345", 42.0) == 42.0 - - def test_parses_valid_float(self): - os.environ["TEST_FLOAT"] = "123.45" - try: - assert parse_env_var_float("TEST_FLOAT", 0.0) == 123.45 - finally: - del os.environ["TEST_FLOAT"] - - def test_returns_default_for_nan(self): - os.environ["TEST_FLOAT"] = "nan" - try: - assert parse_env_var_float("TEST_FLOAT", 99.0) == 99.0 - finally: - del os.environ["TEST_FLOAT"] - - def test_returns_default_for_inf(self): - os.environ["TEST_FLOAT"] = "inf" - try: - assert parse_env_var_float("TEST_FLOAT", 99.0) == 99.0 - finally: - del os.environ["TEST_FLOAT"] - - def test_returns_default_for_negative_inf(self): - os.environ["TEST_FLOAT"] = "-inf" - try: - assert parse_env_var_float("TEST_FLOAT", 99.0) == 99.0 - finally: - del os.environ["TEST_FLOAT"] - - def test_returns_default_for_empty_string(self): - os.environ["TEST_FLOAT"] = "" - try: - assert parse_env_var_float("TEST_FLOAT", 99.0) == 99.0 - finally: - del os.environ["TEST_FLOAT"] - - def test_returns_default_for_invalid_string(self): - os.environ["TEST_FLOAT"] = "not_a_number" - try: - assert parse_env_var_float("TEST_FLOAT", 99.0) == 99.0 - finally: - del os.environ["TEST_FLOAT"] - - def test_allows_negative_values(self): - os.environ["TEST_FLOAT"] = "-5.5" - try: - assert parse_env_var_float("TEST_FLOAT", 0.0) == -5.5 - finally: - del os.environ["TEST_FLOAT"] +from .util import LazyValue, mask_api_key, merge_dicts_with_paths class TestLazyValue(unittest.TestCase): diff --git a/py/src/braintrust/util.py b/py/src/braintrust/util.py index 09c5b2c3..0f4d0927 100644 --- a/py/src/braintrust/util.py +++ b/py/src/braintrust/util.py @@ -1,7 +1,5 @@ import inspect import json -import math -import os import sys import threading import urllib.parse @@ -12,24 +10,6 @@ from requests import HTTPError, Response -def parse_env_var_float(name: str, default: float) -> float: - """Parse a float from an environment variable, returning default if invalid. - - Returns the default value if the env var is missing, empty, not a valid - float, NaN, or infinity. - """ - value = os.environ.get(name) - if value is None: - return default - try: - result = float(value) - if math.isnan(result) or math.isinf(result): - return default - return result - except (ValueError, TypeError): - return default - - GLOBAL_PROJECT = "Global" BT_IS_ASYNC_ATTRIBUTE = "_BT_IS_ASYNC"