1use crate::capability::{Capability, Context, Output};
50use crate::validation::path::{validate_path, PathContext};
51use crate::{Error, Result};
52use serde::{Deserialize, Serialize};
53use serde_json::Value;
54use std::fs;
55use std::io::{Read, Write};
56use std::os::unix::process::CommandExt;
57use std::process::{Child, Command, ExitStatus};
58use std::thread;
59use std::time::{Duration, Instant};
60
/// Timeout, in seconds, applied when a request does not supply `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Upper bound (10 MiB) on the bytes retained from each of the child's
/// stdout and stderr streams.
const MAX_OUTPUT_BYTES: usize = 10 * 1024 * 1024;

/// Upper bound (1 MiB) on the `stdin` payload accepted for a command.
const MAX_STDIN_BYTES: usize = 1024 * 1024;
69
/// Arguments accepted by the `ShellExec` capability, deserialized from the
/// request JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ShellExecArgs {
    /// Program to run. When `args` is absent this string is split on
    /// whitespace: the first token is the program, the rest are its arguments.
    pub cmd: String,
    /// Explicit argument vector. When present, `cmd` is used verbatim as the
    /// program and is not split.
    pub args: Option<Vec<String>>,
    /// Timeout in seconds; `DEFAULT_TIMEOUT_SECS` is used when absent.
    pub timeout_secs: Option<u64>,
    /// Working directory for the child; it is validated and must exist.
    pub cwd: Option<String>,
    /// Text written to the child's stdin; limited to `MAX_STDIN_BYTES`.
    pub stdin: Option<String>,
}
88
89fn resolve_program(program: &str) -> Result<String> {
96 if program.starts_with('/') {
97 return Ok(program.to_string());
98 }
99 if program.contains('/') {
100 return Err(Error::ExecutionFailed(format!(
101 "relative paths are not allowed: '{}'", program
102 )));
103 }
104 if let Ok(path_env) = std::env::var("PATH") {
105 for dir in path_env.split(':') {
106 let candidate = std::path::PathBuf::from(dir).join(program);
107 if candidate.exists() {
108 return Ok(candidate.to_string_lossy().to_string());
109 }
110 }
111 }
112 Ok(program.to_string())
113}
114
/// Screens a command against a small deny-list of destructive operations.
///
/// `program` may be a bare name or a path; only its final path component
/// (lower-cased) is compared, so `/sbin/mkfs.ext4` is treated exactly like
/// `mkfs.ext4`. `args` are the arguments that will be passed to the program.
///
/// Returns `Some(reason)` when the command is blocked, `None` otherwise.
/// Blocked:
/// * `mkfs`, every `mkfs.<fstype>` helper, and `mkswap`;
/// * `fdisk`, `parted`, `sfdisk`, `cfdisk`;
/// * `dd`;
/// * `shutdown`, `reboot`, `halt`, `poweroff`;
/// * `rm` with both a recursive and a force option aimed at `/`, `/*`,
///   `/dev` or `/boot` (or anything beneath the latter two);
/// * `chmod` with mode `777`/`0777` and `/` among its arguments.
fn is_dangerous_command(program: &str, args: &[String]) -> Option<&'static str> {
    /// True when the `rm` option `arg` enables the flag described by
    /// `short` (single-letter spellings) or `long` (GNU long name).
    fn rm_flag_matches(arg: &str, short: &[char], long: &str) -> bool {
        if let Some(rest) = arg.strip_prefix("--") {
            // GNU getopt accepts unambiguous abbreviations such as `--rec`.
            !rest.is_empty() && long.starts_with(rest)
        } else if let Some(cluster) = arg.strip_prefix('-') {
            // Short options may be clustered: `-rfv`.
            cluster.chars().any(|c| short.contains(&c))
        } else {
            false
        }
    }

    /// True when `arg` names the root, `/dev`, `/boot`, or a path below the
    /// latter two. Trailing slashes are ignored; `/bootstrap` does not match.
    fn rm_target_is_protected(arg: &str) -> bool {
        if arg == "/*" {
            return true;
        }
        let trimmed = arg.trim_end_matches('/');
        if trimmed.is_empty() {
            // "/", "//", ... all mean the root; an empty argument does not.
            return arg.starts_with('/');
        }
        trimmed == "/dev"
            || trimmed == "/boot"
            || arg.starts_with("/dev/")
            || arg.starts_with("/boot/")
    }

    // Compare the file name only, so an absolute path cannot dodge the list.
    let name = std::path::Path::new(program)
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or(program)
        .to_lowercase();

    if name == "mkfs" || name.starts_with("mkfs.") || name == "mkswap" {
        return Some("filesystem creation commands are blocked");
    }

    match name.as_str() {
        "fdisk" | "parted" | "sfdisk" | "cfdisk" => {
            return Some("disk partitioning commands are blocked");
        }
        "dd" => {
            return Some("dd (disk destroyer) is blocked");
        }
        "shutdown" | "reboot" | "halt" | "poweroff" => {
            return Some("system power commands are blocked");
        }
        _ => {}
    }

    if name == "rm" {
        // Everything after a literal `--` is an operand, not an option.
        let has_flag = |short: &[char], long: &str| {
            args.iter()
                .take_while(|a| a.as_str() != "--")
                .any(|a| rm_flag_matches(a, short, long))
        };
        let has_recursive = has_flag(&['r', 'R'], "recursive");
        let has_force = has_flag(&['f'], "force");
        let targets_dangerous = args.iter().any(|a| rm_target_is_protected(a));

        if has_recursive && has_force && targets_dangerous {
            return Some("rm -rf on root, devices, or boot is blocked");
        }
    }

    if name == "chmod"
        && args.iter().any(|a| a == "/")
        && args.iter().any(|a| a == "777" || a == "0777")
    {
        return Some("chmod 777 / is blocked");
    }

    None
}
156
157fn wait_with_timeout(
166 child: &mut Child,
167 pgid: u32,
168 timeout_secs: u64,
169) -> Result<(ExitStatus, Vec<u8>, Vec<u8>, Vec<u32>)> {
170 let start = Instant::now();
171 let timeout = Duration::from_secs(timeout_secs);
172 let child_pid = child.id();
173
174 let stdout_thread = child.stdout.take().map(|stdout| {
177 thread::spawn(move || {
178 let mut data = Vec::new();
179 let _ = stdout.take(MAX_OUTPUT_BYTES as u64).read_to_end(&mut data);
180 data
181 })
182 });
183
184 let stderr_thread = child.stderr.take().map(|stderr| {
185 thread::spawn(move || {
186 let mut data = Vec::new();
187 let _ = stderr.take(MAX_OUTPUT_BYTES as u64).read_to_end(&mut data);
188 data
189 })
190 });
191
192 #[allow(unused_assignments)]
193 let mut last_descendants = Vec::new();
194
195 loop {
196 if start.elapsed() > timeout {
197 unsafe {
199 let _ = libc::kill(-(pgid as libc::pid_t), libc::SIGKILL);
200 }
201 last_descendants = get_all_descendants(child_pid);
203 let _status = child.wait().map_err(|e| {
205 Error::ExecutionFailed(format!("failed to reap after kill: {}", e))
206 })?;
207 let _stdout_data = stdout_thread
209 .map(|h| h.join().unwrap_or_default())
210 .unwrap_or_default();
211 let _stderr_data = stderr_thread
212 .map(|h| h.join().unwrap_or_default())
213 .unwrap_or_default();
214 return Err(Error::ExecutionFailed(format!(
215 "command timed out after {}s ({} descendants found)",
216 timeout_secs,
217 last_descendants.len()
218 )));
219 }
220
221 last_descendants = get_all_descendants(child_pid);
223
224 match child.try_wait() {
225 Ok(Some(status)) => {
226 let stdout_data = stdout_thread
227 .map(|h| h.join().unwrap_or_default())
228 .unwrap_or_default();
229 let stderr_data = stderr_thread
230 .map(|h| h.join().unwrap_or_default())
231 .unwrap_or_default();
232 return Ok((status, stdout_data, stderr_data, last_descendants));
233 }
234 Ok(None) => {
235 std::thread::sleep(Duration::from_millis(50));
236 }
237 Err(e) => {
238 return Err(Error::ExecutionFailed(format!("error waiting: {}", e)));
239 }
240 }
241 }
242}
243
/// Lists the direct children of `pid` using procfs.
///
/// The kernel exposes children per thread, in
/// `/proc/<pid>/task/<tid>/children`, so every thread directory is read and
/// the results are concatenated.
///
/// Returns `Some(children)` when procfs gave an answer (including "no such
/// process", which yields an empty list) and `None` when procfs cannot tell:
/// it is not mounted, or no `children` file could be read.
fn read_proc_children(pid: u32) -> Option<Vec<u32>> {
    let task_dir = format!("/proc/{}/task", pid);
    let entries = match fs::read_dir(&task_dir) {
        Ok(entries) => entries,
        Err(_) => {
            // Distinguish "process is gone" from "procfs is unavailable".
            return if std::path::Path::new("/proc/self/task").exists() {
                Some(Vec::new())
            } else {
                None
            };
        }
    };

    let mut children = Vec::new();
    let mut any_readable = false;
    for entry in entries.flatten() {
        if let Ok(content) = fs::read_to_string(entry.path().join("children")) {
            any_readable = true;
            children.extend(
                content
                    .split_whitespace()
                    .filter_map(|s| s.parse::<u32>().ok()),
            );
        }
    }

    if any_readable {
        Some(children)
    } else {
        None
    }
}

