From eb09dfeca69c62c6728bc070693c6783500b9053 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 08:43:44 -0800 Subject: [PATCH 01/21] fix(coglet): ensure logs are under 4MiB across bridge This commit truncates log lines at 4MiB of data and adds a truncation notice to the log line. This protects against a panic/slot poisoning that can happen if we exceed the codec configured (8MiB) size. Any log line that exceeds 1MiB boarders on useless. To ensure that even the most insane log lines are kept without disruption, we have implemented 4MiB limit. --- crates/coglet/src/bridge/codec.rs | 4 +++ crates/coglet/src/bridge/protocol.rs | 40 +++++++++++++++++++++++ crates/coglet/src/worker.rs | 4 ++- crates/coglet/src/worker_tracing_layer.rs | 4 +-- 4 files changed, 49 insertions(+), 3 deletions(-) diff --git a/crates/coglet/src/bridge/codec.rs b/crates/coglet/src/bridge/codec.rs index 76fc89ad4d..d6d34604d4 100644 --- a/crates/coglet/src/bridge/codec.rs +++ b/crates/coglet/src/bridge/codec.rs @@ -64,6 +64,10 @@ impl Encoder for JsonCodec { tracing::trace!(json_size_bytes = json_len, "Encoding frame"); if json_len > 100_000 { tracing::info!( + // This log line should be shipped across the IPC to be emitted, unlike the + // above trace line. This is a real indicator that we've encoded a large + // frame and is generally useful. + target: "coglet::bridge::codec::large_frame", json_size_bytes = json_len, json_size_kb = json_len / 1024, "Large frame being encoded" diff --git a/crates/coglet/src/bridge/protocol.rs b/crates/coglet/src/bridge/protocol.rs index dfab4b2de4..b385ea78f7 100644 --- a/crates/coglet/src/bridge/protocol.rs +++ b/crates/coglet/src/bridge/protocol.rs @@ -42,6 +42,21 @@ impl std::fmt::Display for SlotId { } } +const MAX_WORKER_LOG_SIZE: usize = 1024 * 1024 * 4; // 4MIB +const WORKER_LOG_TRUNCATE_NOTICE: &str = "[**** LOG LINE TRUNCATED AT 4 MiB ****]"; + +/// To ensure no panics happen due to oversized log lines, we truncate at 4 MiB. 1 MiB +/// let alone 4 MiB log line boarder/exceed usefulness from a readability standpoint. +pub fn truncate_worker_log(mut log_message: String) -> String { + if log_message.len() > MAX_WORKER_LOG_SIZE { + let boundary = + log_message.floor_char_boundary(MAX_WORKER_LOG_SIZE - WORKER_LOG_TRUNCATE_NOTICE.len()); + log_message.truncate(boundary); + log_message.push_str(WORKER_LOG_TRUNCATE_NOTICE); + } + log_message +} + /// Control messages from parent to worker. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -392,4 +407,29 @@ mod tests { }; insta::assert_json_snapshot!(resp); } + + #[test] + fn truncate_worker_log_truncates_long_messages() { + let emoji = "🦀"; // 4-byte UTF-8 character + // known size of truncate target, add one more character + let count = 1024 * 1024 * 1024 * 4 / emoji.len() + 1; + let message: String = truncate_worker_log(emoji.repeat(count)); + assert!( + message.ends_with(WORKER_LOG_TRUNCATE_NOTICE), + "log message didn't end with {}", + WORKER_LOG_TRUNCATE_NOTICE + ); + } + + #[test] + fn truncate_worker_log_does_not_truncate_short_messages() { + let emoji = "🦀"; // 4-byte UTF-8 character + // known size of truncate target, add one more character + let count = 10; + let message: String = truncate_worker_log(std::iter::repeat(emoji).take(count).collect()); + assert!( + !message.ends_with(WORKER_LOG_TRUNCATE_NOTICE), + "short log message was truncated" + ); + } } diff --git a/crates/coglet/src/worker.rs b/crates/coglet/src/worker.rs index bcaaec5d15..c6b9c10295 100644 --- a/crates/coglet/src/worker.rs +++ b/crates/coglet/src/worker.rs @@ -19,6 +19,8 @@ use futures::{SinkExt, StreamExt}; use tokio::sync::mpsc; use tokio_util::codec::{FramedRead, FramedWrite}; +use crate::bridge::protocol::truncate_worker_log; + // ============================================================================ // Dropped log tracking // ============================================================================ @@ -158,7 +160,7 @@ impl SlotSender { let msg = SlotResponse::Log { source, - data: data.to_string(), + data: truncate_worker_log(data.to_string()), }; self.tx diff --git a/crates/coglet/src/worker_tracing_layer.rs b/crates/coglet/src/worker_tracing_layer.rs index 4eacdb361c..dafe54d5d2 100644 --- a/crates/coglet/src/worker_tracing_layer.rs +++ b/crates/coglet/src/worker_tracing_layer.rs @@ -10,7 +10,7 @@ use tokio::sync::mpsc; use tracing::{Level, Subscriber}; use tracing_subscriber::layer::{Context, Layer}; -use crate::bridge::protocol::ControlResponse; +use crate::bridge::protocol::{ControlResponse, truncate_worker_log}; pub struct WorkerTracingLayer { tx: mpsc::Sender, @@ -52,7 +52,7 @@ where let mut visitor = MessageVisitor::default(); event.record(&mut visitor); - let message = visitor.message; + let message = truncate_worker_log(visitor.message); // Targets excluded from IPC: // - coglet::bridge::codec: feedback loop when encoding WorkerLog messages From e83ba91f827bdc53330ce062efc7ea2910cd4194 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 10:10:13 -0800 Subject: [PATCH 02/21] feat(coglet): wire file-based output spilling for large IPC frames Outputs > 6MiB are spilled to disk and sent as FileOutput path references over IPC instead of inline, avoiding the 8MiB LengthDelimitedCodec frame limit. Coglet creates per-prediction output dirs at /tmp/coglet/outputs/{prediction_id}/ and passes the path to the worker via SlotRequest::Predict. --- crates/coglet-python/src/lib.rs | 6 +- crates/coglet/src/bridge/codec.rs | 4 ++ crates/coglet/src/bridge/protocol.rs | 20 ++++++ ...tocol__tests__slot_predict_serializes.snap | 3 +- crates/coglet/src/orchestrator.rs | 4 ++ crates/coglet/src/service.rs | 9 +++ crates/coglet/src/worker.rs | 64 +++++++++++++++++-- 7 files changed, 102 insertions(+), 8 deletions(-) diff --git a/crates/coglet-python/src/lib.rs b/crates/coglet-python/src/lib.rs index 95b1693504..5487ad7f09 100644 --- a/crates/coglet-python/src/lib.rs +++ b/crates/coglet-python/src/lib.rs @@ -188,7 +188,8 @@ impl CogletServer { } /// Start the HTTP prediction server. Blocks until shutdown. - #[pyo3(signature = (predictor_ref=None, host="0.0.0.0".to_string(), port=5000, await_explicit_shutdown=false, is_train=false))] + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (predictor_ref=None, host="0.0.0.0".to_string(), port=5000, await_explicit_shutdown=false, is_train=false, output_temp_dir_base="/tmp/coglet/output".to_string()))] fn serve( &self, py: Python<'_>, @@ -197,6 +198,7 @@ impl CogletServer { port: u16, await_explicit_shutdown: bool, is_train: bool, + output_temp_dir_base: String, ) -> PyResult<()> { serve_impl( py, @@ -205,6 +207,7 @@ impl CogletServer { port, await_explicit_shutdown, is_train, + output_temp_dir_base, ) } @@ -261,6 +264,7 @@ fn serve_impl( port: u16, await_explicit_shutdown: bool, is_train: bool, + _output_temp_dir_base: String, ) -> PyResult<()> { let (setup_log_tx, setup_log_rx) = tokio::sync::mpsc::unbounded_channel(); init_tracing(false, Some(setup_log_tx)); diff --git a/crates/coglet/src/bridge/codec.rs b/crates/coglet/src/bridge/codec.rs index d6d34604d4..23e25fcc37 100644 --- a/crates/coglet/src/bridge/codec.rs +++ b/crates/coglet/src/bridge/codec.rs @@ -121,6 +121,7 @@ mod tests { let req = SlotRequest::Predict { id: "test".to_string(), input: serde_json::json!({"x": 1}), + output_dir: "/tmp/coglet/outputs/test".to_string(), }; codec.encode(req.clone(), &mut buf).unwrap(); @@ -131,14 +132,17 @@ mod tests { SlotRequest::Predict { id: id1, input: input1, + output_dir: dir1, }, SlotRequest::Predict { id: id2, input: input2, + output_dir: dir2, }, ) => { assert_eq!(id1, id2); assert_eq!(input1, input2); + assert_eq!(dir1, dir2); } } } diff --git a/crates/coglet/src/bridge/protocol.rs b/crates/coglet/src/bridge/protocol.rs index b385ea78f7..02761b1dc4 100644 --- a/crates/coglet/src/bridge/protocol.rs +++ b/crates/coglet/src/bridge/protocol.rs @@ -201,9 +201,21 @@ pub enum SlotRequest { Predict { id: String, input: serde_json::Value, + /// Directory for writing file outputs (created by coglet before dispatch). + /// Not included in API responses — internal transport detail. + output_dir: String, }, } +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum FileOutputKind { + /// Output is a file-like return type (e.g. File, Path) + FileType, + /// Output exceeds size threshold for bridge codec serialization but is not a file-like return type + Oversized, +} + /// Messages from worker to parent on slot socket. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] @@ -213,6 +225,13 @@ pub enum SlotResponse { data: String, }, + /// Output for a file/path-like output return type or an output that exceeds the size threshold + /// for bridge codec serialization. + FileOutput { + filename: String, + kind: FileOutputKind, + }, + /// Streaming output chunk (for generators). Output { output: serde_json::Value, @@ -360,6 +379,7 @@ mod tests { let req = SlotRequest::Predict { id: "pred_123".to_string(), input: json!({"text": "hello"}), + output_dir: "/tmp/coglet/outputs/pred_123".to_string(), }; insta::assert_json_snapshot!(req); } diff --git a/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap b/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap index 16594e9d37..d22ff9cb96 100644 --- a/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap +++ b/crates/coglet/src/bridge/snapshots/coglet__bridge__protocol__tests__slot_predict_serializes.snap @@ -7,5 +7,6 @@ expression: req "id": "pred_123", "input": { "text": "hello" - } + }, + "output_dir": "/tmp/coglet/outputs/pred_123" } diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 421c9f8b1b..621301b1e0 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -776,6 +776,10 @@ async fn run_event_loop( predictions.remove(&slot_id); } } + Ok(SlotResponse::FileOutput { filename, kind }) => { + // TODO: read file from disk, integrate into prediction output + tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received (not yet wired)"); + } Ok(SlotResponse::Done { id, output, predict_time }) => { tracing::info!( target: "coglet::prediction", diff --git a/crates/coglet/src/service.rs b/crates/coglet/src/service.rs index 17e4722dee..5a852de703 100644 --- a/crates/coglet/src/service.rs +++ b/crates/coglet/src/service.rs @@ -297,9 +297,18 @@ impl PredictionService { .register_prediction(slot_id, Arc::clone(&prediction_arc), idle_tx) .await; + // Create per-prediction output dir for file-based outputs + let output_dir = std::path::PathBuf::from("/tmp/coglet/outputs").join(&prediction_id); + std::fs::create_dir_all(&output_dir) + .map_err(|e| PredictionError::Failed(format!("Failed to create output dir: {}", e)))?; + let request = SlotRequest::Predict { id: prediction_id.clone(), input, + output_dir: output_dir + .to_str() + .expect("output dir path is valid UTF-8") + .to_string(), }; // permit_mut returns None if permit isn't InUse (shouldn't happen here) diff --git a/crates/coglet/src/worker.rs b/crates/coglet/src/worker.rs index c6b9c10295..77adca9bc0 100644 --- a/crates/coglet/src/worker.rs +++ b/crates/coglet/src/worker.rs @@ -11,6 +11,7 @@ use std::collections::HashMap; use std::io; +use std::path::PathBuf; use std::sync::Arc; use std::sync::OnceLock; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -130,7 +131,8 @@ fn init_worker_tracing(tx: mpsc::Sender) { use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::{ - ControlRequest, ControlResponse, LogSource, SlotId, SlotOutcome, SlotRequest, SlotResponse, + ControlRequest, ControlResponse, FileOutputKind, LogSource, SlotId, SlotOutcome, SlotRequest, + SlotResponse, }; use crate::bridge::transport::{ChildTransportInfo, connect_transport}; use crate::orchestrator::HealthcheckResult; @@ -146,11 +148,23 @@ type SlotWriter = #[derive(Clone)] pub struct SlotSender { tx: mpsc::UnboundedSender, + output_dir: PathBuf, + file_counter: Arc, } impl SlotSender { - pub fn new(tx: mpsc::UnboundedSender) -> Self { - Self { tx } + pub fn new(tx: mpsc::UnboundedSender, output_dir: PathBuf) -> Self { + Self { + tx, + output_dir, + file_counter: Arc::new(AtomicUsize::new(0)), + } + } + + /// Generate a unique filename in the output dir. + fn next_output_path(&self, extension: &str) -> PathBuf { + let n = self.file_counter.fetch_add(1, Ordering::Relaxed); + self.output_dir.join(format!("{n}.{extension}")) } pub fn send_log(&self, source: LogSource, data: &str) -> io::Result<()> { @@ -168,8 +182,44 @@ impl SlotSender { .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } + /// Send a file-typed output (e.g. Path, File return types). + /// + /// The file is already on disk at `path` — we just send the path reference. + pub fn send_file_output(&self, path: PathBuf) -> io::Result<()> { + let filename = path + .to_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? + .to_string(); + let msg = SlotResponse::FileOutput { + filename, + kind: FileOutputKind::FileType, + }; + self.tx + .send(msg) + .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) + } + + const MAX_INLINE_OUTPUT_SIZE: usize = 1024 * 1024 * 6; // 6MiB + + /// Send prediction output, either inline or spilled to disk if too large. pub fn send_output(&self, output: serde_json::Value) -> io::Result<()> { - let msg = SlotResponse::Output { output }; + let serialized = serde_json::to_vec(&output) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let msg = if serialized.len() > Self::MAX_INLINE_OUTPUT_SIZE { + let path = self.next_output_path("json"); + std::fs::write(&path, &serialized)?; + let filename = path + .to_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? + .to_string(); + SlotResponse::FileOutput { + filename, + kind: FileOutputKind::Oversized, + } + } else { + SlotResponse::Output { output } + }; self.tx .send(msg) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) @@ -587,7 +637,7 @@ pub async fn run_worker( } match request { - SlotRequest::Predict { id, input } => { + SlotRequest::Predict { id, input, output_dir } => { tracing::trace!(%slot_id, %id, "Prediction request received"); slot_busy.insert(slot_id, true); @@ -606,6 +656,7 @@ pub async fn run_worker( slot_id, id, input, + PathBuf::from(output_dir), handler, writer, ).await; @@ -649,6 +700,7 @@ async fn run_prediction( slot_id: SlotId, prediction_id: String, input: serde_json::Value, + output_dir: PathBuf, handler: Arc, writer: SlotWriter, ) -> SlotCompletion { @@ -656,7 +708,7 @@ async fn run_prediction( // Create channel for log streaming let (log_tx, mut log_rx) = mpsc::unbounded_channel::(); - let slot_sender = Arc::new(SlotSender::new(log_tx)); + let slot_sender = Arc::new(SlotSender::new(log_tx, output_dir)); // Forward logs to slot socket let writer_for_logs = Arc::clone(&writer); From 9e0843206ddf992ebeb8290f2576fd25373bbdd4 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 10:17:07 -0800 Subject: [PATCH 03/21] feat(coglet): wire orchestrator FileOutput handling from disk When the worker spills output to disk (FileOutput), the orchestrator reads it back and integrates it into the prediction via append_output. Handles both Oversized (JSON spill) and FileType (path reference) variants. --- crates/coglet/src/orchestrator.rs | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 621301b1e0..38d4c805a4 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -777,8 +777,33 @@ async fn run_event_loop( } } Ok(SlotResponse::FileOutput { filename, kind }) => { - // TODO: read file from disk, integrate into prediction output - tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received (not yet wired)"); + tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received"); + let output = match std::fs::read(&filename) { + Ok(bytes) => match serde_json::from_slice(&bytes) { + Ok(val) => val, + Err(e) => { + tracing::error!(%slot_id, %filename, error = %e, "Failed to parse FileOutput JSON"); + continue; + } + }, + Err(e) => { + tracing::error!(%slot_id, %filename, error = %e, "Failed to read FileOutput"); + continue; + } + }; + let poisoned = if let Some(pred) = predictions.get(&slot_id) { + if let Some(mut p) = try_lock_prediction(pred) { + p.append_output(output); + false + } else { + true + } + } else { + false + }; + if poisoned { + predictions.remove(&slot_id); + } } Ok(SlotResponse::Done { id, output, predict_time }) => { tracing::info!( From 94d2dd9895b23e86030c8fbdb1d50d89bdcc622b Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 10:29:15 -0800 Subject: [PATCH 04/21] feat(coglet): stream generator yields over IPC as they happen Generator outputs were collected into a Vec and bundled into a single Done frame, which could exceed the 8MiB IPC limit. Now each yield is sent immediately via slot_sender.send_output(), streaming through SlotResponse::Output frames. The Done frame carries an empty array for generators since outputs were already streamed. Plumbs slot_sender from PythonPredictHandler::predict through to process_generator_output and process_async_result for both predict and train paths. --- crates/coglet-python/src/log_writer.rs | 8 +++--- crates/coglet-python/src/predictor.rs | 32 ++++++++++++++--------- crates/coglet-python/src/worker_bridge.rs | 20 +++++++++++--- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/crates/coglet-python/src/log_writer.rs b/crates/coglet-python/src/log_writer.rs index 981e35a4f4..22e69398ac 100644 --- a/crates/coglet-python/src/log_writer.rs +++ b/crates/coglet-python/src/log_writer.rs @@ -595,7 +595,7 @@ mod tests { fn registry_operations() { let prediction_id = "pred_123".to_string(); let (tx, _rx) = mpsc::unbounded_channel(); - let sender = Arc::new(SlotSender::new(tx)); + let sender = Arc::new(SlotSender::new(tx, std::env::temp_dir())); // Register register_prediction(prediction_id.clone(), sender.clone()); @@ -609,7 +609,7 @@ mod tests { #[test] fn slot_sender_sends_log() { let (tx, mut rx) = mpsc::unbounded_channel(); - let sender = SlotSender::new(tx); + let sender = SlotSender::new(tx, std::env::temp_dir()); sender.send_log(LogSource::Stdout, "hello").unwrap(); @@ -626,7 +626,7 @@ mod tests { #[test] fn slot_sender_ignores_empty() { let (tx, mut rx) = mpsc::unbounded_channel(); - let sender = SlotSender::new(tx); + let sender = SlotSender::new(tx, std::env::temp_dir()); sender.send_log(LogSource::Stderr, "").unwrap(); @@ -639,7 +639,7 @@ mod tests { let (tx, rx) = mpsc::unbounded_channel::(); drop(rx); // Close receiver - let sender = SlotSender::new(tx); + let sender = SlotSender::new(tx, std::env::temp_dir()); let result = sender.send_log(LogSource::Stdout, "hello"); assert!(result.is_err()); diff --git a/crates/coglet-python/src/predictor.rs b/crates/coglet-python/src/predictor.rs index 68dd512d26..eead44742e 100644 --- a/crates/coglet-python/src/predictor.rs +++ b/crates/coglet-python/src/predictor.rs @@ -1,10 +1,11 @@ //! Python predictor loading and invocation. -use std::sync::OnceLock; +use std::sync::{Arc, OnceLock}; use pyo3::prelude::*; use pyo3::types::PyDict; +use coglet_core::worker::SlotSender; use coglet_core::{PredictionError, PredictionOutput, PredictionResult}; use crate::cancel; @@ -480,6 +481,7 @@ impl PythonPredictor { pub fn predict_worker( &self, input: serde_json::Value, + slot_sender: Arc, ) -> Result { Python::attach(|py| { let json_module = py.import("json").map_err(|e| { @@ -530,7 +532,7 @@ impl PythonPredictor { let is_generator: bool = result_bound.is_instance(&generator_type).unwrap_or(false); let output = if is_generator { - self.process_generator_output(py, result_bound, &json_module)? + self.process_generator_output(py, result_bound, &json_module, &slot_sender)? } else { self.process_single_output(py, result_bound, &json_module)? }; @@ -550,6 +552,7 @@ impl PythonPredictor { pub fn train_worker( &self, input: serde_json::Value, + slot_sender: Arc, ) -> Result { Python::attach(|py| { let json_module = py.import("json").map_err(|e| { @@ -600,7 +603,7 @@ impl PythonPredictor { let is_generator: bool = result_bound.is_instance(&generator_type).unwrap_or(false); let output = if is_generator { - self.process_generator_output(py, result_bound, &json_module)? + self.process_generator_output(py, result_bound, &json_module, &slot_sender)? } else { self.process_single_output(py, result_bound, &json_module)? }; @@ -615,14 +618,14 @@ impl PythonPredictor { }) } - /// Process generator output into PredictionOutput::Stream. + /// Process generator output by streaming each yield over IPC. fn process_generator_output( &self, py: Python<'_>, result: &Bound<'_, PyAny>, json_module: &Bound<'_, PyAny>, + slot_sender: &SlotSender, ) -> Result { - let mut outputs = Vec::new(); let iter = result .try_iter() .map_err(|e| PredictionError::Failed(format!("Failed to iterate generator: {}", e)))?; @@ -653,10 +656,13 @@ impl PythonPredictor { PredictionError::Failed(format!("Failed to parse output JSON: {}", e)) })?; - outputs.push(item_json); + slot_sender + .send_output(item_json) + .map_err(|e| PredictionError::Failed(format!("Failed to send output: {}", e)))?; } - Ok(PredictionOutput::Stream(outputs)) + // Outputs already streamed over IPC — return empty stream + Ok(PredictionOutput::Stream(vec![])) } /// Process single output into PredictionOutput::Single. @@ -785,6 +791,7 @@ impl PythonPredictor { py: Python<'_>, result: &Bound<'_, PyAny>, is_async_gen: bool, + slot_sender: &SlotSender, ) -> Result { let json_module = py .import("json") @@ -795,8 +802,7 @@ impl PythonPredictor { // Process output let output = if is_async_gen { - // Result is a list - let mut outputs = Vec::new(); + // Result is a pre-collected list — stream each item over IPC if let Ok(list) = result.extract::>>() { for item in list { let processed = output::process_output_item(py, &item).map_err(|e| { @@ -813,10 +819,12 @@ impl PythonPredictor { })?; let item_json: serde_json::Value = serde_json::from_str(&item_str) .map_err(|e| PredictionError::Failed(format!("Failed to parse: {}", e)))?; - outputs.push(item_json); + slot_sender.send_output(item_json).map_err(|e| { + PredictionError::Failed(format!("Failed to send output: {}", e)) + })?; } } - PredictionOutput::Stream(outputs) + PredictionOutput::Stream(vec![]) } else { // Check if result is a generator (sync generator from async predict) let generator_type = types_module.getattr("GeneratorType").map_err(|e| { @@ -825,7 +833,7 @@ impl PythonPredictor { let is_generator: bool = result.is_instance(&generator_type).unwrap_or(false); if is_generator { - self.process_generator_output(py, result, &json_module)? + self.process_generator_output(py, result, &json_module, slot_sender)? } else { self.process_single_output(py, result, &json_module)? } diff --git a/crates/coglet-python/src/worker_bridge.rs b/crates/coglet-python/src/worker_bridge.rs index 60e7f247b6..0b58c18cca 100644 --- a/crates/coglet-python/src/worker_bridge.rs +++ b/crates/coglet-python/src/worker_bridge.rs @@ -381,8 +381,14 @@ impl PredictHandler for PythonPredictHandler { }); // Block on future.result() + let sender_for_async = slot_sender.clone(); let result = Python::attach(|py| match future.call_method0(py, "result") { - Ok(result) => pred.process_async_result(py, result.bind(py), is_async_gen), + Ok(result) => pred.process_async_result( + py, + result.bind(py), + is_async_gen, + &sender_for_async, + ), Err(e) => { let err_str = e.to_string(); if err_str.contains("CancelledError") || err_str.contains("cancelled") { @@ -404,7 +410,7 @@ impl PredictHandler for PythonPredictHandler { // Sync train - set sync prediction ID for log routing crate::log_writer::set_sync_prediction_id(Some(&id)); let _cancelable = crate::cancel::enter_cancelable(); - let r = pred.train_worker(input); + let r = pred.train_worker(input, slot_sender.clone()); crate::log_writer::set_sync_prediction_id(None); r } @@ -446,8 +452,14 @@ impl PredictHandler for PythonPredictHandler { }); // Block on future.result() + let sender_for_async = slot_sender.clone(); let result = Python::attach(|py| match future.call_method0(py, "result") { - Ok(result) => pred.process_async_result(py, result.bind(py), is_async_gen), + Ok(result) => pred.process_async_result( + py, + result.bind(py), + is_async_gen, + &sender_for_async, + ), Err(e) => { let err_str = e.to_string(); if err_str.contains("CancelledError") || err_str.contains("cancelled") { @@ -470,7 +482,7 @@ impl PredictHandler for PythonPredictHandler { crate::log_writer::set_sync_prediction_id(Some(&id)); let _cancelable = crate::cancel::enter_cancelable(); tracing::trace!(%slot, %id, "Calling predict_worker"); - let r = pred.predict_worker(input); + let r = pred.predict_worker(input, slot_sender.clone()); tracing::trace!(%slot, %id, "predict_worker returned"); crate::log_writer::set_sync_prediction_id(None); r From b2bae4d7a3118d1f3b83bb6d324bb8a179184e37 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 11:33:35 -0800 Subject: [PATCH 05/21] feat(coglet): route file outputs to disk via SlotSender MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit File-type outputs (os.PathLike, io.IOBase) are now detected in the worker and written to the per-prediction output dir instead of being base64-encoded in-process. This keeps network I/O out of the worker subprocess — the parent process handles uploads. - Add SlotSender::write_file_output(bytes, ext) for language-agnostic file writing from any FFI worker (Python, Node, etc.) - Add send_output_item() helper that routes by type: PathLike sends existing path, IOBase reads bytes and writes via SlotSender, everything else goes through the existing JSON serialization path - Update process_single_output to also detect and route file types --- crates/coglet-python/src/predictor.rs | 197 ++++++++++++++++++++------ crates/coglet/src/worker.rs | 10 ++ 2 files changed, 166 insertions(+), 41 deletions(-) diff --git a/crates/coglet-python/src/predictor.rs b/crates/coglet-python/src/predictor.rs index eead44742e..22a967a90b 100644 --- a/crates/coglet-python/src/predictor.rs +++ b/crates/coglet-python/src/predictor.rs @@ -103,6 +103,96 @@ fn format_validation_error(py: Python<'_>, err: &PyErr) -> String { err.value(py).to_string() } +/// Send a single output item over IPC, routing file outputs to disk. +/// +/// For Path outputs (os.PathLike): sends the existing file path via send_file_output. +/// For IOBase outputs: reads bytes, writes to output_dir via write_file_output. +/// For everything else: processes through make_encodeable + upload_files, then send_output. +fn send_output_item( + py: Python<'_>, + item: &Bound<'_, PyAny>, + json_module: &Bound<'_, PyAny>, + slot_sender: &SlotSender, +) -> Result<(), PredictionError> { + let os = py + .import("os") + .map_err(|e| PredictionError::Failed(format!("Failed to import os: {}", e)))?; + let io_mod = py + .import("io") + .map_err(|e| PredictionError::Failed(format!("Failed to import io: {}", e)))?; + let pathlike = os + .getattr("PathLike") + .map_err(|e| PredictionError::Failed(format!("Failed to get os.PathLike: {}", e)))?; + let iobase = io_mod + .getattr("IOBase") + .map_err(|e| PredictionError::Failed(format!("Failed to get io.IOBase: {}", e)))?; + + if item.is_instance(&pathlike).unwrap_or(false) { + // Path output — file already on disk, send path reference + let path_str: String = item + .call_method0("__fspath__") + .and_then(|p| p.extract()) + .map_err(|e| PredictionError::Failed(format!("Failed to get fspath: {}", e)))?; + slot_sender + .send_file_output(std::path::PathBuf::from(path_str)) + .map_err(|e| PredictionError::Failed(format!("Failed to send file output: {}", e)))?; + return Ok(()); + } + + if item.is_instance(&iobase).unwrap_or(false) { + // IOBase output — read bytes, write to disk via SlotSender + // Seek to start if seekable + if item + .call_method0("seekable") + .and_then(|r| r.extract::()) + .unwrap_or(false) + { + let _ = item.call_method1("seek", (0,)); + } + let data: Vec = item + .call_method0("read") + .and_then(|d| d.extract()) + .map_err(|e| PredictionError::Failed(format!("Failed to read IOBase: {}", e)))?; + + // Try to guess extension from filename + let ext = item + .getattr("name") + .and_then(|n| n.extract::()) + .ok() + .and_then(|name| { + std::path::Path::new(&name) + .extension() + .and_then(|e| e.to_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| "bin".to_string()); + + slot_sender + .write_file_output(&data, &ext) + .map_err(|e| PredictionError::Failed(format!("Failed to write file output: {}", e)))?; + return Ok(()); + } + + // Non-file output — process normally + let processed = output::process_output_item(py, item) + .map_err(|e| PredictionError::Failed(format!("Failed to process output item: {}", e)))?; + + let item_str: String = json_module + .call_method1("dumps", (&processed,)) + .map_err(|e| PredictionError::Failed(format!("Failed to serialize output item: {}", e)))? + .extract() + .map_err(|e| PredictionError::Failed(format!("Failed to extract output string: {}", e)))?; + + let item_json: serde_json::Value = serde_json::from_str(&item_str) + .map_err(|e| PredictionError::Failed(format!("Failed to parse output JSON: {}", e)))?; + + slot_sender + .send_output(item_json) + .map_err(|e| PredictionError::Failed(format!("Failed to send output: {}", e)))?; + + Ok(()) +} + /// Type alias for Python object (Py). type PyObject = Py; @@ -534,7 +624,7 @@ impl PythonPredictor { let output = if is_generator { self.process_generator_output(py, result_bound, &json_module, &slot_sender)? } else { - self.process_single_output(py, result_bound, &json_module)? + self.process_single_output(py, result_bound, &json_module, &slot_sender)? }; // prepared drops here, cleaning up temp files via RAII @@ -605,7 +695,7 @@ impl PythonPredictor { let output = if is_generator { self.process_generator_output(py, result_bound, &json_module, &slot_sender)? } else { - self.process_single_output(py, result_bound, &json_module)? + self.process_single_output(py, result_bound, &json_module, &slot_sender)? }; drop(prepared); @@ -638,27 +728,7 @@ impl PythonPredictor { PredictionError::Failed(format!("Generator iteration error: {}", e)) })?; - let processed = output::process_output_item(py, &item).map_err(|e| { - PredictionError::Failed(format!("Failed to process output item: {}", e)) - })?; - - let item_str: String = json_module - .call_method1("dumps", (&processed,)) - .map_err(|e| { - PredictionError::Failed(format!("Failed to serialize output item: {}", e)) - })? - .extract() - .map_err(|e| { - PredictionError::Failed(format!("Failed to extract output string: {}", e)) - })?; - - let item_json: serde_json::Value = serde_json::from_str(&item_str).map_err(|e| { - PredictionError::Failed(format!("Failed to parse output JSON: {}", e)) - })?; - - slot_sender - .send_output(item_json) - .map_err(|e| PredictionError::Failed(format!("Failed to send output: {}", e)))?; + send_output_item(py, &item, json_module, slot_sender)?; } // Outputs already streamed over IPC — return empty stream @@ -666,12 +736,73 @@ impl PythonPredictor { } /// Process single output into PredictionOutput::Single. + /// + /// For file outputs (Path/IOBase), the file is sent via slot_sender and + /// an empty Single(Null) is returned since the output was already streamed. fn process_single_output( &self, py: Python<'_>, result: &Bound<'_, PyAny>, json_module: &Bound<'_, PyAny>, + slot_sender: &SlotSender, ) -> Result { + // Check for file-type outputs first + let os = py + .import("os") + .map_err(|e| PredictionError::Failed(format!("Failed to import os: {}", e)))?; + let io_mod = py + .import("io") + .map_err(|e| PredictionError::Failed(format!("Failed to import io: {}", e)))?; + let pathlike = os + .getattr("PathLike") + .map_err(|e| PredictionError::Failed(format!("Failed to get os.PathLike: {}", e)))?; + let iobase = io_mod + .getattr("IOBase") + .map_err(|e| PredictionError::Failed(format!("Failed to get io.IOBase: {}", e)))?; + + if result.is_instance(&pathlike).unwrap_or(false) { + let path_str: String = result + .call_method0("__fspath__") + .and_then(|p| p.extract()) + .map_err(|e| PredictionError::Failed(format!("Failed to get fspath: {}", e)))?; + slot_sender + .send_file_output(std::path::PathBuf::from(path_str)) + .map_err(|e| { + PredictionError::Failed(format!("Failed to send file output: {}", e)) + })?; + return Ok(PredictionOutput::Single(serde_json::Value::Null)); + } + + if result.is_instance(&iobase).unwrap_or(false) { + if result + .call_method0("seekable") + .and_then(|r| r.extract::()) + .unwrap_or(false) + { + let _ = result.call_method1("seek", (0,)); + } + let data: Vec = result + .call_method0("read") + .and_then(|d| d.extract()) + .map_err(|e| PredictionError::Failed(format!("Failed to read IOBase: {}", e)))?; + let ext = result + .getattr("name") + .and_then(|n| n.extract::()) + .ok() + .and_then(|name| { + std::path::Path::new(&name) + .extension() + .and_then(|e| e.to_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| "bin".to_string()); + slot_sender.write_file_output(&data, &ext).map_err(|e| { + PredictionError::Failed(format!("Failed to write file output: {}", e)) + })?; + return Ok(PredictionOutput::Single(serde_json::Value::Null)); + } + + // Non-file output — process normally let processed = output::process_output(py, result, None) .map_err(|e| PredictionError::Failed(format!("Failed to process output: {}", e)))?; @@ -805,23 +936,7 @@ impl PythonPredictor { // Result is a pre-collected list — stream each item over IPC if let Ok(list) = result.extract::>>() { for item in list { - let processed = output::process_output_item(py, &item).map_err(|e| { - PredictionError::Failed(format!("Failed to process output item: {}", e)) - })?; - let item_str: String = json_module - .call_method1("dumps", (&processed,)) - .map_err(|e| { - PredictionError::Failed(format!("Failed to serialize: {}", e)) - })? - .extract() - .map_err(|e| { - PredictionError::Failed(format!("Failed to extract: {}", e)) - })?; - let item_json: serde_json::Value = serde_json::from_str(&item_str) - .map_err(|e| PredictionError::Failed(format!("Failed to parse: {}", e)))?; - slot_sender.send_output(item_json).map_err(|e| { - PredictionError::Failed(format!("Failed to send output: {}", e)) - })?; + send_output_item(py, &item, &json_module, slot_sender)?; } } PredictionOutput::Stream(vec![]) @@ -835,7 +950,7 @@ impl PythonPredictor { if is_generator { self.process_generator_output(py, result, &json_module, slot_sender)? } else { - self.process_single_output(py, result, &json_module)? + self.process_single_output(py, result, &json_module, slot_sender)? } }; diff --git a/crates/coglet/src/worker.rs b/crates/coglet/src/worker.rs index 77adca9bc0..97f8ac12da 100644 --- a/crates/coglet/src/worker.rs +++ b/crates/coglet/src/worker.rs @@ -182,6 +182,16 @@ impl SlotSender { .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } + /// Write raw bytes to a file in the output dir and send as FileOutput. + /// + /// Used by FFI workers (Python, Node, etc.) to hand off file data without + /// needing language-specific file I/O — SlotSender owns the write. + pub fn write_file_output(&self, data: &[u8], extension: &str) -> io::Result<()> { + let path = self.next_output_path(extension); + std::fs::write(&path, data)?; + self.send_file_output(path) + } + /// Send a file-typed output (e.g. Path, File return types). /// /// The file is already on disk at `path` — we just send the path reference. From 9e916df8aad0aa9ac56731afbb2f29f4dbf4b343 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 11:54:16 -0800 Subject: [PATCH 06/21] feat(coglet): base64 data URI fallback for file outputs in orchestrator The orchestrator's FileOutput handler now distinguishes between FileOutputKind::Oversized (JSON spill, deserialize as before) and FileOutputKind::FileType (binary file, base64-encode as data URI). This is the fallback path when no upload_url is configured. File outputs from the worker are read from disk and encoded as data:{mime};base64,{data} URIs using mime_guess for content type detection. --- crates/Cargo.lock | 18 ++++++++++++++++++ crates/coglet/Cargo.toml | 4 ++++ crates/coglet/src/orchestrator.rs | 30 ++++++++++++++++++++++++------ 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/crates/Cargo.lock b/crates/Cargo.lock index bcd0232186..8e2874cae4 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -184,11 +184,13 @@ dependencies = [ "anyhow", "async-trait", "axum", + "base64", "chrono", "dashmap", "futures", "http-body-util", "insta", + "mime_guess", "nix", "reqwest", "serde", @@ -1054,6 +1056,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "mio" version = "1.1.1" @@ -2455,6 +2467,12 @@ dependencies = [ "unic-common", ] +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-ident" version = "1.0.22" diff --git a/crates/coglet/Cargo.toml b/crates/coglet/Cargo.toml index 73ac7cb41a..5666680530 100644 --- a/crates/coglet/Cargo.toml +++ b/crates/coglet/Cargo.toml @@ -21,6 +21,10 @@ async-trait = "0.1" serde.workspace = true serde_json.workspace = true +# Encoding +base64 = "0.22.1" +mime_guess = "2.0.5" + # Identifiers uuid.workspace = true diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 38d4c805a4..5cc4939b85 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -22,7 +22,8 @@ use tokio_util::codec::{FramedRead, FramedWrite}; use crate::PredictionOutput; use crate::bridge::codec::JsonCodec; use crate::bridge::protocol::{ - ControlRequest, ControlResponse, HealthcheckStatus, SlotId, SlotRequest, SlotResponse, + ControlRequest, ControlResponse, FileOutputKind, HealthcheckStatus, SlotId, SlotRequest, + SlotResponse, }; use crate::bridge::transport::create_transport; use crate::permit::{InactiveSlotIdleToken, PermitPool, SlotIdleToken}; @@ -779,11 +780,28 @@ async fn run_event_loop( Ok(SlotResponse::FileOutput { filename, kind }) => { tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received"); let output = match std::fs::read(&filename) { - Ok(bytes) => match serde_json::from_slice(&bytes) { - Ok(val) => val, - Err(e) => { - tracing::error!(%slot_id, %filename, error = %e, "Failed to parse FileOutput JSON"); - continue; + Ok(bytes) => match kind { + FileOutputKind::Oversized => { + match serde_json::from_slice(&bytes) { + Ok(val) => val, + Err(e) => { + tracing::error!(%slot_id, %filename, error = %e, "Failed to parse oversized JSON"); + continue; + } + } + } + FileOutputKind::FileType => { + // Binary file — base64-encode as data URI + // TODO: upload to signed endpoint when upload_url is set + let mime = mime_guess::from_path(&filename) + .first_or_octet_stream() + .to_string(); + use base64::Engine; + let encoded = + base64::engine::general_purpose::STANDARD.encode(&bytes); + serde_json::Value::String(format!( + "data:{mime};base64,{encoded}" + )) } }, Err(e) => { From 653bf75339bc42286c62bf3da3c605391099b6b7 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 12:00:25 -0800 Subject: [PATCH 07/21] feat(coglet): add optional mime_type to FileOutput protocol MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add mime_type: Option to SlotResponse::FileOutput so FFI workers can pass an explicit MIME type. When None, the orchestrator falls back to mime_guess from the file extension (matching old cog behavior). Plumbing is in place for future use — all current callers pass None. --- crates/coglet-python/src/predictor.rs | 14 ++++++++------ crates/coglet/src/bridge/protocol.rs | 3 +++ crates/coglet/src/orchestrator.rs | 10 ++++++---- crates/coglet/src/worker.rs | 14 +++++++++++--- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/crates/coglet-python/src/predictor.rs b/crates/coglet-python/src/predictor.rs index 22a967a90b..8e2c9f169d 100644 --- a/crates/coglet-python/src/predictor.rs +++ b/crates/coglet-python/src/predictor.rs @@ -134,7 +134,7 @@ fn send_output_item( .and_then(|p| p.extract()) .map_err(|e| PredictionError::Failed(format!("Failed to get fspath: {}", e)))?; slot_sender - .send_file_output(std::path::PathBuf::from(path_str)) + .send_file_output(std::path::PathBuf::from(path_str), None) .map_err(|e| PredictionError::Failed(format!("Failed to send file output: {}", e)))?; return Ok(()); } @@ -168,7 +168,7 @@ fn send_output_item( .unwrap_or_else(|| "bin".to_string()); slot_sender - .write_file_output(&data, &ext) + .write_file_output(&data, &ext, None) .map_err(|e| PredictionError::Failed(format!("Failed to write file output: {}", e)))?; return Ok(()); } @@ -766,7 +766,7 @@ impl PythonPredictor { .and_then(|p| p.extract()) .map_err(|e| PredictionError::Failed(format!("Failed to get fspath: {}", e)))?; slot_sender - .send_file_output(std::path::PathBuf::from(path_str)) + .send_file_output(std::path::PathBuf::from(path_str), None) .map_err(|e| { PredictionError::Failed(format!("Failed to send file output: {}", e)) })?; @@ -796,9 +796,11 @@ impl PythonPredictor { .map(|s| s.to_string()) }) .unwrap_or_else(|| "bin".to_string()); - slot_sender.write_file_output(&data, &ext).map_err(|e| { - PredictionError::Failed(format!("Failed to write file output: {}", e)) - })?; + slot_sender + .write_file_output(&data, &ext, None) + .map_err(|e| { + PredictionError::Failed(format!("Failed to write file output: {}", e)) + })?; return Ok(PredictionOutput::Single(serde_json::Value::Null)); } diff --git a/crates/coglet/src/bridge/protocol.rs b/crates/coglet/src/bridge/protocol.rs index 02761b1dc4..2a76461504 100644 --- a/crates/coglet/src/bridge/protocol.rs +++ b/crates/coglet/src/bridge/protocol.rs @@ -230,6 +230,9 @@ pub enum SlotResponse { FileOutput { filename: String, kind: FileOutputKind, + /// Explicit MIME type from the predictor. Falls back to mime_guess when None. + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, }, /// Streaming output chunk (for generators). diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 5cc4939b85..76bfee0c1a 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -777,7 +777,7 @@ async fn run_event_loop( predictions.remove(&slot_id); } } - Ok(SlotResponse::FileOutput { filename, kind }) => { + Ok(SlotResponse::FileOutput { filename, kind, mime_type }) => { tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received"); let output = match std::fs::read(&filename) { Ok(bytes) => match kind { @@ -793,9 +793,11 @@ async fn run_event_loop( FileOutputKind::FileType => { // Binary file — base64-encode as data URI // TODO: upload to signed endpoint when upload_url is set - let mime = mime_guess::from_path(&filename) - .first_or_octet_stream() - .to_string(); + let mime = mime_type.unwrap_or_else(|| { + mime_guess::from_path(&filename) + .first_or_octet_stream() + .to_string() + }); use base64::Engine; let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); diff --git a/crates/coglet/src/worker.rs b/crates/coglet/src/worker.rs index 97f8ac12da..75a7e8cfa4 100644 --- a/crates/coglet/src/worker.rs +++ b/crates/coglet/src/worker.rs @@ -186,16 +186,22 @@ impl SlotSender { /// /// Used by FFI workers (Python, Node, etc.) to hand off file data without /// needing language-specific file I/O — SlotSender owns the write. - pub fn write_file_output(&self, data: &[u8], extension: &str) -> io::Result<()> { + pub fn write_file_output( + &self, + data: &[u8], + extension: &str, + mime_type: Option, + ) -> io::Result<()> { let path = self.next_output_path(extension); std::fs::write(&path, data)?; - self.send_file_output(path) + self.send_file_output(path, mime_type) } /// Send a file-typed output (e.g. Path, File return types). /// /// The file is already on disk at `path` — we just send the path reference. - pub fn send_file_output(&self, path: PathBuf) -> io::Result<()> { + /// `mime_type` is an explicit MIME type; when None the parent guesses from extension. + pub fn send_file_output(&self, path: PathBuf, mime_type: Option) -> io::Result<()> { let filename = path .to_str() .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? @@ -203,6 +209,7 @@ impl SlotSender { let msg = SlotResponse::FileOutput { filename, kind: FileOutputKind::FileType, + mime_type, }; self.tx .send(msg) @@ -226,6 +233,7 @@ impl SlotSender { SlotResponse::FileOutput { filename, kind: FileOutputKind::Oversized, + mime_type: None, } } else { SlotResponse::Output { output } From 5f7f1b63495078c2e7ff5d03d09d1e07de7f3083 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 12:45:03 -0800 Subject: [PATCH 08/21] feat(coglet): wire upload_url from CLI to orchestrator for file uploads Thread --upload-url CLI arg through coglet.server.serve() to the orchestrator event loop. When set, file outputs are PUT to the signed endpoint per-yield as received (matching old Python cog behavior). - upload_file() does PUT with Content-Type, extracts Location header, strips query params, follows redirects via reqwest - Uploads are spawned as tokio tasks so they don't block the event loop - Done handler awaits pending uploads before finalizing the prediction - Failed/Cancelled/Error handlers abort pending uploads immediately - Falls back to base64 data URI encoding when no upload_url is set --- crates/coglet-python/src/lib.rs | 19 ++- crates/coglet/src/orchestrator.rs | 208 +++++++++++++++++++++++------- python/cog/server/http.py | 1 + 3 files changed, 181 insertions(+), 47 deletions(-) diff --git a/crates/coglet-python/src/lib.rs b/crates/coglet-python/src/lib.rs index 5487ad7f09..ce27deca3f 100644 --- a/crates/coglet-python/src/lib.rs +++ b/crates/coglet-python/src/lib.rs @@ -189,7 +189,7 @@ impl CogletServer { /// Start the HTTP prediction server. Blocks until shutdown. #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (predictor_ref=None, host="0.0.0.0".to_string(), port=5000, await_explicit_shutdown=false, is_train=false, output_temp_dir_base="/tmp/coglet/output".to_string()))] + #[pyo3(signature = (predictor_ref=None, host="0.0.0.0".to_string(), port=5000, await_explicit_shutdown=false, is_train=false, output_temp_dir_base="/tmp/coglet/output".to_string(), upload_url=None))] fn serve( &self, py: Python<'_>, @@ -199,6 +199,7 @@ impl CogletServer { await_explicit_shutdown: bool, is_train: bool, output_temp_dir_base: String, + upload_url: Option, ) -> PyResult<()> { serve_impl( py, @@ -208,6 +209,7 @@ impl CogletServer { await_explicit_shutdown, is_train, output_temp_dir_base, + upload_url, ) } @@ -265,6 +267,7 @@ fn serve_impl( await_explicit_shutdown: bool, is_train: bool, _output_temp_dir_base: String, + upload_url: Option, ) -> PyResult<()> { let (setup_log_tx, setup_log_rx) = tokio::sync::mpsc::unbounded_channel(); init_tracing(false, Some(setup_log_tx)); @@ -307,7 +310,15 @@ fn serve_impl( }; info!(predictor_ref = %pred_ref, is_train, "Using subprocess isolation"); - serve_subprocess(py, pred_ref, config, version, is_train, setup_log_rx) + serve_subprocess( + py, + pred_ref, + config, + version, + is_train, + setup_log_rx, + upload_url, + ) } fn serve_subprocess( @@ -317,6 +328,7 @@ fn serve_subprocess( version: VersionInfo, is_train: bool, mut setup_log_rx: tokio::sync::mpsc::UnboundedReceiver, + upload_url: Option, ) -> PyResult<()> { let max_concurrency = read_max_concurrency(py); info!( @@ -326,7 +338,8 @@ fn serve_subprocess( let orch_config = coglet_core::orchestrator::OrchestratorConfig::new(pred_ref) .with_num_slots(max_concurrency) - .with_train(is_train); + .with_train(is_train) + .with_upload_url(upload_url); let service = Arc::new( PredictionService::new_no_pool() diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 76bfee0c1a..4dba8a6157 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -29,6 +29,59 @@ use crate::bridge::transport::create_transport; use crate::permit::{InactiveSlotIdleToken, PermitPool, SlotIdleToken}; use crate::prediction::Prediction; +/// Upload a file to a signed endpoint, returning the final URL. +/// +/// Matches the behavior of Python cog's `put_file_to_signed_endpoint`: +/// PUT to `{endpoint}{filename}` with Content-Type header, then extract +/// the final URL from the Location header (falling back to response URL), +/// stripping query parameters. Follows redirects automatically. +async fn upload_file( + endpoint: &str, + filename: &str, + data: &[u8], + content_type: &str, +) -> Result { + let url = format!("{endpoint}{filename}"); + let client = reqwest::Client::new(); + let resp = client + .put(&url) + .header("Content-Type", content_type) + .body(data.to_vec()) + .timeout(std::time::Duration::from_secs(25)) + .send() + .await + .map_err(|e| format!("upload request failed: {e}"))?; + + if !resp.status().is_success() { + return Err(format!("upload returned status {}", resp.status())); + } + + // Prefer Location header, fall back to final request URL (after redirects) + let final_url = resp + .headers() + .get("location") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| resp.url().to_string()); + + // Strip query parameters (signing gubbins) + match reqwest::Url::parse(&final_url) { + Ok(mut parsed) => { + parsed.set_query(None); + Ok(parsed.to_string()) + } + Err(_) => Ok(final_url), + } +} + +fn ensure_trailing_slash(s: &str) -> String { + if s.ends_with('/') { + s.to_string() + } else { + format!("{s}/") + } +} + /// Try to lock a prediction mutex. /// On poison: logs error, recovers to fail the prediction, returns None. /// Caller should remove the prediction from tracking if None is returned. @@ -210,6 +263,8 @@ pub struct OrchestratorConfig { pub is_async: bool, pub setup_timeout: Duration, pub spawner: Arc, + /// Upload URL prefix for file outputs (from --upload-url CLI arg). + pub upload_url: Option, } impl OrchestratorConfig { @@ -221,9 +276,15 @@ impl OrchestratorConfig { is_async: false, setup_timeout: Duration::from_secs(300), spawner: Arc::new(SimpleSpawner), + upload_url: None, } } + pub fn with_upload_url(mut self, upload_url: Option) -> Self { + self.upload_url = upload_url; + self + } + pub fn with_num_slots(mut self, n: usize) -> Self { self.num_slots = n; self @@ -497,6 +558,7 @@ pub async fn spawn_worker( let pool_for_loop = Arc::clone(&pool); let ctrl_writer_for_loop = Arc::clone(&ctrl_writer); + let upload_url = config.upload_url.clone(); tokio::spawn(async move { run_event_loop( ctrl_reader, @@ -505,6 +567,7 @@ pub async fn spawn_worker( register_rx, healthcheck_rx, pool_for_loop, + upload_url, ) .await; }); @@ -533,12 +596,14 @@ async fn run_event_loop( )>, mut healthcheck_rx: mpsc::Receiver>, pool: Arc, + upload_url: Option, ) { let mut predictions: HashMap>> = HashMap::new(); let mut idle_senders: HashMap> = HashMap::new(); let mut pending_healthchecks: Vec> = Vec::new(); let mut healthcheck_counter: u64 = 0; + let mut pending_uploads: HashMap>> = HashMap::new(); let (slot_msg_tx, mut slot_msg_rx) = mpsc::channel::<(SlotId, Result)>(100); @@ -779,50 +844,89 @@ async fn run_event_loop( } Ok(SlotResponse::FileOutput { filename, kind, mime_type }) => { tracing::debug!(%slot_id, %filename, ?kind, "FileOutput received"); - let output = match std::fs::read(&filename) { - Ok(bytes) => match kind { - FileOutputKind::Oversized => { - match serde_json::from_slice(&bytes) { - Ok(val) => val, - Err(e) => { - tracing::error!(%slot_id, %filename, error = %e, "Failed to parse oversized JSON"); - continue; - } + let bytes = match std::fs::read(&filename) { + Ok(b) => b, + Err(e) => { + tracing::error!(%slot_id, %filename, error = %e, "Failed to read FileOutput"); + continue; + } + }; + match kind { + FileOutputKind::Oversized => { + let output: serde_json::Value = match serde_json::from_slice(&bytes) { + Ok(val) => val, + Err(e) => { + tracing::error!(%slot_id, %filename, error = %e, "Failed to parse oversized JSON"); + continue; + } + }; + let poisoned = if let Some(pred) = predictions.get(&slot_id) { + if let Some(mut p) = try_lock_prediction(pred) { + p.append_output(output); + false + } else { + true } + } else { + false + }; + if poisoned { + predictions.remove(&slot_id); } - FileOutputKind::FileType => { - // Binary file — base64-encode as data URI - // TODO: upload to signed endpoint when upload_url is set - let mime = mime_type.unwrap_or_else(|| { - mime_guess::from_path(&filename) - .first_or_octet_stream() - .to_string() + } + FileOutputKind::FileType => { + let mime = mime_type.unwrap_or_else(|| { + mime_guess::from_path(&filename) + .first_or_octet_stream() + .to_string() + }); + if let Some(ref url) = upload_url { + // Spawn upload task so we don't block the event loop + let pred = predictions.get(&slot_id).cloned(); + let endpoint = ensure_trailing_slash(url); + let basename = std::path::Path::new(&filename) + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("output") + .to_string(); + let handle = tokio::spawn(async move { + match upload_file(&endpoint, &basename, &bytes, &mime).await { + Ok(url) => { + if let Some(pred) = pred + && let Some(mut p) = try_lock_prediction(&pred) + { + p.append_output(serde_json::Value::String(url)); + } + } + Err(e) => { + tracing::error!(error = %e, "Failed to upload file output"); + } + } }); + pending_uploads.entry(slot_id).or_default().push(handle); + } else { + // No upload URL — base64-encode as data URI use base64::Engine; - let encoded = - base64::engine::general_purpose::STANDARD.encode(&bytes); - serde_json::Value::String(format!( + let encoded = base64::engine::general_purpose::STANDARD + .encode(&bytes); + let output = serde_json::Value::String(format!( "data:{mime};base64,{encoded}" - )) + )); + let poisoned = if let Some(pred) = predictions.get(&slot_id) { + if let Some(mut p) = try_lock_prediction(pred) { + p.append_output(output); + false + } else { + true + } + } else { + false + }; + if poisoned { + predictions.remove(&slot_id); + } } - }, - Err(e) => { - tracing::error!(%slot_id, %filename, error = %e, "Failed to read FileOutput"); - continue; - } - }; - let poisoned = if let Some(pred) = predictions.get(&slot_id) { - if let Some(mut p) = try_lock_prediction(pred) { - p.append_output(output); - false - } else { - true } - } else { - false - }; - if poisoned { - predictions.remove(&slot_id); } } Ok(SlotResponse::Done { id, output, predict_time }) => { @@ -832,14 +936,19 @@ async fn run_event_loop( predict_time, "Prediction succeeded" ); + let uploads = pending_uploads.remove(&slot_id).unwrap_or_default(); if let Some(pred) = predictions.remove(&slot_id) { - if let Some(mut p) = try_lock_prediction(&pred) { - let pred_output = output - .map(PredictionOutput::Single) - .unwrap_or(PredictionOutput::Single(serde_json::Value::Null)); - p.set_succeeded(pred_output); - } - // On mutex poison, prediction already failed - nothing more to do + tokio::spawn(async move { + for h in uploads { + let _ = h.await; + } + if let Some(mut p) = try_lock_prediction(&pred) { + let pred_output = output + .map(PredictionOutput::Single) + .unwrap_or(PredictionOutput::Single(serde_json::Value::Null)); + p.set_succeeded(pred_output); + } + }); } else { tracing::warn!(%slot_id, %id, "Prediction not found for Done message"); } @@ -851,6 +960,10 @@ async fn run_event_loop( %error, "Prediction failed" ); + // Abort any pending uploads — prediction is terminal + if let Some(handles) = pending_uploads.remove(&slot_id) { + for h in handles { h.abort(); } + } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { @@ -863,6 +976,10 @@ async fn run_event_loop( prediction_id = %id, "Prediction cancelled" ); + // Abort any pending uploads — prediction is terminal + if let Some(handles) = pending_uploads.remove(&slot_id) { + for h in handles { h.abort(); } + } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { @@ -871,6 +988,9 @@ async fn run_event_loop( } Err(e) => { tracing::error!(%slot_id, error = %e, "Slot socket error"); + if let Some(handles) = pending_uploads.remove(&slot_id) { + for h in handles { h.abort(); } + } if let Some(pred) = predictions.remove(&slot_id) && let Some(mut p) = try_lock_prediction(&pred) { diff --git a/python/cog/server/http.py b/python/cog/server/http.py index e47d5f5a90..972d5714e2 100644 --- a/python/cog/server/http.py +++ b/python/cog/server/http.py @@ -84,5 +84,6 @@ port=port, await_explicit_shutdown=args.await_explicit_shutdown, is_train=is_train, + upload_url=args.upload_url, ) sys.exit(0) From 09a431ef3b9af7cbbe8f59bd74b878f1bf2e33c0 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 13:56:28 -0800 Subject: [PATCH 09/21] fix(coglet): fix stub generation and venv setup - Root .venv is now created by _setup_venv task (used by build:coglet, build:coglet:wheel, build:sdk, generate:stubs, stub:check) - _.python.venv uses REPO_ROOT so all tasks share the same venv - mise clean:python removes .venv, coglet-python/.venv, and *.so - stub_gen.rs: write coglet/_impl.pyi (not __init__.pyi) so the native module stubs don't get overwritten by mypy stubgen - module_variable! declarations for __version__, __build__, server so ty can resolve members imported from coglet._impl - stub:check depends on generate:stubs to avoid duplicating logic - #[allow(clippy::too_many_arguments)] on serve_impl (8 args after upload_url) --- crates/coglet-python/src/bin/stub_gen.rs | 32 +++++++++++++- crates/coglet-python/src/lib.rs | 7 +++ mise.toml | 55 ++++++++++++++---------- 3 files changed, 70 insertions(+), 24 deletions(-) diff --git a/crates/coglet-python/src/bin/stub_gen.rs b/crates/coglet-python/src/bin/stub_gen.rs index 497838374a..e3b262c8bb 100644 --- a/crates/coglet-python/src/bin/stub_gen.rs +++ b/crates/coglet-python/src/bin/stub_gen.rs @@ -1,11 +1,41 @@ //! Generate Python stub files for coglet. //! //! Run with: cargo run --bin stub_gen +//! +//! Custom generate logic: pyo3-stub-gen places classes from the native +//! `coglet._impl` module into the `coglet` parent package, but mypy stubgen +//! overwrites `coglet/__init__.pyi` from the hand-maintained `__init__.py`. +//! We redirect the `coglet` module output to `coglet/_impl.pyi` so the +//! native module types are preserved. use pyo3_stub_gen::Result; +use std::fs; +use std::io::Write; fn main() -> Result<()> { let stub = coglet::stub_info()?; - stub.generate()?; + + for (name, module) in &stub.modules { + let normalized = name.replace('-', "_"); + + let dest = if normalized == "coglet" { + // Native module classes land here — redirect to _impl.pyi + stub.python_root.join("coglet").join("_impl.pyi") + } else { + // Submodules like "coglet._sdk" → coglet/_sdk/__init__.pyi + let path = normalized.replace('.', "/"); + stub.python_root.join(&path).join("__init__.pyi") + }; + + let dir = dest.parent().expect("cannot get parent directory"); + if !dir.exists() { + fs::create_dir_all(dir)?; + } + + let mut f = fs::File::create(&dest)?; + write!(f, "{module}")?; + eprintln!("Generated stub: {}", dest.display()); + } + Ok(()) } diff --git a/crates/coglet-python/src/lib.rs b/crates/coglet-python/src/lib.rs index ce27deca3f..89fc059243 100644 --- a/crates/coglet-python/src/lib.rs +++ b/crates/coglet-python/src/lib.rs @@ -19,6 +19,12 @@ use tracing_subscriber::{EnvFilter, fmt, layer::SubscriberExt, util::SubscriberI // Define stub info gatherer for generating .pyi files pyo3_stub_gen::define_stub_info_gatherer!(stub_info); +// Module-level attributes (pyo3-stub-gen can't see m.add() calls). +// Uses "coglet" because that's the module key in StubInfo for the native module. +pyo3_stub_gen::module_variable!("coglet", "__version__", &str); +pyo3_stub_gen::module_variable!("coglet", "__build__", BuildInfo); +pyo3_stub_gen::module_variable!("coglet", "server", CogletServer); + use coglet_core::{ Health, PredictionService, SetupResult, VersionInfo, transport::{ServerConfig, serve as http_serve}, @@ -259,6 +265,7 @@ impl CogletServer { } } +#[allow(clippy::too_many_arguments)] fn serve_impl( py: Python<'_>, predictor_ref: Option, diff --git a/mise.toml b/mise.toml index e6da17cb8d..7fb21f93b0 100644 --- a/mise.toml +++ b/mise.toml @@ -55,9 +55,9 @@ ty = "0.0.10" [env] _.path = "./bin" _.file = [".env"] -_.python.venv = ".venv" # Set REPO_ROOT only if not already set (e.g., by CI) REPO_ROOT = "{{env.REPO_ROOT | default(value=config_root)}}" +_.python.venv = "{{env.REPO_ROOT}}/.venv" [settings] lockfile = true @@ -73,6 +73,12 @@ quiet = true description = "Create dist directory" run = "mkdir -p dist" +[tasks._setup_venv] +hide = true +quiet = true +description = "Ensure root .venv exists with Python" +run = "test -d .venv || uv venv --quiet" + [tasks._clean_dist] hide = true quiet = true @@ -125,14 +131,17 @@ run = "cargo build --manifest-path crates/Cargo.toml --workspace --release" [tasks."build:coglet"] description = "Build coglet Python wheel (development, local install)" -run = "cd crates/coglet-python && uv run maturin develop" +run = [ + { task = "_setup_venv" }, + "maturin develop --manifest-path crates/coglet-python/Cargo.toml", +] [tasks."build:coglet:wheel"] description = "Build coglet Python wheel (native platform)" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] outputs = ["dist/coglet-*.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml", ] @@ -141,7 +150,7 @@ description = "Build coglet Python wheel for Linux x86_64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] outputs = ["dist/coglet-*-linux_x86_64.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target x86_64-unknown-linux-gnu --zig", ] @@ -150,7 +159,7 @@ description = "Build coglet Python wheel for Linux ARM64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] outputs = ["dist/coglet-*-linux_aarch64.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target aarch64-unknown-linux-gnu --zig", ] @@ -159,7 +168,7 @@ description = "Build coglet Python wheel for macOS ARM64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] outputs = ["dist/coglet-*-macosx_*_arm64.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target aarch64-apple-darwin", ] @@ -168,7 +177,7 @@ description = "Build coglet Python wheel for macOS x86_64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] outputs = ["dist/coglet-*-macosx_*_x86_64.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target x86_64-apple-darwin", ] @@ -185,7 +194,7 @@ description = "Build cog SDK wheel" sources = ["python/**/*.py", "pyproject.toml"] outputs = ["dist/cog-*.whl"] run = [ - { task = "_setup_dist" }, + { tasks = ["_setup_dist", "_setup_venv"] }, "uv build --wheel --out-dir=dist .", ] @@ -405,15 +414,18 @@ run = [ alias = "stub:generate" description = "Generate Python type stubs for coglet" sources = ["crates/coglet-python/src/**/*.rs", "crates/coglet-python/Cargo.toml", "crates/coglet-python/coglet/__init__.py"] -outputs = ["crates/coglet-python/coglet/_sdk/__init__.pyi", "crates/coglet-python/coglet/__init__.pyi"] +outputs = ["coglet/_sdk/__init__.pyi", "coglet/__init__.pyi", "coglet/_impl.pyi"] dir = "crates/coglet-python" -run = ''' -# 1. pyo3-stub-gen: Rust types → _sdk/__init__.pyi (and initial __init__.pyi) -cargo run --bin stub_gen +run = [ + { task = "_setup_venv" }, + ''' +# 1. pyo3-stub-gen: Rust types → _impl.pyi and _sdk/__init__.pyi +uv run --active cargo run --bin stub_gen -# 2. mypy stubgen: Python __init__.py → __init__.pyi (overwrites pyo3 version) +# 2. mypy stubgen: Python __init__.py → __init__.pyi (re-export stubs) uvx --from mypy stubgen coglet/__init__.py -o . --quiet -''' +''', +] [tasks."generate:compat"] description = "Regenerate CUDA/PyTorch/TensorFlow compatibility matrices" @@ -460,15 +472,11 @@ echo "Done." [tasks."stub:check"] description = "Check that coglet Python stubs are up to date" dir = "crates/coglet-python" -run = ''' +run = [ + { task = "generate:stubs" }, + ''' #!/usr/bin/env bash set -e - -# Regenerate stubs in-place -cargo run --bin stub_gen 2>/dev/null -uvx --from mypy stubgen coglet/__init__.py -o . --quiet - -# Check if any .pyi files changed if ! git diff --quiet -- '**/*.pyi'; then echo "ERROR: Stubs are out of date:" git diff -- '**/*.pyi' @@ -477,7 +485,8 @@ if ! git diff --quiet -- '**/*.pyi'; then exit 1 fi echo "Stubs are up to date." -''' +''', +] [tasks."stub:typecheck"] description = "Type-check coglet stubs with ty" @@ -504,7 +513,7 @@ run = "cd crates && cargo clean" [tasks."clean:python"] description = "Clean Python build artifacts" -run = "rm -rf .tox build python/cog.egg-info" +run = "rm -rf .tox build python/cog.egg-info .venv crates/coglet-python/.venv crates/coglet-python/coglet/*.so" # ============================================================================= # Docs tasks From c1ff72d5ec61ac3acaff7433950f3a0d6c723d17 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 13:57:27 -0800 Subject: [PATCH 10/21] chore(coglet): regenerate Python type stubs --- crates/coglet-python/coglet/_impl.pyi | 62 +++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 crates/coglet-python/coglet/_impl.pyi diff --git a/crates/coglet-python/coglet/_impl.pyi b/crates/coglet-python/coglet/_impl.pyi new file mode 100644 index 0000000000..64c9098203 --- /dev/null +++ b/crates/coglet-python/coglet/_impl.pyi @@ -0,0 +1,62 @@ +# This file is automatically generated by pyo3_stub_gen +# ruff: noqa: E501, F401, F403, F405 + +import builtins +import typing +from . import _sdk +__all__ = [ + "BuildInfo", + "Server", + "server", +] + +__build__: BuildInfo +__version__: builtins.str +server: Server +@typing.final +class BuildInfo: + r""" + Frozen build metadata exposed as `coglet.__build__`. + """ + @property + def version(self) -> builtins.str: ... + @property + def git_sha(self) -> builtins.str: ... + @property + def build_time(self) -> builtins.str: ... + @property + def rustc_version(self) -> builtins.str: ... + def __repr__(self) -> builtins.str: ... + +@typing.final +class Server: + r""" + The coglet prediction server. + + Access via `coglet.server`. Frozen — attributes cannot be set or deleted. + + - `coglet.server.active` — `True` when running inside a worker subprocess + - `coglet.server.serve(...)` — start the HTTP prediction server (blocking) + """ + @property + def active(self) -> builtins.bool: + r""" + `True` when running inside a coglet worker subprocess. + """ + def serve(self, predictor_ref: typing.Optional[builtins.str] = None, host: builtins.str = '0.0.0.0', port: builtins.int = 5000, await_explicit_shutdown: builtins.bool = False, is_train: builtins.bool = False, output_temp_dir_base: builtins.str = '/tmp/coglet/output', upload_url: typing.Optional[builtins.str] = None) -> None: + r""" + Start the HTTP prediction server. Blocks until shutdown. + """ + def _run_worker(self) -> None: + r""" + Worker subprocess entry point. Called by the orchestrator. + + Sets the active flag, installs log writers and audit hooks, + then enters the worker event loop. + """ + def _is_cancelable(self) -> builtins.bool: + r""" + Returns `True` if the current thread is in a cancelable predict call. + """ + def __repr__(self) -> builtins.str: ... + From 762925565e1046ad6c928932a66766bcc3f3551d Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 15:50:08 -0800 Subject: [PATCH 11/21] fix(coglet): decouple output from Done message to prevent IPC frame overflow Output is now always sent as a separate message before Done, with automatic spill-to-disk for values exceeding the 6MiB IPC frame limit. The orchestrator reconstructs the final output from accumulated per-yield values (generators, file uploads) rather than from the Done message payload. This fixes three issues: - Large single outputs (>8MiB) caused "frame size too big" panics that poisoned slots, since the entire output was embedded in the Done message - Generator/iterator outputs were lost because per-yield values accumulated in outputs() were ignored by the Done handler in favor of the empty Stream(vec![]) from the worker - File upload outputs (Path, IOBase) yielded from generators were silently dropped for the same reason Also adds --upload-url flag to `cog serve`, mock upload server to the integration test harness, and three new integration tests: - coglet_large_output: 9MiB string output (would poison slot without spill) - coglet_iterator_upload_url: generator Path yields uploaded to --upload-url - coglet_iterator_path_output: generator Path yields as base64 data URIs --- crates/coglet/src/orchestrator.rs | 13 +- crates/coglet/src/prediction.rs | 4 + crates/coglet/src/worker.rs | 89 ++++++---- integration-tests/harness/harness.go | 152 +++++++++++++++++- .../tests/coglet_iterator_path_output.txtar | 30 ++++ .../tests/coglet_iterator_upload_url.txtar | 42 +++++ .../tests/coglet_large_output.txtar | 23 +++ pkg/cli/serve.go | 15 +- pkg/docker/command/command.go | 23 +-- pkg/docker/docker.go | 5 + 10 files changed, 349 insertions(+), 47 deletions(-) create mode 100644 integration-tests/tests/coglet_iterator_path_output.txtar create mode 100644 integration-tests/tests/coglet_iterator_upload_url.txtar create mode 100644 integration-tests/tests/coglet_large_output.txtar diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 4dba8a6157..53e213d77b 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -929,7 +929,7 @@ async fn run_event_loop( } } } - Ok(SlotResponse::Done { id, output, predict_time }) => { + Ok(SlotResponse::Done { id, output: _, predict_time }) => { tracing::info!( target: "coglet::prediction", prediction_id = %id, @@ -943,9 +943,14 @@ async fn run_event_loop( let _ = h.await; } if let Some(mut p) = try_lock_prediction(&pred) { - let pred_output = output - .map(PredictionOutput::Single) - .unwrap_or(PredictionOutput::Single(serde_json::Value::Null)); + // Outputs are accumulated via append_output (from Output + // and FileOutput messages). Done always arrives with + // output: None — the actual values are in outputs(). + let pred_output = match p.take_outputs().as_slice() { + [] => PredictionOutput::Single(serde_json::Value::Null), + [single] => PredictionOutput::Single(single.clone()), + many => PredictionOutput::Stream(many.to_vec()), + }; p.set_succeeded(pred_output); } }); diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index ab976cf8c0..6416838d73 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -153,6 +153,10 @@ impl Prediction { &self.outputs } + pub fn take_outputs(&mut self) -> Vec { + std::mem::take(&mut self.outputs) + } + pub fn output(&self) -> Option<&PredictionOutput> { self.output.as_ref() } diff --git a/crates/coglet/src/worker.rs b/crates/coglet/src/worker.rs index 75a7e8cfa4..ac999b3b5f 100644 --- a/crates/coglet/src/worker.rs +++ b/crates/coglet/src/worker.rs @@ -216,34 +216,42 @@ impl SlotSender { .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } - const MAX_INLINE_OUTPUT_SIZE: usize = 1024 * 1024 * 6; // 6MiB - /// Send prediction output, either inline or spilled to disk if too large. pub fn send_output(&self, output: serde_json::Value) -> io::Result<()> { - let serialized = serde_json::to_vec(&output) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - - let msg = if serialized.len() > Self::MAX_INLINE_OUTPUT_SIZE { - let path = self.next_output_path("json"); - std::fs::write(&path, &serialized)?; - let filename = path - .to_str() - .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? - .to_string(); - SlotResponse::FileOutput { - filename, - kind: FileOutputKind::Oversized, - mime_type: None, - } - } else { - SlotResponse::Output { output } - }; + let msg = build_output_message(&self.output_dir, output)?; self.tx .send(msg) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "slot channel closed")) } } +const MAX_INLINE_OUTPUT_SIZE: usize = 1024 * 1024 * 6; // 6MiB + +/// Build an output message, spilling to disk if larger than the IPC frame limit. +fn build_output_message( + output_dir: &std::path::Path, + output: serde_json::Value, +) -> io::Result { + let serialized = + serde_json::to_vec(&output).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + if serialized.len() > MAX_INLINE_OUTPUT_SIZE { + let path = output_dir.join(format!("spill_{}.json", uuid::Uuid::new_v4())); + std::fs::write(&path, &serialized)?; + let filename = path + .to_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "non-UTF-8 path"))? + .to_string(); + Ok(SlotResponse::FileOutput { + filename, + kind: FileOutputKind::Oversized, + mime_type: None, + }) + } else { + Ok(SlotResponse::Output { output }) + } +} + /// Setup phase errors. /// /// These errors occur during predictor loading and setup, before predictions @@ -726,7 +734,7 @@ async fn run_prediction( // Create channel for log streaming let (log_tx, mut log_rx) = mpsc::unbounded_channel::(); - let slot_sender = Arc::new(SlotSender::new(log_tx, output_dir)); + let slot_sender = Arc::new(SlotSender::new(log_tx, output_dir.clone())); // Forward logs to slot socket let writer_for_logs = Arc::clone(&writer); @@ -741,7 +749,8 @@ async fn run_prediction( tracing::trace!("Prediction log forwarder exiting"); }); - // Run prediction + // Run prediction — slot_sender is moved in, dropped when predict returns, + // which closes the log channel and lets the log forwarder exit. let result = handler .predict(slot_id, prediction_id.clone(), input, slot_sender) .await; @@ -752,16 +761,39 @@ async fn run_prediction( let _ = log_forwarder.await; tracing::trace!(%slot_id, %prediction_id, "Log forwarder done"); - // Send result on slot socket + // Send result on slot socket. + // Output is always sent separately from Done so that large values get + // spilled to disk and never exceed the IPC frame limit. + let mut w = writer.lock().await; let response = match result.outcome { PredictionOutcome::Success { output, predict_time, - } => SlotResponse::Done { - id: prediction_id.clone(), - output: Some(output), - predict_time, - }, + } => { + // Send output as a separate message (handles spilling for large values). + // Skip if null or empty array — those mean "already streamed" (generators). + if !output.is_null() && output != serde_json::Value::Array(vec![]) { + let output_msg = match build_output_message(&output_dir, output) { + Ok(msg) => msg, + Err(e) => { + tracing::error!(error = %e, "Failed to build output message"); + return SlotCompletion::poisoned( + slot_id, + format!("Output spill error: {}", e), + ); + } + }; + if let Err(e) = w.send(output_msg).await { + tracing::error!(error = %e, "Failed to send prediction output"); + return SlotCompletion::poisoned(slot_id, format!("Socket write error: {}", e)); + } + } + SlotResponse::Done { + id: prediction_id.clone(), + output: None, + predict_time, + } + } PredictionOutcome::Cancelled { .. } => SlotResponse::Cancelled { id: prediction_id.clone(), }, @@ -771,7 +803,6 @@ async fn run_prediction( }, }; - let mut w = writer.lock().await; if let Err(e) = w.send(response).await { tracing::error!(error = %e, "Failed to send prediction response"); return SlotCompletion::poisoned(slot_id, format!("Socket write error: {}", e)); diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index 682ddafa5e..efb0687a3b 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -40,6 +40,22 @@ type registryInfo struct { host string // e.g., "localhost:5432" } +// mockUploadRecord records a single upload received by the mock upload server. +type mockUploadRecord struct { + Path string + ContentType string + Size int +} + +// mockUploadServer is a lightweight HTTP server that accepts PUT requests +// and records what was uploaded. +type mockUploadServer struct { + server *http.Server + port int + mu sync.Mutex + uploads []mockUploadRecord +} + type Harness struct { CogBinary string // realHome is captured at creation time before testscript overrides HOME @@ -52,6 +68,9 @@ type Harness struct { // registries tracks test registry containers for cleanup, keyed by work directory registries map[string]*registryInfo registriesMu sync.Mutex + // uploadServers tracks mock upload servers for cleanup, keyed by work directory + uploadServers map[string]*mockUploadServer + uploadServersMu sync.Mutex } // New creates a new Harness, resolving the cog binary location. @@ -68,8 +87,9 @@ func New() (*Harness, error) { CogBinary: cogBinary, realHome: os.Getenv("HOME"), repoRoot: repoRoot, - serverProcs: make(map[string]*serverInfo), - registries: make(map[string]*registryInfo), + serverProcs: make(map[string]*serverInfo), + registries: make(map[string]*registryInfo), + uploadServers: make(map[string]*mockUploadServer), }, nil } @@ -196,6 +216,10 @@ func (h *Harness) Commands() map[string]func(ts *testscript.TestScript, neg bool NewCommand("docker-push", h.cmdDockerPush), NewCommand("mock-weights", h.cmdMockWeights), + // Mock upload server commands + NewCommand("upload-server-start", h.cmdUploadServerStart), + NewCommand("upload-server-count", h.cmdUploadServerCount), + // PTY command (defined in cmd_pty.go) &PtyRunCommand{harness: h}, } @@ -279,6 +303,8 @@ func (h *Harness) Setup(env *testscript.Env) error { h.stopServerByWorkDir(workDir) // Stop the registry for this specific test (if any) h.stopRegistryByWorkDir(workDir) + // Stop the upload server for this specific test (if any) + h.stopUploadServerByWorkDir(workDir) removeDockerImage(imageName) }) @@ -1059,3 +1085,125 @@ func parseSize(s string) (int64, error) { return int64(num * float64(multiplier)), nil } + +// ============================================================================= +// Mock upload server commands +// ============================================================================= + +// cmdUploadServerStart starts a mock HTTP upload server on the host. +// It accepts PUT requests, records them, and responds with a Location header. +// Usage: upload-server-start +// Exports $UPLOAD_SERVER_URL with the server's base URL. +func (h *Harness) cmdUploadServerStart(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("upload-server-start: does not support negation") + } + + workDir := ts.Getenv("WORK") + + h.uploadServersMu.Lock() + if _, exists := h.uploadServers[workDir]; exists { + h.uploadServersMu.Unlock() + ts.Fatalf("upload-server-start: server already running for this test") + } + h.uploadServersMu.Unlock() + + mus := &mockUploadServer{} + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + body, _ := io.ReadAll(r.Body) + record := mockUploadRecord{ + Path: r.URL.Path, + ContentType: r.Header.Get("Content-Type"), + Size: len(body), + } + mus.mu.Lock() + mus.uploads = append(mus.uploads, record) + mus.mu.Unlock() + + // Return a clean URL without query params (simulates a signed URL redirect) + location := fmt.Sprintf("http://host.docker.internal:%d%s", mus.port, r.URL.Path) + w.Header().Set("Location", location) + w.WriteHeader(http.StatusOK) + }) + + mus.server = &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} //nolint:gosec // test harness, not production + + // Bind to all interfaces so the container can reach us via host.docker.internal + ln, err := net.Listen("tcp", "0.0.0.0:0") //nolint:gosec // must be reachable from Docker container + if err != nil { + ts.Fatalf("upload-server-start: failed to listen: %v", err) + } + mus.port = ln.Addr().(*net.TCPAddr).Port + + go func() { _ = mus.server.Serve(ln) }() + + h.uploadServersMu.Lock() + h.uploadServers[workDir] = mus + h.uploadServersMu.Unlock() + + // Advertise host.docker.internal so the container can reach the host server. + // On Linux, cog serve adds --add-host=host.docker.internal:host-gateway. + // On Mac, Docker Desktop resolves host.docker.internal automatically. + url := fmt.Sprintf("http://host.docker.internal:%d/", mus.port) + ts.Setenv("UPLOAD_SERVER_URL", url) + ts.Logf("upload-server-start: listening on 0.0.0.0:%d, container URL: %s", mus.port, url) +} + +// cmdUploadServerCount verifies exactly N uploads were received. +// Usage: upload-server-count N +func (h *Harness) cmdUploadServerCount(ts *testscript.TestScript, neg bool, args []string) { + if len(args) != 1 { + ts.Fatalf("upload-server-count: usage: upload-server-count N") + } + + expected, err := strconv.Atoi(args[0]) + if err != nil { + ts.Fatalf("upload-server-count: invalid count %q: %v", args[0], err) + } + + workDir := ts.Getenv("WORK") + h.uploadServersMu.Lock() + mus, exists := h.uploadServers[workDir] + h.uploadServersMu.Unlock() + + if !exists { + ts.Fatalf("upload-server-count: no upload server running (call upload-server-start first)") + } + + mus.mu.Lock() + got := len(mus.uploads) + mus.mu.Unlock() + + if neg { + if got == expected { + ts.Fatalf("upload-server-count: expected NOT %d uploads but got %d", expected, got) + } + return + } + + if got != expected { + ts.Fatalf("upload-server-count: expected %d uploads but got %d", expected, got) + } +} + +// stopUploadServerByWorkDir shuts down the upload server for a work directory. +func (h *Harness) stopUploadServerByWorkDir(workDir string) { + h.uploadServersMu.Lock() + mus, exists := h.uploadServers[workDir] + if !exists { + h.uploadServersMu.Unlock() + return + } + delete(h.uploadServers, workDir) + h.uploadServersMu.Unlock() + + if mus.server != nil { + _ = mus.server.Close() + } +} diff --git a/integration-tests/tests/coglet_iterator_path_output.txtar b/integration-tests/tests/coglet_iterator_path_output.txtar new file mode 100644 index 0000000000..13e487d142 --- /dev/null +++ b/integration-tests/tests/coglet_iterator_path_output.txtar @@ -0,0 +1,30 @@ +# Test iterator prediction with Path outputs + +cog build -t $TEST_IMAGE +cog predict $TEST_IMAGE +stdout '"output":\["data:image/' +! stdout '"status":"failed"' + +-- cog.yaml -- +build: + python_version: "3.12" + python_packages: + - "pillow==10.4.0" +predict: "predict.py:Predictor" + +-- predict.py -- +import os +import tempfile +from typing import Iterator + +from cog import BasePredictor, Path +from PIL import Image + + +class Predictor(BasePredictor): + def predict(self) -> Iterator[Path]: + for color in ["red", "blue", "green"]: + d = tempfile.mkdtemp() + p = os.path.join(d, f"{color}.png") + Image.new("RGB", (10, 10), color).save(p) + yield Path(p) diff --git a/integration-tests/tests/coglet_iterator_upload_url.txtar b/integration-tests/tests/coglet_iterator_upload_url.txtar new file mode 100644 index 0000000000..eff521644a --- /dev/null +++ b/integration-tests/tests/coglet_iterator_upload_url.txtar @@ -0,0 +1,42 @@ +# Test that iterator Path outputs are uploaded per-yield to --upload-url. + +cog build -t $TEST_IMAGE + +# Start mock upload server on the host, sets $UPLOAD_SERVER_URL +upload-server-start + +cog serve --upload-url $UPLOAD_SERVER_URL + +# Run a prediction — three Path outputs should be uploaded, not base64-encoded +curl POST /predictions '{"input":{}}' +stdout '"status":"succeeded"' +# Outputs should be URLs pointing at the mock server, not data URIs +stdout '"output":\["http://host.docker.internal' +! stdout 'data:image/' + +# Verify the mock server received exactly 3 PUT uploads +upload-server-count 3 + +-- cog.yaml -- +build: + python_version: "3.12" + python_packages: + - "pillow==10.4.0" +predict: "predict.py:Predictor" + +-- predict.py -- +import os +import tempfile +from typing import Iterator + +from cog import BasePredictor, Path +from PIL import Image + + +class Predictor(BasePredictor): + def predict(self) -> Iterator[Path]: + for color in ["red", "blue", "green"]: + d = tempfile.mkdtemp() + p = os.path.join(d, f"{color}.png") + Image.new("RGB", (10, 10), color).save(p) + yield Path(p) diff --git a/integration-tests/tests/coglet_large_output.txtar b/integration-tests/tests/coglet_large_output.txtar new file mode 100644 index 0000000000..ffb7f35f13 --- /dev/null +++ b/integration-tests/tests/coglet_large_output.txtar @@ -0,0 +1,23 @@ +# Test that outputs larger than the 8MiB IPC frame limit spill to disk +# and are reconstructed correctly by the orchestrator. +# Without spilling this would panic the bridge and poison the slot. + +cog build -t $TEST_IMAGE +cog predict $TEST_IMAGE + +# Output should be a 9MiB string of 'x' characters — not an error +stdout 'xxx' +! stdout 'failed' + +-- cog.yaml -- +build: + python_version: "3.12" +predict: "predict.py:Predictor" + +-- predict.py -- +from cog import BasePredictor + +class Predictor(BasePredictor): + def predict(self) -> str: + # 9MiB string — exceeds the 8MiB IPC frame limit + return "x" * (9 * 1024 * 1024) diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index 88092ff03b..f45648cf74 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -14,7 +14,8 @@ import ( ) var ( - port = 8393 + port = 8393 + uploadURL = "" ) func newServeCommand() *cobra.Command { @@ -36,6 +37,7 @@ Generate and run an HTTP server based on the declared model inputs and outputs.` addConfigFlag(cmd) cmd.Flags().IntVarP(&port, "port", "p", port, "Port on which to listen") + cmd.Flags().StringVar(&uploadURL, "upload-url", "", "Upload URL for file outputs (e.g. https://example.com/upload/)") return cmd } @@ -73,6 +75,10 @@ func cmdServe(cmd *cobra.Command, arg []string) error { "--await-explicit-shutdown", "true", } + if uploadURL != "" { + args = append(args, "--upload-url", uploadURL) + } + // Automatically propagate RUST_LOG for Rust coglet debugging env := envFlags if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { @@ -88,6 +94,13 @@ func cmdServe(cmd *cobra.Command, arg []string) error { Workdir: "/src", } + // On Linux, host.docker.internal is not available by default — add it. + // This allows the container to reach services running on the host, + // e.g. when --upload-url points to a local upload server. + if uploadURL != "" { + runOptions.ExtraHosts = []string{"host.docker.internal:host-gateway"} + } + runOptions.Ports = append(runOptions.Ports, command.Port{HostPort: port, ContainerPort: 5000}) console.Info("") diff --git a/pkg/docker/command/command.go b/pkg/docker/command/command.go index 6f510824e7..4c991f6b6a 100644 --- a/pkg/docker/command/command.go +++ b/pkg/docker/command/command.go @@ -48,17 +48,18 @@ type ImageBuildOptions struct { } type RunOptions struct { - Detach bool - Args []string - Env []string - GPUs string - Image string - Ports []Port - Volumes []Volume - Workdir string - Stdin io.Reader - Stdout io.Writer - Stderr io.Writer + Detach bool + Args []string + Env []string + GPUs string + Image string + Ports []Port + Volumes []Volume + Workdir string + ExtraHosts []string + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer } type Port struct { diff --git a/pkg/docker/docker.go b/pkg/docker/docker.go index a9c7733ec5..0643fd9f24 100644 --- a/pkg/docker/docker.go +++ b/pkg/docker/docker.go @@ -447,6 +447,11 @@ func (c *apiClient) containerRun(ctx context.Context, options command.RunOptions } } + // Configure extra hosts (e.g. host.docker.internal on Linux) + if len(options.ExtraHosts) > 0 { + hostCfg.ExtraHosts = options.ExtraHosts + } + networkingCfg := &network.NetworkingConfig{ EndpointsConfig: map[string]*network.EndpointSettings{}, } From a46fffefdddd27dbee33fb0fcae2cf8ae5c5111e Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 16:03:50 -0800 Subject: [PATCH 12/21] chore: go fmt --- integration-tests/harness/harness.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index efb0687a3b..34e312964f 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -84,9 +84,9 @@ func New() (*Harness, error) { return nil, err } return &Harness{ - CogBinary: cogBinary, - realHome: os.Getenv("HOME"), - repoRoot: repoRoot, + CogBinary: cogBinary, + realHome: os.Getenv("HOME"), + repoRoot: repoRoot, serverProcs: make(map[string]*serverInfo), registries: make(map[string]*registryInfo), uploadServers: make(map[string]*mockUploadServer), From b8ca940519e5d9b68f4c9d2dea3bfbdcd8f2e4a5 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Wed, 18 Feb 2026 16:15:35 -0800 Subject: [PATCH 13/21] fix(test): correct coglet_iterator_path_output assertion and go formatting cog predict writes file outputs to disk, not as base64 JSON to stdout. Fix assertion to check stderr for "Written output to" messages. --- .../tests/coglet_iterator_path_output.txtar | 10 +++++++--- integration-tests/tests/coglet_large_output.txtar | 6 +++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/integration-tests/tests/coglet_iterator_path_output.txtar b/integration-tests/tests/coglet_iterator_path_output.txtar index 13e487d142..d49e6db6c4 100644 --- a/integration-tests/tests/coglet_iterator_path_output.txtar +++ b/integration-tests/tests/coglet_iterator_path_output.txtar @@ -1,9 +1,13 @@ -# Test iterator prediction with Path outputs +# Test iterator prediction with Path outputs (no upload URL — files written to disk) cog build -t $TEST_IMAGE cog predict $TEST_IMAGE -stdout '"output":\["data:image/' -! stdout '"status":"failed"' + +# cog predict writes file outputs to disk, not as base64 to stdout +stderr 'Written output to: output.0.png' +stderr 'Written output to: output.1.png' +stderr 'Written output to: output.2.png' +! stderr 'failed' -- cog.yaml -- build: diff --git a/integration-tests/tests/coglet_large_output.txtar b/integration-tests/tests/coglet_large_output.txtar index ffb7f35f13..e5bb562c7b 100644 --- a/integration-tests/tests/coglet_large_output.txtar +++ b/integration-tests/tests/coglet_large_output.txtar @@ -5,9 +5,9 @@ cog build -t $TEST_IMAGE cog predict $TEST_IMAGE -# Output should be a 9MiB string of 'x' characters — not an error -stdout 'xxx' -! stdout 'failed' +# Prediction should succeed — not fail with "frame size too big" +! stderr 'frame size too big' +! stderr 'failed' -- cog.yaml -- build: From 33365fefc152c5f43d9bbd5db1ef25477ac94a1f Mon Sep 17 00:00:00 2001 From: morgan fainberg Date: Thu, 19 Feb 2026 10:41:52 -0800 Subject: [PATCH 14/21] fix(coglet): complete prediction synchronously when no uploads pending The Done handler unconditionally used tokio::spawn to wait for pending uploads before calling set_succeeded(). This introduced a race with service.rs: Notify::notified() could miss the wakeup if the spawned task called notify_waiters() before the service registered its waiter, causing the prediction to hang indefinitely. Split the Done handler into two paths: - No uploads: call set_succeeded() synchronously in the event loop - Has uploads: spawn a task to await them (preserves existing behavior) This restores the synchronous notification path for the common case and fixes the CI hang in the coglet_large_output integration test. --- crates/coglet/src/orchestrator.rs | 33 +++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/crates/coglet/src/orchestrator.rs b/crates/coglet/src/orchestrator.rs index 53e213d77b..c5b795c1d2 100644 --- a/crates/coglet/src/orchestrator.rs +++ b/crates/coglet/src/orchestrator.rs @@ -938,14 +938,13 @@ async fn run_event_loop( ); let uploads = pending_uploads.remove(&slot_id).unwrap_or_default(); if let Some(pred) = predictions.remove(&slot_id) { - tokio::spawn(async move { - for h in uploads { - let _ = h.await; - } + if uploads.is_empty() { + // No pending uploads — complete synchronously to avoid + // a race between tokio::spawn and Notify::notified() in + // service.rs. notify_waiters() only wakes already- + // registered waiters; spawning a task can fire the + // notification before the service registers its waiter. if let Some(mut p) = try_lock_prediction(&pred) { - // Outputs are accumulated via append_output (from Output - // and FileOutput messages). Done always arrives with - // output: None — the actual values are in outputs(). let pred_output = match p.take_outputs().as_slice() { [] => PredictionOutput::Single(serde_json::Value::Null), [single] => PredictionOutput::Single(single.clone()), @@ -953,7 +952,25 @@ async fn run_event_loop( }; p.set_succeeded(pred_output); } - }); + } else { + // Has pending uploads — must spawn to await them. + // TODO(#2748): when cancellation is wired end-to-end, + // this spawn should also observe the cancel token and + // abort in-flight uploads on cancellation. + tokio::spawn(async move { + for h in uploads { + let _ = h.await; + } + if let Some(mut p) = try_lock_prediction(&pred) { + let pred_output = match p.take_outputs().as_slice() { + [] => PredictionOutput::Single(serde_json::Value::Null), + [single] => PredictionOutput::Single(single.clone()), + many => PredictionOutput::Stream(many.to_vec()), + }; + p.set_succeeded(pred_output); + } + }); + } } else { tracing::warn!(%slot_id, %id, "Prediction not found for Done message"); } From 2e9303b8c759629d242970ee5c670f1e3fc42862 Mon Sep 17 00:00:00 2001 From: morgan fainberg Date: Thu, 19 Feb 2026 12:02:35 -0800 Subject: [PATCH 15/21] chore: update llm docs --- docs/cli.md | 1 + docs/llms.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/cli.md b/docs/cli.md index f282d72530..d6f69ef615 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -173,6 +173,7 @@ cog serve [flags] -h, --help help for serve -p, --port int Port on which to listen (default 8393) --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") + --upload-url string Upload URL for file outputs (e.g. https://example.com/upload/) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` diff --git a/docs/llms.txt b/docs/llms.txt index e514c62341..948e992801 100644 --- a/docs/llms.txt +++ b/docs/llms.txt @@ -370,6 +370,7 @@ cog serve [flags] -h, --help help for serve -p, --port int Port on which to listen (default 8393) --progress string Set type of build progress output, 'auto' (default), 'tty', 'plain', or 'quiet' (default "auto") + --upload-url string Upload URL for file outputs (e.g. https://example.com/upload/) --use-cog-base-image Use pre-built Cog base image for faster cold boots (default true) --use-cuda-base-image string Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects (default "auto") ``` From 86a7824863498208c0594ca29e8ef20f59a50d5c Mon Sep 17 00:00:00 2001 From: morgan fainberg Date: Thu, 19 Feb 2026 12:28:47 -0800 Subject: [PATCH 16/21] fix(coglet): use notify_one instead of notify_waiters to eliminate race MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit notify_waiters() only wakes currently-registered waiters. In the predict flow, service.rs checks is_terminal() then awaits on a separate Notify — if the orchestrator fires notify_waiters() between those two steps, the notification is lost and the prediction hangs. notify_one() stores a permit that a future .notified().await consumes immediately, closing the race window entirely. There is exactly one waiter per prediction so notify_one is semantically correct. --- crates/coglet/src/prediction.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/crates/coglet/src/prediction.rs b/crates/coglet/src/prediction.rs index 6416838d73..6ae9dc3616 100644 --- a/crates/coglet/src/prediction.rs +++ b/crates/coglet/src/prediction.rs @@ -119,18 +119,24 @@ impl Prediction { pub fn set_succeeded(&mut self, output: PredictionOutput) { self.status = PredictionStatus::Succeeded; self.output = Some(output); - self.completion.notify_waiters(); + // notify_one stores a permit so a future .notified().await will + // consume it immediately. notify_waiters only wakes currently- + // registered waiters and would race with the service task that + // checks is_terminal() then awaits — the notification can fire + // in between. There is exactly one waiter per prediction + // (service.rs predict()), so notify_one is semantically correct. + self.completion.notify_one(); } pub fn set_failed(&mut self, error: String) { self.status = PredictionStatus::Failed; self.error = Some(error); - self.completion.notify_waiters(); + self.completion.notify_one(); } pub fn set_canceled(&mut self) { self.status = PredictionStatus::Canceled; - self.completion.notify_waiters(); + self.completion.notify_one(); } pub fn elapsed(&self) -> std::time::Duration { From 64592f9d3cdddb334387b802252711daa1d1bdd7 Mon Sep 17 00:00:00 2001 From: morgan fainberg Date: Thu, 19 Feb 2026 12:43:23 -0800 Subject: [PATCH 17/21] fix(build): add source caching, fix output globs, and propagate COG_CA_CERT - build:cog: add sources/outputs so mise skips rebuild when Go sources are unchanged - build:coglet:wheel:linux-x64: fix output glob to match maturin's actual manylinux filename (was never matching, always rebuilding) - build:coglet:wheel:linux-arm64: same glob fix - test:integration: depend on linux-x64 wheel instead of native macOS wheel (only linux wheel is needed for Docker tests) - test:integration: pass extra args through to go test (e.g. -count=4) - test:integration: propagate COG_CA_CERT for custom CA certs (WARP) - Helper tasks (_setup_dist, _setup_venv, _clean_dist): use silent instead of quiet to suppress timing output - AGENTS.md: correct outdated references to embedded Python wheel (wheels are resolved from dist/ at Docker build time) --- AGENTS.md | 6 +++--- integration-tests/harness/harness.go | 5 +++++ mise.toml | 23 ++++++++++++++--------- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 08c137eac1..e34846d11f 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -129,7 +129,7 @@ The CLI code is in the `cmd/cog/` and `pkg/` directories. Support tooling is in The main commands for working on the CLI are: - `go run ./cmd/cog` - Runs the Cog CLI directly from source (requires wheel to be built first) -- `mise run build:cog` - Builds the Cog CLI binary, embedding the Python wheel +- `mise run build:cog` - Builds the Cog CLI binary - `mise run install` - Symlinks the built binary to `/usr/local/bin` (run `build:cog` first), or to a custom path with `PREFIX=/custom/path mise run install` - `mise run test:go` - Runs all Go unit tests - `go test ./pkg/...` - Runs tests directly with `go test` @@ -180,7 +180,7 @@ COG_BINARY=dist/go/*/cog mise run test:integration 1. Run `mise install` to set up the development environment 2. Run `mise run build:sdk` after making changes to the `./python` directory 3. Run `mise run build:coglet:wheel:linux-x64` after making changes to the `./crates` directory (needed for Docker testing) -4. Run `mise run build:cog` to build the CLI (embeds the SDK wheel; picks up coglet wheel from `dist/`) +4. Run `mise run build:cog` to build the CLI (wheels are picked up from `dist/` at Docker build time, not embedded in the binary) 5. Run `mise run fmt:fix` to format code 6. Run `mise run lint` to check code quality 7. Run `mise run docs:llm` to regenerate `docs/llms.txt` after changing `README.md` or any `docs/*.md` file @@ -212,7 +212,7 @@ See `crates/README.md` for detailed architecture documentation. - `crates/coglet-python/` - PyO3 bindings for Python predictor integration ### Key Design Patterns -1. **Embedded Python Wheel**: The Go binary embeds the Python wheel at build time (`pkg/dockerfile/embed/`) +1. **Local Wheel Resolution**: The CLI discovers SDK and coglet wheels from `dist/` at Docker build time (not embedded in the binary) 2. **Docker SDK Integration**: Uses Docker Go SDK for container operations 3. **Type Safety**: Dataclasses for Python type validation, strongly typed Go interfaces 4. **Compatibility Matrix**: Automated CUDA/PyTorch/TensorFlow compatibility management diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index 34e312964f..0abd4b1744 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -290,6 +290,11 @@ func (h *Harness) Setup(env *testscript.Env) error { env.Setenv("RUST_LOG", rustLog) } + // Propagate COG_CA_CERT for custom CA certificates (e.g. Cloudflare WARP) + if caCert := os.Getenv("COG_CA_CERT"); caCert != "" { + env.Setenv("COG_CA_CERT", caCert) + } + // Generate unique image name for this test run imageName := generateUniqueImageName() env.Setenv("TEST_IMAGE", imageName) diff --git a/mise.toml b/mise.toml index 7fb21f93b0..fac2303356 100644 --- a/mise.toml +++ b/mise.toml @@ -69,19 +69,19 @@ experimental = true [tasks._setup_dist] hide = true -quiet = true +silent = true description = "Create dist directory" run = "mkdir -p dist" [tasks._setup_venv] hide = true -quiet = true +silent = true description = "Ensure root .venv exists with Python" run = "test -d .venv || uv venv --quiet" [tasks._clean_dist] hide = true -quiet = true +silent = true description = "Clean dist directory" run = "rm -f dist/cog-*.whl dist/cog-*.tar.gz dist/coglet-*.whl" @@ -115,6 +115,8 @@ echo "Installed $PREFIX/bin/cog -> $BINARY" [tasks."build:cog"] description = "Build cog CLI (development)" +sources = ["cmd/**/*.go", "pkg/**/*.go", "go.mod", "go.sum"] +outputs = ["dist/go/*/cog"] run = "GOFLAGS=-buildvcs=false go run github.com/goreleaser/goreleaser/v2@latest build --clean --snapshot --single-target --id cog --output cog" [tasks."build:cog:release"] @@ -148,7 +150,7 @@ run = [ [tasks."build:coglet:wheel:linux-x64"] description = "Build coglet Python wheel for Linux x86_64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] -outputs = ["dist/coglet-*-linux_x86_64.whl"] +outputs = ["dist/coglet-*manylinux*x86_64*.whl"] run = [ { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target x86_64-unknown-linux-gnu --zig", @@ -157,7 +159,7 @@ run = [ [tasks."build:coglet:wheel:linux-arm64"] description = "Build coglet Python wheel for Linux ARM64" sources = ["crates/**/*.rs", "crates/**/Cargo.toml", "Cargo.lock"] -outputs = ["dist/coglet-*-linux_aarch64.whl"] +outputs = ["dist/coglet-*manylinux*aarch64*.whl"] run = [ { tasks = ["_setup_dist", "_setup_venv"] }, "maturin build --release --out dist --manifest-path crates/coglet-python/Cargo.toml --target aarch64-unknown-linux-gnu --zig", @@ -258,7 +260,7 @@ run = "nox -s coglet" [tasks."test:integration"] description = "Run integration tests (skips slow tests by default, set SHORT=0 for full suite)" -depends = ["build:cog", "build:sdk", "build:coglet:wheel"] +depends = ["build:cog", "build:sdk", "build:coglet:wheel:linux-x64"] run = """ #!/usr/bin/env bash set -e @@ -266,10 +268,13 @@ SHORT_FLAG="-short" if [ "${SHORT:-1}" = "0" ]; then SHORT_FLAG="" fi -if [ -n "$1" ]; then - gotestsum -- -tags integration -v $SHORT_FLAG -run "TestIntegration/$1" -timeout 30m ./integration-tests/... +# If first arg is a bare name (no dash), treat as test name filter; +# remaining args are passed through to go test. +# e.g. mise run test:integration coglet_large_output -count=4 +if [ $# -gt 0 ] && [[ "$1" != -* ]]; then + gotestsum -- -tags integration -v $SHORT_FLAG -run "TestIntegration/$1" "${@:2}" -timeout 30m ./integration-tests/... else - gotestsum -- -tags integration -v $SHORT_FLAG -parallel ${TEST_PARALLEL:-4} -timeout 30m ./integration-tests/... + gotestsum -- -tags integration -v $SHORT_FLAG -parallel ${TEST_PARALLEL:-4} "$@" -timeout 30m ./integration-tests/... fi """ From 900423792f62d34fc63306248f06295eee05ec00 Mon Sep 17 00:00:00 2001 From: morgan fainberg Date: Thu, 19 Feb 2026 14:52:53 -0800 Subject: [PATCH 18/21] fix(serve): unify cog serve build path with cog build via ExcludeSource All CLI dev-mode commands (serve, predict, run, train) now use resolver.Build() with ExcludeSource=true instead of the separate BuildBase() path. This ensures coglet and SDK wheels are properly installed in dev-mode images, sharing Docker layer cache with cog build (all layers before COPY . /src are identical). Key changes: - Add ExcludeSource field to BuildOptions; when true, generates Dockerfile via GenerateModelBase (no COPY . /src) - Update serve, predict, run, train CLI commands to use Build() with serveBuildOptions() instead of BuildBase() - Fix GenerateOpenAPISchema to accept optional sourceDir for volume-mounting /src during schema validation - Centralize harness env var propagation (propagatedEnvVars) used by both Setup() and cmdCogServe - Add coglet_large_output integration test: async webhook-based test for 9MiB output that exceeds IPC frame limit --- integration-tests/harness/harness.go | 267 ++++++++++++++++-- .../tests/coglet_large_output.txtar | 27 +- pkg/cli/predict.go | 4 +- pkg/cli/run.go | 2 +- pkg/cli/serve.go | 15 +- pkg/cli/train.go | 4 +- pkg/dockerfile/standard_generator.go | 1 - pkg/image/build.go | 19 +- pkg/image/openapi_schema.go | 14 +- pkg/model/factory.go | 1 + pkg/model/options.go | 7 + 11 files changed, 313 insertions(+), 48 deletions(-) diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index 0abd4b1744..07b7f897f5 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -26,6 +26,17 @@ import ( "github.com/replicate/cog/pkg/registry_testhelpers" ) +// propagatedEnvVars lists host environment variables that should be propagated +// into testscript environments (Setup) and background processes (cmdCogServe). +// Keep this list in sync: if you add a new env var to propagate, add it here. +var propagatedEnvVars = []string{ + "COG_WHEEL", // SDK wheel override + "COGLET_WHEEL", // coglet wheel override + "RUST_LOG", // Rust logging control + "COG_CA_CERT", // custom CA certificates (e.g. Cloudflare WARP) + "BUILDKIT_PROGRESS", // Docker build output format +} + // Harness provides utilities for running cog integration tests. // serverInfo tracks a running cog serve process and its port type serverInfo struct { @@ -33,6 +44,12 @@ type serverInfo struct { port int } +// containerInfo tracks a running Docker container started by container-start. +type containerInfo struct { + containerID string + port int +} + // registryInfo tracks a running test registry container type registryInfo struct { container *registry_testhelpers.RegistryContainer @@ -56,6 +73,24 @@ type mockUploadServer struct { uploads []mockUploadRecord } +// webhookResult is the summary written to stdout by webhook-server-wait. +type webhookResult struct { + Status string `json:"status"` + OutputSize int `json:"output_size"` + HasError bool `json:"has_error"` +} + +// webhookServer accepts prediction webhook callbacks from coglet. +// It parses the JSON payload to extract status and output size, without +// ever exposing the (potentially huge) output to testscript's log buffer. +type webhookServer struct { + server *http.Server + port int + mu sync.Mutex + result *webhookResult + done chan struct{} // closed on first terminal webhook +} + type Harness struct { CogBinary string // realHome is captured at creation time before testscript overrides HOME @@ -71,6 +106,9 @@ type Harness struct { // uploadServers tracks mock upload servers for cleanup, keyed by work directory uploadServers map[string]*mockUploadServer uploadServersMu sync.Mutex + // webhookServers tracks webhook receiver servers for cleanup, keyed by work directory + webhookServers map[string]*webhookServer + webhookServersMu sync.Mutex } // New creates a new Harness, resolving the cog binary location. @@ -84,12 +122,13 @@ func New() (*Harness, error) { return nil, err } return &Harness{ - CogBinary: cogBinary, - realHome: os.Getenv("HOME"), - repoRoot: repoRoot, - serverProcs: make(map[string]*serverInfo), - registries: make(map[string]*registryInfo), - uploadServers: make(map[string]*mockUploadServer), + CogBinary: cogBinary, + realHome: os.Getenv("HOME"), + repoRoot: repoRoot, + serverProcs: make(map[string]*serverInfo), + registries: make(map[string]*registryInfo), + uploadServers: make(map[string]*mockUploadServer), + webhookServers: make(map[string]*webhookServer), }, nil } @@ -220,6 +259,10 @@ func (h *Harness) Commands() map[string]func(ts *testscript.TestScript, neg bool NewCommand("upload-server-start", h.cmdUploadServerStart), NewCommand("upload-server-count", h.cmdUploadServerCount), + // Webhook receiver commands + NewCommand("webhook-server-start", h.cmdWebhookServerStart), + NewCommand("webhook-server-wait", h.cmdWebhookServerWait), + // PTY command (defined in cmd_pty.go) &PtyRunCommand{harness: h}, } @@ -275,24 +318,11 @@ func (h *Harness) Setup(env *testscript.Env) error { // Disable update checks during tests env.Setenv("COG_NO_UPDATE_CHECK", "1") - // Propagate COG_WHEEL environment variable for runtime selection - if cogWheel := os.Getenv("COG_WHEEL"); cogWheel != "" { - env.Setenv("COG_WHEEL", cogWheel) - } - - // Propagate COGLET_WHEEL for coglet server - if cogletWheel := os.Getenv("COGLET_WHEEL"); cogletWheel != "" { - env.Setenv("COGLET_WHEEL", cogletWheel) - } - - // Propagate RUST_LOG for Rust logging control - if rustLog := os.Getenv("RUST_LOG"); rustLog != "" { - env.Setenv("RUST_LOG", rustLog) - } - - // Propagate COG_CA_CERT for custom CA certificates (e.g. Cloudflare WARP) - if caCert := os.Getenv("COG_CA_CERT"); caCert != "" { - env.Setenv("COG_CA_CERT", caCert) + // Propagate host env vars listed in propagatedEnvVars + for _, key := range propagatedEnvVars { + if val := os.Getenv(key); val != "" { + env.Setenv(key, val) + } } // Generate unique image name for this test run @@ -310,6 +340,8 @@ func (h *Harness) Setup(env *testscript.Env) error { h.stopRegistryByWorkDir(workDir) // Stop the upload server for this specific test (if any) h.stopUploadServerByWorkDir(workDir) + // Stop the webhook server for this specific test (if any) + h.stopWebhookServerByWorkDir(workDir) removeDockerImage(imageName) }) @@ -383,11 +415,17 @@ func (h *Harness) cmdCogServe(ts *testscript.TestScript, neg bool, args []string cmd := exec.Command(h.CogBinary, expandedArgs...) cmd.Dir = workDir - // Build environment from testscript + // Build environment from testscript. + // Always include core vars, plus everything from propagatedEnvVars. var env []string - for _, key := range []string{"HOME", "PATH", "REPO_ROOT", "COG_NO_UPDATE_CHECK", "COG_WHEEL", "COGLET_WHEEL", "RUST_LOG", "BUILDKIT_PROGRESS", "TEST_IMAGE"} { + for _, key := range []string{"HOME", "PATH", "REPO_ROOT", "COG_NO_UPDATE_CHECK", "TEST_IMAGE"} { + if val := ts.Getenv(key); val != "" { + env = append(env, key+"="+val) + } + } + for _, key := range propagatedEnvVars { if val := ts.Getenv(key); val != "" { - env = append(env, fmt.Sprintf("%s=%s", key, val)) + env = append(env, key+"="+val) } } cmd.Env = env @@ -419,14 +457,34 @@ func (h *Harness) cmdCogServe(ts *testscript.TestScript, neg bool, args []string // cmdCurl implements the 'curl' command for testscript. // It makes HTTP requests to the server started with 'serve'. // Includes built-in retry logic (10 attempts, 500ms delay) for resilience. -// Usage: curl [method] [path] [body] +// Usage: curl [-H key:value]... [method] [path] [body] // Examples: // // curl GET /health-check // curl POST /predictions '{"input":{"s":"hello"}}' +// curl -H Prefer:respond-async POST /predictions '{"input":{}}' func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { if len(args) < 2 { - ts.Fatalf("curl: usage: curl [method] [path] [body]") + ts.Fatalf("curl: usage: curl [-H key:value]... [method] [path] [body]") + } + + // Parse -H flags for extra headers + var extraHeaders [][2]string + for len(args) >= 2 && args[0] == "-H" { + kv := args[1] + parts := strings.SplitN(kv, ":", 2) + if len(parts) != 2 { + ts.Fatalf("curl: invalid header %q (expected key:value)", kv) + } + extraHeaders = append(extraHeaders, [2]string{ + strings.TrimSpace(parts[0]), + strings.TrimSpace(parts[1]), + }) + args = args[2:] + } + + if len(args) < 2 { + ts.Fatalf("curl: usage: curl [-H key:value]... [method] [path] [body]") } serverURL := ts.Getenv("SERVER_URL") @@ -438,7 +496,7 @@ func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { path := args[1] var body string if len(args) > 2 { - body = args[2] + body = os.Expand(args[2], ts.Getenv) } // Retry settings @@ -464,6 +522,9 @@ func (h *Harness) cmdCurl(ts *testscript.TestScript, neg bool, args []string) { if body != "" { req.Header.Set("Content-Type", "application/json") } + for _, h := range extraHeaders { + req.Header.Set(h[0], h[1]) + } resp, err := client.Do(req) if err != nil { @@ -1212,3 +1273,149 @@ func (h *Harness) stopUploadServerByWorkDir(workDir string) { _ = mus.server.Close() } } + +// ============================================================================= +// Webhook receiver commands +// ============================================================================= + +// cmdWebhookServerStart starts a webhook receiver that accepts prediction callbacks. +// It parses the JSON payload to extract status and measure the output size, without +// ever exposing the (potentially huge) output to testscript's log buffer. +// Usage: webhook-server-start +// Exports $WEBHOOK_URL with the server's callback URL. +func (h *Harness) cmdWebhookServerStart(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("webhook-server-start: does not support negation") + } + + workDir := ts.Getenv("WORK") + + h.webhookServersMu.Lock() + if _, exists := h.webhookServers[workDir]; exists { + h.webhookServersMu.Unlock() + ts.Fatalf("webhook-server-start: server already running for this test") + } + h.webhookServersMu.Unlock() + + ws := &webhookServer{ + done: make(chan struct{}), + } + + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // Stream-parse the JSON to extract status and measure output size + // without holding the entire output string in testscript memory. + var payload struct { + Status string `json:"status"` + Output string `json:"output"` + Error string `json:"error"` + } + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + + w.WriteHeader(http.StatusOK) + + // Only record terminal statuses + switch payload.Status { + case "succeeded", "failed", "canceled": + default: + return + } + + ws.mu.Lock() + defer ws.mu.Unlock() + + // Only record the first terminal callback + if ws.result != nil { + return + } + ws.result = &webhookResult{ + Status: payload.Status, + OutputSize: len(payload.Output), + HasError: payload.Error != "", + } + close(ws.done) + }) + + ws.server = &http.Server{Handler: mux, ReadHeaderTimeout: 10 * time.Second} //nolint:gosec + + // Bind to all interfaces so the container can reach us via host.docker.internal + ln, err := net.Listen("tcp", "0.0.0.0:0") //nolint:gosec + if err != nil { + ts.Fatalf("webhook-server-start: failed to listen: %v", err) + } + ws.port = ln.Addr().(*net.TCPAddr).Port + + go func() { _ = ws.server.Serve(ln) }() + + h.webhookServersMu.Lock() + h.webhookServers[workDir] = ws + h.webhookServersMu.Unlock() + + url := fmt.Sprintf("http://host.docker.internal:%d/webhook", ws.port) + ts.Setenv("WEBHOOK_URL", url) + ts.Logf("webhook-server-start: listening on 0.0.0.0:%d, container URL: %s", ws.port, url) +} + +// cmdWebhookServerWait blocks until the webhook server receives a terminal prediction callback, +// then writes a compact JSON summary to stdout for assertion with stdout/stderr matchers. +// Usage: webhook-server-wait [timeout] +// Default timeout: 120s +func (h *Harness) cmdWebhookServerWait(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("webhook-server-wait: does not support negation") + } + + timeout := 120 * time.Second + if len(args) > 0 { + if d, err := time.ParseDuration(args[0]); err == nil { + timeout = d + } + } + + workDir := ts.Getenv("WORK") + h.webhookServersMu.Lock() + ws, exists := h.webhookServers[workDir] + h.webhookServersMu.Unlock() + + if !exists { + ts.Fatalf("webhook-server-wait: no webhook server running (call webhook-server-start first)") + } + + select { + case <-ws.done: + case <-time.After(timeout): + ts.Fatalf("webhook-server-wait: timed out after %s waiting for terminal webhook", timeout) + } + + ws.mu.Lock() + result := ws.result + ws.mu.Unlock() + + out, _ := json.Marshal(result) + _, _ = ts.Stdout().Write(out) + _, _ = ts.Stdout().Write([]byte("\n")) +} + +// stopWebhookServerByWorkDir shuts down the webhook server for a work directory. +func (h *Harness) stopWebhookServerByWorkDir(workDir string) { + h.webhookServersMu.Lock() + ws, exists := h.webhookServers[workDir] + if !exists { + h.webhookServersMu.Unlock() + return + } + delete(h.webhookServers, workDir) + h.webhookServersMu.Unlock() + + if ws.server != nil { + _ = ws.server.Close() + } +} diff --git a/integration-tests/tests/coglet_large_output.txtar b/integration-tests/tests/coglet_large_output.txtar index e5bb562c7b..b22828665b 100644 --- a/integration-tests/tests/coglet_large_output.txtar +++ b/integration-tests/tests/coglet_large_output.txtar @@ -1,13 +1,29 @@ # Test that outputs larger than the 8MiB IPC frame limit spill to disk # and are reconstructed correctly by the orchestrator. # Without spilling this would panic the bridge and poison the slot. +# +# Uses async prediction + webhook so the 9MiB output goes directly to our +# Go webhook receiver — never through testscript's log buffer. +# +# --upload-url is set to a dummy value so cog serve adds +# --add-host=host.docker.internal:host-gateway (needed on Linux for the +# webhook callback to reach the host). Nothing is actually uploaded because +# the output is a plain string, not a Path. -cog build -t $TEST_IMAGE -cog predict $TEST_IMAGE +webhook-server-start +cog serve --upload-url http://unused/ -# Prediction should succeed — not fail with "frame size too big" -! stderr 'frame size too big' -! stderr 'failed' +# Async prediction — server returns 202 immediately, delivers result to webhook +curl -H Prefer:respond-async POST /predictions '{"id":"large-output-test","webhook":"$WEBHOOK_URL","webhook_events_filter":["completed"]}' + +# Wait for the webhook callback (up to 120s) +webhook-server-wait + +# 1. Prediction succeeded +stdout '"status":"succeeded"' + +# 2. Output is correct — 9 * 1024 * 1024 = 9437184 bytes +stdout '"output_size":9437184' -- cog.yaml -- build: @@ -17,6 +33,7 @@ predict: "predict.py:Predictor" -- predict.py -- from cog import BasePredictor + class Predictor(BasePredictor): def predict(self) -> str: # 9MiB string — exceeds the 8MiB IPC frame limit diff --git a/pkg/cli/predict.go b/pkg/cli/predict.go index ad78645ef8..2de0951438 100644 --- a/pkg/cli/predict.go +++ b/pkg/cli/predict.go @@ -180,13 +180,13 @@ func cmdPredict(cmd *cobra.Command, args []string) error { return err } - m, err := resolver.BuildBase(ctx, src, buildBaseOptionsFromFlags(cmd)) + m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) if err != nil { return err } imageName = m.ImageRef() - // Base image doesn't have /src in it, so mount as volume + // ExcludeSource build doesn't have /src in it, so mount as volume volumes = append(volumes, command.Volume{ Source: src.ProjectDir, Destination: "/src", diff --git a/pkg/cli/run.go b/pkg/cli/run.go index dcfc354ee9..6662d5618c 100644 --- a/pkg/cli/run.go +++ b/pkg/cli/run.go @@ -65,7 +65,7 @@ func run(cmd *cobra.Command, args []string) error { resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.BuildBase(ctx, src, buildBaseOptionsFromFlags(cmd)) + m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) if err != nil { return err } diff --git a/pkg/cli/serve.go b/pkg/cli/serve.go index f45648cf74..ccccfee4c5 100644 --- a/pkg/cli/serve.go +++ b/pkg/cli/serve.go @@ -42,6 +42,19 @@ Generate and run an HTTP server based on the declared model inputs and outputs.` return cmd } +// serveBuildOptions creates BuildOptions for cog serve. +// Same build path as cog build, but with ExcludeSource so COPY . /src is +// skipped — source is volume-mounted at runtime instead. All other layers +// (wheels, apt, etc.) share Docker layer cache with cog build. +func serveBuildOptions(cmd *cobra.Command) model.BuildOptions { + return model.BuildOptions{ + UseCudaBaseImage: buildUseCudaBaseImage, + UseCogBaseImage: DetermineUseCogBaseImage(cmd), + ProgressOutput: buildProgressOutput, + ExcludeSource: true, + } +} + func cmdServe(cmd *cobra.Command, arg []string) error { ctx := cmd.Context() @@ -56,7 +69,7 @@ func cmdServe(cmd *cobra.Command, arg []string) error { } resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.BuildBase(ctx, src, buildBaseOptionsFromFlags(cmd)) + m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) if err != nil { return err } diff --git a/pkg/cli/train.go b/pkg/cli/train.go index 368a872db3..122de8ae40 100644 --- a/pkg/cli/train.go +++ b/pkg/cli/train.go @@ -73,13 +73,13 @@ func cmdTrain(cmd *cobra.Command, args []string) error { return err } - m, err := resolver.BuildBase(ctx, src, buildBaseOptionsFromFlags(cmd)) + m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) if err != nil { return err } imageName = m.ImageRef() - // Base image doesn't have /src in it, so mount as volume + // ExcludeSource build doesn't have /src in it, so mount as volume volumes = append(volumes, command.Volume{ Source: src.ProjectDir, Destination: "/src", diff --git a/pkg/dockerfile/standard_generator.go b/pkg/dockerfile/standard_generator.go index a35b13863a..65a177233c 100644 --- a/pkg/dockerfile/standard_generator.go +++ b/pkg/dockerfile/standard_generator.go @@ -467,7 +467,6 @@ func (g *StandardGenerator) installCog() (string, error) { if !g.requiresCog { return "", nil } - // Use override if set, otherwise auto-detect via env var / dist / PyPI var wheelConfig *wheels.WheelConfig var err error diff --git a/pkg/image/build.go b/pkg/image/build.go index e22c436f0c..c116963796 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -54,6 +54,7 @@ func Build( useCogBaseImage *bool, strip bool, precompile bool, + excludeSource bool, annotations map[string]string, dockerCommand command.Command, client registry.Client) (string, error) { @@ -162,7 +163,15 @@ func Build( return "", fmt.Errorf("Failed to build runner Docker image: %w", err) } } else { - dockerfileContents, err := generator.GenerateDockerfileWithoutSeparateWeights(ctx) + var dockerfileContents string + if excludeSource { + // Dev mode (cog serve): same layers as cog build but without + // COPY . /src — source is volume-mounted at runtime instead. + // This shares Docker layer cache with full builds. + dockerfileContents, err = generator.GenerateModelBase(ctx) + } else { + dockerfileContents, err = generator.GenerateDockerfileWithoutSeparateWeights(ctx) + } if err != nil { return "", fmt.Errorf("Failed to generate Dockerfile: %w", err) } @@ -196,7 +205,13 @@ func Build( schemaJSON = data } else { console.Info("Validating model schema...") - schema, err := GenerateOpenAPISchema(ctx, dockerCommand, tmpImageId, cfg.Build.GPU) + // When excludeSource is true (cog serve), /src was not COPYed into the + // image, so we need to volume-mount the project directory for schema generation. + schemaSourceDir := "" + if excludeSource { + schemaSourceDir = dir + } + schema, err := GenerateOpenAPISchema(ctx, dockerCommand, tmpImageId, cfg.Build.GPU, schemaSourceDir) if err != nil { return "", fmt.Errorf("Failed to get type signature: %w", err) } diff --git a/pkg/image/openapi_schema.go b/pkg/image/openapi_schema.go index 3b176e5de4..cca3d566ee 100644 --- a/pkg/image/openapi_schema.go +++ b/pkg/image/openapi_schema.go @@ -12,7 +12,8 @@ import ( // GenerateOpenAPISchema by running the image and executing Cog // This will be run as part of the build process then added as a label to the image. It can be retrieved more efficiently with the label by using GetOpenAPISchema -func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, imageName string, enableGPU bool) (map[string]any, error) { +// sourceDir, when non-empty, is volume-mounted as /src (needed for ExcludeSource builds where COPY . /src was skipped). +func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, imageName string, enableGPU bool, sourceDir string) (map[string]any, error) { console.Debugf("=== image.GenerateOpenAPISchema %s", imageName) var stdout bytes.Buffer var stderr bytes.Buffer @@ -23,19 +24,24 @@ func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, im gpus = "all" } - err := docker.RunWithIO(ctx, dockerClient, command.RunOptions{ + runOpts := command.RunOptions{ Image: imageName, Args: []string{ "python", "-m", "cog.command.openapi_schema", }, GPUs: gpus, - }, nil, &stdout, &stderr) + } + if sourceDir != "" { + runOpts.Volumes = []command.Volume{{Source: sourceDir, Destination: "/src"}} + } + + err := docker.RunWithIO(ctx, dockerClient, runOpts, nil, &stdout, &stderr) if enableGPU && err == docker.ErrMissingDeviceDriver { console.Debug(stdout.String()) console.Debug(stderr.String()) console.Debug("Missing device driver, re-trying without GPU") - return GenerateOpenAPISchema(ctx, dockerClient, imageName, false) + return GenerateOpenAPISchema(ctx, dockerClient, imageName, false, sourceDir) } if err != nil { diff --git a/pkg/model/factory.go b/pkg/model/factory.go index 70095a51e8..434ab66d32 100644 --- a/pkg/model/factory.go +++ b/pkg/model/factory.go @@ -56,6 +56,7 @@ func (f *DockerfileFactory) Build(ctx context.Context, src *Source, opts BuildOp opts.UseCogBaseImage, opts.Strip, opts.Precompile, + opts.ExcludeSource, opts.Annotations, f.docker, f.registry, diff --git a/pkg/model/options.go b/pkg/model/options.go index c5ae55f4e8..b9c4608b78 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -49,6 +49,13 @@ type BuildOptions struct { // artifacts and pushes create an OCI Image Index. Set via COG_OCI_INDEX=1. // Remove this field once index pushes are validated with all registries. OCIIndex bool + + // ExcludeSource skips the COPY . /src step in the generated Dockerfile. + // Used by `cog serve` to produce an image identical to `cog build` minus + // the source copy — the source directory is volume-mounted at runtime. + // All other layers (wheel installs, apt, etc.) are shared with `cog build` + // via Docker layer caching. + ExcludeSource bool } // WithDefaults returns a copy of BuildOptions with defaults applied from Source. From 4cc2f96c7f5e429d7b0fcc74f2183ff66b15bb4b Mon Sep 17 00:00:00 2001 From: morgan fainberg Date: Thu, 19 Feb 2026 14:57:32 -0800 Subject: [PATCH 19/21] =?UTF-8?q?refactor:=20remove=20BuildBase=20?= =?UTF-8?q?=E2=80=94=20all=20dev-mode=20commands=20use=20Build(ExcludeSour?= =?UTF-8?q?ce)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the entire BuildBase code path that is now dead: - Factory.BuildBase interface method - DockerfileFactory.BuildBase implementation - Resolver.BuildBase method - image.BuildBase function - BuildBaseOptions struct and WithDefaults - buildBaseOptionsFromFlags CLI helper - config.BaseDockerImageName helper - mockFactory.buildBaseFunc in tests All dev-mode CLI commands (serve, predict, run, train) now use Build(ExcludeSource=true) which shares Docker layer cache with cog build and properly installs all wheels. --- pkg/cli/build.go | 10 -------- pkg/config/image_name.go | 5 ---- pkg/image/build.go | 50 -------------------------------------- pkg/model/factory.go | 30 ++--------------------- pkg/model/options.go | 24 ------------------ pkg/model/resolver.go | 32 ------------------------ pkg/model/resolver_test.go | 12 ++------- 7 files changed, 4 insertions(+), 159 deletions(-) diff --git a/pkg/cli/build.go b/pkg/cli/build.go index 6bf88f596b..034a60bfe3 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -193,13 +193,3 @@ func buildOptionsFromFlags(cmd *cobra.Command, imageName string, annotations map OCIIndex: model.OCIIndexEnabled(), } } - -// buildBaseOptionsFromFlags creates BuildBaseOptions from the current CLI flag values. -func buildBaseOptionsFromFlags(cmd *cobra.Command) model.BuildBaseOptions { - return model.BuildBaseOptions{ - UseCudaBaseImage: buildUseCudaBaseImage, - UseCogBaseImage: DetermineUseCogBaseImage(cmd), - ProgressOutput: buildProgressOutput, - RequiresCog: true, - } -} diff --git a/pkg/config/image_name.go b/pkg/config/image_name.go index fa98a07ce6..9dca07f095 100644 --- a/pkg/config/image_name.go +++ b/pkg/config/image_name.go @@ -30,8 +30,3 @@ func DockerImageName(projectDir string) string { return projectName } - -// BaseDockerImageName returns the Docker image name for base images -func BaseDockerImageName(projectDir string) string { - return DockerImageName(projectDir) + "-base" -} diff --git a/pkg/image/build.go b/pkg/image/build.go index c116963796..77717c6dea 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -352,56 +352,6 @@ func BuildAddLabelsAndSchemaToImage(ctx context.Context, dockerClient command.Co return imageID, nil } -func BuildBase(ctx context.Context, dockerClient command.Command, cfg *config.Config, dir string, configFilename string, useCudaBaseImage string, useCogBaseImage *bool, progressOutput string, client registry.Client, requiresCog bool) (string, error) { - // TODO: better image management so we don't eat up disk space - // https://github.com/replicate/cog/issues/80 - imageName := config.BaseDockerImageName(dir) - - console.Info("Building Docker image from environment in cog.yaml...") - generator, err := dockerfile.NewGenerator(cfg, dir, configFilename, dockerClient, client, requiresCog) - if err != nil { - return "", fmt.Errorf("Error creating Dockerfile generator: %w", err) - } - contextDir, err := generator.BuildDir() - if err != nil { - return "", err - } - buildContexts, err := generator.BuildContexts() - if err != nil { - return "", err - } - defer func() { - if err := generator.Cleanup(); err != nil { - console.Warnf("Error cleaning up Dockerfile generator: %s", err) - } - }() - - generator.SetUseCudaBaseImage(useCudaBaseImage) - if useCogBaseImage != nil { - generator.SetUseCogBaseImage(*useCogBaseImage) - } - - dockerfileContents, err := generator.GenerateModelBase(ctx) - if err != nil { - return "", fmt.Errorf("Failed to generate Dockerfile: %w", err) - } - - buildOpts := command.ImageBuildOptions{ - WorkingDir: dir, - DockerfileContents: dockerfileContents, - ImageName: imageName, - NoCache: false, - ProgressOutput: progressOutput, - Epoch: &config.BuildSourceEpochTimestamp, - ContextDir: contextDir, - BuildContexts: buildContexts, - } - if _, err := dockerClient.ImageBuild(ctx, buildOpts); err != nil { - return "", fmt.Errorf("Failed to build Docker image: %w", err) - } - return imageName, nil -} - func isGitWorkTree(ctx context.Context, dir string) bool { ctx, cancel := context.WithTimeout(ctx, 3*time.Second) defer cancel() diff --git a/pkg/model/factory.go b/pkg/model/factory.go index 434ab66d32..c960da6d8f 100644 --- a/pkg/model/factory.go +++ b/pkg/model/factory.go @@ -12,12 +12,10 @@ import ( // Different implementations handle different build strategies. type Factory interface { // Build creates a Docker image from source and returns ImageArtifact metadata. + // For dev mode (cog serve), set ExcludeSource=true in BuildOptions to skip + // COPY . /src — the source directory is volume-mounted at runtime instead. Build(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) - // BuildBase creates a base image for dev mode (without /src copied). - // The source directory is expected to be mounted as a volume at runtime. - BuildBase(ctx context.Context, src *Source, opts BuildBaseOptions) (*ImageArtifact, error) - // Name returns the factory name for logging/debugging. Name() string } @@ -72,30 +70,6 @@ func (f *DockerfileFactory) Build(ctx context.Context, src *Source, opts BuildOp }, nil } -// BuildBase delegates to the existing image.BuildBase() function. -func (f *DockerfileFactory) BuildBase(ctx context.Context, src *Source, opts BuildBaseOptions) (*ImageArtifact, error) { - imageName, err := image.BuildBase( - ctx, - f.docker, - src.Config, - src.ProjectDir, - src.ConfigFilename, - opts.UseCudaBaseImage, - opts.UseCogBaseImage, - opts.ProgressOutput, - f.registry, - opts.RequiresCog, - ) - if err != nil { - return nil, err - } - - return &ImageArtifact{ - Reference: imageName, - Source: ImageSourceBuild, - }, nil -} - // DefaultFactory returns a Factory based on environment variables. // It checks COG_BUILDER and COGPACK to select the appropriate backend. // diff --git a/pkg/model/options.go b/pkg/model/options.go index b9c4608b78..3e6fc36adb 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -73,27 +73,3 @@ func (o BuildOptions) WithDefaults(src *Source) BuildOptions { return o } - -// BuildBaseOptions contains settings for building a base image (dev mode). -// Base images don't copy /src - the source is mounted as a volume at runtime. -type BuildBaseOptions struct { - // UseCudaBaseImage controls CUDA base image usage: "auto", "true", or "false". - UseCudaBaseImage string - - // UseCogBaseImage controls cog base image usage. nil means auto-detect. - UseCogBaseImage *bool - - // ProgressOutput controls build output format: "auto", "plain", or "tty". - ProgressOutput string - - // RequiresCog indicates whether the build requires cog to be installed. - RequiresCog bool -} - -// WithDefaults returns a copy of BuildBaseOptions with defaults applied. -func (o BuildBaseOptions) WithDefaults() BuildBaseOptions { - if o.ProgressOutput == "" { - o.ProgressOutput = "auto" - } - return o -} diff --git a/pkg/model/resolver.go b/pkg/model/resolver.go index 1afec5c5c1..a90961a2a4 100644 --- a/pkg/model/resolver.go +++ b/pkg/model/resolver.go @@ -275,38 +275,6 @@ func (r *Resolver) Push(ctx context.Context, m *Model, opts PushOptions) error { return pusher.Push(ctx, m, opts) } -// BuildBase creates a base image for dev mode (without /src copied). -// The source directory is expected to be mounted as a volume at runtime. -// Returns a Model with the built image info and the source config. -// -// NOTE: Unlike Build(), this does not use ImageBuilder because base images -// don't have labels yet (they're added during full builds). The returned -// ImageArtifact has no labels, no descriptor, and no inspect results. -func (r *Resolver) BuildBase(ctx context.Context, src *Source, opts BuildBaseOptions) (*Model, error) { - if src == nil { - return nil, fmt.Errorf("source is required for BuildBase") - } - if src.Config == nil { - return nil, fmt.Errorf("source.Config is required for BuildBase") - } - if src.ProjectDir == "" { - return nil, fmt.Errorf("source.ProjectDir is required for BuildBase") - } - opts = opts.WithDefaults() - - img, err := r.factory.BuildBase(ctx, src, opts) - if err != nil { - return nil, err - } - - // For base builds, we don't have labels yet (they're added in full builds). - // Return the model with the source config and the built image. - return &Model{ - Image: img, - Config: src.Config, - }, nil -} - // loadLocal loads a Model from the local docker daemon. func (r *Resolver) loadLocal(ctx context.Context, ref *ParsedRef) (*Model, error) { resp, err := r.docker.Inspect(ctx, ref.String()) diff --git a/pkg/model/resolver_test.go b/pkg/model/resolver_test.go index 0819a41696..1f9f2e8509 100644 --- a/pkg/model/resolver_test.go +++ b/pkg/model/resolver_test.go @@ -143,9 +143,8 @@ func (m *mockRegistry) WriteLayer(ctx context.Context, opts registry.WriteLayerO // mockFactory implements Factory for testing. type mockFactory struct { - name string - buildFunc func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) - buildBaseFunc func(ctx context.Context, src *Source, opts BuildBaseOptions) (*ImageArtifact, error) + name string + buildFunc func(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) } func (f *mockFactory) Build(ctx context.Context, src *Source, opts BuildOptions) (*ImageArtifact, error) { @@ -155,13 +154,6 @@ func (f *mockFactory) Build(ctx context.Context, src *Source, opts BuildOptions) return &ImageArtifact{Reference: opts.ImageName, Source: ImageSourceBuild}, nil } -func (f *mockFactory) BuildBase(ctx context.Context, src *Source, opts BuildBaseOptions) (*ImageArtifact, error) { - if f.buildBaseFunc != nil { - return f.buildBaseFunc(ctx, src, opts) - } - return &ImageArtifact{Reference: "cog-base", Source: ImageSourceBuild}, nil -} - func (f *mockFactory) Name() string { return f.name } From ecf89ec6e41d8a5251a82815689be6f0b78404d0 Mon Sep 17 00:00:00 2001 From: morgan fainberg Date: Thu, 19 Feb 2026 14:59:57 -0800 Subject: [PATCH 20/21] chore: remove unused containerInfo type (lint fix) --- integration-tests/harness/harness.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/integration-tests/harness/harness.go b/integration-tests/harness/harness.go index 07b7f897f5..99bf2b801b 100644 --- a/integration-tests/harness/harness.go +++ b/integration-tests/harness/harness.go @@ -44,12 +44,6 @@ type serverInfo struct { port int } -// containerInfo tracks a running Docker container started by container-start. -type containerInfo struct { - containerID string - port int -} - // registryInfo tracks a running test registry container type registryInfo struct { container *registry_testhelpers.RegistryContainer From 9be713c0ee09980ca81606a05989bde62dbc068b Mon Sep 17 00:00:00 2001 From: morgan fainberg Date: Thu, 19 Feb 2026 15:16:00 -0800 Subject: [PATCH 21/21] fix(run): skip schema validation for cog run MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit cog run executes arbitrary commands (echo, cat, etc.) and may not have a predictor or trainer in cog.yaml. Add SkipSchemaValidation to BuildOptions so schema generation, validation, and bundling are all skipped for cog run — which previously worked because BuildBase never validated schemas. --- pkg/cli/run.go | 4 +++- pkg/image/build.go | 52 +++++++++++++++++++++++++++++--------------- pkg/model/factory.go | 1 + pkg/model/options.go | 5 +++++ 4 files changed, 43 insertions(+), 19 deletions(-) diff --git a/pkg/cli/run.go b/pkg/cli/run.go index 6662d5618c..501d82f047 100644 --- a/pkg/cli/run.go +++ b/pkg/cli/run.go @@ -65,7 +65,9 @@ func run(cmd *cobra.Command, args []string) error { resolver := model.NewResolver(dockerClient, registry.NewRegistryClient()) - m, err := resolver.Build(ctx, src, serveBuildOptions(cmd)) + opts := serveBuildOptions(cmd) + opts.SkipSchemaValidation = true + m, err := resolver.Build(ctx, src, opts) if err != nil { return err } diff --git a/pkg/image/build.go b/pkg/image/build.go index 77717c6dea..c85d8df7e4 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -55,6 +55,7 @@ func Build( strip bool, precompile bool, excludeSource bool, + skipSchemaValidation bool, annotations map[string]string, dockerCommand command.Command, client registry.Client) (string, error) { @@ -195,7 +196,10 @@ func Build( } var schemaJSON []byte - if schemaFile != "" { + switch { + case skipSchemaValidation: + console.Debug("Skipping model schema validation") + case schemaFile != "": console.Infof("Validating model schema from %s...", schemaFile) data, err := os.ReadFile(schemaFile) if err != nil { @@ -203,7 +207,7 @@ func Build( } schemaJSON = data - } else { + default: console.Info("Validating model schema...") // When excludeSource is true (cog serve), /src was not COPYed into the // image, so we need to volume-mount the project directory for schema generation. @@ -224,20 +228,22 @@ func Build( schemaJSON = data } - // save open_api schema file - if err := os.WriteFile(bundledSchemaFile, schemaJSON, 0o644); err != nil { - return "", fmt.Errorf("failed to store bundled schema file %s: %w", bundledSchemaFile, err) - } + if !skipSchemaValidation { + // save open_api schema file + if err := os.WriteFile(bundledSchemaFile, schemaJSON, 0o644); err != nil { + return "", fmt.Errorf("failed to store bundled schema file %s: %w", bundledSchemaFile, err) + } - loader := openapi3.NewLoader() - loader.IsExternalRefsAllowed = true - doc, err := loader.LoadFromData(schemaJSON) - if err != nil { - return "", fmt.Errorf("Failed to load model schema JSON: %w", err) - } - err = doc.Validate(loader.Context) - if err != nil { - return "", fmt.Errorf("Model schema is invalid: %w\n\n%s", err, string(schemaJSON)) + loader := openapi3.NewLoader() + loader.IsExternalRefsAllowed = true + doc, err := loader.LoadFromData(schemaJSON) + if err != nil { + return "", fmt.Errorf("Failed to load model schema JSON: %w", err) + } + err = doc.Validate(loader.Context) + if err != nil { + return "", fmt.Errorf("Model schema is invalid: %w\n\n%s", err, string(schemaJSON)) + } } console.Info("Adding labels to image...") @@ -314,8 +320,13 @@ func Build( maps.Copy(labels, annotations) - // The final image ID comes from the label-adding step - imageID, err := BuildAddLabelsAndSchemaToImage(ctx, dockerCommand, tmpImageId, imageName, labels, bundledSchemaFile, progressOutput) + // The final image ID comes from the label-adding step. + // When schema validation is skipped (cog run), there is no schema file to bundle. + schemaFileToBundle := bundledSchemaFile + if skipSchemaValidation { + schemaFileToBundle = "" + } + imageID, err := BuildAddLabelsAndSchemaToImage(ctx, dockerCommand, tmpImageId, imageName, labels, schemaFileToBundle, progressOutput) if err != nil { return "", fmt.Errorf("Failed to add labels to image: %w", err) } @@ -336,7 +347,12 @@ func Build( // The new image is based on the provided image with the labels and schema file appended to it. // tmpName is the source image to build from, image is the final image name/tag. func BuildAddLabelsAndSchemaToImage(ctx context.Context, dockerClient command.Command, tmpName, image string, labels map[string]string, bundledSchemaFile string, progressOutput string) (string, error) { - dockerfile := fmt.Sprintf("FROM %s\nCOPY %s .cog\n", tmpName, bundledSchemaFile) + var dockerfile string + if bundledSchemaFile != "" { + dockerfile = fmt.Sprintf("FROM %s\nCOPY %s .cog\n", tmpName, bundledSchemaFile) + } else { + dockerfile = fmt.Sprintf("FROM %s\n", tmpName) + } buildOpts := command.ImageBuildOptions{ DockerfileContents: dockerfile, diff --git a/pkg/model/factory.go b/pkg/model/factory.go index c960da6d8f..1818f496f2 100644 --- a/pkg/model/factory.go +++ b/pkg/model/factory.go @@ -55,6 +55,7 @@ func (f *DockerfileFactory) Build(ctx context.Context, src *Source, opts BuildOp opts.Strip, opts.Precompile, opts.ExcludeSource, + opts.SkipSchemaValidation, opts.Annotations, f.docker, f.registry, diff --git a/pkg/model/options.go b/pkg/model/options.go index 3e6fc36adb..8b03b15041 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -56,6 +56,11 @@ type BuildOptions struct { // All other layers (wheel installs, apt, etc.) are shared with `cog build` // via Docker layer caching. ExcludeSource bool + + // SkipSchemaValidation skips OpenAPI schema generation and validation. + // Used by `cog run` which executes arbitrary commands and may not have + // a predictor or trainer defined in cog.yaml. + SkipSchemaValidation bool } // WithDefaults returns a copy of BuildOptions with defaults applied from Source.