Skip to main content

symbi_runtime/toolclad/
session_executor.rs

1//! SessionExecutor — PTY-based interactive CLI tool session manager.
2//!
3//! Spawns interactive CLI tools in pseudo-terminals, manages per-interaction
4//! command validation and policy checking, and captures evidence transcripts.
5
6#[cfg(feature = "toolclad-session")]
7use pty_process::blocking::{Command as PtyCommand, Pty};
8
9use std::collections::HashMap;
10#[cfg(feature = "toolclad-session")]
11use std::io::Read;
12#[cfg(feature = "toolclad-session")]
13use std::io::Write;
14use std::sync::{Arc, Mutex};
15#[cfg(feature = "toolclad-session")]
16use std::time::{Duration, Instant};
17
18use super::manifest::Manifest;
19#[cfg(feature = "toolclad-session")]
20use super::manifest::SessionDef;
21use super::session_state::*;
22
23/// Manages interactive CLI tool sessions via PTY.
24pub struct SessionExecutor {
25    sessions: Arc<Mutex<HashMap<SessionId, SessionHandle>>>,
26    manifests: HashMap<String, Manifest>,
27}
28
29/// A live session handle.
30struct SessionHandle {
31    #[cfg(feature = "toolclad-session")]
32    pty: Pty,
33    #[cfg(feature = "toolclad-session")]
34    child: std::process::Child,
35    state: SessionState,
36    transcript: SessionTranscript,
37    #[allow(dead_code)]
38    manifest_name: String,
39}
40
41impl SessionExecutor {
42    pub fn new(manifests: Vec<(String, Manifest)>) -> Self {
43        let session_manifests: HashMap<String, Manifest> = manifests
44            .into_iter()
45            .filter(|(_, m)| m.tool.mode == "session")
46            .collect();
47        Self {
48            sessions: Arc::new(Mutex::new(HashMap::new())),
49            manifests: session_manifests,
50        }
51    }
52
53    pub fn handles(&self, tool_name: &str) -> bool {
54        // Check for "toolname.command" pattern
55        if let Some(base) = tool_name.split('.').next() {
56            if let Some(m) = self.manifests.get(base) {
57                if let Some(session) = &m.session {
58                    let cmd = tool_name
59                        .strip_prefix(base)
60                        .unwrap_or("")
61                        .trim_start_matches('.');
62                    return !cmd.is_empty() && session.commands.contains_key(cmd);
63                }
64            }
65        }
66        false
67    }
68
69    /// Execute a session command. Creates the session if it doesn't exist.
70    pub fn execute_session_command(
71        &self,
72        tool_name: &str,
73        args_json: &str,
74    ) -> Result<serde_json::Value, String> {
75        let (manifest_name, command_name) = parse_session_tool_name(tool_name)?;
76
77        let manifest = self
78            .manifests
79            .get(&manifest_name)
80            .ok_or_else(|| format!("No session manifest for '{}'", manifest_name))?;
81        let session_def = manifest
82            .session
83            .as_ref()
84            .ok_or("Manifest has no [session] section")?;
85        let cmd_def = session_def
86            .commands
87            .get(&command_name)
88            .ok_or_else(|| format!("Unknown session command: {}", command_name))?;
89
90        // Parse and validate arguments
91        let args: HashMap<String, serde_json::Value> =
92            serde_json::from_str(args_json).map_err(|e| format!("Invalid arguments: {}", e))?;
93
94        let command_str = args
95            .get("command")
96            .and_then(|v| v.as_str())
97            .ok_or("Session command requires 'command' argument")?;
98
99        // Validate command against pattern
100        let re = regex::Regex::new(&cmd_def.pattern)
101            .map_err(|e| format!("Invalid command pattern: {}", e))?;
102        if !re.is_match(command_str) {
103            return Err(format!(
104                "Command '{}' does not match pattern '{}' for {}",
105                command_str, cmd_def.pattern, command_name
106            ));
107        }
108
109        // Check max interactions
110        {
111            let sessions = self.sessions.lock().map_err(|e| e.to_string())?;
112            if let Some(handle) = sessions.get(&manifest_name) {
113                if handle.state.interaction_count >= session_def.max_interactions {
114                    return Err(format!(
115                        "Session '{}' exceeded max interactions ({})",
116                        manifest_name, session_def.max_interactions
117                    ));
118                }
119            }
120        }
121
122        // Ensure session exists (spawn if needed)
123        #[cfg(feature = "toolclad-session")]
124        {
125            self.ensure_session(&manifest_name, manifest, session_def)?;
126        }
127
128        // Send command and get response
129        #[cfg(feature = "toolclad-session")]
130        {
131            let mut sessions = self.sessions.lock().map_err(|e| e.to_string())?;
132            let handle = sessions
133                .get_mut(&manifest_name)
134                .ok_or("Session not found after ensure")?;
135
136            let start = Instant::now();
137
138            // Write command to PTY
139            handle
140                .pty
141                .write_all(format!("{}\n", command_str).as_bytes())
142                .map_err(|e| format!("Failed to write to PTY: {}", e))?;
143            handle
144                .pty
145                .flush()
146                .map_err(|e| format!("Flush failed: {}", e))?;
147
148            // Log command
149            handle.transcript.append(
150                TranscriptDirection::Command,
151                command_str,
152                Some(&command_name),
153            );
154
155            // Read until prompt
156            let output_wait = session_def
157                .interaction
158                .as_ref()
159                .map(|i| i.output_wait_ms)
160                .unwrap_or(2000);
161            let max_bytes = session_def
162                .interaction
163                .as_ref()
164                .map(|i| i.output_max_bytes)
165                .unwrap_or(1_048_576) as usize;
166
167            let output = read_until_prompt_blocking(
168                &mut handle.pty,
169                &session_def.ready_pattern,
170                Duration::from_millis(output_wait * 5), // give 5x the wait time
171                max_bytes,
172            )?;
173
174            let duration_ms = start.elapsed().as_millis() as u64;
175
176            // Strip ANSI and extract meaningful output
177            let clean_output = strip_ansi(&output.0);
178            let prompt = output.1.clone();
179
180            // Update state
181            handle.state.interaction_count += 1;
182            handle.state.last_interaction_at = Instant::now();
183            handle.state.prompt = prompt.clone();
184            handle.state.inferred_state = infer_state(&prompt);
185
186            // Log response
187            handle.transcript.append(
188                TranscriptDirection::Response,
189                &clean_output,
190                Some(&command_name),
191            );
192
193            // Build envelope
194            let scan_id = format!(
195                "{}-{}",
196                chrono::Utc::now().timestamp(),
197                uuid::Uuid::new_v4().as_fields().0
198            );
199            return Ok(serde_json::json!({
200                "status": "success",
201                "scan_id": scan_id,
202                "tool": tool_name,
203                "session_id": handle.state.session_id,
204                "duration_ms": duration_ms,
205                "timestamp": chrono::Utc::now().to_rfc3339(),
206                "exit_code": 0,
207                "stderr": "",
208                "results": {
209                    "output": clean_output,
210                    "prompt": prompt,
211                    "session_state": handle.state.inferred_state,
212                    "interaction_count": handle.state.interaction_count,
213                }
214            }));
215        }
216
217        #[cfg(not(feature = "toolclad-session"))]
218        Err("Session mode requires the 'toolclad-session' feature".to_string())
219    }
220
221    #[cfg(feature = "toolclad-session")]
222    fn ensure_session(
223        &self,
224        name: &str,
225        _manifest: &Manifest,
226        session_def: &SessionDef,
227    ) -> Result<(), String> {
228        let mut sessions = self.sessions.lock().map_err(|e| e.to_string())?;
229        if sessions.contains_key(name) {
230            return Ok(());
231        }
232
233        // Spawn PTY
234        let pty = Pty::new().map_err(|e| format!("Failed to create PTY: {}", e))?;
235        let pts = pty.pts().map_err(|e| format!("Failed to get PTS: {}", e))?;
236
237        let child = PtyCommand::new("sh")
238            .arg("-c")
239            .arg(&session_def.startup_command)
240            .spawn(&pts)
241            .map_err(|e| format!("Failed to spawn '{}': {}", session_def.startup_command, e))?;
242
243        let session_id = format!("session-{}-{}", name, uuid::Uuid::new_v4().as_fields().0);
244
245        let handle = SessionHandle {
246            pty,
247            child,
248            state: SessionState {
249                status: SessionStatus::Spawning,
250                prompt: String::new(),
251                inferred_state: "spawning".to_string(),
252                interaction_count: 0,
253                started_at: Instant::now(),
254                last_interaction_at: Instant::now(),
255                session_id,
256            },
257            transcript: SessionTranscript::default(),
258            manifest_name: name.to_string(),
259        };
260
261        sessions.insert(name.to_string(), handle);
262
263        // Wait for ready pattern
264        let handle = sessions.get_mut(name).unwrap();
265        let timeout = Duration::from_secs(session_def.startup_timeout_seconds);
266        let output = read_until_prompt_blocking(
267            &mut handle.pty,
268            &session_def.ready_pattern,
269            timeout,
270            1_048_576,
271        )
272        .map_err(|e| format!("Session startup failed: {}", e))?;
273
274        handle.state.status = SessionStatus::Ready;
275        handle.state.prompt = output.1;
276        handle.state.inferred_state = "ready".to_string();
277        handle
278            .transcript
279            .append(TranscriptDirection::System, "Session started", None);
280
281        Ok(())
282    }
283
284    /// Get session transcript for evidence.
285    pub fn get_transcript(&self, manifest_name: &str) -> Option<SessionTranscript> {
286        let sessions = self.sessions.lock().ok()?;
287        sessions.get(manifest_name).map(|h| h.transcript.clone())
288    }
289
290    /// Cleanup all sessions.
291    pub fn cleanup(&self) {
292        if let Ok(mut sessions) = self.sessions.lock() {
293            for (_name, handle) in sessions.drain() {
294                #[cfg(feature = "toolclad-session")]
295                {
296                    let mut child = handle.child;
297                    let _ = child.kill();
298                }
299                #[cfg(not(feature = "toolclad-session"))]
300                {
301                    let _ = handle;
302                }
303            }
304        }
305    }
306}
307
/// Split a session tool name of the form `manifest.command` at its first dot.
///
/// Returns `(manifest_name, command_name)`. Everything after the first dot
/// belongs to the command, so `"a.b.c"` yields `("a", "b.c")`.
///
/// # Errors
///
/// Returns a message when `name` contains no dot, or when either side of the
/// first dot is empty — an empty manifest or command name can never resolve,
/// and `SessionExecutor::handles` rejects an empty command as well.
fn parse_session_tool_name(name: &str) -> Result<(String, String), String> {
    match name.split_once('.') {
        Some((manifest, command)) if !manifest.is_empty() && !command.is_empty() => {
            Ok((manifest.to_string(), command.to_string()))
        }
        _ => Err(format!(
            "Invalid session tool name: '{}' (expected 'session.command')",
            name
        )),
    }
}
318
/// Read from `pty` until one of the last few output lines matches `pattern`.
///
/// After every successful read the last three lines of everything captured so
/// far are trimmed and tested against `pattern`, newest line first. On a match
/// the function returns `(output, prompt)`: the complete captured output
/// (lossily decoded as UTF-8) and the trimmed line that matched.
///
/// # Errors
///
/// Returns a message when `pattern` is not a valid regex, when `timeout`
/// elapsed, when more than `max_bytes` were captured without a match, when the
/// read fails, or when the PTY reports end-of-file before a prompt appeared.
/// Timeout and end-of-file messages include at most the first 200 bytes of
/// the captured output, cut on a character boundary.
///
/// NOTE(review): the deadline is only evaluated between reads, so it can be
/// exceeded by however long a single `pty.read` call takes — confirm whether
/// the PTY is configured as non-blocking by the caller.
#[cfg(feature = "toolclad-session")]
fn read_until_prompt_blocking(
    pty: &mut Pty,
    pattern: &str,
    timeout: Duration,
    max_bytes: usize,
) -> Result<(String, String), String> {
    /// Number of trailing lines tested for the prompt after each read.
    const PROMPT_SCAN_LINES: usize = 3;
    /// Upper bound, in bytes, on captured output echoed in error messages.
    const ERROR_PREVIEW_BYTES: usize = 200;

    let re = regex::Regex::new(pattern)
        .map_err(|e| format!("Invalid ready pattern '{}': {}", pattern, e))?;

    let start = Instant::now();
    let mut buffer: Vec<u8> = Vec::new();
    let mut chunk = [0u8; 1024];

    loop {
        if start.elapsed() > timeout {
            let partial = String::from_utf8_lossy(&buffer);
            return Err(format!(
                "Timeout waiting for prompt pattern '{}'. Got: {}",
                pattern,
                truncate_on_char_boundary(&partial, ERROR_PREVIEW_BYTES)
            ));
        }
        if buffer.len() > max_bytes {
            return Err("Output exceeded max bytes".to_string());
        }

        match pty.read(&mut chunk) {
            Ok(0) => break,
            Ok(n) => {
                buffer.extend_from_slice(&chunk[..n]);

                // Only the last few lines can hold the prompt, so decode just
                // that tail instead of the whole buffer on every read. The
                // tail starts right after the (PROMPT_SCAN_LINES + 1)-th
                // newline from the end, which always covers at least the last
                // PROMPT_SCAN_LINES lines; a newline byte is never part of a
                // multi-byte UTF-8 sequence, so cutting there is safe.
                let tail_start = buffer
                    .iter()
                    .enumerate()
                    .rev()
                    .filter(|&(_, &b)| b == b'\n')
                    .nth(PROMPT_SCAN_LINES)
                    .map_or(0, |(idx, _)| idx + 1);
                let tail = String::from_utf8_lossy(&buffer[tail_start..]);

                // Check if prompt pattern appears at the end
                let matched = tail
                    .lines()
                    .rev()
                    .take(PROMPT_SCAN_LINES)
                    .map(str::trim)
                    .find(|line| re.is_match(line))
                    .map(str::to_owned);
                if let Some(prompt) = matched {
                    let output = String::from_utf8_lossy(&buffer).into_owned();
                    return Ok((output, prompt));
                }
            }
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                // Nothing available yet; back off briefly before polling again.
                std::thread::sleep(Duration::from_millis(50));
            }
            Err(e) => return Err(format!("PTY read error: {}", e)),
        }
    }

    let text = String::from_utf8_lossy(&buffer);
    Err(format!(
        "PTY closed before prompt. Got: {}",
        truncate_on_char_boundary(&text, ERROR_PREVIEW_BYTES)
    ))
}

