diff --git a/AGENTS.md b/AGENTS.md index 08c137eac1..e34846d11f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -129,7 +129,7 @@ The CLI code is in the `cmd/cog/` and `pkg/` directories. Support tooling is in The main commands for working on the CLI are: - `go run ./cmd/cog` - Runs the Cog CLI directly from source (requires wheel to be built first) -- `mise run build:cog` - Builds the Cog CLI binary, embedding the Python wheel +- `mise run build:cog` - Builds the Cog CLI binary - `mise run install` - Symlinks the built binary to `/usr/local/bin` (run `build:cog` first), or to a custom path with `PREFIX=/custom/path mise run install` - `mise run test:go` - Runs all Go unit tests - `go test ./pkg/...` - Runs tests directly with `go test` @@ -180,7 +180,7 @@ COG_BINARY=dist/go/*/cog mise run test:integration 1. Run `mise install` to set up the development environment 2. Run `mise run build:sdk` after making changes to the `./python` directory 3. Run `mise run build:coglet:wheel:linux-x64` after making changes to the `./crates` directory (needed for Docker testing) -4. Run `mise run build:cog` to build the CLI (embeds the SDK wheel; picks up coglet wheel from `dist/`) +4. Run `mise run build:cog` to build the CLI (wheels are picked up from `dist/` at Docker build time, not embedded in the binary) 5. Run `mise run fmt:fix` to format code 6. Run `mise run lint` to check code quality 7. Run `mise run docs:llm` to regenerate `docs/llms.txt` after changing `README.md` or any `docs/*.md` file @@ -212,7 +212,7 @@ See `crates/README.md` for detailed architecture documentation. - `crates/coglet-python/` - PyO3 bindings for Python predictor integration ### Key Design Patterns -1. **Embedded Python Wheel**: The Go binary embeds the Python wheel at build time (`pkg/dockerfile/embed/`) +1. **Local Wheel Resolution**: The CLI discovers SDK and coglet wheels from `dist/` at Docker build time (not embedded in the binary) 2. **Docker SDK Integration**: Uses Docker Go SDK for container operations 3. **Type Safety**: Dataclasses for Python type validation, strongly typed Go interfaces 4. **Compatibility Matrix**: Automated CUDA/PyTorch/TensorFlow compatibility management diff --git a/crates/Cargo.lock b/crates/Cargo.lock index bcd0232186..8e2874cae4 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -184,11 +184,13 @@ dependencies = [ "anyhow", "async-trait", "axum", + "base64", "chrono", "dashmap", "futures", "http-body-util", "insta", + "mime_guess", "nix", "reqwest", "serde", @@ -1054,6 +1056,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "mio" version = "1.1.1" @@ -2455,6 +2467,12 @@ dependencies = [ "unic-common", ] +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-ident" version = "1.0.22" diff --git a/crates/coglet-python/coglet/_impl.pyi b/crates/coglet-python/coglet/_impl.pyi new file mode 100644 index 0000000000..64c9098203 --- /dev/null +++ b/crates/coglet-python/coglet/_impl.pyi @@ -0,0 +1,62 @@ +# This file is automatically generated by pyo3_stub_gen +# ruff: noqa: E501, F401, F403, F405 + +import builtins +import typing +from . import _sdk +__all__ = [ + "BuildInfo", + "Server", + "server", +] + +__build__: BuildInfo +__version__: builtins.str +server: Server +@typing.final +class BuildInfo: + r""" + Frozen build metadata exposed as `coglet.__build__`. + """ + @property + def version(self) -> builtins.str: ... + @property + def git_sha(self) -> builtins.str: ... + @property + def build_time(self) -> builtins.str: ... + @property + def rustc_version(self) -> builtins.str: ... + def __repr__(self) -> builtins.str: ... + +@typing.final +class Server: + r""" + The coglet prediction server. + + Access via `coglet.server`. Frozen — attributes cannot be set or deleted. + + - `coglet.server.active` — `True` when running inside a worker subprocess + - `coglet.server.serve(...)` — start the HTTP prediction server (blocking) + """ + @property + def active(self) -> builtins.bool: + r""" + `True` when running inside a coglet worker subprocess. + """ + def serve(self, predictor_ref: typing.Optional[builtins.str] = None, host: builtins.str = '0.0.0.0', port: builtins.int = 5000, await_explicit_shutdown: builtins.bool = False, is_train: builtins.bool = False, output_temp_dir_base: builtins.str = '/tmp/coglet/output', upload_url: typing.Optional[builtins.str] = None) -> None: + r""" + Start the HTTP prediction server. Blocks until shutdown. + """ + def _run_worker(self) -> None: + r""" + Worker subprocess entry point. Called by the orchestrator. + + Sets the active flag, installs log writers and audit hooks, + then enters the worker event loop. + """ + def _is_cancelable(self) -> builtins.bool: + r""" + Returns `True` if the current thread is in a cancelable predict call. + """ + def __repr__(self) -> builtins.str: ... + diff --git a/crates/coglet-python/src/bin/stub_gen.rs b/crates/coglet-python/src/bin/stub_gen.rs index 497838374a..e3b262c8bb 100644 --- a/crates/coglet-python/src/bin/stub_gen.rs +++ b/crates/coglet-python/src/bin/stub_gen.rs @@ -1,11 +1,41 @@ //! Generate Python stub files for coglet. //! //! Run with: cargo run --bin stub_gen +//! +//! Custom generate logic: pyo3-stub-gen places classes from the native +//! `coglet._impl` module into the `coglet` parent package, but mypy stubgen +//! overwrites `coglet/__init__.pyi` from the hand-maintained `__init__.py`. +//! We redirect the `coglet` module output to `coglet/_impl.pyi` so the +//! native module types are preserved. use pyo3_stub_gen::Result; +use std::fs; +use std::io::Write; fn main() -> Result<()> { let stub = coglet::stub_info()?; - stub.generate()?; + + for (name, module) in &stub.modules { + let normalized = name.replace('-', "_"); + + let dest = if normalized == "coglet" { + // Native module classes land here — redirect to _impl.pyi + stub.python_root.join("coglet").join("_impl.pyi") + } else { + // Submodules like "coglet._sdk" → coglet/_sdk/__init__.pyi + let path = normalized.replace('.', "/"); + stub.python_root.join(&path).join("__init__.pyi") + }; + + let dir = dest.parent().expect("cannot get parent directory"); + if !dir.exists() { + fs::create_dir_all(dir)?; + } + + let mut f = fs::File::create(&dest)?; + write!(f, "{module}")?; + eprintln!("Generated stub: {}", dest.display()); + } + Ok(()) } diff --git a/crates/coglet-python/src/lib.rs b/crates/coglet-python/src/lib.rs index 95b1693504..89fc059243 100644 --- a/crates/coglet-python/src/lib.rs +++ b/crates/coglet-python/src/lib.rs @@ -19,6 +19,12 @@ use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberI // Define stub info gatherer for generating .pyi files pyo3_stub_gen::define_stub_info_gatherer!(stub_info); +// Module-level attributes (pyo3-stub-gen can't see m.add() calls). +// Uses "coglet" because that's the module key in StubInfo for the native module. +pyo3_stub_gen::module_variable!("coglet", "__version__", &str); +pyo3_stub_gen::module_variable!("coglet", "__build__", BuildInfo); +pyo3_stub_gen::module_variable!("coglet", "server", CogletServer); + use coglet_core::{ Health, PredictionService, SetupResult, VersionInfo, transport::{ServerConfig, serve as http_serve}, @@ -188,7 +194,8 @@ impl CogletServer { } /// Start the HTTP prediction server. Blocks until shutdown. - #[pyo3(signature = (predictor_ref=None, host="0.0.0.0".to_string(), port=5000, await_explicit_shutdown=false, is_train=false))] + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (predictor_ref=None, host="0.0.0.0".to_string(), port=5000, await_explicit_shutdown=false, is_train=false, output_temp_dir_base="/tmp/coglet/output".to_string(), upload_url=None))] fn serve( &self, py: Python<'_>, @@ -197,6 +204,8 @@ impl CogletServer { port: u16, await_explicit_shutdown: bool, is_train: bool, + output_temp_dir_base: String, + upload_url: Option, ) -> PyResult<()> { serve_impl( py, @@ -205,6 +214,8 @@ impl CogletServer { port, await_explicit_shutdown, is_train, + output_temp_dir_base, + upload_url, ) } @@ -254,6 +265,7 @@ impl CogletServer { } } +#[allow(clippy::too_many_arguments)] fn serve_impl( py: Python<'_>, predictor_ref: Option, @@ -261,6 +273,8 @@ fn serve_impl( port: u16, await_explicit_shutdown: bool, is_train: bool, + _output_temp_dir_base: String, + upload_url: Option, ) -> PyResult<()> { let (setup_log_tx, setup_log_rx) = tokio::sync::mpsc::unbounded_channel(); init_tracing(false, Some(setup_log_tx)); @@ -303,7 +317,15 @@ fn serve_impl( }; info!(predictor_ref = %pred_ref, is_train, "Using subprocess isolation"); - serve_subprocess(py, pred_ref, config, version, is_train, setup_log_rx) + serve_subprocess( + py, + pred_ref, + config, + version, + is_train, + setup_log_rx, + upload_url, + ) } fn serve_subprocess( @@ -313,6 +335,7 @@ fn serve_subprocess( version: VersionInfo, is_train: bool, mut setup_log_rx: tokio::sync::mpsc::UnboundedReceiver, + upload_url: Option, ) -> PyResult<()> { let max_concurrency = read_max_concurrency(py); info!( @@ -322,7 +345,8 @@ fn serve_subprocess( let orch_config = coglet_core::orchestrator::OrchestratorConfig::new(pred_ref) .with_num_slots(max_concurrency) - .with_train(is_train); + .with_train(is_train) + .with_upload_url(upload_url); let service = Arc::new( PredictionService::new_no_pool() diff --git a/crates/coglet-python/src/log_writer.rs b/crates/coglet-python/src/log_writer.rs index 981e35a4f4..22e69398ac 100644 --- a/crates/coglet-python/src/log_writer.rs +++ b/crates/coglet-python/src/log_writer.rs @@ -595,7 +595,7 @@ mod tests { fn registry_operations() { let prediction_id = "pred_123".to_string(); let (tx, _rx) = mpsc::unbounded_channel(); - let sender = Arc::new(SlotSender::new(tx)); + let sender = Arc::new(SlotSender::new(tx, std::env::temp_dir())); // Register register_prediction(prediction_id.clone(), sender.clone()); @@ -609,7 +609,7 @@ mod tests { #[test] fn slot_sender_sends_log() { let (tx, mut rx) = mpsc::unbounded_channel(); - let sender = SlotSender::new(tx); + let sender = SlotSender::new(tx, std::env::temp_dir()); sender.send_log(LogSource::Stdout, "hello").unwrap(); @@ -626,7 +626,7 @@ mod tests { #[test] fn slot_sender_ignores_empty() { let (tx, mut rx) = mpsc::unbounded_channel(); - let sender = SlotSender::new(tx); + let sender = SlotSender::new(tx, std::env::temp_dir()); sender.send_log(LogSource::Stderr, "").unwrap(); @@ -639,7 +639,7 @@ mod tests { let (tx, rx) = mpsc::unbounded_channel::(); drop(rx); // Close receiver - let sender = SlotSender::new(tx); + let sender = SlotSender::new(tx, std::env::temp_dir()); let result = sender.send_log(LogSource::Stdout, "hello"); assert!(result.is_err()); diff --git a/crates/coglet-python/src/predictor.rs b/crates/coglet-python/src/predictor.rs index 68dd512d26..8e2c9f169d 100644 --- a/crates/coglet-python/src/predictor.rs +++ b/crates/coglet-python/src/predictor.rs @@ -1,10 +1,11 @@ //! Python predictor loading and invocation. -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; use pyo3::prelude::*; use pyo3::types::PyDict; +use coglet_core::worker::SlotSender; use coglet_core::{PredictionError, PredictionOutput, PredictionResult}; use crate::cancel; @@ -102,6 +103,96 @@ fn format_validation_error(py: Python<'_>, err: &PyErr) -> String { err.value(py).to_string() } +/// Send a single output item over IPC, routing file outputs to disk. +/// +/// For Path outputs (os.PathLike): sends the existing file path via send_file_output. +/// For IOBase outputs: reads bytes, writes to output_dir via write_file_output. +/// For everything else: processes through make_encodeable + upload_files, then send_output. +fn send_output_item( + py: Python<'_>, + item: &Bound<'_, PyAny>, + json_module: &Bound<'_, PyAny>, + slot_sender: &SlotSender, +) -> Result<(), PredictionError> { + let os = py + .import("os") + .map_err(|e| PredictionError::Failed(format!("Failed to import os: {}", e)))?; + let io_mod = py + .import("io") + .map_err(|e| PredictionError::Failed(format!("Failed to import io: {}", e)))?; + let pathlike = os + .getattr("PathLike") + .map_err(|e| PredictionError::Failed(format!("Failed to get os.PathLike: {}", e)))?; + let iobase = io_mod + .getattr("IOBase") + .map_err(|e| PredictionError::Failed(format!("Failed to get io.IOBase: {}", e)))?; + + if item.is_instance(&pathlike).unwrap_or(false) { + // Path output — file already on disk, send path reference + let path_str: String = item + .call_method0("__fspath__") + .and_then(|p| p.extract()) + .map_err(|e| PredictionError::Failed(format!("Failed to get fspath: {}", e)))?; + slot_sender + .send_file_output(std::path::PathBuf::from(path_str), None) + .map_err(|e| PredictionError::Failed(format!("Failed to send file output: {}", e)))?; + return Ok(()); + } + + if item.is_instance(&iobase).unwrap_or(false) { + // IOBase output — read bytes, write to disk via SlotSender + // Seek to start if seekable + if item + .call_method0("seekable") + .and_then(|r| r.extract::()) + .unwrap_or(false) + { + let _ = item.call_method1("seek", (0,)); + } + let data: Vec = item + .call_method0("read") + .and_then(|d| d.extract()) + .map_err(|e| PredictionError::Failed(format!("Failed to read IOBase: {}", e)))?; + + // Try to guess extension from filename + let ext = item + .getattr("name") + .and_then(|n| n.extract::()) + .ok() + .and_then(|name| { + std::path::Path::new(&name) + .extension() + .and_then(|e| e.to_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| "bin".to_string()); + + slot_sender + .write_file_output(&data, &ext, None) + .map_err(|e| PredictionError::Failed(format!("Failed to write file output: {}", e)))?; + return Ok(()); + } + + // Non-file output — process normally + let processed = output::process_output_item(py, item) + .map_err(|e| PredictionError::Failed(format!("Failed to process output item: {}", e)))?; + + let item_str: String = json_module + .call_method1("dumps", (&processed,)) + .map_err(|e| PredictionError::Failed(format!("Failed to serialize output item: {}", e)))? + .extract() + .map_err(|e| PredictionError::Failed(format!("Failed to extract output string: {}", e)))?; + + let item_json: serde_json::Value = serde_json::from_str(&item_str) + .map_err(|e| PredictionError::Failed(format!("Failed to parse output JSON: {}", e)))?; + + slot_sender + .send_output(item_json) + .map_err(|e| PredictionError::Failed(format!("Failed to send output: {}", e)))?; + + Ok(()) +} + /// Type alias for Python object (Py). type PyObject = Py; @@ -480,6 +571,7 @@ impl PythonPredictor { pub fn predict_worker( &self, input: serde_json::Value, + slot_sender: Arc, ) -> Result { Python::attach(|py| { let json_module = py.import("json").map_err(|e| { @@ -530,9 +622,9 @@ impl PythonPredictor { let is_generator: bool = result_bound.is_instance(&generator_type).unwrap_or(false); let output = if is_generator { - self.process_generator_output(py, result_bound, &json_module)? + self.process_generator_output(py, result_bound, &json_module, &slot_sender)? } else { - self.process_single_output(py, result_bound, &json_module)? + self.process_single_output(py, result_bound, &json_module, &slot_sender)? }; // prepared drops here, cleaning up temp files via RAII @@ -550,6 +642,7 @@ impl PythonPredictor { pub fn train_worker( &self, input: serde_json::Value, + slot_sender: Arc, ) -> Result { Python::attach(|py| { let json_module = py.import("json").map_err(|e| { @@ -600,9 +693,9 @@ impl PythonPredictor { let is_generator: bool = result_bound.is_instance(&generator_type).unwrap_or(false); let output = if is_generator { - self.process_generator_output(py, result_bound, &json_module)? + self.process_generator_output(py, result_bound, &json_module, &slot_sender)? } else { - self.process_single_output(py, result_bound, &json_module)? + self.process_single_output(py, result_bound, &json_module, &slot_sender)? }; drop(prepared); @@ -615,14 +708,14 @@ impl PythonPredictor { }) } - /// Process generator output into PredictionOutput::Stream. + /// Process generator output by streaming each yield over IPC. fn process_generator_output( &self, py: Python<'_>, result: &Bound<'_, PyAny>, json_module: &Bound<'_, PyAny>, + slot_sender: &SlotSender, ) -> Result { - let mut outputs = Vec::new(); let iter = result .try_iter() .map_err(|e| PredictionError::Failed(format!("Failed to iterate generator: {}", e)))?; @@ -635,37 +728,83 @@ impl PythonPredictor { PredictionError::Failed(format!("Generator iteration error: {}", e)) })?; - let processed = output::process_output_item(py, &item).map_err(|e| { - PredictionError::Failed(format!("Failed to process output item: {}", e)) - })?; - - let item_str: String = json_module - .call_method1("dumps", (&processed,)) - .map_err(|e| { - PredictionError::Failed(format!("Failed to serialize output item: {}", e)) - })? - .extract() - .map_err(|e| { - PredictionError::Failed(format!("Failed to extract output string: {}", e)) - })?; - - let item_json: serde_json::Value = serde_json::from_str(&item_str).map_err(|e| { - PredictionError::Failed(format!("Failed to parse output JSON: {}", e)) - })?; - - outputs.push(item_json); + send_output_item(py, &item, json_module, slot_sender)?; } - Ok(PredictionOutput::Stream(outputs)) + // Outputs already streamed over IPC — return empty stream + Ok(PredictionOutput::Stream(vec![])) } /// Process single output into PredictionOutput::Single. + /// + /// For file outputs (Path/IOBase), the file is sent via slot_sender and + /// an empty Single(Null) is returned since the output was already streamed. fn process_single_output( &self, py: Python<'_>, result: &Bound<'_, PyAny>, json_module: &Bound<'_, PyAny>, + slot_sender: &SlotSender, ) -> Result { + // Check for file-type outputs first + let os = py + .import("os") + .map_err(|e| PredictionError::Failed(format!("Failed to import os: {}", e)))?; + let io_mod = py + .import("io") + .map_err(|e| PredictionError::Failed(format!("Failed to import io: {}", e)))?; + let pathlike = os + .getattr("PathLike") + .map_err(|e| PredictionError::Failed(format!("Failed to get os.PathLike: {}", e)))?; + let iobase = io_mod + .getattr("IOBase") + .map_err(|e| PredictionError::Failed(format!("Failed to get io.IOBase: {}", e)))?; + + if result.is_instance(&pathlike).unwrap_or(false) { + let path_str: String = result + .call_method0("__fspath__") + .and_then(|p| p.extract()) + .map_err(|e| PredictionError::Failed(format!("Failed to get fspath: {}", e)))?; + slot_sender + .send_file_output(std::path::PathBuf::from(path_str), None) + .map_err(|e| { + PredictionError::Failed(format!("Failed to send file output: {}", e)) + })?; + return Ok(PredictionOutput::Single(serde_json::Value::Null)); + } + + if result.is_instance(&iobase).unwrap_or(false) { + if result + .call_method0("seekable") + .and_then(|r| r.extract::()) + .unwrap_or(false) + { + let _ = result.call_method1("seek", (0,)); + } + let data: Vec = result + .call_method0("read") + .and_then(|d| d.extract()) + .map_err(|e| PredictionError::Failed(format!("Failed to read IOBase: {}", e)))?; + let ext = result + .getattr("name") + .and_then(|n| n.extract::()) + .ok() + .and_then(|name| { + std::path::Path::new(&name) + .extension() + .and_then(|e| e.to_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| "bin".to_string()); + slot_sender + .write_file_output(&data, &ext, None) + .map_err(|e| { + PredictionError::Failed(format!("Failed to write file output: {}", e)) + })?; + return Ok(PredictionOutput::Single(serde_json::Value::Null)); + } + + // Non-file output — process normally let processed = output::process_output(py, result, None) .map_err(|e| PredictionError::Failed(format!("Failed to process output: {}", e)))?; @@ -785,6 +924,7 @@ impl PythonPredictor { py: Python<'_>, result: &Bound<'_, PyAny>, is_async_gen: bool, + slot_sender: &SlotSender, ) -> Result { let json_module = py .import("json") @@ -795,28 +935,13 @@ impl PythonPredictor { // Process output let output = if is_async_gen { - // Result is a list - let mut outputs = Vec::new(); + // Result is a pre-collected list — stream each item over IPC if let Ok(list) = result.extract::>>() { for item in list { - let processed = output::process_output_item(py, &item).map_err(|e| { - PredictionError::Failed(format!("Failed to process output item: {}", e)) - })?; - let item_str: String = json_module - .call_method1("dumps", (&processed,)) - .map_err(|e| { - PredictionError::Failed(format!("Failed to serialize: {}", e)) - })? - .extract() - .map_err(|e| { - PredictionError::Failed(format!("Failed to extract: {}", e)) - })?; - let item_json: serde_json::Value = serde_json::from_str(&item_str) - .map_err(|e| PredictionError::Failed(format!("Failed to parse: {}", e)))?; - outputs.push(item_json); + send_output_item(py, &item, &json_module, slot_sender)?; } } - PredictionOutput::Stream(outputs) + PredictionOutput::Stream(vec![]) } else { // Check if result is a generator (sync generator from async predict) let generator_type = types_module.getattr("GeneratorType").map_err(|e| { @@ -825,9 +950,9 @@ impl PythonPredictor { let is_generator: bool = result.is_instance(&generator_type).unwrap_or(false); if is_generator { - self.process_generator_output(py, result, &json_module)? + self.process_generator_output(py, result, &json_module, slot_sender)? } else { - self.process_single_output(py, result, &json_module)? + self.process_single_output(py, result, &json_module, slot_sender)? } }; diff --git a/crates/coglet-python/src/worker_bridge.rs b/crates/coglet-python/src/worker_bridge.rs index 60e7f247b6..0b58c18cca 100644 --- a/crates/coglet-python/src/worker_bridge.rs +++ b/crates/coglet-python/src/worker_bridge.rs @@ -381,8 +381,14 @@ impl PredictHandler for PythonPredictHandler { }); // Block on future.result() + let sender_for_async = slot_sender.clone(); let result = Python::attach(|py| match future.call_method0(py, "result") { - Ok(result) => pred.process_async_result(py, result.bind(py), is_async_gen), + Ok(result) => pred.process_async_result( + py, + result.bind(py), + is_async_gen, + &sender_for_async, + ), Err(e) => { let err_str = e.to_string(); if err_str.contains("CancelledError") || err_str.contains("cancelled") { @@ -404,7 +410,7 @@ impl PredictHandler for PythonPredictHandler { // Sync train - set sync prediction ID for log routing crate::log_writer::set_sync_prediction_id(Some(&id)); let _cancelable = crate::cancel::enter_cancelable(); - let r = pred.train_worker(input); + let r = pred.train_worker(input, slot_sender.clone()); crate::log_writer::set_sync_prediction_id(None); r } @@ -446,8 +452,14 @@ impl PredictHandler for PythonPredictHandler { }); // Block on future.result() + let sender_for_async = slot_sender.clone(); let result = Python::attach(|py| match future.call_method0(py, "result") { - Ok(result) => pred.process_async_result(py, result.bind(py), is_async_gen), + Ok(result) => pred.process_async_result( + py, + result.bind(py), + is_async_gen, + &sender_for_async, + ), Err(e) => { let err_str = e.to_string(); if err_str.contains("CancelledError") || err_str.contains("cancelled") { @@ -470,7 +482,7 @@ impl PredictHandler for PythonPredictHandler { crate::log_writer::set_sync_prediction_id(Some(&id)); let _cancelable = crate::cancel::enter_cancelable(); tracing::trace!(%slot, %id, "Calling predict_worker"); - let r = pred.predict_worker(input); + let r = pred.predict_worker(input, slot_sender.clone()); tracing::trace!(%slot, %id, "predict_worker returned"); crate::log_writer::set_sync_prediction_id(None); r diff --git a/crates/coglet/Cargo.toml b/crates/coglet/Cargo.toml index 73ac7cb41a..5666680530 100644 --- a/crates/coglet/Cargo.toml +++ b/crates/coglet/Cargo.toml @@ -21,6 +21,10 @@ async-trait = "0.1" serde.workspace = true serde_json.workspace = true +# Encoding +base64 = "0.22.1" +mime_guess = "2.0.5" + # Identifiers uuid.workspace = true diff --git a/crates/coglet/src/bridge/codec.rs b/crates/coglet/src/bridge/codec.rs index 76fc89ad4d..23e25fcc37 100644 --- a/crates/coglet/src/bridge/codec.rs +++ b/crates/coglet/src/bridge/codec.rs @@ -64,6 +64,10 @@ impl Encoder for JsonCodec { tracing::trace!(json_size_bytes = json_len, "Encoding frame"); if json_len > 100_000 { tracing::info!( + // This log line should be shipped across the IPC to be emitted, unlike the + // above trace line. This is a real indicator that we've encoded a large + // frame and is generally useful. + target: "coglet::bridge::codec::large_frame", json_size_bytes = json_len, json_size_kb = json_len / 1024, "Large frame being encoded" @@ -117,6 +121,7 @@ mod tests { let req = SlotRequest::Predict { id: "test".to_string(), input: serde_json::json!({"x": 1}), + output_dir: "/tmp/coglet/outputs/test".to_string(), }; codec.encode(req.clone(), &mut buf).unwrap(); @@ -127,14 +132,17 @@ mod tests { SlotRequest::Predict { id: id1, input: input1, + output_dir: dir1, }, SlotRequest::Predict { id: id2, input: input2, + output_dir: dir2, }, ) => { assert_eq!(id1, id2); assert_eq!(input1, input2); + assert_eq!(dir1, dir2); } } } diff --git a/crates/coglet/src/bridge/protocol.rs b/crates/coglet/src/bridge/protocol.rs index dfab4b2de4..2a76461504 100644 --- a/crates/coglet/src/bridge/protocol.rs +++ b/crates/coglet/src/bridge/protocol.rs @@ -42,6 +42,21 @@ impl std::fmt::Display for SlotId { } } +const MAX_WORKER_LOG_SIZE: usize = 1024 * 1024 * 4; // 4MIB +const WORKER_LOG_TRUNCATE_NOTICE: &str = "[**** LOG LINE TRUNCATED AT 4 MiB ****]"; + +/// To ensure no panics happen due to oversized log lines, we truncate at 4 MiB. 1 MiB +/// let alone 4 MiB log line boarder/exceed usefulness from a readability standpoint. +pub fn truncate_worker_log(mut log_message: String) -> String { + if log_message.len() > MAX_WORKER_LOG_SIZE { + let boundary = + log_message.floor_char_boundary(MAX_WORKER_LOG_SIZE - WORKER_LOG_TRUNCATE_NOTICE.len()); + log_message.truncate(boundary); + log_message.push_str(WORKER_LOG_TRUNCATE_NOTICE); + } + log_message +} + /// Control messages from parent to worker. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -186,9 +201,21 @@ pub enum SlotRequest { Predict { id: String, input: serde_json::Value, + /// Directory for writing file outputs (created by coglet before dispatch). + /// Not included in API responses — internal transport detail. + output_dir: String, }, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum FileOutputKind { + /// Output is a file-like return type (e.g. File, Path) + FileType, + /// Output exceeds size threshold for bridge codec serialization but is not a file-like return type + Oversized, +} + /// Messages from worker to parent on slot socket. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -198,6 +225,16 @@ pub enum SlotResponse { data: String, }, + /// Output for a file/path-like output return type or an output that exceeds the size threshold + /// for bridge codec serialization. + FileOutput { + filename: String, + kind: FileOutputKind, + /// Explicit MIME type from the predictor. Falls back to mime_guess when None. + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, + }, + /// Streaming output chunk (for generators). Output { output: serde_json::Value, @@ -345,6 +382,7 @@ mod tests { let req = SlotRequest::Predict { id: "pred_123".to_string(), input: json!({"text": "hello"}), + output_dir: "/tmp/coglet/outputs/pred_123".to_string(), }; insta::assert_json_snapshot!(req); } @@ -392,4 +430,29 @@ mod tests { }; insta::assert_json_snapshot!(resp); } + + #[test] + fn truncate_worker_log_truncates_long_messages() { + let emoji = "🦀"; // 4-byte UTF-8 character + // known size of truncate target, add one more character + let count = 1024 * 1024 * 1024 * 4 / emoji.len() + 1; + let message: String = truncate_worker_log(emoji.repeat(count)); + assert!( + message.ends_with(WORKER_LOG_TRUNCATE_NOTICE), + "log message didn't end with {}", + WORKER_LOG_TRUNCATE_NOTICE + ); + } + + #[test] + fn truncate_worker_log_does_not_truncate_short_messages() { + let emoji = "🦀"; // 4-byte UTF-8 character + // known size of truncate target, add one more character + let count = 10; + let message: String = truncate_worker_log(std::iter::repeat(emoji).take(count).collect()); + assert!( + !message.ends_with(WORKER_LOG_TRUNCATE_NOTICE), + "short log message was truncated" + ); + } } diff --git a/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap b/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap index 16594e9d37..d22ff9cb96 100644 --- a/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap +++ b/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap @@ -7,5 +7,6 @@ expression: req "id": "pred_123", "input": { "text": "hello" - } + }, + "output_dir": "/tmp/coglet/outputs/pred_123" } diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 421c9f8b1b..c5b795c1d2 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -22,12 +22,66 @@ use tokio_util::codec::{FramedRead, FramedWrite}; use crate::PredictionOutput; use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::{ - ControlRequest, ControlResponse, HealthcheckStatus, SlotId, SlotRequest, SlotResponse, + ControlRequest, ControlResponse, FileOutputKind, HealthcheckStatus, SlotId, SlotRequest, + SlotResponse, }; use crate::bridge::transport::create_transport; use crate::permit::{InactiveSlotIdleToken, PermitPool, SlotIdleToken}; use crate::prediction::Prediction; +/// Upload a file to a signed endpoint, returning the final URL. +/// +/// Matches the behavior of Python cog's `put_file_to_signed_endpoint`: +/// PUT to `{endpoint}{filename}` with Content-Type header, then extract +/// the final URL from the Location header (falling back to response URL), +/// stripping query parameters. Follows redirects automatically. +async fn upload_file( + endpoint: &str, + filename: &str, + data: &[u8], + content_type: &str, +) -> Result { + let url = format!("{endpoint}{filename}"); + let client = reqwest::Client::new(); + let resp = client + .put(&url) + .header("Content-Type", content_type) + .body(data.to_vec()) + .timeout(std::time::Duration::from_secs(25)) + .send() + .await + .map_err(|e| format!("upload request failed: {e}"))?; + + if !resp.status().is_success() { + return Err(format!("upload returned status {}", resp.status())); + } + + // Prefer Location header, fall back to final request URL (after redirects) + let final_url = resp + .headers() + .get("location") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| resp.url().to_string()); + + // Strip query parameters (signing gubbins) + match reqwest::Url::parse(&final_url) { + Ok(mut parsed) => { + parsed.set_query(None); + Ok(parsed.to_string()) + } + Err(_) => Ok(final_url), + } +} + +fn ensure_trailing_slash(s: &str) -> String { + if s.ends_with('/') { + s.to_string() + } else { + format!("{s}/") + } +} + /// Try to lock a prediction mutex. /// On poison: logs error, recovers to fail the prediction, returns None. /// Caller should remove the prediction from tracking if None is returned. @@ -209,6 +263,8 @@ pub struct OrchestratorConfig { pub is_async: bool, pub setup_timeout: Duration, pub spawner: Arc, + /// Upload URL prefix for file outputs (from --upload-url CLI arg). + pub upload_url: Option, } impl OrchestratorConfig { @@ -220,9 +276,15 @@ impl OrchestratorConfig { is_async: false, setup_timeout: Duration::from_secs(300), spawner: Arc::new(SimpleSpawner), + upload_url: None, } } + pub fn with_upload_url(mut self, upload_url: Option) -> Self { + self.upload_url = upload_url; + self + } + pub fn with_num_slots(mut self, n: usize) -> Self { self.num_slots = n; self @@ -496,6 +558,7 @@ pub async fn spawn_worker( let pool_for_loop = Arc::clone(&pool); let ctrl_writer_for_loop = Arc::clone(&ctrl_writer); + let upload_url = config.upload_url.clone(); tokio::spawn(async move { run_event_loop( ctrl_reader, @@ -504,6 +567,7 @@ pub async fn spawn_worker( register_rx, healthcheck_rx, pool_for_loop, + upload_url, ) .await; }); @@ -532,12 +596,14 @@ async fn run_event_loop( )>, mut healthcheck_rx: mpsc::Receiver>, pool: Arc, + upload_url: Option, ) { let mut predictions: HashMap>> = HashMap::new(); let mut idle_senders: HashMap> = HashMap::new(); let mut pending_healthchecks: Vec> = Vec::new(); let mut healthcheck_counter: u64 = 0; + let mut pending_uploads: HashMap>> = HashMap::new(); let (slot_msg_tx, mut slot_msg_rx) = mpsc::channel::<(SlotId, Result)>(100); @@ -776,21 +842,135 @@ async fn run_event_loop( predictions.remove(&slot_id); } } - Ok(SlotResponse::Done { id, output, predict_time }) => { + Ok(SlotResponse::FileOutput { filename, kind, mime_type }) => { + tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received"); + let bytes = match std::fs::read(&filename) { + Ok(b) => b, + Err(e) => { + tracing::error!(%slot_id, %filename, error = %e, "Failed to read FileOutput"); + continue; + } + }; + match kind { + FileOutputKind::Oversized => { + let output: serde_json::Value = match serde_json::from_slice(&bytes) { + Ok(val) => val, + Err(e) => { + tracing::error!(%slot_id, %filename, error = %e, "Failed to parse oversized JSON"); + continue; + } + }; + let poisoned = if let Some(pred) = predictions.get(&slot_id) { + if let Some(mut p) = try_lock_prediction(pred) { + p.append_output(output); + false + } else { + true + } + } else { + false + }; + if poisoned { + predictions.remove(&slot_id); + } + } + FileOutputKind::FileType => { + let mime = mime_type.unwrap_or_else(|| { + mime_guess::from_path(&filename) + .first_or_octet_stream() + .to_string() + }); + if let Some(ref url) = upload_url { + // Spawn upload task so we don't block the event loop + let pred = predictions.get(&slot_id).cloned(); + let endpoint = ensure_trailing_slash(url); + let basename = std::path::Path::new(&filename) + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("output") + .to_string(); + let handle = tokio::spawn(async move { + match upload_file(&endpoint, &basename, &bytes, &mime).await { + Ok(url) => { + if let Some(pred) = pred + && let Some(mut p) = try_lock_prediction(&pred) + { + p.append_output(serde_json::Value::String(url)); + } + } + Err(e) => { + tracing::error!(error = %e, "Failed to upload file output"); + } + } + }); + pending_uploads.entry(slot_id).or_default().push(handle); + } else { + // No upload URL — base64-encode as data URI + use base64::Engine; + let encoded = base64::engine::general_purpose::STANDARD + .encode(&bytes); + let output = serde_json::Value::String(format!( + "data:{mime};base64,{encoded}" + )); + let poisoned = if let Some(pred) = predictions.get(&slot_id) { + if let Some(mut p) = try_lock_prediction(pred) { + p.append_output(output); + false + } else { + true + } + } else { + false + }; + if poisoned { + predictions.remove(&slot_id); + } + } + } + } + } + Ok(SlotResponse::Done { id, output: _, predict_time }) => { tracing::info!( target: "coglet::prediction", prediction_id = %id, predict_time, "Prediction succeeded" ); + let uploads = pending_uploads.remove(&slot_id).unwrap_or_default(); if let Some(pred) = predictions.remove(&slot_id) { - if let Some(mut p) = try_lock_prediction(&pred) { - let pred_output = output - .map(PredictionOutput::Single) - .unwrap_or(PredictionOutput::Single(serde_json::Value::Null)); - p.set_succeeded(pred_output); + if uploads.is_empty() { + // No pending uploads — complete synchronously to avoid + // a race between tokio::spawn and Notify::notified() in + // service.rs. notify_waiters() only wakes already- + // registered waiters; spawning a task can fire the + // notification before the service registers its waiter. + if let Some(mut p) = try_lock_prediction(&pred) { + let pred_output = match p.take_outputs().as_slice() { + [] => PredictionOutput::Single(serde_json::Value::Null), + [single] => PredictionOutput::Single(single.clone()), + many => PredictionOutput::Stream(many.to_vec()), + }; + p.set_succeeded(pred_output); + } + } else { + // Has pending uploads — must spawn to await them. + // TODO(#2748): when cancellation is wired end-to-end, + // this spawn should also observe the cancel token and + // abort in-flight uploads on cancellation. + tokio::spawn(async move { + for h in uploads { + let _ = h.await; + } + if let Some(mut p) = try_lock_prediction(&pred) { + let pred_output = match p.take_outputs().as_slice() { + [] => PredictionOutput::Single(serde_json::Value::Null), + [single] => PredictionOutput::Single(single.clone()), + many => PredictionOutput::Stream(many.to_vec()), + }; + p.set_succeeded(pred_output); + } + }); } - // On mutex poison, prediction already failed - nothing more to do } else { tracing::warn!(%slot_id, %id, "Prediction not found for Done message"); } @@ -802,6 +982,10 @@ async fn run_event_loop( %error, "Prediction failed" ); + // Abort any pending uploads — prediction is terminal + if let Some(handles) = pending_uploads.remove(&slot_id) { + for h in handles { h.abort(); } + } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { @@ -814,6 +998,10 @@ async fn run_event_loop( prediction_id = %id, "Prediction cancelled" ); + // Abort any pending uploads — prediction is terminal + if let Some(handles) = pending_uploads.remove(&slot_id) { + for h in handles { h.abort(); } + } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { @@ -822,6 +1010,9 @@ async fn run_event_loop( } Err(e) => { tracing::error!(%slot_id, error = %e, "Slot socket error"); + if let Some(handles) = pending_uploads.remove(&slot_id) { + for h in handles { h.abort(); } + } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index ab976cf8c0..6ae9dc3616 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -119,18 +119,24 @@ impl Prediction { pub fn set_succeeded(&mut self, output: PredictionOutput) { self.status = PredictionStatus::Succeeded; self.output = Some(output); - self.completion.notify_waiters(); + // notify_one stores a permit so a future .notified().await will + // consume it immediately. notify_waiters only wakes currently- + // registered waiters and would race with the service task that + // checks is_terminal() then awaits — the notification can fire + // in between. There is exactly one waiter per prediction + // (service.rs predict()), so notify_one is semantically correct. + self.completion.notify_one(); } pub fn set_failed(&mut self, error: String) { self.status = PredictionStatus::Failed; self.error = Some(error); - self.completion.notify_waiters(); + self.completion.notify_one(); } pub fn set_canceled(&mut self) { self.status = PredictionStatus::Canceled; - self.completion.notify_waiters(); + self.completion.notify_one(); } pub fn elapsed(&self) -> std::time::Duration { @@ -153,6 +159,10 @@ impl Prediction { &self.outputs } + pub fn take_outputs(&mut self) -> Vec { + std::mem::take(&mut self.outputs) + } + pub fn output(&self) -> Option<&PredictionOutput> { self.output.as_ref() } diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 17e4722dee..5a852de703 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -297,9 +297,18 @@ impl PredictionService { .register_prediction(slot_id, Arc::clone(&prediction_arc), idle_tx) .await; + // Create per-prediction output dir for file-based outputs + let output_dir = std::path::PathBuf::from("/tmp/coglet/outputs").join(&prediction_id); + std::fs::create_dir_all(&output_dir) + .map_err(|e| PredictionError::Failed(format!("Failed to create output dir: {}", e)))?; + let request = SlotRequest::Predict { id: prediction_id.clone(), input, + output_dir: output_dir + .to_str() + .expect("output dir path is valid UTF-8") + .to_string(), }; // permit_mut returns None if permit isn't InUse (shouldn't happen here) diff --git a/crates/coglet/src/worker.rs b/crates/coglet/src/worker.rs index bcaaec5d15..ac999b3b5f 100644 --- a/crates/coglet/src/worker.rs +++ b/crates/coglet/src/worker.rs @@ -11,6 +11,7 @@ use std::collections::HashMap; use std::io; +use std::path::PathBuf; use std::sync::Arc; use std::sync::OnceLock; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -19,6 +20,8 @@ use futures::{SinkExt, StreamExt}; use tokio::sync::mpsc; use tokio_util::codec::{FramedRead, FramedWrite}; +use crate::bridge::protocol::truncate_worker_log; + // ============================================================================ // Dropped log tracking // ============================================================================ @@ -128,7 +131,8 @@ fn init_worker_tracing(tx: mpsc::Sender) { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::{ - ControlRequest, ControlResponse, LogSource, SlotId, SlotOutcome, SlotRequest, SlotResponse, + ControlRequest, ControlResponse, FileOutputKind, LogSource, SlotId, SlotOutcome, SlotRequest, + SlotResponse, }; use crate::bridge::transport::{ChildTransportInfo, connect_transport}; use crate::orchestrator::HealthcheckResult; @@ -144,11 +148,23 @@ type SlotWriter = #[derive(Clone)] pub struct SlotSender { tx: mpsc::UnboundedSender, + output_dir: PathBuf, + file_counter: Arc, } impl SlotSender { - pub fn new(tx: mpsc::UnboundedSender) -> Self { - Self { tx } + pub fn new(tx: mpsc::UnboundedSender, output_dir: PathBuf) -> Self { + Self { + tx, + output_dir, + file_counter: Arc::new(AtomicUsize::new(0)), + } + } + + /// Generate a unique filename in the output dir. + fn next_output_path(&self, extension: &str) -> PathBuf { + let n = self.file_counter.fetch_add(1, Ordering::Relaxed); + self.output_dir.join(format!("{n}.{extension}")) } pub fn send_log(&self, source: LogSource, data: &str) -> io::Result<()> { @@ -158,7 +174,7 @@ impl SlotSender { let msg = SlotResponse::Log { source, - data: data.to_string(), + data: truncate_worker_log(data.to_string()), }; self.tx @@ -166,14 +182,76 @@ impl SlotSender { .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } + /// Write raw bytes to a file in the output dir and send as FileOutput. + /// + /// Used by FFI workers (Python, Node, etc.) to hand off file data without + /// needing language-specific file I/O — SlotSender owns the write. + pub fn write_file_output( + &self, + data: &[u8], + extension: &str, + mime_type: Option, + ) -> io::Result<()> { + let path = self.next_output_path(extension); + std::fs::write(&path, data)?; + self.send_file_output(path, mime_type) + } + + /// Send a file-typed output (e.g. Path, File return types). + /// + /// The file is already on disk at `path` — we just send the path reference. + /// `mime_type` is an explicit MIME type; when None the parent guesses from extension. + pub fn send_file_output(&self, path: PathBuf, mime_type: Option) -> io::Result<()> { + let filename = path + .to_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? + .to_string(); + let msg = SlotResponse::FileOutput { + filename, + kind: FileOutputKind::FileType, + mime_type, + }; + self.tx + .send(msg) + .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) + } + + /// Send prediction output, either inline or spilled to disk if too large. pub fn send_output(&self, output: serde_json::Value) -> io::Result<()> { - let msg = SlotResponse::Output { output }; + let msg = build_output_message(&self.output_dir, output)?; self.tx .send(msg) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } } +const MAX_INLINE_OUTPUT_SIZE: usize = 1024 * 1024 * 6; // 6MiB + +/// Build an output message, spilling to disk if larger than the IPC frame limit. +fn build_output_message( + output_dir: &std::path::Path, + output: serde_json::Value, +) -> io::Result { + let serialized = + serde_json::to_vec(&output).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + if serialized.len() > MAX_INLINE_OUTPUT_SIZE { + let path = output_dir.join(format!("spill_{}.json", uuid::Uuid::new_v4())); + std::fs::write(&path, &serialized)?; + let filename = path + .to_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? + .to_string(); + Ok(SlotResponse::FileOutput { + filename, + kind: FileOutputKind::Oversized, + mime_type: None, + }) + } else { + Ok(SlotResponse::Output { output }) + } +} + /// Setup phase errors. /// /// These errors occur during predictor loading and setup, before predictions @@ -585,7 +663,7 @@ pub async fn run_worker( } match request { - SlotRequest::Predict { id, input } => { + SlotRequest::Predict { id, input, output_dir } => { tracing::trace!(%slot_id, %id, "Prediction request received"); slot_busy.insert(slot_id, true); @@ -604,6 +682,7 @@ pub async fn run_worker( slot_id, id, input, + PathBuf::from(output_dir), handler, writer, ).await; @@ -647,6 +726,7 @@ async fn run_prediction( slot_id: SlotId, prediction_id: String, input: serde_json::Value, + output_dir: PathBuf, handler: Arc, writer: SlotWriter, ) -> SlotCompletion { @@ -654,7 +734,7 @@ async fn run_prediction( // Create channel for log streaming let (log_tx, mut log_rx) = mpsc::unbounded_channel::(); - let slot_sender = Arc::new(SlotSender::new(log_tx)); + let slot_sender = Arc::new(SlotSender::new(log_tx, output_dir.clone())); // Forward logs to slot socket let writer_for_logs = Arc::clone(&writer); @@ -669,7 +749,8 @@ async fn run_prediction( tracing::trace!("Prediction log forwarder exiting"); }); - // Run prediction + // Run prediction — slot_sender is moved in, dropped when predict returns, + // which closes the log channel and lets the log forwarder exit. let result = handler .predict(slot_id, prediction_id.clone(), input, slot_sender) .await; @@ -680,16 +761,39 @@ async fn run_prediction( let _ = log_forwarder.await; tracing::trace!(%slot_id, %prediction_id, "Log forwarder done"); - // Send result on slot socket + // Send result on slot socket. + // Output is always sent separately from Done so that large values get + // spilled to disk and never exceed the IPC frame limit. + let mut w = writer.lock().await; let response = match result.outcome { PredictionOutcome::Success { output, predict_time, - } => SlotResponse::Done { - id: prediction_id.clone(), - output: Some(output), - predict_time, - }, + } => { + // Send output as a separate message (handles spilling for large values). + // Skip if null or empty array — those mean "already streamed" (generators). + if !output.is_null() && output != serde_json::Value::Array(vec![]) { + let output_msg = match build_output_message(&output_dir, output) { + Ok(msg) => msg, + Err(e) => { + tracing::error!(error = %e, "Failed to build output message"); + return SlotCompletion::poisoned( + slot_id, + format!("Output spill error: {}", e), + ); + } + }; + if let Err(e) = w.send(output_msg).await { + tracing::error!(error = %e, "Failed to send prediction output"); + return SlotCompletion::poisoned(slot_id, format!("Socket write error: {}", e)); + } + } + SlotResponse::Done { + id: prediction_id.clone(), + output: None, + predict_time, + } + } PredictionOutcome::Cancelled { .. } => SlotResponse::Cancelled { id: prediction_id.clone(), }, @@ -699,7 +803,6 @@ async fn run_prediction( }, }; - let mut w = writer.lock().await; if let Err(e) = w.send(response).await { tracing::error!(error = %e, "Failed to send prediction response"); return SlotCompletion::poisoned(slot_id, format!("Socket write error: {}", e)); diff --git a/crates/coglet/src/worker_tracing_layer.rs b/crates/coglet/src/worker_tracing_layer.rs index 4eacdb361c..dafe54d5d2 100644 --- a/crates/coglet/src/worker_tracing_layer.rs +++ b/crates/coglet/src/worker_tracing_layer.rs @@ -10,7 +10,7 @@ use tokio::sync::mpsc; use tracing::{Level, Subscriber}; use tracing_subscriber::layer::{Context, Layer}; -use crate::bridge::protocol::ControlResponse; +use crate::bridge::protocol::{ControlResponse, truncate_worker_log}; pub struct WorkerTracingLayer { tx: mpsc::Sender, @@ -52,7 +52,7 @@ where let mut visitor = MessageVisitor::default(); event.record(&mut visitor); - let message = visitor.message; + let message = truncate_worker_log(visitor.message); // Targets excluded from IPC: // - coglet::bridge::codec: feedback loop when encoding WorkerLog messages diff --git a/docs/cli.md b/docs/cli.md index f282d72530..d6f69ef615 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -173,6 +173,7 @@ cog serve [flags] -h, --help help for serve -p, --port int Port on which to listen (default 8393) --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") + --upload-url string Upload URL for file outputs (e.g. https://example.com/upload/) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` diff --git a/docs/llms.txt b/docs/llms.txt index e514c62341..948e992801 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -370,6 +370,7 @@ cog serve [flags] -h, --help help for serve -p, --port int Port on which to listen (default 8393) --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") + --upload-url string Upload URL for file outputs (e.g. https://example.com/upload/) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index 682ddafa5e..99bf2b801b 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -26,6 +26,17 @@ import ( "github.com/replicate/cog/pkg/registry_testhelpers" ) +// propagatedEnvVars lists host environment variables that should be propagated +// into testscript environments (Setup) and background processes (cmdCogServe). +// Keep this list in sync: if you add a new env var to propagate, add it here. +var propagatedEnvVars = []string{ + "COG_WHEEL", // SDK wheel override + "COGLET_WHEEL", // coglet wheel override + "RUST_LOG", // Rust logging control + "COG_CA_CERT", // custom CA certificates (e.g. Cloudflare WARP) + "BUILDKIT_PROGRESS", // Docker build output format +} + // Harness provides utilities for running cog integration tests. // serverInfo tracks a running cog serve process and its port type serverInfo struct { @@ -40,6 +51,40 @@ type registryInfo struct { host string // e.g., "localhost:5432" } +// mockUploadRecord records a single upload received by the mock upload server. +type mockUploadRecord struct { + Path string + ContentType string + Size int +} + +// mockUploadServer is a lightweight HTTP server that accepts PUT requests +// and records what was uploaded. +type mockUploadServer struct { + server *http.Server + port int + mu sync.Mutex + uploads []mockUploadRecord +} + +// webhookResult is the summary written to stdout by webhook-server-wait. +type webhookResult struct { + Status string `json:"status"` + OutputSize int `json:"output_size"` + HasError bool `json:"has_error"` +} + +// webhookServer accepts prediction webhook callbacks from coglet. +// It parses the JSON payload to extract status and output size, without +// ever exposing the (potentially huge) output to testscript's log buffer. +type webhookServer struct { + server *http.Server + port int + mu sync.Mutex + result *webhookResult + done chan struct{} // closed on first terminal webhook +} + type Harness struct { CogBinary string // realHome is captured at creation time before testscript overrides HOME @@ -52,6 +97,12 @@ type Harness struct { // registries tracks test registry containers for cleanup, keyed by work directory registries map[string]*registryInfo registriesMu sync.Mutex + // uploadServers tracks mock upload servers for cleanup, keyed by work directory + uploadServers map[string]*mockUploadServer + uploadServersMu sync.Mutex + // webhookServers tracks webhook receiver servers for cleanup, keyed by work directory + webhookServers map[string]*webhookServer + webhookServersMu sync.Mutex } // New creates a new Harness, resolving the cog binary location. @@ -65,11 +116,13 @@ func New() (*Harness, error) { return nil, err } return &Harness{ - CogBinary: cogBinary, - realHome: os.Getenv("HOME"), - repoRoot: repoRoot, - serverProcs: make(map[string]*serverInfo), - registries: make(map[string]*registryInfo), + CogBinary: cogBinary, + realHome: os.Getenv("HOME"), + repoRoot: repoRoot, + serverProcs: make(map[string]*serverInfo), + registries: make(map[string]*registryInfo), + uploadServers: make(map[string]*mockUploadServer), + webhookServers: make(map[string]*webhookServer), }, nil } @@ -196,6 +249,14 @@ func (h *Harness) Commands() map[string]func(ts *testscript.TestScript, neg bool NewCommand("docker-push", h.cmdDockerPush), NewCommand("mock-weights", h.cmdMockWeights), + // Mock upload server commands + NewCommand("upload-server-start", h.cmdUploadServerStart), + NewCommand("upload-server-count", h.cmdUploadServerCount), + + // Webhook receiver commands + NewCommand("webhook-server-start", h.cmdWebhookServerStart), + NewCommand("webhook-server-wait", h.cmdWebhookServerWait), + // PTY command (defined in cmd_pty.go) &PtyRunCommand{harness: h}, } @@ -251,19 +312,11 @@ func (h *Harness) Setup(env *testscript.Env) error { // Disable update checks during tests env.Setenv("COG_NO_UPDATE_CHECK", "1") - // Propagate COG_WHEEL environment variable for runtime selection - if cogWheel := os.Getenv("COG_WHEEL"); cogWheel != "" { - env.Setenv("COG_WHEEL", cogWheel) - } - - // Propagate COGLET_WHEEL for coglet server - if cogletWheel := os.Getenv("COGLET_WHEEL"); cogletWheel != "" { - env.Setenv("COGLET_WHEEL", cogletWheel) - } - - // Propagate RUST_LOG for Rust logging control - if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { - env.Setenv("RUST_LOG", rustLog) + // Propagate host env vars listed in propagatedEnvVars + for _, key := range propagatedEnvVars { + if val := os.Getenv(key); val != "" { + env.Setenv(key, val) + } } // Generate unique image name for this test run @@ -279,6 +332,10 @@ func (h *Harness) Setup(env *testscript.Env) error { h.stopServerByWorkDir(workDir) // Stop the registry for this specific test (if any) h.stopRegistryByWorkDir(workDir) + // Stop the upload server for this specific test (if any) + h.stopUploadServerByWorkDir(workDir) + // Stop the webhook server for this specific test (if any) + h.stopWebhookServerByWorkDir(workDir) removeDockerImage(imageName) }) @@ -352,11 +409,17 @@ func (h *Harness) cmdCogServe(ts *testscript.TestScript, neg bool, args []string cmd := exec.Command(h.CogBinary, expandedArgs...) cmd.Dir = workDir - // Build environment from testscript + // Build environment from testscript. + // Always include core vars, plus everything from propagatedEnvVars. var env []string - for _, key := range []string{"HOME", "PATH", "REPO_ROOT", "COG_NO_UPDATE_CHECK", "COG_WHEEL", "COGLET_WHEEL", "RUST_LOG", "BUILDKIT_PROGRESS", "TEST_IMAGE"} { + for _, key := range []string{"HOME", "PATH", "REPO_ROOT", "COG_NO_UPDATE_CHECK", "TEST_IMAGE"} { if val := ts.Getenv(key); val != "" { - env = append(env, fmt.Sprintf("%s=%s", key, val)) + env = append(env, key+"="+val) + } + } + for _, key := range propagatedEnvVars { + if val := ts.Getenv(key); val != "" { + env = append(env, key+"="+val) } } cmd.Env = env @@ -388,14 +451,34 @@ func (h *Harness) cmdCogServe(ts *testscript.TestScript, neg bool, args []string // cmdCurl implements the 'curl' command for testscript. // It makes HTTP requests to the server started with 'serve'. // Includes built-in retry logic (10 attempts, 500ms delay) for resilience. -// Usage: curl [method] [path] [body] +// Usage: curl [-H key:value]... [method] [path] [body] // Examples: // // curl GET /health-check // curl POST /predictions '{"input":{"s":"hello"}}' +// curl -H Prefer:respond-async POST /predictions '{"input":{}}' func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { if len(args) < 2 { - ts.Fatalf("curl: usage: curl [method] [path] [body]") + ts.Fatalf("curl: usage: curl [-H key:value]... [method] [path] [body]") + } + + // Parse -H flags for extra headers + var extraHeaders [][2]string + for len(args) >= 2 && args[0] == "-H" { + kv := args[1] + parts := strings.SplitN(kv, ":", 2) + if len(parts) != 2 { + ts.Fatalf("curl: invalid header %q (expected key:value)", kv) + } + extraHeaders = append(extraHeaders, [2]string{ + strings.TrimSpace(parts[0]), + strings.TrimSpace(parts[1]), + }) + args = args[2:] + } + + if len(args) < 2 { + ts.Fatalf("curl: usage: curl [-H key:value]... [method] [path] [body]") } serverURL := ts.Getenv("SERVER_URL") @@ -407,7 +490,7 @@ func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { path := args[1] var body string if len(args) > 2 { - body = args[2] + body = os.Expand(args[2], ts.Getenv) } // Retry settings @@ -433,6 +516,9 @@ func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { if body != "" { req.Header.Set("Content-Type", "application/json") } + for _, h := range extraHeaders { + req.Header.Set(h[0], h[1]) + } resp, err := client.Do(req) if err != nil { @@ -1059,3 +1145,271 @@ func parseSize(s string) (int64, error) { return int64(num * float64(multiplier)), nil } + +// ============================================================================= +// Mock upload server commands +// ============================================================================= + +// cmdUploadServerStart starts a mock HTTP upload server on the host. +// It accepts PUT requests, records them, and responds with a Location header. +// Usage: upload-server-start +// Exports $UPLOAD_SERVER_URL with the server's base URL. +func (h *Harness) cmdUploadServerStart(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("upload-server-start: does not support negation") + } + + workDir := ts.Getenv("WORK") + + h.uploadServersMu.Lock() + if _, exists := h.uploadServers[workDir]; exists { + h.uploadServersMu.Unlock() + ts.Fatalf("upload-server-start: server already running for this test") + } + h.uploadServersMu.Unlock() + + mus := &mockUploadServer{} + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + body, _ := io.ReadAll(r.Body) + record := mockUploadRecord{ + Path: r.URL.Path, + ContentType: r.Header.Get("Content-Type"), + Size: len(body), + } + mus.mu.Lock() + mus.uploads = append(mus.uploads, record) + mus.mu.Unlock() + + // Return a clean URL without query params (simulates a signed URL redirect) + location := fmt.Sprintf("http://host.docker.internal:%d%s", mus.port, r.URL.Path) + w.Header().Set("Location", location) + w.WriteHeader(http.StatusOK) + }) + + mus.server = &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} //nolint:gosec // test harness, not production + + // Bind to all interfaces so the container can reach us via host.docker.internal + ln, err := net.Listen("tcp", "0.0.0.0:0") //nolint:gosec // must be reachable from Docker container + if err != nil { + ts.Fatalf("upload-server-start: failed to listen: %v", err) + } + mus.port = ln.Addr().(*net.TCPAddr).Port + + go func() { _ = mus.server.Serve(ln) }() + + h.uploadServersMu.Lock() + h.uploadServers[workDir] = mus + h.uploadServersMu.Unlock() + + // Advertise host.docker.internal so the container can reach the host server. + // On Linux, cog serve adds --add-host=host.docker.internal:host-gateway. + // On Mac, Docker Desktop resolves host.docker.internal automatically. + url := fmt.Sprintf("http://host.docker.internal:%d/", mus.port) + ts.Setenv("UPLOAD_SERVER_URL", url) + ts.Logf("upload-server-start: listening on 0.0.0.0:%d, container URL: %s", mus.port, url) +} + +// cmdUploadServerCount verifies exactly N uploads were received. +// Usage: upload-server-count N +func (h *Harness) cmdUploadServerCount(ts *testscript.TestScript, neg bool, args []string) { + if len(args) != 1 { + ts.Fatalf("upload-server-count: usage: upload-server-count N") + } + + expected, err := strconv.Atoi(args[0]) + if err != nil { + ts.Fatalf("upload-server-count: invalid count %q: %v", args[0], err) + } + + workDir := ts.Getenv("WORK") + h.uploadServersMu.Lock() + mus, exists := h.uploadServers[workDir] + h.uploadServersMu.Unlock() + + if !exists { + ts.Fatalf("upload-server-count: no upload server running (call upload-server-start first)") + } + + mus.mu.Lock() + got := len(mus.uploads) + mus.mu.Unlock() + + if neg { + if got == expected { + ts.Fatalf("upload-server-count: expected NOT %d uploads but got %d", expected, got) + } + return + } + + if got != expected { + ts.Fatalf("upload-server-count: expected %d uploads but got %d", expected, got) + } +} + +// stopUploadServerByWorkDir shuts down the upload server for a work directory. +func (h *Harness) stopUploadServerByWorkDir(workDir string) { + h.uploadServersMu.Lock() + mus, exists := h.uploadServers[workDir] + if !exists { + h.uploadServersMu.Unlock() + return + } + delete(h.uploadServers, workDir) + h.uploadServersMu.Unlock() + + if mus.server != nil { + _ = mus.server.Close() + } +} + +// ============================================================================= +// Webhook receiver commands +// ============================================================================= + +// cmdWebhookServerStart starts a webhook receiver that accepts prediction callbacks. +// It parses the JSON payload to extract status and measure the output size, without +// ever exposing the (potentially huge) output to testscript's log buffer. +// Usage: webhook-server-start +// Exports $WEBHOOK_URL with the server's callback URL. +func (h *Harness) cmdWebhookServerStart(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("webhook-server-start: does not support negation") + } + + workDir := ts.Getenv("WORK") + + h.webhookServersMu.Lock() + if _, exists := h.webhookServers[workDir]; exists { + h.webhookServersMu.Unlock() + ts.Fatalf("webhook-server-start: server already running for this test") + } + h.webhookServersMu.Unlock() + + ws := &webhookServer{ + done: make(chan struct{}), + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // Stream-parse the JSON to extract status and measure output size + // without holding the entire output string in testscript memory. + var payload struct { + Status string `json:"status"` + Output string `json:"output"` + Error string `json:"error"` + } + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) + + // Only record terminal statuses + switch payload.Status { + case "succeeded", "failed", "canceled": + default: + return + } + + ws.mu.Lock() + defer ws.mu.Unlock() + + // Only record the first terminal callback + if ws.result != nil { + return + } + ws.result = &webhookResult{ + Status: payload.Status, + OutputSize: len(payload.Output), + HasError: payload.Error != "", + } + close(ws.done) + }) + + ws.server = &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} //nolint:gosec + + // Bind to all interfaces so the container can reach us via host.docker.internal + ln, err := net.Listen("tcp", "0.0.0.0:0") //nolint:gosec + if err != nil { + ts.Fatalf("webhook-server-start: failed to listen: %v", err) + } + ws.port = ln.Addr().(*net.TCPAddr).Port + + go func() { _ = ws.server.Serve(ln) }() + + h.webhookServersMu.Lock() + h.webhookServers[workDir] = ws + h.webhookServersMu.Unlock() + + url := fmt.Sprintf("http://host.docker.internal:%d/webhook", ws.port) + ts.Setenv("WEBHOOK_URL", url) + ts.Logf("webhook-server-start: listening on 0.0.0.0:%d, container URL: %s", ws.port, url) +} + +// cmdWebhookServerWait blocks until the webhook server receives a terminal prediction callback, +// then writes a compact JSON summary to stdout for assertion with stdout/stderr matchers. +// Usage: webhook-server-wait [timeout] +// Default timeout: 120s +func (h *Harness) cmdWebhookServerWait(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("webhook-server-wait: does not support negation") + } + + timeout := 120 * time.Second + if len(args) > 0 { + if d, err := time.ParseDuration(args[0]); err == nil { + timeout = d + } + } + + workDir := ts.Getenv("WORK") + h.webhookServersMu.Lock() + ws, exists := h.webhookServers[workDir] + h.webhookServersMu.Unlock() + + if !exists { + ts.Fatalf("webhook-server-wait: no webhook server running (call webhook-server-start first)") + } + + select { + case <-ws.done: + case <-time.After(timeout): + ts.Fatalf("webhook-server-wait: timed out after %s waiting for terminal webhook", timeout) + } + + ws.mu.Lock() + result := ws.result + ws.mu.Unlock() + + out, _ := json.Marshal(result) + _, _ = ts.Stdout().Write(out) + _, _ = ts.Stdout().Write([]byte("\n")) +} + +// stopWebhookServerByWorkDir shuts down the webhook server for a work directory. +func (h *Harness) stopWebhookServerByWorkDir(workDir string) { + h.webhookServersMu.Lock() + ws, exists := h.webhookServers[workDir] + if !exists { + h.webhookServersMu.Unlock() + return + } + delete(h.webhookServers, workDir) + h.webhookServersMu.Unlock() + + if ws.server != nil { + _ = ws.server.Close() + } +} diff --git a/integration-tests/tests/coglet_iterator_path_output.txtar b/integration-tests/tests/coglet_iterator_path_output.txtar new file mode 100644 index 0000000000..d49e6db6c4 --- /dev/null +++ b/integration-tests/tests/coglet_iterator_path_output.txtar @@ -0,0 +1,34 @@ +# Test iterator prediction with Path outputs (no upload URL — files written to disk) + +cog build -t $TEST_IMAGE +cog predict $TEST_IMAGE + +# cog predict writes file outputs to disk, not as base64 to stdout +stderr 'Written output to: output.0.png' +stderr 'Written output to: output.1.png' +stderr 'Written output to: output.2.png' +! stderr 'failed' + +-- cog.yaml -- +build: + python_version: "3.12" + python_packages: + - "pillow==10.4.0" +predict: "predict.py:Predictor" + +-- predict.py -- +import os +import tempfile +from typing import Iterator + +from cog import BasePredictor, Path +from PIL import Image + + +class Predictor(BasePredictor): + def predict(self) -> Iterator[Path]: + for color in ["red", "blue", "green"]: + d = tempfile.mkdtemp() + p = os.path.join(d, f"{color}.png") + Image.new("RGB", (10, 10), color).save(p) + yield Path(p) diff --git a/integration-tests/tests/coglet_iterator_upload_url.txtar b/integration-tests/tests/coglet_iterator_upload_url.txtar new file mode 100644 index 0000000000..eff521644a --- /dev/null +++ b/integration-tests/tests/coglet_iterator_upload_url.txtar @@ -0,0 +1,42 @@ +# Test that iterator Path outputs are uploaded per-yield to --upload-url. + +cog build -t $TEST_IMAGE + +# Start mock upload server on the host, sets $UPLOAD_SERVER_URL +upload-server-start + +cog serve --upload-url $UPLOAD_SERVER_URL + +# Run a prediction — three Path outputs should be uploaded, not base64-encoded +curl POST /predictions '{"input":{}}' +stdout '"status":"succeeded"' +# Outputs should be URLs pointing at the mock server, not data URIs +stdout '"output":\["http://host.docker.internal' +! stdout 'data:image/' + +# Verify the mock server received exactly 3 PUT uploads +upload-server-count 3 + +-- cog.yaml -- +build: + python_version: "3.12" + python_packages: + - "pillow==10.4.0" +predict: "predict.py:Predictor" + +-- predict.py -- +import os +import tempfile +from typing import Iterator + +from cog import BasePredictor, Path +from PIL import Image + + +class Predictor(BasePredictor): + def predict(self) -> Iterator[Path]: + for color in ["red", "blue", "green"]: + d = tempfile.mkdtemp() + p = os.path.join(d, f"{color}.png") + Image.new("RGB", (10, 10), color).save(p) + yield Path(p) diff --git a/integration-tests/tests/coglet_large_output.txtar b/integration-tests/tests/coglet_large_output.txtar new file mode 100644 index 0000000000..b22828665b --- /dev/null +++ b/integration-tests/tests/coglet_large_output.txtar @@ -0,0 +1,40 @@ +# Test that outputs larger than the 8MiB IPC frame limit spill to disk +# and are reconstructed correctly by the orchestrator. +# Without spilling this would panic the bridge and poison the slot. +# +# Uses async prediction + webhook so the 9MiB output goes directly to our +# Go webhook receiver — never through testscript's log buffer. +# +# --upload-url is set to a dummy value so cog serve adds +# --add-host=host.docker.internal:host-gateway (needed on Linux for the +# webhook callback to reach the host). Nothing is actually uploaded because +# the output is a plain string, not a Path. + +webhook-server-start +cog serve --upload-url http://unused/ + +# Async prediction — server returns 202 immediately, delivers result to webhook +curl -H Prefer:respond-async POST /predictions '{"id":"large-output-test","webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}' + +# Wait for the webhook callback (up to 120s) +webhook-server-wait + +# 1. Prediction succeeded +stdout '"status":"succeeded"' + +# 2. Output is correct — 9 * 1024 * 1024 = 9437184 bytes +stdout '"output_size":9437184' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + + +class Predictor(BasePredictor): + def predict(self) -> str: + # 9MiB string — exceeds the 8MiB IPC frame limit + return "x" * (9 * 1024 * 1024) diff --git a/mise.toml b/mise.toml index e6da17cb8d..fac2303356 100644 --- a/mise.toml +++ b/mise.toml @@ -55,9 +55,9 @@ ty = "0.0.10" [env] _.path = "./bin" _.file = [".env"] -_.python.venv = ".venv" # Set REPO_ROOT only if not already set (e.g., by CI) REPO_ROOT = "{{env.REPO_ROOT | default(value=config_root)}}" +_.python.venv = "{{env.REPO_ROOT}}/.venv" [settings] lockfile = true @@ -69,13 +69,19 @@ experimental = true [tasks._setup_dist] hide = true -quiet = true +silent = true description = "Create dist directory" run = "mkdir -p dist" +[tasks._setup_venv] +hide = true +silent = true +description = "Ensure root .venv exists with Python" +run = "test -d .venv || uv venv --quiet" + [tasks._clean_dist] hide = true -quiet = true +silent = true description = "Clean dist directory" run = "rm -f dist/cog-*.whl dist/cog-*.tar.gz dist/coglet-*.whl" @@ -109,6 +115,8 @@ echo "Installed $PREFIX/bin/cog -> $BINARY" [tasks."build:cog"] description = "Build cog CLI (development)" +sources = ["cmd/**/*.go", "pkg/**/*.go", "go.mod", "go.sum"] +outputs = ["dist/go/*/cog"] run = "GOFLAGS=-buildvcs=false go run github.com/goreleaser/goreleaser/v2@latest build --clean --snapshot --single-target --id cog --output cog" [tasks."build:cog:release"] @@ -125,32 +133,35 @@ run = "cargo build --manifest-path crates/Cargo.toml --workspace --release" [tasks."build:coglet"] description = "Build coglet Python wheel (development, local install)" -run = "cd crates/coglet-python && uv run maturin develop" +run = [ + { task = "_setup_venv" }, + "maturin develop --manifest-path crates/coglet-python/Cargo.toml", +] [tasks."build:coglet:wheel"] description = "Build coglet Python wheel (native platform)" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] outputs = ["dist/coglet-*.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml", ] [tasks."build:coglet:wheel:linux-x64"] description = "Build coglet Python wheel for Linux x86_64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] -outputs = ["dist/coglet-*-linux_x86_64.whl"] +outputs = ["dist/coglet-*manylinux*x86_64*.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target x86_64-unknown-linux-gnu --zig", ] [tasks."build:coglet:wheel:linux-arm64"] description = "Build coglet Python wheel for Linux ARM64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] -outputs = ["dist/coglet-*-linux_aarch64.whl"] +outputs = ["dist/coglet-*manylinux*aarch64*.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target aarch64-unknown-linux-gnu --zig", ] @@ -159,7 +170,7 @@ description = "Build coglet Python wheel for macOS ARM64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] outputs = ["dist/coglet-*-macosx_*_arm64.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target aarch64-apple-darwin", ] @@ -168,7 +179,7 @@ description = "Build coglet Python wheel for macOS x86_64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] outputs = ["dist/coglet-*-macosx_*_x86_64.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target x86_64-apple-darwin", ] @@ -185,7 +196,7 @@ description = "Build cog SDK wheel" sources = ["python/**/*.py", "pyproject.toml"] outputs = ["dist/cog-*.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "uv build --wheel --out-dir=dist .", ] @@ -249,7 +260,7 @@ run = "nox -s coglet" [tasks."test:integration"] description = "Run integration tests (skips slow tests by default, set SHORT=0 for full suite)" -depends = ["build:cog", "build:sdk", "build:coglet:wheel"] +depends = ["build:cog", "build:sdk", "build:coglet:wheel:linux-x64"] run = """ #!/usr/bin/env bash set -e @@ -257,10 +268,13 @@ SHORT_FLAG="-short" if [ "${SHORT:-1}" = "0" ]; then SHORT_FLAG="" fi -if [ -n "$1" ]; then - gotestsum -- -tags integration -v $SHORT_FLAG -run "TestIntegration/$1" -timeout 30m ./integration-tests/... +# If first arg is a bare name (no dash), treat as test name filter; +# remaining args are passed through to go test. +# e.g. mise run test:integration coglet_large_output -count=4 +if [ $# -gt 0 ] && [[ "$1" != -* ]]; then + gotestsum -- -tags integration -v $SHORT_FLAG -run "TestIntegration/$1" "${@:2}" -timeout 30m ./integration-tests/... else - gotestsum -- -tags integration -v $SHORT_FLAG -parallel ${TEST_PARALLEL:-4} -timeout 30m ./integration-tests/... + gotestsum -- -tags integration -v $SHORT_FLAG -parallel ${TEST_PARALLEL:-4} "$@" -timeout 30m ./integration-tests/... fi """ @@ -405,15 +419,18 @@ run = [ alias = "stub:generate" description = "Generate Python type stubs for coglet" sources = ["crates/coglet-python/src/**/*.rs", "crates/coglet-python/Cargo.toml", "crates/coglet-python/coglet/__init__.py"] -outputs = ["crates/coglet-python/coglet/_sdk/__init__.pyi", "crates/coglet-python/coglet/__init__.pyi"] +outputs = ["coglet/_sdk/__init__.pyi", "coglet/__init__.pyi", "coglet/_impl.pyi"] dir = "crates/coglet-python" -run = ''' -# 1. pyo3-stub-gen: Rust types → _sdk/__init__.pyi (and initial __init__.pyi) -cargo run --bin stub_gen +run = [ + { task = "_setup_venv" }, + ''' +# 1. pyo3-stub-gen: Rust types → _impl.pyi and _sdk/__init__.pyi +uv run --active cargo run --bin stub_gen -# 2. mypy stubgen: Python __init__.py → __init__.pyi (overwrites pyo3 version) +# 2. mypy stubgen: Python __init__.py → __init__.pyi (re-export stubs) uvx --from mypy stubgen coglet/__init__.py -o . --quiet -''' +''', +] [tasks."generate:compat"] description = "Regenerate CUDA/PyTorch/TensorFlow compatibility matrices" @@ -460,15 +477,11 @@ echo "Done." [tasks."stub:check"] description = "Check that coglet Python stubs are up to date" dir = "crates/coglet-python" -run = ''' +run = [ + { task = "generate:stubs" }, + ''' #!/usr/bin/env bash set -e - -# Regenerate stubs in-place -cargo run --bin stub_gen 2>/dev/null -uvx --from mypy stubgen coglet/__init__.py -o . --quiet - -# Check if any .pyi files changed if ! git diff --quiet -- '**/*.pyi'; then echo "ERROR: Stubs are out of date:" git diff -- '**/*.pyi' @@ -477,7 +490,8 @@ if ! git diff --quiet -- '**/*.pyi'; then exit 1 fi echo "Stubs are up to date." -''' +''', +] [tasks."stub:typecheck"] description = "Type-check coglet stubs with ty" @@ -504,7 +518,7 @@ run = "cd crates && cargo clean" [tasks."clean:python"] description = "Clean Python build artifacts" -run = "rm -rf .tox build python/cog.egg-info" +run = "rm -rf .tox build python/cog.egg-info .venv crates/coglet-python/.venv crates/coglet-python/coglet/*.so" # ============================================================================= # Docs tasks diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 6bf88f596b..034a60bfe3 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -193,13 +193,3 @@ func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map OCIIndex: model.OCIIndexEnabled(), } } - -// buildBaseOptionsFromFlags creates BuildBaseOptions from the current CLI flag values. -func buildBaseOptionsFromFlags(cmd *cobra.Command) model.BuildBaseOptions { - return model.BuildBaseOptions{ - UseCudaBaseImage: buildUseCudaBaseImage, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - ProgressOutput: buildProgressOutput, - RequiresCog: true, - } -} diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index ad78645ef8..2de0951438 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -180,13 +180,13 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return err } - m, err := resolver.BuildBase(ctx, src, buildBaseOptionsFromFlags(cmd)) + m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) if err != nil { return err } imageName = m.ImageRef() - // Base image doesn't have /src in it, so mount as volume + // ExcludeSource build doesn't have /src in it, so mount as volume volumes = append(volumes, command.Volume{ Source: src.ProjectDir, Destination: "/src", diff --git a/pkg/cli/run.go b/pkg/cli/run.go index dcfc354ee9..501d82f047 100644 --- a/pkg/cli/run.go +++ b/pkg/cli/run.go @@ -65,7 +65,9 @@ func run(cmd *cobra.Command, args []string) error { resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.BuildBase(ctx, src, buildBaseOptionsFromFlags(cmd)) + opts := serveBuildOptions(cmd) + opts.SkipSchemaValidation = true + m, err := resolver.Build(ctx, src, opts) if err != nil { return err } diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index 88092ff03b..ccccfee4c5 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -14,7 +14,8 @@ import ( ) var ( - port = 8393 + port = 8393 + uploadURL = "" ) func newServeCommand() *cobra.Command { @@ -36,10 +37,24 @@ Generate and run an HTTP server based on the declared model inputs and outputs.` addConfigFlag(cmd) cmd.Flags().IntVarP(&port, "port", "p", port, "Port on which to listen") + cmd.Flags().StringVar(&uploadURL, "upload-url", "", "Upload URL for file outputs (e.g. https://example.com/upload/)") return cmd } +// serveBuildOptions creates BuildOptions for cog serve. +// Same build path as cog build, but with ExcludeSource so COPY . /src is +// skipped — source is volume-mounted at runtime instead. All other layers +// (wheels, apt, etc.) share Docker layer cache with cog build. +func serveBuildOptions(cmd *cobra.Command) model.BuildOptions { + return model.BuildOptions{ + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + ProgressOutput: buildProgressOutput, + ExcludeSource: true, + } +} + func cmdServe(cmd *cobra.Command, arg []string) error { ctx := cmd.Context() @@ -54,7 +69,7 @@ func cmdServe(cmd *cobra.Command, arg []string) error { } resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.BuildBase(ctx, src, buildBaseOptionsFromFlags(cmd)) + m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) if err != nil { return err } @@ -73,6 +88,10 @@ func cmdServe(cmd *cobra.Command, arg []string) error { "--await-explicit-shutdown", "true", } + if uploadURL != "" { + args = append(args, "--upload-url", uploadURL) + } + // Automatically propagate RUST_LOG for Rust coglet debugging env := envFlags if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { @@ -88,6 +107,13 @@ func cmdServe(cmd *cobra.Command, arg []string) error { Workdir: "/src", } + // On Linux, host.docker.internal is not available by default — add it. + // This allows the container to reach services running on the host, + // e.g. when --upload-url points to a local upload server. + if uploadURL != "" { + runOptions.ExtraHosts = []string{"host.docker.internal:host-gateway"} + } + runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000}) console.Info("") diff --git a/pkg/cli/train.go b/pkg/cli/train.go index 368a872db3..122de8ae40 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -73,13 +73,13 @@ func cmdTrain(cmd *cobra.Command, args []string) error { return err } - m, err := resolver.BuildBase(ctx, src, buildBaseOptionsFromFlags(cmd)) + m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) if err != nil { return err } imageName = m.ImageRef() - // Base image doesn't have /src in it, so mount as volume + // ExcludeSource build doesn't have /src in it, so mount as volume volumes = append(volumes, command.Volume{ Source: src.ProjectDir, Destination: "/src", diff --git a/pkg/config/image_name.go b/pkg/config/image_name.go index fa98a07ce6..9dca07f095 100644 --- a/pkg/config/image_name.go +++ b/pkg/config/image_name.go @@ -30,8 +30,3 @@ func DockerImageName(projectDir string) string { return projectName } - -// BaseDockerImageName returns the Docker image name for base images -func BaseDockerImageName(projectDir string) string { - return DockerImageName(projectDir) + "-base" -} diff --git a/pkg/docker/command/command.go b/pkg/docker/command/command.go index 6f510824e7..4c991f6b6a 100644 --- a/pkg/docker/command/command.go +++ b/pkg/docker/command/command.go @@ -48,17 +48,18 @@ type ImageBuildOptions struct { } type RunOptions struct { - Detach bool - Args []string - Env []string - GPUs string - Image string - Ports []Port - Volumes []Volume - Workdir string - Stdin io.Reader - Stdout io.Writer - Stderr io.Writer + Detach bool + Args []string + Env []string + GPUs string + Image string + Ports []Port + Volumes []Volume + Workdir string + ExtraHosts []string + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer } type Port struct { diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index a9c7733ec5..0643fd9f24 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -447,6 +447,11 @@ func (c *apiClient) containerRun(ctx context.Context, options command.RunOptions } } + // Configure extra hosts (e.g. host.docker.internal on Linux) + if len(options.ExtraHosts) > 0 { + hostCfg.ExtraHosts = options.ExtraHosts + } + networkingCfg := &network.NetworkingConfig{ EndpointsConfig: map[string]*network.EndpointSettings{}, } diff --git a/pkg/dockerfile/standard_generator.go b/pkg/dockerfile/standard_generator.go index a35b13863a..65a177233c 100644 --- a/pkg/dockerfile/standard_generator.go +++ b/pkg/dockerfile/standard_generator.go @@ -467,7 +467,6 @@ func (g *StandardGenerator) installCog() (string, error) { if !g.requiresCog { return "", nil } - // Use override if set, otherwise auto-detect via env var / dist / PyPI var wheelConfig *wheels.WheelConfig var err error diff --git a/pkg/image/build.go b/pkg/image/build.go index e22c436f0c..c85d8df7e4 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -54,6 +54,8 @@ func Build( useCogBaseImage *bool, strip bool, precompile bool, + excludeSource bool, + skipSchemaValidation bool, annotations map[string]string, dockerCommand command.Command, client registry.Client) (string, error) { @@ -162,7 +164,15 @@ func Build( return "", fmt.Errorf("Failed to build runner Docker image: %w", err) } } else { - dockerfileContents, err := generator.GenerateDockerfileWithoutSeparateWeights(ctx) + var dockerfileContents string + if excludeSource { + // Dev mode (cog serve): same layers as cog build but without + // COPY . /src — source is volume-mounted at runtime instead. + // This shares Docker layer cache with full builds. + dockerfileContents, err = generator.GenerateModelBase(ctx) + } else { + dockerfileContents, err = generator.GenerateDockerfileWithoutSeparateWeights(ctx) + } if err != nil { return "", fmt.Errorf("Failed to generate Dockerfile: %w", err) } @@ -186,7 +196,10 @@ func Build( } var schemaJSON []byte - if schemaFile != "" { + switch { + case skipSchemaValidation: + console.Debug("Skipping model schema validation") + case schemaFile != "": console.Infof("Validating model schema from %s...", schemaFile) data, err := os.ReadFile(schemaFile) if err != nil { @@ -194,9 +207,15 @@ func Build( } schemaJSON = data - } else { + default: console.Info("Validating model schema...") - schema, err := GenerateOpenAPISchema(ctx, dockerCommand, tmpImageId, cfg.Build.GPU) + // When excludeSource is true (cog serve), /src was not COPYed into the + // image, so we need to volume-mount the project directory for schema generation. + schemaSourceDir := "" + if excludeSource { + schemaSourceDir = dir + } + schema, err := GenerateOpenAPISchema(ctx, dockerCommand, tmpImageId, cfg.Build.GPU, schemaSourceDir) if err != nil { return "", fmt.Errorf("Failed to get type signature: %w", err) } @@ -209,20 +228,22 @@ func Build( schemaJSON = data } - // save open_api schema file - if err := os.WriteFile(bundledSchemaFile, schemaJSON, 0o644); err != nil { - return "", fmt.Errorf("failed to store bundled schema file %s: %w", bundledSchemaFile, err) - } + if !skipSchemaValidation { + // save open_api schema file + if err := os.WriteFile(bundledSchemaFile, schemaJSON, 0o644); err != nil { + return "", fmt.Errorf("failed to store bundled schema file %s: %w", bundledSchemaFile, err) + } - loader := openapi3.NewLoader() - loader.IsExternalRefsAllowed = true - doc, err := loader.LoadFromData(schemaJSON) - if err != nil { - return "", fmt.Errorf("Failed to load model schema JSON: %w", err) - } - err = doc.Validate(loader.Context) - if err != nil { - return "", fmt.Errorf("Model schema is invalid: %w\n\n%s", err, string(schemaJSON)) + loader := openapi3.NewLoader() + loader.IsExternalRefsAllowed = true + doc, err := loader.LoadFromData(schemaJSON) + if err != nil { + return "", fmt.Errorf("Failed to load model schema JSON: %w", err) + } + err = doc.Validate(loader.Context) + if err != nil { + return "", fmt.Errorf("Model schema is invalid: %w\n\n%s", err, string(schemaJSON)) + } } console.Info("Adding labels to image...") @@ -299,8 +320,13 @@ func Build( maps.Copy(labels, annotations) - // The final image ID comes from the label-adding step - imageID, err := BuildAddLabelsAndSchemaToImage(ctx, dockerCommand, tmpImageId, imageName, labels, bundledSchemaFile, progressOutput) + // The final image ID comes from the label-adding step. + // When schema validation is skipped (cog run), there is no schema file to bundle. + schemaFileToBundle := bundledSchemaFile + if skipSchemaValidation { + schemaFileToBundle = "" + } + imageID, err := BuildAddLabelsAndSchemaToImage(ctx, dockerCommand, tmpImageId, imageName, labels, schemaFileToBundle, progressOutput) if err != nil { return "", fmt.Errorf("Failed to add labels to image: %w", err) } @@ -321,7 +347,12 @@ func Build( // The new image is based on the provided image with the labels and schema file appended to it. // tmpName is the source image to build from, image is the final image name/tag. func BuildAddLabelsAndSchemaToImage(ctx context.Context, dockerClient command.Command, tmpName, image string, labels map[string]string, bundledSchemaFile string, progressOutput string) (string, error) { - dockerfile := fmt.Sprintf("FROM %s\nCOPY %s .cog\n", tmpName, bundledSchemaFile) + var dockerfile string + if bundledSchemaFile != "" { + dockerfile = fmt.Sprintf("FROM %s\nCOPY %s .cog\n", tmpName, bundledSchemaFile) + } else { + dockerfile = fmt.Sprintf("FROM %s\n", tmpName) + } buildOpts := command.ImageBuildOptions{ DockerfileContents: dockerfile, @@ -337,56 +368,6 @@ func BuildAddLabelsAndSchemaToImage(ctx context.Context, dockerClient command.Co return imageID, nil } -func BuildBase(ctx context.Context, dockerClient command.Command, cfg *config.Config, dir string, configFilename string, useCudaBaseImage string, useCogBaseImage *bool, progressOutput string, client registry.Client, requiresCog bool) (string, error) { - // TODO: better image management so we don't eat up disk space - // https://github.com/replicate/cog/issues/80 - imageName := config.BaseDockerImageName(dir) - - console.Info("Building Docker image from environment in cog.yaml...") - generator, err := dockerfile.NewGenerator(cfg, dir, configFilename, dockerClient, client, requiresCog) - if err != nil { - return "", fmt.Errorf("Error creating Dockerfile generator: %w", err) - } - contextDir, err := generator.BuildDir() - if err != nil { - return "", err - } - buildContexts, err := generator.BuildContexts() - if err != nil { - return "", err - } - defer func() { - if err := generator.Cleanup(); err != nil { - console.Warnf("Error cleaning up Dockerfile generator: %s", err) - } - }() - - generator.SetUseCudaBaseImage(useCudaBaseImage) - if useCogBaseImage != nil { - generator.SetUseCogBaseImage(*useCogBaseImage) - } - - dockerfileContents, err := generator.GenerateModelBase(ctx) - if err != nil { - return "", fmt.Errorf("Failed to generate Dockerfile: %w", err) - } - - buildOpts := command.ImageBuildOptions{ - WorkingDir: dir, - DockerfileContents: dockerfileContents, - ImageName: imageName, - NoCache: false, - ProgressOutput: progressOutput, - Epoch: &config.BuildSourceEpochTimestamp, - ContextDir: contextDir, - BuildContexts: buildContexts, - } - if _, err := dockerClient.ImageBuild(ctx, buildOpts); err != nil { - return "", fmt.Errorf("Failed to build Docker image: %w", err) - } - return imageName, nil -} - func isGitWorkTree(ctx context.Context, dir string) bool { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() diff --git a/pkg/image/openapi_schema.go b/pkg/image/openapi_schema.go index 3b176e5de4..cca3d566ee 100644 --- a/pkg/image/openapi_schema.go +++ b/pkg/image/openapi_schema.go @@ -12,7 +12,8 @@ import ( // GenerateOpenAPISchema by running the image and executing Cog // This will be run as part of the build process then added as a label to the image. It can be retrieved more efficiently with the label by using GetOpenAPISchema -func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, imageName string, enableGPU bool) (map[string]any, error) { +// sourceDir, when non-empty, is volume-mounted as /src (needed for ExcludeSource builds where COPY . /src was skipped). +func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, imageName string, enableGPU bool, sourceDir string) (map[string]any, error) { console.Debugf("=== image.GenerateOpenAPISchema %s", imageName) var stdout bytes.Buffer var stderr bytes.Buffer @@ -23,19 +24,24 @@ func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, im gpus = "all" } - err := docker.RunWithIO(ctx, dockerClient, command.RunOptions{ + runOpts := command.RunOptions{ Image: imageName, Args: []string{ "python", "-m", "cog.command.openapi_schema", }, GPUs: gpus, - }, nil, &stdout, &stderr) + } + if sourceDir != "" { + runOpts.Volumes = []command.Volume{{Source: sourceDir, Destination: "/src"}} + } + + err := docker.RunWithIO(ctx, dockerClient, runOpts, nil, &stdout, &stderr) if enableGPU && err == docker.ErrMissingDeviceDriver { console.Debug(stdout.String()) console.Debug(stderr.String()) console.Debug("Missing device driver, re-trying without GPU") - return GenerateOpenAPISchema(ctx, dockerClient, imageName, false) + return GenerateOpenAPISchema(ctx, dockerClient, imageName, false, sourceDir) } if err != nil { diff --git a/pkg/model/factory.go b/pkg/model/factory.go index 70095a51e8..1818f496f2 100644 --- a/pkg/model/factory.go +++ b/pkg/model/factory.go @@ -12,12 +12,10 @@ import ( // Different implementations handle different build strategies. type Factory interface { // Build creates a Docker image from source and returns ImageArtifact metadata. + // For dev mode (cog serve), set ExcludeSource=true in BuildOptions to skip + // COPY . /src — the source directory is volume-mounted at runtime instead. Build(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) - // BuildBase creates a base image for dev mode (without /src copied). - // The source directory is expected to be mounted as a volume at runtime. - BuildBase(ctx context.Context, src *Source, opts BuildBaseOptions) (*ImageArtifact, error) - // Name returns the factory name for logging/debugging. Name() string } @@ -56,6 +54,8 @@ func (f *DockerfileFactory) Build(ctx context.Context, src *Source, opts BuildOp opts.UseCogBaseImage, opts.Strip, opts.Precompile, + opts.ExcludeSource, + opts.SkipSchemaValidation, opts.Annotations, f.docker, f.registry, @@ -71,30 +71,6 @@ func (f *DockerfileFactory) Build(ctx context.Context, src *Source, opts BuildOp }, nil } -// BuildBase delegates to the existing image.BuildBase() function. -func (f *DockerfileFactory) BuildBase(ctx context.Context, src *Source, opts BuildBaseOptions) (*ImageArtifact, error) { - imageName, err := image.BuildBase( - ctx, - f.docker, - src.Config, - src.ProjectDir, - src.ConfigFilename, - opts.UseCudaBaseImage, - opts.UseCogBaseImage, - opts.ProgressOutput, - f.registry, - opts.RequiresCog, - ) - if err != nil { - return nil, err - } - - return &ImageArtifact{ - Reference: imageName, - Source: ImageSourceBuild, - }, nil -} - // DefaultFactory returns a Factory based on environment variables. // It checks COG_BUILDER and COGPACK to select the appropriate backend. // diff --git a/pkg/model/options.go b/pkg/model/options.go index c5ae55f4e8..8b03b15041 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -49,6 +49,18 @@ type BuildOptions struct { // artifacts and pushes create an OCI Image Index. Set via COG_OCI_INDEX=1. // Remove this field once index pushes are validated with all registries. OCIIndex bool + + // ExcludeSource skips the COPY . /src step in the generated Dockerfile. + // Used by `cog serve` to produce an image identical to `cog build` minus + // the source copy — the source directory is volume-mounted at runtime. + // All other layers (wheel installs, apt, etc.) are shared with `cog build` + // via Docker layer caching. + ExcludeSource bool + + // SkipSchemaValidation skips OpenAPI schema generation and validation. + // Used by `cog run` which executes arbitrary commands and may not have + // a predictor or trainer defined in cog.yaml. + SkipSchemaValidation bool } // WithDefaults returns a copy of BuildOptions with defaults applied from Source. @@ -66,27 +78,3 @@ func (o BuildOptions) WithDefaults(src *Source) BuildOptions { return o } - -// BuildBaseOptions contains settings for building a base image (dev mode). -// Base images don't copy /src - the source is mounted as a volume at runtime. -type BuildBaseOptions struct { - // UseCudaBaseImage controls CUDA base image usage: "auto", "true", or "false". - UseCudaBaseImage string - - // UseCogBaseImage controls cog base image usage. nil means auto-detect. - UseCogBaseImage *bool - - // ProgressOutput controls build output format: "auto", "plain", or "tty". - ProgressOutput string - - // RequiresCog indicates whether the build requires cog to be installed. - RequiresCog bool -} - -// WithDefaults returns a copy of BuildBaseOptions with defaults applied. -func (o BuildBaseOptions) WithDefaults() BuildBaseOptions { - if o.ProgressOutput == "" { - o.ProgressOutput = "auto" - } - return o -} diff --git a/pkg/model/resolver.go b/pkg/model/resolver.go index 1afec5c5c1..a90961a2a4 100644 --- a/pkg/model/resolver.go +++ b/pkg/model/resolver.go @@ -275,38 +275,6 @@ func (r *Resolver) Push(ctx context.Context, m *Model, opts PushOptions) error { return pusher.Push(ctx, m, opts) } -// BuildBase creates a base image for dev mode (without /src copied). -// The source directory is expected to be mounted as a volume at runtime. -// Returns a Model with the built image info and the source config. -// -// NOTE: Unlike Build(), this does not use ImageBuilder because base images -// don't have labels yet (they're added during full builds). The returned -// ImageArtifact has no labels, no descriptor, and no inspect results. -func (r *Resolver) BuildBase(ctx context.Context, src *Source, opts BuildBaseOptions) (*Model, error) { - if src == nil { - return nil, fmt.Errorf("source is required for BuildBase") - } - if src.Config == nil { - return nil, fmt.Errorf("source.Config is required for BuildBase") - } - if src.ProjectDir == "" { - return nil, fmt.Errorf("source.ProjectDir is required for BuildBase") - } - opts = opts.WithDefaults() - - img, err := r.factory.BuildBase(ctx, src, opts) - if err != nil { - return nil, err - } - - // For base builds, we don't have labels yet (they're added in full builds). - // Return the model with the source config and the built image. - return &Model{ - Image: img, - Config: src.Config, - }, nil -} - // loadLocal loads a Model from the local docker daemon. func (r *Resolver) loadLocal(ctx context.Context, ref *ParsedRef) (*Model, error) { resp, err := r.docker.Inspect(ctx, ref.String()) diff --git a/pkg/model/resolver_test.go b/pkg/model/resolver_test.go index 0819a41696..1f9f2e8509 100644 --- a/pkg/model/resolver_test.go +++ b/pkg/model/resolver_test.go @@ -143,9 +143,8 @@ func (m *mockRegistry) WriteLayer(ctx context.Context, opts registry.WriteLayerO // mockFactory implements Factory for testing. type mockFactory struct { - name string - buildFunc func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) - buildBaseFunc func(ctx context.Context, src *Source, opts BuildBaseOptions) (*ImageArtifact, error) + name string + buildFunc func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) } func (f *mockFactory) Build(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) { @@ -155,13 +154,6 @@ func (f *mockFactory) Build(ctx context.Context, src *Source, opts BuildOptions) return &ImageArtifact{Reference: opts.ImageName, Source: ImageSourceBuild}, nil } -func (f *mockFactory) BuildBase(ctx context.Context, src *Source, opts BuildBaseOptions) (*ImageArtifact, error) { - if f.buildBaseFunc != nil { - return f.buildBaseFunc(ctx, src, opts) - } - return &ImageArtifact{Reference: "cog-base", Source: ImageSourceBuild}, nil -} - func (f *mockFactory) Name() string { return f.name } diff --git a/python/cog/server/http.py b/python/cog/server/http.py index e47d5f5a90..972d5714e2 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -84,5 +84,6 @@ port=port, await_explicit_shutdown=args.await_explicit_shutdown, is_train=is_train, + upload_url=args.upload_url, ) sys.exit(0)