/// Returns the PIDs of the direct children of `pid` as reported by procfs,
/// or an empty list when the process does not exist or procfs is unavailable.
fn get_direct_children(pid: u32) -> Vec<u32> {
    read_proc_children(pid).unwrap_or_default()
}

/// Lists the direct children of `pid` by running `pgrep -P <pid>`.
/// Any failure (tool missing, non-zero exit) yields an empty list.
fn pgrep_children(pid: u32) -> Vec<u32> {
    match std::process::Command::new("pgrep")
        .arg("-P")
        .arg(pid.to_string())
        .output()
    {
        Ok(output) if output.status.success() => String::from_utf8_lossy(&output.stdout)
            .lines()
            .filter_map(|line| line.trim().parse::<u32>().ok())
            .collect(),
        _ => Vec::new(),
    }
}

/// Returns every descendant of `pid` (children, grandchildren, ...), not
/// including `pid` itself.
///
/// The process tree is walked depth-first. procfs is consulted first;
/// `pgrep` is spawned only for a process procfs could not answer for, so the
/// common case costs no extra process. The result is a best-effort snapshot:
/// processes may start or exit while the walk is in progress.
fn get_all_descendants(pid: u32) -> Vec<u32> {
    let mut descendants = Vec::new();
    let mut stack = vec![pid];
    let mut visited = std::collections::HashSet::new();

    while let Some(current) = stack.pop() {
        // `insert` returns false when the PID was already processed.
        if !visited.insert(current) {
            continue;
        }

        let children =
            read_proc_children(current).unwrap_or_else(|| pgrep_children(current));

        for child in children {
            if !visited.contains(&child) {
                descendants.push(child);
                stack.push(child);
            }
        }
    }

    descendants
}
309
310pub struct ShellExec;
321
322impl Capability for ShellExec {
323 fn name(&self) -> &'static str {
324 "ShellExec"
325 }
326
327 fn description(&self) -> &'static str {
328 "Execute a shell command with timeout, output capture, process group isolation, and audit logging. Dangerous commands are blocked."
329 }
330
331 fn schema(&self) -> Value {
335 serde_json::json!({
336 "type": "object",
337 "properties": {
338 "cmd": { "type": "string" },
339 "args": { "type": "array", "items": { "type": "string" } },
340 "timeout_secs": { "type": "integer", "minimum": 1, "maximum": 300 },
341 "cwd": { "type": "string" },
342 "stdin": { "type": "string" }
343 },
344 "required": ["cmd"]
345 })
346 }
347
348 fn validate(&self, args: &Value) -> Result<()> {
349 let args: ShellExecArgs = serde_json::from_value(args.clone())
350 .map_err(|e| Error::SchemaValidationFailed(e.to_string()))?;
351
352 if args.cmd.is_empty() {
353 return Err(Error::SchemaValidationFailed("cmd is empty".into()));
354 }
355
356 Ok(())
357 }
358
359 fn execute(&self, args: &Value, ctx: &Context) -> Result<Output> {
360 if ctx.dry_run {
362 return Ok(Output {
363 success: true,
364 data: serde_json::json!({
365 "cmd": args.get("cmd").and_then(|v| v.as_str()).unwrap_or(""),
366 "dry_run": true,
367 }),
368 message: Some("DRY RUN: would execute shell command".to_string()),
369 });
370 }
371
372 let args: ShellExecArgs = serde_json::from_value(args.clone())
373 .map_err(|e| Error::ExecutionFailed(e.to_string()))?;
374
375 let timeout = args.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS);
376
377 let (program, program_args): (String, Vec<String>) =
379 if let Some(ref explicit_args) = args.args {
380 (args.cmd.clone(), explicit_args.clone())
381 } else {
382 let mut parts = args.cmd.split_whitespace();
383 let program = parts
384 .next()
385 .ok_or_else(|| Error::ExecutionFailed("cmd is empty after split".into()))?
386 .to_string();
387 (program, parts.map(String::from).collect())
388 };
389
390 if let Some(reason) = is_dangerous_command(&program, &program_args) {
392 return Err(Error::ExecutionFailed(format!(
393 "dangerous command blocked: {}", reason
394 )));
395 }
396
397 let resolved_program = resolve_program(&program)?;
399
400 let mut cmd = Command::new(&resolved_program);
402 cmd.args(&program_args);
403
404 if let Some(cwd) = &args.cwd {
406 let path_ctx = PathContext {
407 require_exists: true,
408 require_file: false,
409 ..Default::default()
410 };
411 let cwd_path = validate_path(cwd, &path_ctx)
412 .map_err(|e| Error::ExecutionFailed(format!("invalid cwd: {}", e)))?;
413 cmd.current_dir(cwd_path);
414 }
415
416 let mut child = cmd
418 .process_group(0)
419 .stdout(std::process::Stdio::piped())
420 .stderr(std::process::Stdio::piped())
421 .stdin(if args.stdin.is_some() {
422 std::process::Stdio::piped()
423 } else {
424 std::process::Stdio::null()
425 })
426 .spawn()
427 .map_err(|e| Error::ExecutionFailed(format!("failed to spawn: {}", e)))?;
428
429 let child_pid = child.id();
430 let pgid = child_pid; if let Some(ref stdin_content) = args.stdin {
434 if stdin_content.len() > MAX_STDIN_BYTES {
435 return Err(Error::ExecutionFailed(format!(
436 "stdin exceeds maximum size ({} > {} bytes)",
437 stdin_content.len(),
438 MAX_STDIN_BYTES
439 )));
440 }
441 if let Some(mut stdin_pipe) = child.stdin.take() {
442 stdin_pipe
443 .write_all(stdin_content.as_bytes())
444 .map_err(|e| Error::ExecutionFailed(format!("failed to write stdin: {}", e)))?;
445 }
447 }
448
449 let (exit_status, stdout, stderr, descendants) =
451 wait_with_timeout(&mut child, pgid, timeout)?;
452
453 let mut spawned_pids = vec![child_pid];
454 spawned_pids.extend(descendants);
455
456 let stdout_str = String::from_utf8_lossy(&stdout).to_string();
457 let stderr_str = String::from_utf8_lossy(&stderr).to_string();
458 let success = exit_status.success();
459
460 Ok(Output {
461 success,
462 data: serde_json::json!({
463 "cmd": args.cmd,
464 "stdout": stdout_str,
465 "stderr": stderr_str,
466 "exit_code": exit_status.code().unwrap_or(-1),
467 "pid": child_pid,
468 "spawned_pids": spawned_pids,
469 "timeout_secs": timeout,
470 "timed_out": exit_status.code().is_none(),
471 "truncated": stdout.len() >= MAX_OUTPUT_BYTES || stderr.len() >= MAX_OUTPUT_BYTES,
472 }),
473 message: if success {
474 Some("Command completed successfully".to_string())
475 } else if exit_status.code().is_none() {
476 Some(format!("Command timed out after {}s", timeout))
477 } else {
478 Some(format!(
479 "Command failed with exit code {}",
480 exit_status.code().unwrap_or(-1)
481 ))
482 },
483 })
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capability::Capability;
    use std::time::Instant;

    /// Builds the execution context shared by every test.
    fn ctx(dry_run: bool) -> Context {
        Context {
            dry_run,
            job_id: "test".into(),
            working_dir: std::env::temp_dir(),
        }
    }

    /// Runs `ShellExec` for real (not a dry run) with the given arguments.
    fn run(args: Value) -> Result<Output> {
        ShellExec.execute(&args, &ctx(false))
    }

    #[test]
    fn executes_uptime() {
        let result = run(serde_json::json!({ "cmd": "uptime" })).expect("Execution failed");

        assert!(result.success);
        assert!(result.data["stdout"].as_str().unwrap().contains("up"));
    }

    #[test]
    fn captures_exit_code() {
        let result = run(serde_json::json!({ "cmd": "false" })).expect("Execution failed");

        assert!(!result.success);
        assert_eq!(result.data["exit_code"].as_i64().unwrap(), 1);
    }

    #[test]
    fn captures_stderr() {
        let result = run(serde_json::json!({
            "cmd": "cat",
            "args": ["/nonexistent_path_for_stderr_test"]
        }))
        .expect("Execution failed");

        assert!(!result.success);
        assert!(result.data["stderr"].as_str().unwrap().contains("No such file"));
    }

    #[test]
    fn captures_pid() {
        let result = run(serde_json::json!({ "cmd": "echo hello" })).expect("Execution failed");

        assert!(result.success);
        assert!(result.data["pid"].as_u64().is_some());
    }

    #[test]
    fn captures_spawned_pids() {
        let result = run(serde_json::json!({ "cmd": "echo hello" })).expect("Execution failed");

        assert!(result.success);
        let spawned = result.data["spawned_pids"]
            .as_array()
            .expect("spawned_pids should be array");
        assert!(!spawned.is_empty());
    }

    #[test]
    fn enforces_timeout() {
        let start = Instant::now();
        let result = run(serde_json::json!({ "cmd": "sleep 5", "timeout_secs": 1 }));
        let elapsed = start.elapsed();

        assert!(elapsed.as_secs() < 3);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("timed out"));
    }

    #[test]
    fn validates_empty_cmd() {
        let cap = ShellExec;
        let result = cap.validate(&serde_json::json!({ "cmd": "" }));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("empty"));
    }

    #[test]
    fn respects_dry_run() {
        let result = ShellExec
            .execute(
                &serde_json::json!({ "cmd": "rm", "args": ["-rf", "/"] }),
                &ctx(true),
            )
            .expect("Execution failed");

        assert!(result.success);
        assert_eq!(result.data["dry_run"].as_bool(), Some(true));
        assert_eq!(result.data["cmd"].as_str().unwrap(), "rm");
    }

    #[test]
    fn prevents_shell_injection() {
        let result = run(serde_json::json!({
            "cmd": "echo",
            "args": ["hello; rm -rf /"]
        }))
        .expect("Execution failed");

        assert!(result.success);
        assert!(result.data["stdout"]
            .as_str()
            .unwrap()
            .contains("hello; rm -rf /"));
    }

    #[test]
    fn explicit_args_separation() {
        let result = run(serde_json::json!({
            "cmd": "echo",
            "args": ["hello", "world"]
        }))
        .expect("Execution failed");

        assert!(result.success);
        assert!(result.data["stdout"]
            .as_str()
            .unwrap()
            .contains("hello world"));
    }

    #[test]
    fn test_get_all_descendants_finds_children() {
        // Spawn a real child of this test process and check that it is found.
        let mut child = Command::new("sleep")
            .arg("5")
            .spawn()
            .expect("failed to spawn sleep");
        let descendants = get_all_descendants(std::process::id());
        let _ = child.kill();
        let _ = child.wait();

        assert!(
            descendants.contains(&child.id()),
            "spawned child {} missing from {:?}",
            child.id(),
            descendants
        );
    }

    #[test]
    fn test_get_all_descendants_nonexistent_pid() {
        // `u32::MAX` is above any PID the kernel hands out.
        let descendants = get_all_descendants(u32::MAX);
        assert!(
            descendants.is_empty(),
            "Non-existent PID should have no descendants"
        );
    }

    #[test]
    fn blocks_dangerous_rm_rf_root() {
        let result = run(serde_json::json!({ "cmd": "rm", "args": ["-rf", "/"] }));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("dangerous"));
    }

    #[test]
    fn blocks_dangerous_dd() {
        let result = run(serde_json::json!({
            "cmd": "dd",
            "args": ["if=/dev/zero", "of=/dev/sda"]
        }));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("dd"));
    }

    #[test]
    fn blocks_dangerous_mkfs() {
        let result = run(serde_json::json!({ "cmd": "mkfs.ext4", "args": ["/dev/sda1"] }));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("filesystem"));
    }

    #[test]
    fn pipes_stdin() {
        let output = run(serde_json::json!({ "cmd": "cat", "stdin": "hello from stdin" }))
            .expect("stdin pipe failed");
        assert!(output.success);
        assert!(output.data["stdout"]
            .as_str()
            .unwrap()
            .contains("hello from stdin"));
    }

    #[test]
    fn rejects_relative_path() {
        let result = run(serde_json::json!({ "cmd": "./malicious_script" }));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("relative paths"));
    }

    #[test]
    fn output_has_truncated_flag() {
        let result = run(serde_json::json!({ "cmd": "echo", "args": ["hello"] }))
            .expect("Execution failed");
        assert_eq!(result.data["truncated"].as_bool(), Some(false));
    }

    #[test]
    fn kills_descendants_on_timeout() {
        let start = Instant::now();
        let result = run(serde_json::json!({
            "cmd": "bash",
            "args": ["-c", "sleep 30 & sleep 30 & wait"],
            "timeout_secs": 1
        }));
        let elapsed = start.elapsed();

        assert!(elapsed.as_secs() < 3, "should timeout quickly, took {:?}", elapsed);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("timed out"));
        assert!(err.contains("descendants"), "should report descendant count: {}", err);
    }
}