Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 50 additions & 7 deletions src/acp/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@ struct PoolState {
/// Stored separately so cancel can work without locking the connection.
cancel_handles: HashMap<String, (Arc<tokio::sync::Mutex<tokio::process::ChildStdin>>, String)>,
/// Suspended sessions: thread_key → ACP sessionId.
/// Saved on eviction so sessions can be resumed via `session/load`.
/// Used at runtime to decide which thread can be resumed via `session/load`
/// because it no longer has a live in-memory connection.
suspended: HashMap<String, String>,
/// Persisted resumable sessions: thread_key → ACP sessionId.
/// Includes both suspended sessions and active sessions so a process restart
/// can recover any live thread via `session/load`.
persisted: HashMap<String, String>,
/// Serializes create/resume work per thread so rapid same-thread requests
/// cannot race each other into duplicate `session/load` attempts.
creating: HashMap<String, Arc<Mutex<()>>>,
Expand Down Expand Up @@ -68,6 +73,7 @@ impl SessionPool {
state: RwLock::new(PoolState {
active: HashMap::new(),
cancel_handles: HashMap::new(),
persisted: suspended.clone(),
suspended,
creating: HashMap::new(),
}),
Expand All @@ -87,8 +93,8 @@ impl SessionPool {
}
}

fn save_mapping(&self, suspended: &HashMap<String, String>) {
let data = match serde_json::to_string_pretty(suspended) {
fn save_mapping(&self, persisted: &HashMap<String, String>) {
let data = match serde_json::to_string_pretty(persisted) {
Ok(d) => d,
Err(e) => {
warn!(error = %e, "failed to serialize thread mapping");
Expand Down Expand Up @@ -220,14 +226,19 @@ impl SessionPool {
warn!(thread_id, "stale connection, rebuilding");
drop(existing);
state.active.remove(thread_id);
state.cancel_handles.remove(thread_id);
}

if state.active.len() >= self.max_sessions {
if let Some((key, expected_conn, _, sid)) = eviction_candidate {
if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() {
state.cancel_handles.remove(&key);
info!(evicted = %key, "pool full, suspending oldest idle session");
if let Some(sid) = sid {
state.persisted.insert(key.clone(), sid.clone());
state.suspended.insert(key, sid);
} else {
state.persisted.remove(&key);
}
} else {
warn!(evicted = %key, "pool full but eviction candidate changed before removal");
Expand All @@ -245,14 +256,21 @@ impl SessionPool {
return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions));
}

if cancel_session_id.is_empty() {
state.persisted.remove(thread_id);
} else {
state
.persisted
.insert(thread_id.to_string(), cancel_session_id.clone());
}
state.suspended.remove(thread_id);
state.active.insert(thread_id.to_string(), new_conn);
self.save_mapping(&state.suspended);
if !cancel_session_id.is_empty() {
state
.cancel_handles
.insert(thread_id.to_string(), (cancel_handle, cancel_session_id));
}
self.save_mapping(&state.persisted);
Ok(())
}

Expand Down Expand Up @@ -367,8 +385,9 @@ impl SessionPool {
let had_active = state.active.remove(thread_id).is_some();
state.cancel_handles.remove(thread_id);
state.suspended.remove(thread_id);
state.persisted.remove(thread_id);
state.creating.remove(thread_id);
self.save_mapping(&state.suspended);
self.save_mapping(&state.persisted);
if had_active {
info!(thread_id, "session reset");
Ok(())
Expand Down Expand Up @@ -410,12 +429,16 @@ impl SessionPool {
for (key, expected_conn, sid) in stale {
if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() {
info!(thread_id = %key, "cleaning up idle session");
state.cancel_handles.remove(&key);
if let Some(sid) = sid {
state.persisted.insert(key.clone(), sid.clone());
state.suspended.insert(key, sid);
} else {
state.persisted.remove(&key);
}
}
}
self.save_mapping(&state.suspended);
self.save_mapping(&state.persisted);
}

pub async fn shutdown(&self) {
Expand All @@ -441,11 +464,13 @@ impl SessionPool {

let mut state = self.state.write().await;
for (key, sid) in session_ids {
state.persisted.insert(key.clone(), sid.clone());
state.suspended.insert(key, sid);
}
self.save_mapping(&state.suspended);
self.save_mapping(&state.persisted);
let count = state.active.len();
state.active.clear();
state.cancel_handles.clear();
info!(count, "pool shutdown complete");
}
}
Expand Down Expand Up @@ -491,4 +516,22 @@ mod tests {
assert!(Arc::ptr_eq(&first, &second));
assert_eq!(map.len(), 1);
}

#[test]
fn persisted_mapping_can_include_active_and_suspended_sessions() {
let persisted = HashMap::from([
("active-thread".to_string(), "session-active".to_string()),
("suspended-thread".to_string(), "session-suspended".to_string()),
]);

let serialized = serde_json::to_string_pretty(&persisted).expect("serialize persisted mapping");
let roundtrip: HashMap<String, String> =
serde_json::from_str(&serialized).expect("deserialize persisted mapping");

assert_eq!(roundtrip.get("active-thread"), Some(&"session-active".to_string()));
assert_eq!(
roundtrip.get("suspended-thread"),
Some(&"session-suspended".to_string())
);
}
}
Loading