1#[cfg(feature = "toolclad-session")]
7use pty_process::blocking::{Command as PtyCommand, Pty};
8
9use std::collections::HashMap;
10#[cfg(feature = "toolclad-session")]
11use std::io::Read;
12#[cfg(feature = "toolclad-session")]
13use std::io::Write;
14use std::sync::{Arc, Mutex};
15#[cfg(feature = "toolclad-session")]
16use std::time::{Duration, Instant};
17
18use super::manifest::Manifest;
19#[cfg(feature = "toolclad-session")]
20use super::manifest::SessionDef;
21use super::session_state::*;
22
23pub struct SessionExecutor {
25 sessions: Arc<Mutex<HashMap<SessionId, SessionHandle>>>,
26 manifests: HashMap<String, Manifest>,
27}
28
29struct SessionHandle {
31 #[cfg(feature = "toolclad-session")]
32 pty: Pty,
33 #[cfg(feature = "toolclad-session")]
34 child: std::process::Child,
35 state: SessionState,
36 transcript: SessionTranscript,
37 #[allow(dead_code)]
38 manifest_name: String,
39}
40
41impl SessionExecutor {
42 pub fn new(manifests: Vec<(String, Manifest)>) -> Self {
43 let session_manifests: HashMap<String, Manifest> = manifests
44 .into_iter()
45 .filter(|(_, m)| m.tool.mode == "session")
46 .collect();
47 Self {
48 sessions: Arc::new(Mutex::new(HashMap::new())),
49 manifests: session_manifests,
50 }
51 }
52
53 pub fn handles(&self, tool_name: &str) -> bool {
54 if let Some(base) = tool_name.split('.').next() {
56 if let Some(m) = self.manifests.get(base) {
57 if let Some(session) = &m.session {
58 let cmd = tool_name
59 .strip_prefix(base)
60 .unwrap_or("")
61 .trim_start_matches('.');
62 return !cmd.is_empty() && session.commands.contains_key(cmd);
63 }
64 }
65 }
66 false
67 }
68
69 pub fn execute_session_command(
71 &self,
72 tool_name: &str,
73 args_json: &str,
74 ) -> Result<serde_json::Value, String> {
75 let (manifest_name, command_name) = parse_session_tool_name(tool_name)?;
76
77 let manifest = self
78 .manifests
79 .get(&manifest_name)
80 .ok_or_else(|| format!("No session manifest for '{}'", manifest_name))?;
81 let session_def = manifest
82 .session
83 .as_ref()
84 .ok_or("Manifest has no [session] section")?;
85 let cmd_def = session_def
86 .commands
87 .get(&command_name)
88 .ok_or_else(|| format!("Unknown session command: {}", command_name))?;
89
90 let args: HashMap<String, serde_json::Value> =
92 serde_json::from_str(args_json).map_err(|e| format!("Invalid arguments: {}", e))?;
93
94 let command_str = args
95 .get("command")
96 .and_then(|v| v.as_str())
97 .ok_or("Session command requires 'command' argument")?;
98
99 let re = regex::Regex::new(&cmd_def.pattern)
101 .map_err(|e| format!("Invalid command pattern: {}", e))?;
102 if !re.is_match(command_str) {
103 return Err(format!(
104 "Command '{}' does not match pattern '{}' for {}",
105 command_str, cmd_def.pattern, command_name
106 ));
107 }
108
109 {
111 let sessions = self.sessions.lock().map_err(|e| e.to_string())?;
112 if let Some(handle) = sessions.get(&manifest_name) {
113 if handle.state.interaction_count >= session_def.max_interactions {
114 return Err(format!(
115 "Session '{}' exceeded max interactions ({})",
116 manifest_name, session_def.max_interactions
117 ));
118 }
119 }
120 }
121
122 #[cfg(feature = "toolclad-session")]
124 {
125 self.ensure_session(&manifest_name, manifest, session_def)?;
126 }
127
128 #[cfg(feature = "toolclad-session")]
130 {
131 let mut sessions = self.sessions.lock().map_err(|e| e.to_string())?;
132 let handle = sessions
133 .get_mut(&manifest_name)
134 .ok_or("Session not found after ensure")?;
135
136 let start = Instant::now();
137
138 handle
140 .pty
141 .write_all(format!("{}\n", command_str).as_bytes())
142 .map_err(|e| format!("Failed to write to PTY: {}", e))?;
143 handle
144 .pty
145 .flush()
146 .map_err(|e| format!("Flush failed: {}", e))?;
147
148 handle.transcript.append(
150 TranscriptDirection::Command,
151 command_str,
152 Some(&command_name),
153 );
154
155 let output_wait = session_def
157 .interaction
158 .as_ref()
159 .map(|i| i.output_wait_ms)
160 .unwrap_or(2000);
161 let max_bytes = session_def
162 .interaction
163 .as_ref()
164 .map(|i| i.output_max_bytes)
165 .unwrap_or(1_048_576) as usize;
166
167 let output = read_until_prompt_blocking(
168 &mut handle.pty,
169 &session_def.ready_pattern,
170 Duration::from_millis(output_wait * 5), max_bytes,
172 )?;
173
174 let duration_ms = start.elapsed().as_millis() as u64;
175
176 let clean_output = strip_ansi(&output.0);
178 let prompt = output.1.clone();
179
180 handle.state.interaction_count += 1;
182 handle.state.last_interaction_at = Instant::now();
183 handle.state.prompt = prompt.clone();
184 handle.state.inferred_state = infer_state(&prompt);
185
186 handle.transcript.append(
188 TranscriptDirection::Response,
189 &clean_output,
190 Some(&command_name),
191 );
192
193 let scan_id = format!(
195 "{}-{}",
196 chrono::Utc::now().timestamp(),
197 uuid::Uuid::new_v4().as_fields().0
198 );
199 return Ok(serde_json::json!({
200 "status": "success",
201 "scan_id": scan_id,
202 "tool": tool_name,
203 "session_id": handle.state.session_id,
204 "duration_ms": duration_ms,
205 "timestamp": chrono::Utc::now().to_rfc3339(),
206 "exit_code": 0,
207 "stderr": "",
208 "results": {
209 "output": clean_output,
210 "prompt": prompt,
211 "session_state": handle.state.inferred_state,
212 "interaction_count": handle.state.interaction_count,
213 }
214 }));
215 }
216
217 #[cfg(not(feature = "toolclad-session"))]
218 Err("Session mode requires the 'toolclad-session' feature".to_string())
219 }
220
221 #[cfg(feature = "toolclad-session")]
222 fn ensure_session(
223 &self,
224 name: &str,
225 _manifest: &Manifest,
226 session_def: &SessionDef,
227 ) -> Result<(), String> {
228 let mut sessions = self.sessions.lock().map_err(|e| e.to_string())?;
229 if sessions.contains_key(name) {
230 return Ok(());
231 }
232
233 let pty = Pty::new().map_err(|e| format!("Failed to create PTY: {}", e))?;
235 let pts = pty.pts().map_err(|e| format!("Failed to get PTS: {}", e))?;
236
237 let child = PtyCommand::new("sh")
238 .arg("-c")
239 .arg(&session_def.startup_command)
240 .spawn(&pts)
241 .map_err(|e| format!("Failed to spawn '{}': {}", session_def.startup_command, e))?;
242
243 let session_id = format!("session-{}-{}", name, uuid::Uuid::new_v4().as_fields().0);
244
245 let handle = SessionHandle {
246 pty,
247 child,
248 state: SessionState {
249 status: SessionStatus::Spawning,
250 prompt: String::new(),
251 inferred_state: "spawning".to_string(),
252 interaction_count: 0,
253 started_at: Instant::now(),
254 last_interaction_at: Instant::now(),
255 session_id,
256 },
257 transcript: SessionTranscript::default(),
258 manifest_name: name.to_string(),
259 };
260
261 sessions.insert(name.to_string(), handle);
262
263 let handle = sessions.get_mut(name).unwrap();
265 let timeout = Duration::from_secs(session_def.startup_timeout_seconds);
266 let output = read_until_prompt_blocking(
267 &mut handle.pty,
268 &session_def.ready_pattern,
269 timeout,
270 1_048_576,
271 )
272 .map_err(|e| format!("Session startup failed: {}", e))?;
273
274 handle.state.status = SessionStatus::Ready;
275 handle.state.prompt = output.1;
276 handle.state.inferred_state = "ready".to_string();
277 handle
278 .transcript
279 .append(TranscriptDirection::System, "Session started", None);
280
281 Ok(())
282 }
283
284 pub fn get_transcript(&self, manifest_name: &str) -> Option<SessionTranscript> {
286 let sessions = self.sessions.lock().ok()?;
287 sessions.get(manifest_name).map(|h| h.transcript.clone())
288 }
289
290 pub fn cleanup(&self) {
292 if let Ok(mut sessions) = self.sessions.lock() {
293 for (_name, handle) in sessions.drain() {
294 #[cfg(feature = "toolclad-session")]
295 {
296 let mut child = handle.child;
297 let _ = child.kill();
298 }
299 #[cfg(not(feature = "toolclad-session"))]
300 {
301 let _ = handle;
302 }
303 }
304 }
305 }
306}
307
/// Splits a session tool name at its first `.` into `(manifest_name, command_name)`.
///
/// Everything after the first dot (further dots included) becomes the command name.
/// A name without any dot is rejected.
fn parse_session_tool_name(name: &str) -> Result<(String, String), String> {
    name.split_once('.')
        .map(|(base, command)| (base.to_owned(), command.to_owned()))
        .ok_or_else(|| {
            format!(
                "Invalid session tool name: '{}' (expected 'session.command')",
                name
            )
        })
}
318
/// Reads from `pty` until one of the last three lines of accumulated output, after
/// trimming, matches `pattern`. Returns `(full_output, matching_line)`; the output is
/// decoded lossily as UTF-8.
///
/// # Errors
/// - `pattern` is not a valid regex.
/// - `timeout` elapsed before a matching line was seen.
/// - more than `max_bytes` bytes were buffered without a match.
/// - the PTY returned end-of-file (`Ok(0)`) or a read error.
///
/// Error messages include at most the first 200 bytes of what was read, cut on a character
/// boundary.
///
/// NOTE(review): the elapsed-time check only runs between reads. If the PTY is in blocking
/// mode (as its `blocking` module name suggests), a program that prints nothing can keep a
/// single `read` blocked past `timeout`; enforcing the deadline strictly would need
/// poll/non-blocking I/O.
#[cfg(feature = "toolclad-session")]
fn read_until_prompt_blocking(
    pty: &mut Pty,
    pattern: &str,
    timeout: Duration,
    max_bytes: usize,
) -> Result<(String, String), String> {
    let re = regex::Regex::new(pattern)
        .map_err(|e| format!("Invalid ready pattern '{}': {}", pattern, e))?;

    let start = Instant::now();
    let mut buffer = Vec::new();
    let mut chunk = [0u8; 1024];

    loop {
        if start.elapsed() > timeout {
            let partial = String::from_utf8_lossy(&buffer);
            return Err(format!(
                "Timeout waiting for prompt pattern '{}'. Got: {}",
                pattern,
                truncate_on_char_boundary(&partial, 200)
            ));
        }
        if buffer.len() > max_bytes {
            return Err("Output exceeded max bytes".to_string());
        }

        match pty.read(&mut chunk) {
            Ok(0) => break,
            Ok(n) => {
                buffer.extend_from_slice(&chunk[..n]);
                let text = String::from_utf8_lossy(&buffer);
                // The prompt is expected near the end of the output, so only the last
                // three lines are checked.
                let prompt_line = text
                    .lines()
                    .rev()
                    .take(3)
                    .map(str::trim)
                    .find(|line| re.is_match(line));
                if let Some(line) = prompt_line {
                    return Ok((text.to_string(), line.to_string()));
                }
            }
            // A signal interrupted the read before any data arrived; simply retry.
            Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                std::thread::sleep(Duration::from_millis(50));
            }
            Err(e) => return Err(format!("PTY read error: {}", e)),
        }
    }

    let text = String::from_utf8_lossy(&buffer);
    Err(format!(
        "PTY closed before prompt. Got: {}",
        truncate_on_char_boundary(&text, 200)
    ))
}

