1use crate::error::SessionManagerError;
2use crate::types::{SessionHandle, SessionRuntimeState};
3use stakpak_agent_core::AgentCommand;
4use std::{collections::HashMap, future::Future, sync::Arc};
5use tokio::sync::RwLock;
6use uuid::Uuid;
7
/// Tracks the runtime state of every session, keyed by session id.
///
/// The state map sits behind an `Arc<RwLock<..>>`, so clones of a manager
/// (the type derives `Clone`) all observe and mutate the same map.
#[derive(Clone, Default)]
pub struct SessionManager {
    /// Runtime state per session id. `SessionManager::state` reports a
    /// session with no entry as `SessionRuntimeState::Idle`.
    states: Arc<RwLock<HashMap<Uuid, SessionRuntimeState>>>,
}
13
14impl SessionManager {
15 pub fn new() -> Self {
16 Self::default()
17 }
18
19 pub async fn state(&self, session_id: Uuid) -> SessionRuntimeState {
20 let guard = self.states.read().await;
21 guard
22 .get(&session_id)
23 .cloned()
24 .unwrap_or(SessionRuntimeState::Idle)
25 }
26
27 pub async fn active_run_id(&self, session_id: Uuid) -> Option<Uuid> {
28 self.state(session_id).await.run_id()
29 }
30
31 pub async fn running_runs(&self) -> Vec<(Uuid, Uuid)> {
32 let guard = self.states.read().await;
33 guard
34 .iter()
35 .filter_map(|(session_id, state)| match state {
36 SessionRuntimeState::Running { run_id, .. } => Some((*session_id, *run_id)),
37 SessionRuntimeState::Idle
38 | SessionRuntimeState::Starting { .. }
39 | SessionRuntimeState::Failed { .. } => None,
40 })
41 .collect()
42 }
43
44 pub async fn start_run<F, Fut>(
45 &self,
46 session_id: Uuid,
47 spawn_actor: F,
48 ) -> Result<Uuid, SessionManagerError>
49 where
50 F: FnOnce(Uuid) -> Fut,
51 Fut: Future<Output = Result<SessionHandle, String>>,
52 {
53 let run_id = {
54 let mut guard = self.states.write().await;
55 match guard.get(&session_id) {
56 Some(SessionRuntimeState::Starting { .. })
57 | Some(SessionRuntimeState::Running { .. }) => {
58 return Err(SessionManagerError::SessionAlreadyRunning);
59 }
60 _ => {}
61 }
62
63 let run_id = Uuid::new_v4();
64 guard.insert(session_id, SessionRuntimeState::Starting { run_id });
65 run_id
66 };
67
68 match spawn_actor(run_id).await {
69 Ok(handle) => {
70 let mut guard = self.states.write().await;
71 if matches!(
72 guard.get(&session_id),
73 Some(SessionRuntimeState::Starting { run_id: active_run_id })
74 if *active_run_id == run_id
75 ) {
76 guard.insert(session_id, SessionRuntimeState::Running { run_id, handle });
77 Ok(run_id)
78 } else {
79 let error = "session state changed before actor startup completed".to_string();
80 guard.insert(
81 session_id,
82 SessionRuntimeState::Failed {
83 last_error: error.clone(),
84 },
85 );
86 Err(SessionManagerError::ActorStartupFailed(error))
87 }
88 }
89 Err(error) => {
90 let mut guard = self.states.write().await;
91 guard.insert(
92 session_id,
93 SessionRuntimeState::Failed {
94 last_error: error.clone(),
95 },
96 );
97 Err(SessionManagerError::ActorStartupFailed(error))
98 }
99 }
100 }
101
102 pub async fn mark_run_finished(
103 &self,
104 session_id: Uuid,
105 run_id: Uuid,
106 outcome: Result<(), String>,
107 ) -> Result<(), SessionManagerError> {
108 let mut guard = self.states.write().await;
109
110 match guard.get(&session_id) {
111 Some(SessionRuntimeState::Starting {
112 run_id: active_run_id,
113 })
114 | Some(SessionRuntimeState::Running {
115 run_id: active_run_id,
116 ..
117 }) => {
118 if *active_run_id != run_id {
119 return Err(SessionManagerError::RunMismatch {
120 active_run_id: *active_run_id,
121 requested_run_id: run_id,
122 });
123 }
124 }
125 Some(SessionRuntimeState::Idle) | None | Some(SessionRuntimeState::Failed { .. }) => {
126 return Err(SessionManagerError::SessionNotRunning);
127 }
128 }
129
130 match outcome {
131 Ok(()) => {
132 guard.insert(session_id, SessionRuntimeState::Idle);
133 }
134 Err(error) => {
135 guard.insert(
136 session_id,
137 SessionRuntimeState::Failed { last_error: error },
138 );
139 }
140 }
141
142 Ok(())
143 }
144
145 pub async fn send_command(
146 &self,
147 session_id: Uuid,
148 run_id: Uuid,
149 command: AgentCommand,
150 ) -> Result<(), SessionManagerError> {
151 let command_tx = {
152 let guard = self.states.read().await;
153 match guard.get(&session_id) {
154 Some(SessionRuntimeState::Running {
155 run_id: active_run_id,
156 handle,
157 }) => {
158 if *active_run_id != run_id {
159 return Err(SessionManagerError::RunMismatch {
160 active_run_id: *active_run_id,
161 requested_run_id: run_id,
162 });
163 }
164 handle.command_tx.clone()
165 }
166 Some(SessionRuntimeState::Starting { .. }) => {
167 return Err(SessionManagerError::SessionStarting);
168 }
169 Some(SessionRuntimeState::Idle)
170 | None
171 | Some(SessionRuntimeState::Failed { .. }) => {
172 return Err(SessionManagerError::SessionNotRunning);
173 }
174 }
175 };
176
177 command_tx
178 .send(command)
179 .await
180 .map_err(|_| SessionManagerError::CommandChannelClosed)
181 }
182
183 pub async fn cancel_run(
184 &self,
185 session_id: Uuid,
186 run_id: Uuid,
187 ) -> Result<(), SessionManagerError> {
188 let cancel_token = {
189 let guard = self.states.read().await;
190 match guard.get(&session_id) {
191 Some(SessionRuntimeState::Running {
192 run_id: active_run_id,
193 handle,
194 }) => {
195 if *active_run_id != run_id {
196 return Err(SessionManagerError::RunMismatch {
197 active_run_id: *active_run_id,
198 requested_run_id: run_id,
199 });
200 }
201 handle.cancel.clone()
202 }
203 Some(SessionRuntimeState::Starting { .. }) => {
204 return Err(SessionManagerError::SessionStarting);
205 }
206 Some(SessionRuntimeState::Idle)
207 | None
208 | Some(SessionRuntimeState::Failed { .. }) => {
209 return Err(SessionManagerError::SessionNotRunning);
210 }
211 }
212 };
213
214 cancel_token.cancel();
215 Ok(())
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use stakpak_agent_core::AgentCommand;
    use std::sync::Arc;
    use tokio::{sync::Barrier, sync::mpsc, time::Duration};
    use tokio_util::sync::CancellationToken;

    /// Builds a handle backed by a fresh command channel and cancellation
    /// token, returning the receiving end so tests can observe commands.
    fn make_handle() -> (SessionHandle, mpsc::Receiver<AgentCommand>) {
        let (command_tx, command_rx) = mpsc::channel(8);
        (
            SessionHandle::new(command_tx, CancellationToken::new()),
            command_rx,
        )
    }

    /// Starts a run whose actor "spawn" immediately yields `handle`, and
    /// returns the allocated run id. Panics if `start_run` fails.
    async fn start_with_handle(
        manager: &SessionManager,
        session_id: Uuid,
        handle: SessionHandle,
    ) -> Uuid {
        match manager
            .start_run(
                session_id,
                move |_allocated_run_id| async move { Ok(handle) },
            )
            .await
        {
            Ok(run_id) => run_id,
            Err(error) => panic!("start_run should succeed: {error}"),
        }
    }

    #[tokio::test]
    async fn start_run_is_atomic_under_concurrency() {
        let manager = Arc::new(SessionManager::new());
        let session_id = Uuid::new_v4();
        let barrier = Arc::new(Barrier::new(2));

        let mut tasks = Vec::new();
        for _ in 0..2 {
            let manager_clone = manager.clone();
            let barrier_clone = barrier.clone();
            let session = session_id;
            tasks.push(tokio::spawn(async move {
                // Release both tasks together so they race on the same session.
                barrier_clone.wait().await;
                manager_clone
                    .start_run(session, |_run_id| async {
                        tokio::time::sleep(Duration::from_millis(10)).await;
                        let (handle, _rx) = make_handle();
                        Ok(handle)
                    })
                    .await
            }));
        }

        let mut successes = 0usize;
        let mut conflicts = 0usize;

        for task in tasks {
            match task.await {
                Ok(Ok(_)) => successes += 1,
                Ok(Err(SessionManagerError::SessionAlreadyRunning)) => conflicts += 1,
                Ok(Err(other)) => panic!("unexpected error: {other}"),
                Err(join_error) => panic!("join error: {join_error}"),
            }
        }

        assert_eq!(successes, 1);
        assert_eq!(conflicts, 1);
    }

    #[tokio::test]
    async fn run_scoped_command_rejects_stale_run_id() {
        let manager = SessionManager::new();
        let session_id = Uuid::new_v4();

        let (handle, _rx) = make_handle();
        let run_id = start_with_handle(&manager, session_id, handle).await;

        let wrong_run_id = Uuid::new_v4();
        let result = manager
            .send_command(session_id, wrong_run_id, AgentCommand::Cancel)
            .await;

        assert_eq!(
            result,
            Err(SessionManagerError::RunMismatch {
                active_run_id: run_id,
                requested_run_id: wrong_run_id,
            })
        );
    }

    #[tokio::test]
    async fn run_scoped_command_accepts_active_run_id() {
        let manager = SessionManager::new();
        let session_id = Uuid::new_v4();

        let (handle, mut rx) = make_handle();
        let run_id = start_with_handle(&manager, session_id, handle).await;

        let send_result = manager
            .send_command(session_id, run_id, AgentCommand::Cancel)
            .await;
        assert!(send_result.is_ok());

        let received = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
        match received {
            Ok(Some(AgentCommand::Cancel)) => {}
            Ok(Some(_other)) => panic!("unexpected command variant"),
            Ok(None) => panic!("command channel closed unexpectedly"),
            Err(timeout_error) => panic!("did not receive command in time: {timeout_error}"),
        }
    }

    #[tokio::test]
    async fn running_runs_lists_only_running_sessions() {
        let manager = SessionManager::new();
        let running_session_id = Uuid::new_v4();
        let finished_session_id = Uuid::new_v4();

        let (running_handle, _running_rx) = make_handle();
        let running_run_id =
            start_with_handle(&manager, running_session_id, running_handle).await;

        let (finished_handle, _finished_rx) = make_handle();
        let finished_run_id =
            start_with_handle(&manager, finished_session_id, finished_handle).await;

        let mark_finished = manager
            .mark_run_finished(finished_session_id, finished_run_id, Ok(()))
            .await;
        assert!(mark_finished.is_ok());

        let running_runs = manager.running_runs().await;
        assert_eq!(running_runs.len(), 1);
        assert_eq!(running_runs[0], (running_session_id, running_run_id));
    }

    #[tokio::test]
    async fn startup_failure_transitions_to_failed_state() {
        let manager = SessionManager::new();
        let session_id = Uuid::new_v4();

        let result = manager
            .start_run(session_id, |_run_id| async move { Err("boom".to_string()) })
            .await;

        assert_eq!(
            result,
            Err(SessionManagerError::ActorStartupFailed("boom".to_string()))
        );

        let state = manager.state(session_id).await;
        match state {
            SessionRuntimeState::Failed { last_error } => {
                assert_eq!(last_error, "boom".to_string());
            }
            other => panic!("expected failed state, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn mark_run_finished_requires_active_run_match() {
        let manager = SessionManager::new();
        let session_id = Uuid::new_v4();

        let (handle, _rx) = make_handle();
        let run_id = start_with_handle(&manager, session_id, handle).await;

        let wrong_run_id = Uuid::new_v4();
        let mismatch = manager
            .mark_run_finished(session_id, wrong_run_id, Ok(()))
            .await;

        assert_eq!(
            mismatch,
            Err(SessionManagerError::RunMismatch {
                active_run_id: run_id,
                requested_run_id: wrong_run_id,
            })
        );

        let finish = manager.mark_run_finished(session_id, run_id, Ok(())).await;
        assert!(finish.is_ok());

        let state = manager.state(session_id).await;
        assert!(matches!(state, SessionRuntimeState::Idle));
    }

    #[tokio::test]
    async fn cancel_run_requires_active_run_match_and_cancels_token() {
        let manager = SessionManager::new();
        let session_id = Uuid::new_v4();

        let (handle, _rx) = make_handle();
        let cancel = handle.cancel.clone();
        let run_id = start_with_handle(&manager, session_id, handle).await;

        // A stale run id must be rejected and must not trigger the token.
        let wrong_run_id = Uuid::new_v4();
        let mismatch = manager.cancel_run(session_id, wrong_run_id).await;
        assert_eq!(
            mismatch,
            Err(SessionManagerError::RunMismatch {
                active_run_id: run_id,
                requested_run_id: wrong_run_id,
            })
        );
        assert!(!cancel.is_cancelled());

        let cancel_result = manager.cancel_run(session_id, run_id).await;
        assert!(cancel_result.is_ok());
        assert!(cancel.is_cancelled());
    }
}