// scud/commands/spawn/headless/store.rs

//! In-memory storage for streaming agent output.
//!
//! Provides thread-safe storage for multiple headless agent sessions,
//! with methods for event storage, output rendering, and session management.

use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::{Instant, SystemTime, UNIX_EPOCH};

use anyhow::Result;

use super::events::{StreamEvent, StreamEventKind};

/// Cap on the number of rendered output lines kept per session.
///
/// Once a session holds this many lines, the oldest tenth is discarded before
/// the next line is stored.
const MAX_OUTPUT_LINES: usize = 10_000;

/// Cap on the number of raw stream events kept per session.
///
/// Once a session holds this many events, the oldest tenth is discarded before
/// the next event is stored.
const MAX_EVENTS: usize = 50_000;

/// Lifecycle state of a headless agent session.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SessionStatus {
    /// Session created; the harness has not reported a session ID yet.
    Starting,
    /// The harness has assigned a session ID; the session is in progress.
    Running,
    /// The agent reported successful completion.
    Completed,
    /// The agent reported an unsuccessful completion or emitted an error.
    Failed,
}

29/// Stream data for a single agent session
30#[derive(Debug)]
31pub struct SessionStream {
32    /// Unique session ID (from harness)
33    pub session_id: String,
34    /// Associated task ID
35    pub task_id: String,
36    /// Tag/phase
37    pub tag: String,
38    /// All events received (bounded by MAX_EVENTS)
39    pub events: Vec<StreamEvent>,
40    /// Rendered output lines for display (bounded by MAX_OUTPUT_LINES)
41    pub output_lines: Vec<String>,
42    /// Current status
43    pub status: SessionStatus,
44    /// When the session started
45    pub started_at: Instant,
46    /// Process ID (for interruption)
47    pub pid: Option<u32>,
48    /// Partial line buffer for incomplete text deltas
49    partial_line: String,
50}
51
52impl SessionStream {
53    pub fn new(task_id: &str, tag: &str) -> Self {
54        Self {
55            session_id: String::new(),
56            task_id: task_id.to_string(),
57            tag: tag.to_string(),
58            events: Vec::new(),
59            output_lines: Vec::new(),
60            status: SessionStatus::Starting,
61            started_at: Instant::now(),
62            pid: None,
63            partial_line: String::new(),
64        }
65    }
66
67    /// Add an event and update output lines
68    pub fn push_event(&mut self, mut event: StreamEvent) {
69        event.timestamp_ms = self.started_at.elapsed().as_millis() as u64;
70
71        // Update output lines based on event
72        match &event.kind {
73            StreamEventKind::TextDelta { text } => {
74                self.append_text(text);
75            }
76            StreamEventKind::ToolStart {
77                tool_name,
78                input_summary,
79                ..
80            } => {
81                // Flush any partial line first
82                self.flush_partial_line();
83                self.push_line(format!(">> {} {}", tool_name, input_summary));
84            }
85            StreamEventKind::ToolResult {
86                tool_name, success, ..
87            } => {
88                self.flush_partial_line();
89                let status = if *success { "ok" } else { "failed" };
90                self.push_line(format!("<< {} {}", tool_name, status));
91            }
92            StreamEventKind::Complete { success } => {
93                self.flush_partial_line();
94                self.status = if *success {
95                    SessionStatus::Completed
96                } else {
97                    SessionStatus::Failed
98                };
99            }
100            StreamEventKind::Error { message } => {
101                self.flush_partial_line();
102                self.push_line(format!("ERROR: {}", message));
103                self.status = SessionStatus::Failed;
104            }
105            StreamEventKind::SessionAssigned { session_id } => {
106                self.session_id = session_id.clone();
107                self.status = SessionStatus::Running;
108            }
109        }
110
111        // Store event with memory limit
112        if self.events.len() >= MAX_EVENTS {
113            // Remove oldest 10% when limit reached
114            let drain_count = MAX_EVENTS / 10;
115            self.events.drain(0..drain_count);
116        }
117        self.events.push(event);
118    }
119
120    /// Append text, handling newlines properly
121    fn append_text(&mut self, text: &str) {
122        for ch in text.chars() {
123            if ch == '\n' {
124                // Complete the current line and start a new one
125                let line = std::mem::take(&mut self.partial_line);
126                self.push_line(line);
127            } else {
128                self.partial_line.push(ch);
129            }
130        }
131    }
132
133    /// Flush any remaining partial line as a complete line
134    fn flush_partial_line(&mut self) {
135        if !self.partial_line.is_empty() {
136            let line = std::mem::take(&mut self.partial_line);
137            self.push_line(line);
138        }
139    }
140
141    /// Push a line to output with memory limit
142    fn push_line(&mut self, line: String) {
143        if self.output_lines.len() >= MAX_OUTPUT_LINES {
144            // Remove oldest 10% when limit reached
145            let drain_count = MAX_OUTPUT_LINES / 10;
146            self.output_lines.drain(0..drain_count);
147        }
148        self.output_lines.push(line);
149    }
150
151    /// Get the last N output lines
152    pub fn tail(&self, n: usize) -> &[String] {
153        let start = self.output_lines.len().saturating_sub(n);
154        &self.output_lines[start..]
155    }
156
157    /// Get all output lines including any partial line in progress
158    pub fn get_all_output(&self) -> Vec<String> {
159        let mut lines = self.output_lines.clone();
160        if !self.partial_line.is_empty() {
161            lines.push(self.partial_line.clone());
162        }
163        lines
164    }
165
166    /// Check if session is still active
167    pub fn is_active(&self) -> bool {
168        matches!(self.status, SessionStatus::Starting | SessionStatus::Running)
169    }
170
171    /// Get the event count
172    pub fn event_count(&self) -> usize {
173        self.events.len()
174    }
175
176    /// Get the output line count
177    pub fn line_count(&self) -> usize {
178        self.output_lines.len()
179    }
180}
181
182/// Thread-safe store for multiple agent sessions
183#[derive(Debug, Clone, Default)]
184pub struct StreamStore {
185    sessions: Arc<RwLock<HashMap<String, SessionStream>>>,
186}
187
188impl StreamStore {
189    pub fn new() -> Self {
190        Self::default()
191    }
192
193    /// Create a new session for a task
194    pub fn create_session(&self, task_id: &str, tag: &str) -> String {
195        let mut sessions = self.sessions.write().unwrap();
196        let stream = SessionStream::new(task_id, tag);
197        let key = task_id.to_string();
198        sessions.insert(key.clone(), stream);
199        key
200    }
201
202    /// Push an event to a session
203    pub fn push_event(&self, task_id: &str, event: StreamEvent) {
204        let mut sessions = self.sessions.write().unwrap();
205        if let Some(stream) = sessions.get_mut(task_id) {
206            stream.push_event(event);
207        }
208    }
209
210    /// Set the harness session ID for a task
211    pub fn set_session_id(&self, task_id: &str, session_id: &str) {
212        let mut sessions = self.sessions.write().unwrap();
213        if let Some(stream) = sessions.get_mut(task_id) {
214            stream.session_id = session_id.to_string();
215            stream.status = SessionStatus::Running;
216        }
217    }
218
219    /// Set the process ID for a task
220    pub fn set_pid(&self, task_id: &str, pid: u32) {
221        let mut sessions = self.sessions.write().unwrap();
222        if let Some(stream) = sessions.get_mut(task_id) {
223            stream.pid = Some(pid);
224        }
225    }
226
227    /// Get output lines for a task
228    pub fn get_output(&self, task_id: &str, limit: usize) -> Vec<String> {
229        let sessions = self.sessions.read().unwrap();
230        sessions
231            .get(task_id)
232            .map(|s| s.tail(limit).to_vec())
233            .unwrap_or_default()
234    }
235
236    /// Get all output lines for a task, including any partial line
237    pub fn get_all_output(&self, task_id: &str) -> Vec<String> {
238        let sessions = self.sessions.read().unwrap();
239        sessions
240            .get(task_id)
241            .map(|s| s.get_all_output())
242            .unwrap_or_default()
243    }
244
245    /// Get session status
246    pub fn get_status(&self, task_id: &str) -> Option<SessionStatus> {
247        let sessions = self.sessions.read().unwrap();
248        sessions.get(task_id).map(|s| s.status.clone())
249    }
250
251    /// Get harness session ID for continuation
252    pub fn get_session_id(&self, task_id: &str) -> Option<String> {
253        let sessions = self.sessions.read().unwrap();
254        sessions
255            .get(task_id)
256            .filter(|s| !s.session_id.is_empty())
257            .map(|s| s.session_id.clone())
258    }
259
260    /// List all active task IDs
261    pub fn active_tasks(&self) -> Vec<String> {
262        let sessions = self.sessions.read().unwrap();
263        sessions
264            .iter()
265            .filter(|(_, s)| s.is_active())
266            .map(|(k, _)| k.clone())
267            .collect()
268    }
269
270    /// Get all task IDs
271    pub fn all_tasks(&self) -> Vec<String> {
272        let sessions = self.sessions.read().unwrap();
273        sessions.keys().cloned().collect()
274    }
275
276    /// Check if a session exists
277    pub fn has_session(&self, task_id: &str) -> bool {
278        let sessions = self.sessions.read().unwrap();
279        sessions.contains_key(task_id)
280    }
281
282    /// Remove a session
283    pub fn remove_session(&self, task_id: &str) -> Option<SessionStream> {
284        let mut sessions = self.sessions.write().unwrap();
285        sessions.remove(task_id)
286    }
287
288    /// Get session statistics
289    pub fn session_stats(&self, task_id: &str) -> Option<(usize, usize)> {
290        let sessions = self.sessions.read().unwrap();
291        sessions
292            .get(task_id)
293            .map(|s| (s.event_count(), s.line_count()))
294    }
295
296    /// Save session metadata for later continuation
297    ///
298    /// Persists the session ID, task ID, tag, and PID to a JSON file
299    /// in the `.scud/headless/` directory for use by the `attach` command.
300    pub fn save_session_metadata(&self, task_id: &str, project_root: &Path) -> Result<()> {
301        let sessions = self.sessions.read().unwrap();
302        let session = sessions
303            .get(task_id)
304            .ok_or_else(|| anyhow::anyhow!("Session not found: {}", task_id))?;
305
306        let metadata_dir = project_root.join(".scud").join("headless");
307        std::fs::create_dir_all(&metadata_dir)?;
308
309        let metadata = serde_json::json!({
310            "task_id": session.task_id,
311            "session_id": session.session_id,
312            "tag": session.tag,
313            "pid": session.pid,
314            "status": format!("{:?}", session.status),
315            "started_at_ms": session.started_at.elapsed().as_millis() as u64,
316        });
317
318        let metadata_file = metadata_dir.join(format!("{}.json", task_id));
319        std::fs::write(&metadata_file, serde_json::to_string_pretty(&metadata)?)?;
320
321        Ok(())
322    }
323
324    /// Load session metadata from disk
325    ///
326    /// Returns the stored session_id if available for continuation.
327    pub fn load_session_metadata(task_id: &str, project_root: &Path) -> Result<Option<String>> {
328        let metadata_file = project_root
329            .join(".scud")
330            .join("headless")
331            .join(format!("{}.json", task_id));
332
333        if !metadata_file.exists() {
334            return Ok(None);
335        }
336
337        let content = std::fs::read_to_string(&metadata_file)?;
338        let data: serde_json::Value = serde_json::from_str(&content)?;
339
340        Ok(data
341            .get("session_id")
342            .and_then(|v| v.as_str())
343            .filter(|s| !s.is_empty())
344            .map(|s| s.to_string()))
345    }
346}
347
/// Unit tests for `SessionStream` and `StreamStore`.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- SessionStream ----

    #[test]
    fn test_session_stream_new() {
        let stream = SessionStream::new("task-1", "phase-a");
        assert_eq!(stream.task_id, "task-1");
        assert_eq!(stream.tag, "phase-a");
        assert_eq!(stream.status, SessionStatus::Starting);
        assert!(stream.session_id.is_empty());
        assert!(stream.events.is_empty());
        assert!(stream.output_lines.is_empty());
        assert!(stream.pid.is_none());
    }

    #[test]
    fn test_push_text_delta_single_line() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::text_delta("Hello world"));

        // Text without a newline stays in the partial buffer.
        assert_eq!(stream.output_lines.len(), 0);
        assert_eq!(stream.partial_line, "Hello world");
        assert_eq!(stream.events.len(), 1);
    }

    #[test]
    fn test_push_text_delta_with_newline() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::text_delta("Hello\nWorld\n"));

        assert_eq!(stream.output_lines.len(), 2);
        assert_eq!(stream.output_lines[0], "Hello");
        assert_eq!(stream.output_lines[1], "World");
        assert!(stream.partial_line.is_empty());
    }

    #[test]
    fn test_push_text_delta_incremental() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::text_delta("Hel"));
        stream.push_event(StreamEvent::text_delta("lo "));
        stream.push_event(StreamEvent::text_delta("world\n"));

        assert_eq!(stream.output_lines.len(), 1);
        assert_eq!(stream.output_lines[0], "Hello world");
    }

    #[test]
    fn test_push_text_delta_blank_lines() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::text_delta("a\n\nb\n"));

        // Consecutive newlines yield an empty line between them.
        assert_eq!(stream.output_lines, vec!["a", "", "b"]);
        assert!(stream.partial_line.is_empty());
    }

    #[test]
    fn test_push_tool_start() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::text_delta("Some text"));
        stream.push_event(StreamEvent::tool_start("Read", "tool-1", "src/main.rs"));

        // Tool start flushes the partial line first.
        assert_eq!(stream.output_lines.len(), 2);
        assert_eq!(stream.output_lines[0], "Some text");
        assert_eq!(stream.output_lines[1], ">> Read src/main.rs");
    }

    #[test]
    fn test_push_tool_result() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::new(StreamEventKind::ToolResult {
            tool_name: "Read".to_string(),
            tool_id: "tool-1".to_string(),
            success: true,
        }));

        assert_eq!(stream.output_lines.len(), 1);
        assert_eq!(stream.output_lines[0], "<< Read ok");
    }

    #[test]
    fn test_push_tool_result_failed() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::new(StreamEventKind::ToolResult {
            tool_name: "Bash".to_string(),
            tool_id: "tool-2".to_string(),
            success: false,
        }));

        assert_eq!(stream.output_lines[0], "<< Bash failed");
    }

    #[test]
    fn test_session_assigned() {
        let mut stream = SessionStream::new("task-1", "test");
        assert_eq!(stream.status, SessionStatus::Starting);

        stream.push_event(StreamEvent::new(StreamEventKind::SessionAssigned {
            session_id: "sess-abc123".to_string(),
        }));

        assert_eq!(stream.session_id, "sess-abc123");
        assert_eq!(stream.status, SessionStatus::Running);
    }

    #[test]
    fn test_complete_success() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::complete(true));

        assert_eq!(stream.status, SessionStatus::Completed);
    }

    #[test]
    fn test_complete_failure() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::complete(false));

        assert_eq!(stream.status, SessionStatus::Failed);
    }

    #[test]
    fn test_error_event() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::error("Something went wrong"));

        assert_eq!(stream.status, SessionStatus::Failed);
        assert_eq!(stream.output_lines[0], "ERROR: Something went wrong");
    }

    #[test]
    fn test_tail() {
        let mut stream = SessionStream::new("task-1", "test");
        for i in 0..10 {
            stream.push_event(StreamEvent::text_delta(&format!("Line {}\n", i)));
        }

        let last3 = stream.tail(3);
        assert_eq!(last3.len(), 3);
        assert_eq!(last3[0], "Line 7");
        assert_eq!(last3[1], "Line 8");
        assert_eq!(last3[2], "Line 9");
    }

    #[test]
    fn test_tail_less_than_requested() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::text_delta("Only one\n"));

        let last10 = stream.tail(10);
        assert_eq!(last10.len(), 1);
        assert_eq!(last10[0], "Only one");
    }

    #[test]
    fn test_get_all_output_with_partial() {
        let mut stream = SessionStream::new("task-1", "test");
        stream.push_event(StreamEvent::text_delta("Complete line\n"));
        stream.push_event(StreamEvent::text_delta("Partial"));

        let output = stream.get_all_output();
        assert_eq!(output.len(), 2);
        assert_eq!(output[0], "Complete line");
        assert_eq!(output[1], "Partial");
    }

    #[test]
    fn test_is_active() {
        let mut stream = SessionStream::new("task-1", "test");
        assert!(stream.is_active()); // Starting

        stream.status = SessionStatus::Running;
        assert!(stream.is_active());

        stream.status = SessionStatus::Completed;
        assert!(!stream.is_active());

        stream.status = SessionStatus::Failed;
        assert!(!stream.is_active());
    }

    #[test]
    fn test_event_timestamp() {
        let mut stream = SessionStream::new("task-1", "test");

        // Small sleep so the elapsed-time timestamp is non-zero.
        std::thread::sleep(std::time::Duration::from_millis(10));

        stream.push_event(StreamEvent::text_delta("Hello"));
        assert!(stream.events[0].timestamp_ms > 0);
    }

    #[test]
    fn test_memory_limit_output_lines() {
        let mut stream = SessionStream::new("task-1", "test");
        let total = MAX_OUTPUT_LINES + 1000;

        for i in 0..total {
            stream.push_event(StreamEvent::text_delta(&format!("Line {}\n", i)));
        }

        assert!(stream.output_lines.len() <= MAX_OUTPUT_LINES);
        // Trimming discards the oldest lines, never the newest.
        assert_eq!(
            stream.output_lines.last().unwrap(),
            &format!("Line {}", total - 1)
        );
    }

    #[test]
    fn test_memory_limit_events() {
        let mut stream = SessionStream::new("task-1", "test");
        let total = MAX_EVENTS + 1000;

        for i in 0..total {
            stream.push_event(StreamEvent::text_delta(&format!("{}", i)));
        }

        assert!(stream.events.len() <= MAX_EVENTS);
        // Trimming discards the oldest events, never the newest.
        match &stream.events.last().expect("events should not be empty").kind {
            StreamEventKind::TextDelta { text } => {
                assert_eq!(text, &format!("{}", total - 1));
            }
            _ => panic!("expected the newest event to be a TextDelta"),
        }
    }

    // ---- StreamStore ----

    #[test]
    fn test_store_create_session() {
        let store = StreamStore::new();
        let key = store.create_session("task-1", "phase-a");

        assert_eq!(key, "task-1");
        assert!(store.has_session("task-1"));
    }

    #[test]
    fn test_store_push_event() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");
        store.push_event("task-1", StreamEvent::text_delta("Hello\n"));

        let output = store.get_output("task-1", 100);
        assert_eq!(output.len(), 1);
        assert_eq!(output[0], "Hello");
    }

    #[test]
    fn test_store_push_event_unknown_task_is_noop() {
        let store = StreamStore::new();
        store.push_event("missing", StreamEvent::text_delta("Hello\n"));

        assert!(!store.has_session("missing"));
        assert!(store.all_tasks().is_empty());
    }

    #[test]
    fn test_store_set_session_id() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");
        store.set_session_id("task-1", "sess-xyz");

        assert_eq!(store.get_session_id("task-1"), Some("sess-xyz".to_string()));
        assert_eq!(store.get_status("task-1"), Some(SessionStatus::Running));
    }

    #[test]
    fn test_store_set_pid() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");
        store.set_pid("task-1", 12345);

        let stream = store
            .remove_session("task-1")
            .expect("session should exist");
        assert_eq!(stream.pid, Some(12345));
    }

    #[test]
    fn test_store_get_status() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");

        assert_eq!(store.get_status("task-1"), Some(SessionStatus::Starting));

        store.push_event("task-1", StreamEvent::complete(true));
        assert_eq!(store.get_status("task-1"), Some(SessionStatus::Completed));
    }

    #[test]
    fn test_store_active_tasks() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");
        store.create_session("task-2", "phase-a");
        store.push_event("task-2", StreamEvent::complete(true));

        let active = store.active_tasks();
        assert_eq!(active.len(), 1);
        assert!(active.contains(&"task-1".to_string()));
    }

    #[test]
    fn test_store_all_tasks() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");
        store.create_session("task-2", "phase-b");

        let all = store.all_tasks();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_store_remove_session() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");
        assert!(store.has_session("task-1"));

        let removed = store.remove_session("task-1");
        assert!(removed.is_some());
        assert!(!store.has_session("task-1"));
    }

    #[test]
    fn test_store_session_stats() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");
        store.push_event("task-1", StreamEvent::text_delta("Line 1\n"));
        store.push_event("task-1", StreamEvent::text_delta("Line 2\n"));

        let (events, lines) = store
            .session_stats("task-1")
            .expect("session should exist");
        assert_eq!(events, 2);
        assert_eq!(lines, 2);
    }

    #[test]
    fn test_store_nonexistent_session() {
        let store = StreamStore::new();

        assert_eq!(store.get_output("nonexistent", 100), Vec::<String>::new());
        assert_eq!(store.get_status("nonexistent"), None);
        assert_eq!(store.get_session_id("nonexistent"), None);
        assert_eq!(store.session_stats("nonexistent"), None);
    }

    #[test]
    fn test_store_thread_safety() {
        use std::thread;

        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");

        // Clones share one underlying map, so every thread writes to the same session.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let store = store.clone();
                thread::spawn(move || {
                    for j in 0..100 {
                        store.push_event(
                            "task-1",
                            StreamEvent::text_delta(&format!("Thread {} line {}\n", i, j)),
                        );
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let (events, lines) = store.session_stats("task-1").unwrap();
        assert_eq!(events, 1000); // 10 threads * 100 events
        assert_eq!(lines, 1000); // 10 threads * 100 lines
    }

    #[test]
    fn test_save_and_load_session_metadata() {
        let temp_dir = std::env::temp_dir().join(format!("scud_test_{}", std::process::id()));
        // Start clean in case an earlier, aborted run left files behind.
        std::fs::remove_dir_all(&temp_dir).ok();
        std::fs::create_dir_all(&temp_dir).unwrap();

        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");
        store.set_session_id("task-1", "sess-abc123");
        store.set_pid("task-1", 12345);

        store.save_session_metadata("task-1", &temp_dir).unwrap();

        let metadata_file = temp_dir.join(".scud").join("headless").join("task-1.json");
        let content = std::fs::read_to_string(&metadata_file).unwrap();
        let data: serde_json::Value = serde_json::from_str(&content).unwrap();
        assert_eq!(data["task_id"].as_str(), Some("task-1"));
        assert_eq!(data["session_id"].as_str(), Some("sess-abc123"));
        assert_eq!(data["tag"].as_str(), Some("phase-a"));
        assert_eq!(data["pid"].as_u64(), Some(12345));
        assert_eq!(data["status"].as_str(), Some("Running"));

        let loaded = StreamStore::load_session_metadata("task-1", &temp_dir).unwrap();
        assert_eq!(loaded, Some("sess-abc123".to_string()));

        std::fs::remove_dir_all(&temp_dir).ok();
    }

    #[test]
    fn test_load_nonexistent_metadata() {
        let temp_dir = std::env::temp_dir().join(format!("scud_test_ne_{}", std::process::id()));
        let loaded = StreamStore::load_session_metadata("nonexistent", &temp_dir).unwrap();
        assert_eq!(loaded, None);
    }

    #[test]
    fn test_get_session_id_empty_string() {
        let store = StreamStore::new();
        store.create_session("task-1", "phase-a");

        // The session ID is empty by default, which must read back as None.
        assert_eq!(store.get_session_id("task-1"), None);
    }
}