diff --git a/src/acp/pool.rs b/src/acp/pool.rs index 42fc1113b..98ef3e35d 100644 --- a/src/acp/pool.rs +++ b/src/acp/pool.rs @@ -18,8 +18,13 @@ struct PoolState { /// Stored separately so cancel can work without locking the connection. cancel_handles: HashMap>, String)>, /// Suspended sessions: thread_key → ACP sessionId. - /// Saved on eviction so sessions can be resumed via `session/load`. + /// Used at runtime to decide which thread can be resumed via `session/load` + /// because it no longer has a live in-memory connection. suspended: HashMap, + /// Persisted resumable sessions: thread_key → ACP sessionId. + /// Includes both suspended sessions and active sessions so a process restart + /// can recover any live thread via `session/load`. + persisted: HashMap, /// Serializes create/resume work per thread so rapid same-thread requests /// cannot race each other into duplicate `session/load` attempts. creating: HashMap>>, @@ -68,6 +73,7 @@ impl SessionPool { state: RwLock::new(PoolState { active: HashMap::new(), cancel_handles: HashMap::new(), + persisted: suspended.clone(), suspended, creating: HashMap::new(), }), @@ -87,8 +93,8 @@ impl SessionPool { } } - fn save_mapping(&self, suspended: &HashMap) { - let data = match serde_json::to_string_pretty(suspended) { + fn save_mapping(&self, persisted: &HashMap) { + let data = match serde_json::to_string_pretty(persisted) { Ok(d) => d, Err(e) => { warn!(error = %e, "failed to serialize thread mapping"); @@ -220,14 +226,19 @@ impl SessionPool { warn!(thread_id, "stale connection, rebuilding"); drop(existing); state.active.remove(thread_id); + state.cancel_handles.remove(thread_id); } if state.active.len() >= self.max_sessions { if let Some((key, expected_conn, _, sid)) = eviction_candidate { if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { + state.cancel_handles.remove(&key); info!(evicted = %key, "pool full, suspending oldest idle session"); if let Some(sid) = sid { + state.persisted.insert(key.clone(), sid.clone()); state.suspended.insert(key, sid); + } else { + state.persisted.remove(&key); } } else { warn!(evicted = %key, "pool full but eviction candidate changed before removal"); @@ -245,14 +256,21 @@ impl SessionPool { return Err(anyhow!("pool exhausted ({} sessions)", self.max_sessions)); } + if cancel_session_id.is_empty() { + state.persisted.remove(thread_id); + } else { + state + .persisted + .insert(thread_id.to_string(), cancel_session_id.clone()); + } state.suspended.remove(thread_id); state.active.insert(thread_id.to_string(), new_conn); - self.save_mapping(&state.suspended); if !cancel_session_id.is_empty() { state .cancel_handles .insert(thread_id.to_string(), (cancel_handle, cancel_session_id)); } + self.save_mapping(&state.persisted); Ok(()) } @@ -367,8 +385,9 @@ impl SessionPool { let had_active = state.active.remove(thread_id).is_some(); state.cancel_handles.remove(thread_id); state.suspended.remove(thread_id); + state.persisted.remove(thread_id); state.creating.remove(thread_id); - self.save_mapping(&state.suspended); + self.save_mapping(&state.persisted); if had_active { info!(thread_id, "session reset"); Ok(()) @@ -410,12 +429,16 @@ impl SessionPool { for (key, expected_conn, sid) in stale { if remove_if_same_handle(&mut state.active, &key, &expected_conn).is_some() { info!(thread_id = %key, "cleaning up idle session"); + state.cancel_handles.remove(&key); if let Some(sid) = sid { + state.persisted.insert(key.clone(), sid.clone()); state.suspended.insert(key, sid); + } else { + state.persisted.remove(&key); } } } - self.save_mapping(&state.suspended); + self.save_mapping(&state.persisted); } pub async fn shutdown(&self) { @@ -441,11 +464,13 @@ impl SessionPool { let mut state = self.state.write().await; for (key, sid) in session_ids { + state.persisted.insert(key.clone(), sid.clone()); state.suspended.insert(key, sid); } - self.save_mapping(&state.suspended); + self.save_mapping(&state.persisted); let count = state.active.len(); state.active.clear(); + state.cancel_handles.clear(); info!(count, "pool shutdown complete"); } } @@ -491,4 +516,22 @@ mod tests { assert!(Arc::ptr_eq(&first, &second)); assert_eq!(map.len(), 1); } + + #[test] + fn persisted_mapping_can_include_active_and_suspended_sessions() { + let persisted = HashMap::from([ + ("active-thread".to_string(), "session-active".to_string()), + ("suspended-thread".to_string(), "session-suspended".to_string()), + ]); + + let serialized = serde_json::to_string_pretty(&persisted).expect("serialize persisted mapping"); + let roundtrip: HashMap = + serde_json::from_str(&serialized).expect("deserialize persisted mapping"); + + assert_eq!(roundtrip.get("active-thread"), Some(&"session-active".to_string())); + assert_eq!( + roundtrip.get("suspended-thread"), + Some(&"session-suspended".to_string()) + ); + } }