/// Returns the longest prefix of `text` that is at most `max_bytes` long and ends on a
/// UTF-8 character boundary. Plain byte slicing (`&text[..max_bytes]`) panics when the cut
/// lands inside a multi-byte character, such as the U+FFFD that lossy decoding inserts.
#[cfg(any(feature = "toolclad-session", test))]
fn truncate_on_char_boundary(text: &str, max_bytes: usize) -> &str {
    if text.len() <= max_bytes {
        return text;
    }
    let mut end = max_bytes;
    // Index 0 is always a boundary, so this terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}
373
/// Removes terminal escape sequences from `input` and returns the remaining text.
///
/// Handles:
/// - CSI sequences: `ESC [`, parameter bytes (0x30–0x3F), intermediate bytes (0x20–0x2F),
///   and a final byte (0x40–0x7E). This includes SGR colours such as `ESC[32m` and
///   private-mode forms like `ESC[?2004h` (bracketed paste), which the previous
///   `[0-9;]*[a-zA-Z]` pattern let through.
/// - OSC sequences: `ESC ]` followed by a payload terminated by BEL or `ESC \`, such as
///   window-title updates.
#[cfg(any(feature = "toolclad-session", test))]
fn strip_ansi(input: &str) -> String {
    let re = regex::Regex::new(r"\x1b\[[0-?]*[ -/]*[@-~]|\x1b\][^\x07\x1b]*(?:\x07|\x1b\\)")
        .expect("ANSI escape pattern is a valid regex");
    re.replace_all(input, "").into_owned()
}
380
/// Classifies a session prompt.
///
/// Returns `"error"` if the prompt contains "error" in any letter case, otherwise `"ready"`.
#[cfg(any(feature = "toolclad-session", test))]
fn infer_state(prompt: &str) -> String {
    let mentions_error = prompt.to_lowercase().contains("error");
    let state = if mentions_error { "error" } else { "ready" };
    state.to_owned()
}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Session manifest used by the executor tests: runs `cat` and allows one `echo` command.
    const TEST_MANIFEST_TOML: &str = r#"
[tool]
name = "test_session"
mode = "session"
version = "1.0.0"
description = "Test"

[session]
startup_command = "cat"
ready_pattern = "^$"

[session.commands.echo]
pattern = "^echo .+$"
description = "Echo text"

[output]
format = "text"

[output.schema]
type = "object"
"#;

    #[test]
    fn test_parse_session_tool_name() {
        let parsed = parse_session_tool_name("psql_session.select").unwrap();
        assert_eq!(parsed.0, "psql_session");
        assert_eq!(parsed.1, "select");
    }

    #[test]
    fn test_parse_session_tool_name_invalid() {
        let outcome = parse_session_tool_name("no_dot");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_strip_ansi() {
        let cases = [
            ("\x1b[32mhello\x1b[0m", "hello"),
            ("no escapes", "no escapes"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(strip_ansi(input), expected);
        }
    }

    #[test]
    fn test_infer_state() {
        let cases = [("dbname=> ", "ready"), ("ERROR: ", "error")];
        for &(prompt, expected) in cases.iter() {
            assert_eq!(infer_state(prompt), expected);
        }
    }

    #[test]
    fn test_session_executor_handles() {
        let manifest: Manifest = toml::from_str(TEST_MANIFEST_TOML).unwrap();
        let executor = SessionExecutor::new(vec![("test_session".to_string(), manifest)]);

        assert!(executor.handles("test_session.echo"));
        for rejected in ["test_session.unknown", "other_tool"].iter() {
            assert!(!executor.handles(rejected), "{} should not be handled", rejected);
        }
    }

    #[test]
    fn test_command_pattern_validation() {
        let select_only = regex::Regex::new("^SELECT .+$").unwrap();
        assert!(select_only.is_match("SELECT * FROM users"));
        assert!(!select_only.is_match("DROP TABLE users"));
    }

    #[test]
    fn test_transcript() {
        let mut transcript = SessionTranscript::default();
        transcript.append(TranscriptDirection::Command, "SELECT 1", Some("select"));
        transcript.append(TranscriptDirection::Response, "1\n(1 row)", Some("select"));

        assert_eq!(transcript.entries.len(), 2);
        let first = &transcript.entries[0];
        assert!(matches!(first.direction, TranscriptDirection::Command));
    }
}