From aac16e7273c018dc7ae8cd985895a67bcc0c1f75 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Fri, 20 Jun 2025 13:36:25 -0400 Subject: [PATCH 1/4] Add dotenv environment variable defaults * Allow a .env file to specify the default environment variables to be passed to predict --- go.mod | 1 + go.sum | 2 ++ pkg/cli/predict.go | 21 ++++++++++++++++++- .../fixtures/env-cli-project/.env | 1 + .../fixtures/env-cli-project/cog.yaml | 1 + .../fixtures/env-cli-project/predict.py | 7 +++++++ .../test_integration/test_predict.py | 15 +++++++++++++ 7 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 test-integration/test_integration/fixtures/env-cli-project/.env create mode 100644 test-integration/test_integration/fixtures/env-cli-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/env-cli-project/predict.py diff --git a/go.mod b/go.mod index 7507397498..a648f3f173 100644 --- a/go.mod +++ b/go.mod @@ -176,6 +176,7 @@ require ( github.com/jgautheron/goconst v1.7.1 // indirect github.com/jingyugao/rowserrcheck v1.1.1 // indirect github.com/jjti/go-spancheck v0.6.4 // indirect + github.com/joho/godotenv v1.5.1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/julz/importas v0.2.0 // indirect github.com/karamaru-alpha/copyloopvar v1.2.1 // indirect diff --git a/go.sum b/go.sum index 25938fc8d1..77d2429bb3 100644 --- a/go.sum +++ b/go.sum @@ -338,6 +338,8 @@ github.com/jingyugao/rowserrcheck v1.1.1 h1:zibz55j/MJtLsjP1OF4bSdgXxwL1b+Vn7Tjz github.com/jingyugao/rowserrcheck v1.1.1/go.mod h1:4yvlZSDb3IyDTUZJUmpZfm2Hwok+Dtp+nu2qOq+er9c= github.com/jjti/go-spancheck v0.6.4 h1:Tl7gQpYf4/TMU7AT84MN83/6PutY21Nb9fuQjFTpRRc= github.com/jjti/go-spancheck v0.6.4/go.mod h1:yAEYdKJ2lRkDA8g7X+oKUHXOWVAXSBJRv04OhF+QUjk= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/julz/importas v0.2.0 h1:y+MJN/UdL63QbFJHws9BVC5RpA2iq0kpjrFajTGivjQ= diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index f2396b5852..8353faf118 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -17,6 +17,7 @@ import ( "time" "github.com/getkin/kin-openapi/openapi3" + "github.com/joho/godotenv" "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" "golang.org/x/sys/unix" @@ -69,10 +70,10 @@ the prediction on that.`, addFastFlag(cmd) addLocalImage(cmd) addConfigFlag(cmd) + addEnvFlag(cmd) cmd.Flags().StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i path=@image.jpg") cmd.Flags().StringVarP(&outPath, "output", "o", "", "Output path") - cmd.Flags().StringArrayVarP(&envFlags, "env", "e", []string{}, "Environment variables, in the form name=value") cmd.Flags().BoolVar(&useReplicateAPIToken, "use-replicate-token", false, "Pass REPLICATE_API_TOKEN from local environment into the model context") cmd.Flags().StringVar(&inputJSON, "json", "", "Pass inputs as JSON object, read from file (@inputs.json) or via stdin (@-)") @@ -643,3 +644,21 @@ func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error func addSetupTimeoutFlag(cmd *cobra.Command) { cmd.Flags().Uint32Var(&setupTimeout, "setup-timeout", 5*60, "The timeout for a container to setup (in seconds).") } + +func addEnvFlag(cmd *cobra.Command) { + defaultEnv := []string{} + + // Attempt to fill default environment variables with a .env file + exists, _ := files.Exists(".env") + if exists { + var dotEnv map[string]string + dotEnv, err := godotenv.Read() + if err == nil { + for key, value := range dotEnv { + defaultEnv = append(defaultEnv, fmt.Sprintf("%s=%s", key, value)) + } + } + } + + cmd.Flags().StringArrayVarP(&envFlags, "env", "e", defaultEnv, "Environment variables, in the form name=value") +} diff --git a/test-integration/test_integration/fixtures/env-cli-project/.env b/test-integration/test_integration/fixtures/env-cli-project/.env new file mode 100644 index 0000000000..92e90afffc --- /dev/null +++ b/test-integration/test_integration/fixtures/env-cli-project/.env @@ -0,0 +1 @@ +TEST_VAR=test_value diff --git a/test-integration/test_integration/fixtures/env-cli-project/cog.yaml b/test-integration/test_integration/fixtures/env-cli-project/cog.yaml new file mode 100644 index 0000000000..cdf01d2b46 --- /dev/null +++ b/test-integration/test_integration/fixtures/env-cli-project/cog.yaml @@ -0,0 +1 @@ +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/env-cli-project/predict.py b/test-integration/test_integration/fixtures/env-cli-project/predict.py new file mode 100644 index 0000000000..6029e472c4 --- /dev/null +++ b/test-integration/test_integration/fixtures/env-cli-project/predict.py @@ -0,0 +1,7 @@ +from cog import BasePredictor +import os + + +class Predictor(BasePredictor): + def predict(self, name: str) -> str: + return f"ENV[{name}]={os.getenv(name)}" diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index ff1d4fa626..2f74cf83a6 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -801,3 +801,18 @@ def test_predict_future_annotations(cog_binary): timeout=120.0, ) assert result.returncode == 0 + + +def test_predict_dotenv(docker_image, cog_binary): + project_dir = Path(__file__).parent / "fixtures/future-annotations-project" + + result = subprocess.run( + [cog_binary, "predict", "--debug", docker_image, "-i", "name=TEST_VAR"], + cwd=project_dir, + check=True, + capture_output=True, + text=True, + timeout=120.0, + ) + assert result.returncode == 0 + assert result.stdout == "ENV[TEST_VAR]=test_value\n" From a4364762941c42a429f27a1c1dbf2c621016e14b Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Fri, 20 Jun 2025 13:55:38 -0400 Subject: [PATCH 2/4] Fix fixture location --- test-integration/test_integration/test_predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 2f74cf83a6..2329ce218e 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -804,7 +804,7 @@ def test_predict_future_annotations(cog_binary): def test_predict_dotenv(docker_image, cog_binary): - project_dir = Path(__file__).parent / "fixtures/future-annotations-project" + project_dir = Path(__file__).parent / "fixtures/env-cli-project" result = subprocess.run( [cog_binary, "predict", "--debug", docker_image, "-i", "name=TEST_VAR"], From ee8f669c5c92b0a3ce49c928a88eb8b7d402d25c Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Wed, 25 Jun 2025 14:00:03 -0400 Subject: [PATCH 3/4] Remove check --- test-integration/test_integration/test_predict.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index 2329ce218e..bdbf42c0fa 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -809,7 +809,6 @@ def test_predict_dotenv(docker_image, cog_binary): result = subprocess.run( [cog_binary, "predict", "--debug", docker_image, "-i", "name=TEST_VAR"], cwd=project_dir, - check=True, capture_output=True, text=True, timeout=120.0, From 55bca6515361abf4a1c787fb7c2c09bc07410675 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Wed, 25 Jun 2025 14:26:16 -0400 Subject: [PATCH 4/4] Remove docker_image in predict --- test-integration/test_integration/test_predict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index bdbf42c0fa..926e44a922 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -803,11 +803,11 @@ def test_predict_future_annotations(cog_binary): assert result.returncode == 0 -def test_predict_dotenv(docker_image, cog_binary): +def test_predict_dotenv(cog_binary): project_dir = Path(__file__).parent / "fixtures/env-cli-project" result = subprocess.run( - [cog_binary, "predict", "--debug", docker_image, "-i", "name=TEST_VAR"], + [cog_binary, "predict", "--debug", "-i", "name=TEST_VAR"], cwd=project_dir, capture_output=True, text=True,