1use std::collections::HashMap;
14use std::sync::atomic::{AtomicU32, Ordering};
15use std::sync::{Arc, Mutex};
16use std::time::{Duration, Instant};
17
18use crate::tools::strip_ansi;
19use crate::{Result, RuntimeError};
20
21use super::config::ShellConfig;
22use super::pty::PtyHandle;
23use super::readiness::{ReadinessDetector, ReadinessResult, ReadinessStrategy};
24
25pub struct SessionManager {
31 sessions: Mutex<HashMap<String, ShellSession>>,
32 config: ShellConfig,
33 next_id: AtomicU32,
34}
35
36struct ShellSession {
38 pty: PtyHandle,
39 detector: ReadinessDetector,
40 created_at: Instant,
41 last_active: Instant,
42 idle_timeout: Duration,
43 status: SessionStatus,
44}
45
/// Lifecycle state of a shell session.
///
/// Equality is total (all payloads are `Eq`), so the type derives `Eq` as well
/// as `PartialEq`, which lets it be used in `Eq`-bounded contexts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionStatus {
    /// The child process was alive when last checked.
    Active,
    /// The child process has exited; the exit code is included when known.
    Exited(Option<i32>),
    /// The session has been marked closed.
    Closed,
}
53
/// Options for `SessionManager::create_session`.
///
/// Every field is optional. `SessionOpts::default()` requests the manager's
/// defaults throughout, so callers can write
/// `SessionOpts { command: Some(..), ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct SessionOpts {
    /// Command to run; falls back to `$SHELL`, or `bash` if that is unset.
    pub command: Option<String>,
    /// Working directory handed to the PTY spawn; `None` leaves the choice to it.
    pub working_directory: Option<String>,
    /// Environment variables handed to the PTY spawn. Whether they extend or
    /// replace the inherited environment is decided by `PtyHandle::spawn`.
    pub env: HashMap<String, String>,
    /// Terminal height in rows; falls back to the configured `default_rows`.
    pub rows: Option<u16>,
    /// Terminal width in columns; falls back to the configured `default_cols`.
    pub cols: Option<u16>,
    /// Readiness silence timeout in milliseconds; falls back to the configured
    /// `readiness_timeout_ms`.
    pub readiness_timeout_ms: Option<u64>,
    /// Idle timeout in seconds before the session may be reaped; falls back to
    /// the configured `idle_timeout`.
    pub idle_timeout: Option<u64>,
}
64
/// Outcome of sending input to a shell session.
#[derive(Debug)]
pub struct SendResult {
    /// Output produced after the input, with ANSI escapes and carriage
    /// returns removed.
    pub output: String,
    /// Session state after the call: `"active"`, `"timeout"`, or an
    /// `exited…` status string.
    pub status: String,
}
72
73pub struct ShellSessionInfo {
75 pub id: String,
76 pub status: SessionStatus,
77 pub created_at: Instant,
78 pub last_active: Instant,
79}
80
81fn normalize_output(raw: &str) -> String {
87 strip_ansi(raw).replace("\r\n", "\n").replace('\r', "")
88}
89
90fn process_input_escapes(input: &str) -> String {
96 let mut result = String::with_capacity(input.len());
97 let mut chars = input.chars().peekable();
98
99 while let Some(ch) = chars.next() {
100 if ch == '\\' {
101 match chars.peek() {
102 Some('n') => { chars.next(); result.push('\n'); }
103 Some('r') => { chars.next(); result.push('\r'); }
104 Some('t') => { chars.next(); result.push('\t'); }
105 Some('\\') => { chars.next(); result.push('\\'); }
106 Some('a') => { chars.next(); result.push('\x07'); } Some('b') => { chars.next(); result.push('\x08'); } Some('0') => { chars.next(); result.push('\0'); } Some('e') => {
110 chars.next();
111 tracing::warn!("blocked \\e escape sequence (raw ESC) in shell input");
112 }
113 Some('x') => {
114 chars.next(); let mut hex = String::new();
116 for _ in 0..2 {
117 if let Some(&c) = chars.peek() {
118 if c.is_ascii_hexdigit() {
119 hex.push(c);
120 chars.next();
121 } else {
122 break;
123 }
124 }
125 }
126 if let Ok(byte) = u8::from_str_radix(&hex, 16) {
127 if byte == 0x1b {
128 tracing::warn!("blocked \\x1b escape sequence (raw ESC) in shell input");
130 } else if byte >= 0x80 {
131 tracing::warn!("blocked \\x{hex:} high byte (>= 0x80) in shell input");
133 } else {
134 result.push(byte as char);
136 }
137 } else {
138 result.push('\\');
140 result.push('x');
141 result.push_str(&hex);
142 }
143 }
144 _ => {
145 result.push(ch);
147 }
148 }
149 } else {
150 result.push(ch);
151 }
152 }
153 result
154}
155
156fn status_string(status: &SessionStatus) -> String {
158 match status {
159 SessionStatus::Active => "active".into(),
160 SessionStatus::Exited(Some(code)) => format!("exited({code})"),
161 SessionStatus::Exited(None) => "exited".into(),
162 SessionStatus::Closed => "closed".into(),
163 }
164}
165
166async fn wait_for_output(
171 pty: &mut PtyHandle,
172 detector: &ReadinessDetector,
173 timeout_override: Option<u64>,
174 tx_delta: Option<&tokio::sync::mpsc::UnboundedSender<String>>,
175 max_output: usize,
176) -> (String, String) {
177 let override_detector;
180 let effective_detector = if let Some(ms) = timeout_override {
181 override_detector = ReadinessDetector::new(
182 ReadinessStrategy::Hybrid,
183 &[], ms,
185 ms.saturating_mul(10).max(10_000), );
187 &override_detector
188 } else {
189 detector
190 };
191
192 let mut output = String::new();
193 let start = Instant::now();
194 let mut last_output_time = Instant::now();
195 let poll_interval = Duration::from_millis(50);
196
197 loop {
198 let bytes = pty.try_read_output(poll_interval).await;
200
201 if !bytes.is_empty() {
202 let text = String::from_utf8_lossy(&bytes);
203 output.push_str(&text);
204 last_output_time = Instant::now();
205
206 if let Some(tx) = tx_delta {
208 let _ = tx.send(normalize_output(&text));
209 }
210 }
211
212 if output.len() > max_output {
214 let mut trunc = max_output;
215 while trunc > 0 && !output.is_char_boundary(trunc) {
216 trunc -= 1;
217 }
218 output.truncate(trunc);
219 return (normalize_output(&output), "active".into());
220 }
221
222 if !pty.is_alive() {
224 tokio::time::sleep(Duration::from_millis(50)).await;
226 let remaining = pty.try_read_output(Duration::from_millis(100)).await;
227 if !remaining.is_empty() {
228 let remaining_text = String::from_utf8_lossy(&remaining);
229 output.push_str(&remaining_text);
230
231 if let Some(tx) = tx_delta {
233 let _ = tx.send(normalize_output(&remaining_text));
234 }
235 }
236 return (normalize_output(&output), status_string(&SessionStatus::Exited(None)));
238 }
239
240 let silence_elapsed = last_output_time.elapsed();
242 let total_elapsed = start.elapsed();
243
244 match effective_detector.check(&output, silence_elapsed, total_elapsed) {
245 ReadinessResult::Ready => return (normalize_output(&output), "active".into()),
246 ReadinessResult::SilenceTimeout => return (normalize_output(&output), "active".into()),
247 ReadinessResult::MaxTimeout => return (normalize_output(&output), "timeout".into()),
248 ReadinessResult::Waiting => continue,
249 }
250 }
251}
252
253impl SessionManager {
258 pub fn new(config: ShellConfig) -> Arc<Self> {
260 Arc::new(Self {
261 sessions: Mutex::new(HashMap::new()),
262 config,
263 next_id: AtomicU32::new(0),
264 })
265 }
266
267 pub async fn create_session(
271 &self,
272 opts: SessionOpts,
273 tx_delta: Option<&tokio::sync::mpsc::UnboundedSender<String>>,
274 ) -> Result<(String, String, String)> {
275 {
277 let sessions = self.sessions.lock().map_err(|e| {
278 RuntimeError::Tool(format!("session lock poisoned: {e}"))
279 })?;
280 if sessions.len() >= self.config.max_sessions {
281 return Err(RuntimeError::Tool(format!(
282 "maximum session limit reached ({})",
283 self.config.max_sessions
284 )));
285 }
286 }
287
288 let seq = self.next_id.fetch_add(1, Ordering::SeqCst) + 1;
290 let id = format!("shell_{:02}", seq);
291
292 let command = opts.command.unwrap_or_else(|| {
294 std::env::var("SHELL").unwrap_or_else(|_| "bash".into())
295 });
296 let rows = opts.rows.unwrap_or(self.config.default_rows);
297 let cols = opts.cols.unwrap_or(self.config.default_cols);
298
299 let idle_timeout = opts
300 .idle_timeout
301 .map(Duration::from_secs)
302 .unwrap_or(self.config.idle_timeout);
303
304 let mut pty = PtyHandle::spawn(
306 &command,
307 opts.working_directory.as_deref(),
308 opts.env,
309 rows,
310 cols,
311 )?;
312
313 let silence_ms = opts
315 .readiness_timeout_ms
316 .unwrap_or(self.config.readiness_timeout_ms);
317 let detector = ReadinessDetector::new(
318 super::readiness::ReadinessStrategy::Hybrid,
319 &self.config.prompt_patterns,
320 silence_ms,
321 self.config.max_readiness_timeout_ms,
322 );
323
324 tokio::time::sleep(Duration::from_millis(200)).await;
329 let (initial_output, status_str) =
330 wait_for_output(&mut pty, &detector, opts.readiness_timeout_ms, tx_delta, 30000).await;
331
332 let now = Instant::now();
333 let status = if status_str.starts_with("exited") {
334 SessionStatus::Exited(None)
335 } else {
336 SessionStatus::Active
337 };
338
339 let session = ShellSession {
340 pty,
341 detector,
342 created_at: now,
343 last_active: now,
344 idle_timeout,
345 status,
346 };
347
348 {
350 let mut sessions = self.sessions.lock().map_err(|e| {
351 RuntimeError::Tool(format!("session lock poisoned: {e}"))
352 })?;
353 sessions.insert(id.clone(), session);
354 }
355
356 Ok((id, initial_output, status_str))
357 }
358
359 pub async fn send_input(
361 &self,
362 id: &str,
363 input: &str,
364 timeout_ms: Option<u64>,
365 tx_delta: Option<&tokio::sync::mpsc::UnboundedSender<String>>,
366 ) -> Result<SendResult> {
367 let mut session = {
369 let mut sessions = self.sessions.lock().map_err(|e| {
370 RuntimeError::Tool(format!("session lock poisoned: {e}"))
371 })?;
372 sessions.remove(id).ok_or_else(|| {
373 RuntimeError::Tool(format!(
374 "session {id} not found — it may have been closed, reaped, or is currently in use by another call"
375 ))
376 })?
377 };
378
379 if session.status != SessionStatus::Active {
381 let s_str = status_string(&session.status);
383 let mut sessions = self.sessions.lock().map_err(|e| {
384 RuntimeError::Tool(format!("session lock poisoned: {e}"))
385 })?;
386 sessions.insert(id.to_string(), session);
387 return Err(RuntimeError::Tool(format!(
388 "session {id} is not active (status: {s_str})"
389 )));
390 }
391
392 let processed = process_input_escapes(input);
394 session.pty.write(processed.as_bytes())?;
395
396 let (output, status_str) =
398 wait_for_output(&mut session.pty, &session.detector, timeout_ms, tx_delta, 30000).await;
399
400 session.last_active = Instant::now();
402 if !session.pty.is_alive() {
403 session.status = SessionStatus::Exited(None);
404 }
405
406 let result = SendResult {
407 output,
408 status: status_str,
409 };
410
411 {
413 let mut sessions = self.sessions.lock().map_err(|e| {
414 RuntimeError::Tool(format!("session lock poisoned: {e}"))
415 })?;
416 sessions.insert(id.to_string(), session);
417 }
418
419 Ok(result)
420 }
421
422 pub async fn close_session(&self, id: &str) -> Result<String> {
426 let mut session = {
427 let mut sessions = self.sessions.lock().map_err(|e| {
428 RuntimeError::Tool(format!("session lock poisoned: {e}"))
429 })?;
430 match sessions.remove(id) {
431 Some(s) => s,
432 None => return Ok(String::new()),
433 }
434 };
435
436 let remaining = session
438 .pty
439 .try_read_output(Duration::from_millis(100))
440 .await;
441 let final_output = if remaining.is_empty() {
442 String::new()
443 } else {
444 strip_ansi(&String::from_utf8_lossy(&remaining))
445 };
446
447 drop(session);
449
450 Ok(final_output)
451 }
452
453 pub fn reap_idle(&self) -> Vec<String> {
457 let mut sessions = match self.sessions.lock() {
458 Ok(s) => s,
459 Err(e) => {
460 tracing::error!("session lock poisoned: {e}");
461 return Vec::new();
462 }
463 };
464
465 let grace_period = Duration::from_secs(5);
466
467 let ids_to_reap: Vec<String> = sessions
468 .iter()
469 .filter(|(_, s)| {
470 let elapsed = s.last_active.elapsed();
471 elapsed > s.idle_timeout && elapsed > grace_period
472 })
473 .map(|(id, _)| id.clone())
474 .collect();
475
476 for id in &ids_to_reap {
477 sessions.remove(id);
478 }
480
481 ids_to_reap
482 }
483
484 pub fn shutdown_all(&self) {
486 match self.sessions.lock() {
487 Ok(mut sessions) => {
488 sessions.drain();
489 }
491 Err(e) => {
492 tracing::error!("session lock poisoned: {e}");
493 }
494 }
495 }
496
497 pub fn active_count(&self) -> usize {
499 match self.sessions.lock() {
500 Ok(s) => s.len(),
501 Err(e) => {
502 tracing::error!("session lock poisoned: {e}");
503 0
504 }
505 }
506 }
507
508 pub fn list_sessions(&self) -> Vec<ShellSessionInfo> {
510 match self.sessions.lock() {
511 Ok(sessions) => {
512 sessions
513 .iter()
514 .map(|(id, s)| ShellSessionInfo {
515 id: id.clone(),
516 status: s.status.clone(),
517 created_at: s.created_at,
518 last_active: s.last_active,
519 })
520 .collect()
521 }
522 Err(e) => {
523 tracing::error!("session lock poisoned: {e}");
524 Vec::new()
525 }
526 }
527 }
528}
529
530impl Drop for SessionManager {
535 fn drop(&mut self) {
536 self.shutdown_all();
537 }
538}
539
540pub fn start_reaper(
550 manager: Arc<SessionManager>,
551 cancel: tokio_util::sync::CancellationToken,
552) -> tokio::task::JoinHandle<()> {
553 tokio::spawn(async move {
554 let interval = Duration::from_secs(30);
555 loop {
556 tokio::select! {
557 _ = cancel.cancelled() => break,
558 _ = tokio::time::sleep(interval) => {
559 let reaped = manager.reap_idle();
560 for id in &reaped {
561 tracing::info!(session_id = %id, "reaped idle shell session");
562 }
563 }
564 }
565 }
566 })
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    fn default_manager() -> Arc<SessionManager> {
        SessionManager::new(ShellConfig::default())
    }

    /// Options that run `command` with every other setting left at its default.
    fn opts_for(command: &str) -> SessionOpts {
        SessionOpts {
            command: Some(command.to_string()),
            working_directory: None,
            env: HashMap::new(),
            rows: None,
            cols: None,
            readiness_timeout_ms: None,
            idle_timeout: None,
        }
    }

    // ---- Session lifecycle (these spawn real processes through a PTY) ----

    #[tokio::test]
    async fn test_create_session_echo_hello() {
        let mgr = default_manager();
        let (id, output, _status) = mgr
            .create_session(opts_for("echo hello"), None)
            .await
            .expect("failed to create session");

        assert!(id.starts_with("shell_"));
        assert!(
            output.contains("hello"),
            "expected 'hello' in output, got: {output:?}"
        );
    }

    #[tokio::test]
    async fn test_send_input_echo() {
        let mgr = default_manager();
        let (id, _initial, _status) = mgr
            .create_session(opts_for("bash"), None)
            .await
            .expect("failed to create session");

        // With terminal echo on, the typed line comes back as `echo te''st`,
        // which does not contain `test`. Only bash actually executing the
        // command prints `test`. Sending plain `echo test` would pass even if
        // the command never ran, because the echoed input already contains it.
        let result = mgr
            .send_input(&id, "echo te''st\n", None, None)
            .await
            .expect("failed to send input");

        assert!(
            result.output.contains("test"),
            "expected 'test' in output, got: {:?}",
            result.output
        );

        let _ = mgr.close_session(&id).await;
    }

    #[tokio::test]
    async fn test_close_session_idempotent() {
        let mgr = default_manager();
        let (id, _, _status) = mgr
            .create_session(opts_for("bash"), None)
            .await
            .expect("failed to create session");

        let result1 = mgr.close_session(&id).await;
        assert!(result1.is_ok(), "first close should succeed");

        let result2 = mgr.close_session(&id).await;
        assert!(result2.is_ok(), "second close should also succeed (idempotent)");
        assert_eq!(result2.unwrap(), "", "second close returns empty string");
    }

    #[tokio::test]
    async fn test_max_sessions_limit() {
        let mut config = ShellConfig::default();
        config.max_sessions = 2;
        let mgr = SessionManager::new(config);

        let (id1, _, _s) = mgr
            .create_session(opts_for("bash"), None)
            .await
            .expect("session 1");
        let (id2, _, _s) = mgr
            .create_session(opts_for("bash"), None)
            .await
            .expect("session 2");

        let result = mgr.create_session(opts_for("bash"), None).await;
        assert!(result.is_err(), "third session should fail");
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("maximum session limit"),
            "error should mention limit, got: {err_msg}"
        );

        let _ = mgr.close_session(&id1).await;
        let _ = mgr.close_session(&id2).await;
    }

    #[tokio::test]
    async fn test_session_not_found() {
        let mgr = default_manager();
        let result = mgr.send_input("shell_99", "hello\n", None, None).await;
        assert!(result.is_err(), "send to non-existent session should fail");
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("not found"),
            "error should mention 'not found', got: {err_msg}"
        );
    }

    // ---- Output normalization ----

    #[test]
    fn test_normalize_output_crlf() {
        assert_eq!(normalize_output("hello\r\nworld\r\n"), "hello\nworld\n");
    }

    #[test]
    fn test_normalize_output_lone_cr() {
        assert_eq!(normalize_output("abc\rdef"), "abcdef");
    }

    // ---- Input escape processing ----

    #[test]
    fn test_escape_newline() {
        assert_eq!(process_input_escapes(r"hello\n"), "hello\n");
    }

    #[test]
    fn test_escape_tab() {
        assert_eq!(process_input_escapes(r"a\tb"), "a\tb");
    }

    #[test]
    fn test_escape_ctrl_c() {
        assert_eq!(process_input_escapes(r"\x03"), "\x03");
    }

    #[test]
    fn test_escape_ctrl_d() {
        assert_eq!(process_input_escapes(r"\x04"), "\x04");
    }

    #[test]
    fn test_escape_literal_backslash() {
        assert_eq!(process_input_escapes(r"a\\b"), "a\\b");
    }

    #[test]
    fn test_escape_real_newline_passthrough() {
        assert_eq!(process_input_escapes("hello\n"), "hello\n");
    }

    #[test]
    fn test_escape_mixed() {
        assert_eq!(process_input_escapes(r"ls -la\n"), "ls -la\n");
        assert_eq!(process_input_escapes(r"124\n"), "124\n");
    }

    #[test]
    fn test_escape_unknown_sequence() {
        assert_eq!(process_input_escapes(r"\q"), "\\q");
    }

    #[test]
    fn test_escape_hex_partial() {
        assert_eq!(process_input_escapes(r"\xZZ"), "\\xZZ");
    }

    #[test]
    fn test_escape_bell() {
        assert_eq!(process_input_escapes(r"\a"), "\x07");
    }

    #[test]
    fn test_escape_backspace() {
        assert_eq!(process_input_escapes(r"\b"), "\x08");
    }

    #[test]
    fn test_escape_null() {
        assert_eq!(process_input_escapes(r"\0"), "\0");
    }

    #[test]
    fn test_escape_esc_blocked() {
        assert_eq!(process_input_escapes(r"\e"), "");
    }

    #[test]
    fn test_escape_hex_1b_blocked() {
        assert_eq!(process_input_escapes(r"\x1b"), "");
    }

    #[test]
    fn test_escape_hex_high_byte_blocked() {
        assert_eq!(process_input_escapes(r"\x80"), "");
        assert_eq!(process_input_escapes(r"\xff"), "");
    }

    #[test]
    fn test_escape_hex_del_allowed() {
        assert_eq!(process_input_escapes(r"\x7f"), "\x7f");
    }
}