/// Longest prefix of `text` that is at most `max_bytes` bytes long and ends on
/// a UTF-8 character boundary, so slicing can never panic mid-character.
#[cfg(feature = "toolclad-session")]
fn truncate_on_char_boundary(text: &str, max_bytes: usize) -> &str {
    let mut end = text.len().min(max_bytes);
    // Index 0 is always a boundary, so this loop terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}
373
/// Strip ANSI escape sequences.
///
/// Removes every complete CSI control sequence — `ESC [`, then any parameter
/// bytes (`0`–`?`), then any intermediate bytes (space–`/`), then one final
/// byte (`@`–`~`). Besides plain colour codes such as `ESC[32m` this covers
/// private-mode sequences like `ESC[?2004h` that line-editing tools emit
/// around their prompt. Incomplete sequences and all other text, including
/// non-ASCII characters, are copied through unchanged.
///
/// The scan is a single pass over the input and needs no regex compilation.
#[cfg(any(feature = "toolclad-session", test))]
fn strip_ansi(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut cleaned = String::with_capacity(input.len());
    // Start of the run of ordinary text that has not been copied yet.
    let mut pending_from = 0;
    let mut pos = 0;

    while pos < bytes.len() {
        if bytes[pos] == 0x1b && bytes.get(pos + 1) == Some(&b'[') {
            if let Some(end) = csi_sequence_end(bytes, pos + 2) {
                // `pos` and `end` both sit next to ASCII bytes, so these are
                // valid character boundaries for slicing.
                cleaned.push_str(&input[pending_from..pos]);
                pending_from = end;
                pos = end;
                continue;
            }
        }
        pos += 1;
    }

    cleaned.push_str(&input[pending_from..]);
    cleaned
}

/// Given the index just past an `ESC [` introducer, return the index just past
/// the final byte of the control sequence, or `None` if the sequence is not
/// terminated by a valid final byte.
#[cfg(any(feature = "toolclad-session", test))]
fn csi_sequence_end(bytes: &[u8], start: usize) -> Option<usize> {
    let rest = &bytes[start..];
    let params = rest
        .iter()
        .take_while(|b| (0x30..=0x3f).contains(*b))
        .count();
    let intermediates = rest[params..]
        .iter()
        .take_while(|b| (0x20..=0x2f).contains(*b))
        .count();
    let final_idx = params + intermediates;
    match rest.get(final_idx) {
        Some(b) if (0x40..=0x7e).contains(b) => Some(start + final_idx + 1),
        _ => None,
    }
}
380
/// Classify a session by the text of its prompt.
///
/// Yields `"error"` when the prompt contains the word "error" in any letter
/// case, and `"ready"` for every other prompt (including an empty one).
#[cfg(any(feature = "toolclad-session", test))]
fn infer_state(prompt: &str) -> String {
    let mentions_error = prompt.to_lowercase().contains("error");
    let label = if mentions_error { "error" } else { "ready" };
    String::from(label)
}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed `manifest.command` name splits into its two halves.
    #[test]
    fn test_parse_session_tool_name() {
        let parsed = parse_session_tool_name("psql_session.select").unwrap();
        assert_eq!(parsed.0, "psql_session");
        assert_eq!(parsed.1, "select");
    }

    /// A name without a dot is rejected.
    #[test]
    fn test_parse_session_tool_name_invalid() {
        let outcome = parse_session_tool_name("no_dot");
        assert!(outcome.is_err());
    }

    /// Colour codes are removed; plain text passes through untouched.
    #[test]
    fn test_strip_ansi() {
        let coloured = "\x1b[32mhello\x1b[0m";
        assert_eq!(strip_ansi(coloured), "hello");

        let plain = "no escapes";
        assert_eq!(strip_ansi(plain), "no escapes");
    }

    /// The prompt text decides between the "ready" and "error" states.
    #[test]
    fn test_infer_state() {
        let ready = infer_state("dbname=> ");
        let failed = infer_state("ERROR: ");
        assert_eq!(ready, "ready");
        assert_eq!(failed, "error");
    }

    /// `handles` accepts only commands declared by a loaded session manifest.
    #[test]
    fn test_session_executor_handles() {
        let manifest_toml = r#"
[tool]
name = "test_session"
mode = "session"
version = "1.0.0"
description = "Test"

[session]
startup_command = "cat"
ready_pattern = "^$"

[session.commands.echo]
pattern = "^echo .+$"
description = "Echo text"

[output]
format = "text"

[output.schema]
type = "object"
"#;
        let parsed: Manifest = toml::from_str(manifest_toml).unwrap();
        let registry = vec![("test_session".to_string(), parsed)];
        let executor = SessionExecutor::new(registry);

        let declared = executor.handles("test_session.echo");
        let undeclared = executor.handles("test_session.unknown");
        let foreign = executor.handles("other_tool");

        assert!(declared);
        assert!(!undeclared);
        assert!(!foreign);
    }

    /// An anchored command pattern admits matching input and rejects the rest.
    #[test]
    fn test_command_pattern_validation() {
        let select_only = regex::Regex::new("^SELECT .+$").unwrap();
        let allowed = select_only.is_match("SELECT * FROM users");
        let denied = select_only.is_match("DROP TABLE users");
        assert!(allowed);
        assert!(!denied);
    }

    /// Appended entries are kept in order along with their direction.
    #[test]
    fn test_transcript() {
        let mut transcript = SessionTranscript::default();
        transcript.append(TranscriptDirection::Command, "SELECT 1", Some("select"));
        transcript.append(TranscriptDirection::Response, "1\n(1 row)", Some("select"));

        assert_eq!(transcript.entries.len(), 2);

        let first_direction = &transcript.entries[0].direction;
        assert!(matches!(first_direction, TranscriptDirection::Command));
    }
}