Skip to main content

stakpak_server/
session_manager.rs

1use crate::error::SessionManagerError;
2use crate::types::{SessionHandle, SessionRuntimeState};
3use stakpak_agent_core::AgentCommand;
4use std::{collections::HashMap, future::Future, sync::Arc};
5use tokio::sync::RwLock;
6use uuid::Uuid;
7
8/// In-memory runtime run coordinator (not persistent session storage).
9#[derive(Clone, Default)]
10pub struct SessionManager {
11    states: Arc<RwLock<HashMap<Uuid, SessionRuntimeState>>>,
12}
13
14impl SessionManager {
15    pub fn new() -> Self {
16        Self::default()
17    }
18
19    pub async fn state(&self, session_id: Uuid) -> SessionRuntimeState {
20        let guard = self.states.read().await;
21        guard
22            .get(&session_id)
23            .cloned()
24            .unwrap_or(SessionRuntimeState::Idle)
25    }
26
27    pub async fn active_run_id(&self, session_id: Uuid) -> Option<Uuid> {
28        self.state(session_id).await.run_id()
29    }
30
31    pub async fn running_runs(&self) -> Vec<(Uuid, Uuid)> {
32        let guard = self.states.read().await;
33        guard
34            .iter()
35            .filter_map(|(session_id, state)| match state {
36                SessionRuntimeState::Running { run_id, .. } => Some((*session_id, *run_id)),
37                SessionRuntimeState::Idle
38                | SessionRuntimeState::Starting { .. }
39                | SessionRuntimeState::Failed { .. } => None,
40            })
41            .collect()
42    }
43
44    pub async fn start_run<F, Fut>(
45        &self,
46        session_id: Uuid,
47        spawn_actor: F,
48    ) -> Result<Uuid, SessionManagerError>
49    where
50        F: FnOnce(Uuid) -> Fut,
51        Fut: Future<Output = Result<SessionHandle, String>>,
52    {
53        let run_id = {
54            let mut guard = self.states.write().await;
55            match guard.get(&session_id) {
56                Some(SessionRuntimeState::Starting { .. })
57                | Some(SessionRuntimeState::Running { .. }) => {
58                    return Err(SessionManagerError::SessionAlreadyRunning);
59                }
60                _ => {}
61            }
62
63            let run_id = Uuid::new_v4();
64            guard.insert(session_id, SessionRuntimeState::Starting { run_id });
65            run_id
66        };
67
68        match spawn_actor(run_id).await {
69            Ok(handle) => {
70                let mut guard = self.states.write().await;
71                if matches!(
72                    guard.get(&session_id),
73                    Some(SessionRuntimeState::Starting { run_id: active_run_id })
74                        if *active_run_id == run_id
75                ) {
76                    guard.insert(session_id, SessionRuntimeState::Running { run_id, handle });
77                    Ok(run_id)
78                } else {
79                    let error = "session state changed before actor startup completed".to_string();
80                    guard.insert(
81                        session_id,
82                        SessionRuntimeState::Failed {
83                            last_error: error.clone(),
84                        },
85                    );
86                    Err(SessionManagerError::ActorStartupFailed(error))
87                }
88            }
89            Err(error) => {
90                let mut guard = self.states.write().await;
91                guard.insert(
92                    session_id,
93                    SessionRuntimeState::Failed {
94                        last_error: error.clone(),
95                    },
96                );
97                Err(SessionManagerError::ActorStartupFailed(error))
98            }
99        }
100    }
101
102    pub async fn mark_run_finished(
103        &self,
104        session_id: Uuid,
105        run_id: Uuid,
106        outcome: Result<(), String>,
107    ) -> Result<(), SessionManagerError> {
108        let mut guard = self.states.write().await;
109
110        match guard.get(&session_id) {
111            Some(SessionRuntimeState::Starting {
112                run_id: active_run_id,
113            })
114            | Some(SessionRuntimeState::Running {
115                run_id: active_run_id,
116                ..
117            }) => {
118                if *active_run_id != run_id {
119                    return Err(SessionManagerError::RunMismatch {
120                        active_run_id: *active_run_id,
121                        requested_run_id: run_id,
122                    });
123                }
124            }
125            Some(SessionRuntimeState::Idle) | None | Some(SessionRuntimeState::Failed { .. }) => {
126                return Err(SessionManagerError::SessionNotRunning);
127            }
128        }
129
130        match outcome {
131            Ok(()) => {
132                guard.insert(session_id, SessionRuntimeState::Idle);
133            }
134            Err(error) => {
135                guard.insert(
136                    session_id,
137                    SessionRuntimeState::Failed { last_error: error },
138                );
139            }
140        }
141
142        Ok(())
143    }
144
145    pub async fn send_command(
146        &self,
147        session_id: Uuid,
148        run_id: Uuid,
149        command: AgentCommand,
150    ) -> Result<(), SessionManagerError> {
151        let command_tx = {
152            let guard = self.states.read().await;
153            match guard.get(&session_id) {
154                Some(SessionRuntimeState::Running {
155                    run_id: active_run_id,
156                    handle,
157                }) => {
158                    if *active_run_id != run_id {
159                        return Err(SessionManagerError::RunMismatch {
160                            active_run_id: *active_run_id,
161                            requested_run_id: run_id,
162                        });
163                    }
164                    handle.command_tx.clone()
165                }
166                Some(SessionRuntimeState::Starting { .. }) => {
167                    return Err(SessionManagerError::SessionStarting);
168                }
169                Some(SessionRuntimeState::Idle)
170                | None
171                | Some(SessionRuntimeState::Failed { .. }) => {
172                    return Err(SessionManagerError::SessionNotRunning);
173                }
174            }
175        };
176
177        command_tx
178            .send(command)
179            .await
180            .map_err(|_| SessionManagerError::CommandChannelClosed)
181    }
182
183    pub async fn cancel_run(
184        &self,
185        session_id: Uuid,
186        run_id: Uuid,
187    ) -> Result<(), SessionManagerError> {
188        let cancel_token = {
189            let guard = self.states.read().await;
190            match guard.get(&session_id) {
191                Some(SessionRuntimeState::Running {
192                    run_id: active_run_id,
193                    handle,
194                }) => {
195                    if *active_run_id != run_id {
196                        return Err(SessionManagerError::RunMismatch {
197                            active_run_id: *active_run_id,
198                            requested_run_id: run_id,
199                        });
200                    }
201                    handle.cancel.clone()
202                }
203                Some(SessionRuntimeState::Starting { .. }) => {
204                    return Err(SessionManagerError::SessionStarting);
205                }
206                Some(SessionRuntimeState::Idle)
207                | None
208                | Some(SessionRuntimeState::Failed { .. }) => {
209                    return Err(SessionManagerError::SessionNotRunning);
210                }
211            }
212        };
213
214        cancel_token.cancel();
215        Ok(())
216    }
217}
218
219#[cfg(test)]
220mod tests {
221    use super::*;
222    use stakpak_agent_core::AgentCommand;
223    use std::sync::Arc;
224    use tokio::{sync::Barrier, sync::mpsc, time::Duration};
225    use tokio_util::sync::CancellationToken;
226
227    fn make_handle() -> (SessionHandle, mpsc::Receiver<AgentCommand>) {
228        let (command_tx, command_rx) = mpsc::channel(8);
229        (
230            SessionHandle::new(command_tx, CancellationToken::new()),
231            command_rx,
232        )
233    }
234
235    #[tokio::test]
236    async fn start_run_is_atomic_under_concurrency() {
237        let manager = Arc::new(SessionManager::new());
238        let session_id = Uuid::new_v4();
239        let barrier = Arc::new(Barrier::new(2));
240
241        let mut tasks = Vec::new();
242        for _ in 0..2 {
243            let manager_clone = manager.clone();
244            let barrier_clone = barrier.clone();
245            let session = session_id;
246            tasks.push(tokio::spawn(async move {
247                barrier_clone.wait().await;
248                manager_clone
249                    .start_run(session, |_run_id| async {
250                        tokio::time::sleep(Duration::from_millis(10)).await;
251                        let (handle, _rx) = make_handle();
252                        Ok(handle)
253                    })
254                    .await
255            }));
256        }
257
258        let mut successes = 0usize;
259        let mut conflicts = 0usize;
260
261        for task in tasks {
262            match task.await {
263                Ok(Ok(_)) => successes += 1,
264                Ok(Err(SessionManagerError::SessionAlreadyRunning)) => conflicts += 1,
265                Ok(Err(other)) => panic!("unexpected error: {other}"),
266                Err(join_error) => panic!("join error: {join_error}"),
267            }
268        }
269
270        assert_eq!(successes, 1);
271        assert_eq!(conflicts, 1);
272    }
273
274    #[tokio::test]
275    async fn run_scoped_command_rejects_stale_run_id() {
276        let manager = SessionManager::new();
277        let session_id = Uuid::new_v4();
278
279        let (handle, _rx) = make_handle();
280        let run_id = match manager
281            .start_run(
282                session_id,
283                move |_allocated_run_id| async move { Ok(handle) },
284            )
285            .await
286        {
287            Ok(run_id) => run_id,
288            Err(error) => panic!("start_run should succeed: {error}"),
289        };
290
291        let wrong_run_id = Uuid::new_v4();
292        let result = manager
293            .send_command(session_id, wrong_run_id, AgentCommand::Cancel)
294            .await;
295
296        assert_eq!(
297            result,
298            Err(SessionManagerError::RunMismatch {
299                active_run_id: run_id,
300                requested_run_id: wrong_run_id,
301            })
302        );
303    }
304
305    #[tokio::test]
306    async fn run_scoped_command_accepts_active_run_id() {
307        let manager = SessionManager::new();
308        let session_id = Uuid::new_v4();
309
310        let (handle, mut rx) = make_handle();
311        let run_id = match manager
312            .start_run(
313                session_id,
314                move |_allocated_run_id| async move { Ok(handle) },
315            )
316            .await
317        {
318            Ok(run_id) => run_id,
319            Err(error) => panic!("start_run should succeed: {error}"),
320        };
321
322        let send_result = manager
323            .send_command(session_id, run_id, AgentCommand::Cancel)
324            .await;
325        assert!(send_result.is_ok());
326
327        let received = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
328        match received {
329            Ok(Some(AgentCommand::Cancel)) => {}
330            Ok(Some(_other)) => panic!("unexpected command variant"),
331            Ok(None) => panic!("command channel closed unexpectedly"),
332            Err(timeout_error) => panic!("did not receive command in time: {timeout_error}"),
333        }
334    }
335
336    #[tokio::test]
337    async fn running_runs_lists_only_running_sessions() {
338        let manager = SessionManager::new();
339        let running_session_id = Uuid::new_v4();
340        let finished_session_id = Uuid::new_v4();
341
342        let (running_handle, _running_rx) = make_handle();
343        let running_run_id = match manager
344            .start_run(running_session_id, move |_allocated_run_id| async move {
345                Ok(running_handle)
346            })
347            .await
348        {
349            Ok(run_id) => run_id,
350            Err(error) => panic!("start_run should succeed: {error}"),
351        };
352
353        let (finished_handle, _finished_rx) = make_handle();
354        let finished_run_id = match manager
355            .start_run(finished_session_id, move |_allocated_run_id| async move {
356                Ok(finished_handle)
357            })
358            .await
359        {
360            Ok(run_id) => run_id,
361            Err(error) => panic!("start_run should succeed: {error}"),
362        };
363
364        let mark_finished = manager
365            .mark_run_finished(finished_session_id, finished_run_id, Ok(()))
366            .await;
367        assert!(mark_finished.is_ok());
368
369        let running_runs = manager.running_runs().await;
370        assert_eq!(running_runs.len(), 1);
371        assert_eq!(running_runs[0], (running_session_id, running_run_id));
372    }
373
374    #[tokio::test]
375    async fn startup_failure_transitions_to_failed_state() {
376        let manager = SessionManager::new();
377        let session_id = Uuid::new_v4();
378
379        let result = manager
380            .start_run(session_id, |_run_id| async move { Err("boom".to_string()) })
381            .await;
382
383        assert_eq!(
384            result,
385            Err(SessionManagerError::ActorStartupFailed("boom".to_string()))
386        );
387
388        let state = manager.state(session_id).await;
389        match state {
390            SessionRuntimeState::Failed { last_error } => {
391                assert_eq!(last_error, "boom".to_string());
392            }
393            other => panic!("expected failed state, got: {other:?}"),
394        }
395    }
396
397    #[tokio::test]
398    async fn mark_run_finished_requires_active_run_match() {
399        let manager = SessionManager::new();
400        let session_id = Uuid::new_v4();
401
402        let (handle, _rx) = make_handle();
403        let run_id = match manager
404            .start_run(
405                session_id,
406                move |_allocated_run_id| async move { Ok(handle) },
407            )
408            .await
409        {
410            Ok(run_id) => run_id,
411            Err(error) => panic!("start_run should succeed: {error}"),
412        };
413
414        let wrong_run_id = Uuid::new_v4();
415        let mismatch = manager
416            .mark_run_finished(session_id, wrong_run_id, Ok(()))
417            .await;
418
419        assert_eq!(
420            mismatch,
421            Err(SessionManagerError::RunMismatch {
422                active_run_id: run_id,
423                requested_run_id: wrong_run_id,
424            })
425        );
426
427        let finish = manager.mark_run_finished(session_id, run_id, Ok(())).await;
428        assert!(finish.is_ok());
429
430        let state = manager.state(session_id).await;
431        assert!(matches!(state, SessionRuntimeState::Idle));
432    }
433
434    #[tokio::test]
435    async fn cancel_run_requires_active_run_match_and_cancels_token() {
436        let manager = SessionManager::new();
437        let session_id = Uuid::new_v4();
438
439        let (handle, _rx) = make_handle();
440        let cancel = handle.cancel.clone();
441        let run_id = match manager
442            .start_run(
443                session_id,
444                move |_allocated_run_id| async move { Ok(handle) },
445            )
446            .await
447        {
448            Ok(run_id) => run_id,
449            Err(error) => panic!("start_run should succeed: {error}"),
450        };
451
452        let cancel_result = manager.cancel_run(session_id, run_id).await;
453        assert!(cancel_result.is_ok());
454        assert!(cancel.is_cancelled());
455    }
456}