Skip to main content

zeph_tools/shell/
mod.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::path::PathBuf;
5use std::time::{Duration, Instant};
6
7use tokio::process::Command;
8use tokio_util::sync::CancellationToken;
9
10use schemars::JsonSchema;
11use serde::Deserialize;
12
13use std::sync::Arc;
14
15use parking_lot::RwLock;
16
17use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
18use crate::config::ShellConfig;
19use crate::executor::{
20    ClaimSource, FilterStats, ToolCall, ToolError, ToolEvent, ToolEventTx, ToolExecutor, ToolOutput,
21};
22use crate::filter::{OutputFilterRegistry, sanitize_output};
23use crate::permissions::{PermissionAction, PermissionPolicy};
24
25mod transaction;
26use transaction::{TransactionSnapshot, affected_paths, build_scope_matchers, is_write_command};
27
/// The default list of blocked command patterns used by [`ShellExecutor`].
///
/// Exposed so other executors (e.g. `AcpShellExecutor`) can reuse the same
/// blocklist without duplicating it. All entries are lowercase.
pub const DEFAULT_BLOCKED_COMMANDS: &[&str] = &[
    "rm -rf /", "sudo", "mkfs", "dd if=", "curl", "wget", "nc ", "ncat", "netcat", "shutdown",
    "reboot", "halt",
];

/// Module-internal name for [`DEFAULT_BLOCKED_COMMANDS`].
const DEFAULT_BLOCKED: &[&str] = DEFAULT_BLOCKED_COMMANDS;

/// Shell interpreters that can run arbitrary code handed to them via `-c` or as
/// positional arguments; see [`effective_shell_command`].
pub const SHELL_INTERPRETERS: &[&str] = &[
    "bash", "sh", "zsh", "fish", "dash", "ksh", "csh", "tcsh",
];

/// Subshell metacharacters that could hide a blocked command inside a benign
/// wrapper: command substitution (`$(`, `` ` ``) and process substitution
/// (`<(`, `>(`).
///
/// [`check_blocklist`] rejects any command containing one of these outright,
/// because nested shell evaluation cannot be analysed safely with static checks.
const SUBSHELL_METACHARS: &[&str] = &["$(", "`", "<(", ">("];
47
48/// Check if `command` matches any pattern in `blocklist`.
49///
50/// Returns the matched pattern string if the command is blocked, `None` otherwise.
51/// The check is case-insensitive and handles common shell escape sequences.
52///
53/// Commands containing subshell metacharacters (`$(` or `` ` ``) are always
54/// blocked because nested evaluation cannot be safely analysed statically.
55#[must_use]
56pub fn check_blocklist(command: &str, blocklist: &[String]) -> Option<String> {
57    let lower = command.to_lowercase();
58    // Reject commands that embed subshell constructs to prevent blocklist bypass.
59    for meta in SUBSHELL_METACHARS {
60        if lower.contains(meta) {
61            return Some((*meta).to_owned());
62        }
63    }
64    let cleaned = strip_shell_escapes(&lower);
65    let commands = tokenize_commands(&cleaned);
66    for blocked in blocklist {
67        for cmd_tokens in &commands {
68            if tokens_match_pattern(cmd_tokens, blocked) {
69                return Some(blocked.clone());
70            }
71        }
72    }
73    None
74}
75
76/// Build the effective command string for blocklist evaluation when the binary is a
77/// shell interpreter (bash, sh, zsh, etc.) and args contains a `-c` script.
78///
79/// Returns `None` if the args do not follow the `-c <script>` pattern.
80#[must_use]
81pub fn effective_shell_command<'a>(binary: &str, args: &'a [String]) -> Option<&'a str> {
82    let base = binary.rsplit('/').next().unwrap_or(binary);
83    if !SHELL_INTERPRETERS.contains(&base) {
84        return None;
85    }
86    // Find "-c" and return the next element as the script to check.
87    let pos = args.iter().position(|a| a == "-c")?;
88    args.get(pos + 1).map(String::as_str)
89}
90
/// Network-capable commands appended to the blocklist by [`ShellExecutor::new`]
/// whenever `allow_network` is disabled. They are appended even when
/// `allowed_commands` lists them.
const NETWORK_COMMANDS: &[&str] = &[
    "curl",
    "wget",
    "nc ",
    "ncat",
    "netcat",
];
92
93#[derive(Deserialize, JsonSchema)]
94pub(crate) struct BashParams {
95    /// The bash command to execute
96    command: String,
97}
98
99/// Bash block extraction and execution via `tokio::process::Command`.
100#[derive(Debug)]
101pub struct ShellExecutor {
102    timeout: Duration,
103    blocked_commands: Vec<String>,
104    allowed_paths: Vec<PathBuf>,
105    confirm_patterns: Vec<String>,
106    env_blocklist: Vec<String>,
107    audit_logger: Option<Arc<AuditLogger>>,
108    tool_event_tx: Option<ToolEventTx>,
109    permission_policy: Option<PermissionPolicy>,
110    output_filter_registry: Option<OutputFilterRegistry>,
111    cancel_token: Option<CancellationToken>,
112    skill_env: RwLock<Option<std::collections::HashMap<String, String>>>,
113    transactional: bool,
114    auto_rollback: bool,
115    auto_rollback_exit_codes: Vec<i32>,
116    snapshot_required: bool,
117    max_snapshot_bytes: u64,
118    transaction_scope_matchers: Vec<globset::GlobMatcher>,
119}
120
121impl ShellExecutor {
122    #[must_use]
123    pub fn new(config: &ShellConfig) -> Self {
124        let allowed: Vec<String> = config
125            .allowed_commands
126            .iter()
127            .map(|s| s.to_lowercase())
128            .collect();
129
130        let mut blocked: Vec<String> = DEFAULT_BLOCKED
131            .iter()
132            .filter(|s| !allowed.contains(&s.to_lowercase()))
133            .map(|s| (*s).to_owned())
134            .collect();
135        blocked.extend(config.blocked_commands.iter().map(|s| s.to_lowercase()));
136
137        if !config.allow_network {
138            for cmd in NETWORK_COMMANDS {
139                let lower = cmd.to_lowercase();
140                if !blocked.contains(&lower) {
141                    blocked.push(lower);
142                }
143            }
144        }
145
146        blocked.sort();
147        blocked.dedup();
148
149        let allowed_paths = if config.allowed_paths.is_empty() {
150            vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
151        } else {
152            config.allowed_paths.iter().map(PathBuf::from).collect()
153        };
154
155        Self {
156            timeout: Duration::from_secs(config.timeout),
157            blocked_commands: blocked,
158            allowed_paths,
159            confirm_patterns: config.confirm_patterns.clone(),
160            env_blocklist: config.env_blocklist.clone(),
161            audit_logger: None,
162            tool_event_tx: None,
163            permission_policy: None,
164            output_filter_registry: None,
165            cancel_token: None,
166            skill_env: RwLock::new(None),
167            transactional: config.transactional,
168            auto_rollback: config.auto_rollback,
169            auto_rollback_exit_codes: config.auto_rollback_exit_codes.clone(),
170            snapshot_required: config.snapshot_required,
171            max_snapshot_bytes: config.max_snapshot_bytes,
172            transaction_scope_matchers: build_scope_matchers(&config.transaction_scope),
173        }
174    }
175
176    /// Set environment variables to inject when executing the active skill's bash blocks.
177    pub fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
178        *self.skill_env.write() = env;
179    }
180
181    #[must_use]
182    pub fn with_audit(mut self, logger: Arc<AuditLogger>) -> Self {
183        self.audit_logger = Some(logger);
184        self
185    }
186
187    #[must_use]
188    pub fn with_tool_event_tx(mut self, tx: ToolEventTx) -> Self {
189        self.tool_event_tx = Some(tx);
190        self
191    }
192
193    #[must_use]
194    pub fn with_permissions(mut self, policy: PermissionPolicy) -> Self {
195        self.permission_policy = Some(policy);
196        self
197    }
198
199    #[must_use]
200    pub fn with_cancel_token(mut self, token: CancellationToken) -> Self {
201        self.cancel_token = Some(token);
202        self
203    }
204
205    #[must_use]
206    pub fn with_output_filters(mut self, registry: OutputFilterRegistry) -> Self {
207        self.output_filter_registry = Some(registry);
208        self
209    }
210
211    /// Execute a bash block bypassing the confirmation check (called after user confirms).
212    ///
213    /// # Errors
214    ///
215    /// Returns `ToolError` on blocked commands, sandbox violations, or execution failures.
216    pub async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
217        self.execute_inner(response, true).await
218    }
219
220    async fn execute_inner(
221        &self,
222        response: &str,
223        skip_confirm: bool,
224    ) -> Result<Option<ToolOutput>, ToolError> {
225        let blocks = extract_bash_blocks(response);
226        if blocks.is_empty() {
227            return Ok(None);
228        }
229
230        let mut outputs = Vec::with_capacity(blocks.len());
231        let mut cumulative_filter_stats: Option<FilterStats> = None;
232        let mut last_envelope: Option<ShellOutputEnvelope> = None;
233        #[allow(clippy::cast_possible_truncation)]
234        let blocks_executed = blocks.len() as u32;
235
236        for block in &blocks {
237            let (output_line, per_block_stats, envelope) =
238                self.execute_block(block, skip_confirm).await?;
239            if let Some(fs) = per_block_stats {
240                let stats = cumulative_filter_stats.get_or_insert_with(FilterStats::default);
241                stats.raw_chars += fs.raw_chars;
242                stats.filtered_chars += fs.filtered_chars;
243                stats.raw_lines += fs.raw_lines;
244                stats.filtered_lines += fs.filtered_lines;
245                stats.confidence = Some(match (stats.confidence, fs.confidence) {
246                    (Some(prev), Some(cur)) => crate::filter::worse_confidence(prev, cur),
247                    (Some(prev), None) => prev,
248                    (None, Some(cur)) => cur,
249                    (None, None) => unreachable!(),
250                });
251                if stats.command.is_none() {
252                    stats.command = fs.command;
253                }
254                if stats.kept_lines.is_empty() && !fs.kept_lines.is_empty() {
255                    stats.kept_lines = fs.kept_lines;
256                }
257            }
258            last_envelope = Some(envelope);
259            outputs.push(output_line);
260        }
261
262        let raw_response = last_envelope
263            .as_ref()
264            .and_then(|e| serde_json::to_value(e).ok());
265
266        Ok(Some(ToolOutput {
267            tool_name: "bash".to_owned(),
268            summary: outputs.join("\n\n"),
269            blocks_executed,
270            filter_stats: cumulative_filter_stats,
271            diff: None,
272            streamed: self.tool_event_tx.is_some(),
273            terminal_id: None,
274            locations: None,
275            raw_response,
276            claim_source: Some(ClaimSource::Shell),
277        }))
278    }
279
280    #[allow(clippy::too_many_lines)]
281    async fn execute_block(
282        &self,
283        block: &str,
284        skip_confirm: bool,
285    ) -> Result<(String, Option<FilterStats>, ShellOutputEnvelope), ToolError> {
286        self.check_permissions(block, skip_confirm).await?;
287        self.validate_sandbox(block)?;
288
289        // Take a transactional snapshot before executing write commands.
290        let mut snapshot_warning: Option<String> = None;
291        let snapshot = if self.transactional && is_write_command(block) {
292            let paths = affected_paths(block, &self.transaction_scope_matchers);
293            if paths.is_empty() {
294                None
295            } else {
296                match TransactionSnapshot::capture(&paths, self.max_snapshot_bytes) {
297                    Ok(snap) => {
298                        tracing::debug!(
299                            files = snap.file_count(),
300                            bytes = snap.total_bytes(),
301                            "transaction snapshot captured"
302                        );
303                        Some(snap)
304                    }
305                    Err(e) if self.snapshot_required => {
306                        return Err(ToolError::SnapshotFailed {
307                            reason: e.to_string(),
308                        });
309                    }
310                    Err(e) => {
311                        tracing::warn!(err = %e, "transaction snapshot failed, proceeding without rollback");
312                        snapshot_warning =
313                            Some(format!("[warn] snapshot failed: {e}; rollback unavailable"));
314                        None
315                    }
316                }
317            }
318        } else {
319            None
320        };
321
322        if let Some(ref tx) = self.tool_event_tx {
323            let _ = tx.send(ToolEvent::Started {
324                tool_name: "bash".to_owned(),
325                command: block.to_owned(),
326            });
327        }
328
329        let start = Instant::now();
330        let skill_env_snapshot: Option<std::collections::HashMap<String, String>> =
331            self.skill_env.read().clone();
332        let (mut envelope, out) = execute_bash(
333            block,
334            self.timeout,
335            self.tool_event_tx.as_ref(),
336            self.cancel_token.as_ref(),
337            skill_env_snapshot.as_ref(),
338            &self.env_blocklist,
339        )
340        .await;
341        let exit_code = envelope.exit_code;
342        if exit_code == 130
343            && self
344                .cancel_token
345                .as_ref()
346                .is_some_and(CancellationToken::is_cancelled)
347        {
348            return Err(ToolError::Cancelled);
349        }
350        #[allow(clippy::cast_possible_truncation)]
351        let duration_ms = start.elapsed().as_millis() as u64;
352
353        // Perform auto-rollback if configured and the exit code qualifies.
354        if let Some(snap) = snapshot {
355            let should_rollback = self.auto_rollback
356                && if self.auto_rollback_exit_codes.is_empty() {
357                    exit_code >= 2
358                } else {
359                    self.auto_rollback_exit_codes.contains(&exit_code)
360                };
361            if should_rollback {
362                match snap.rollback() {
363                    Ok(report) => {
364                        tracing::info!(
365                            restored = report.restored_count,
366                            deleted = report.deleted_count,
367                            "transaction rollback completed"
368                        );
369                        self.log_audit(
370                            block,
371                            AuditResult::Rollback {
372                                restored: report.restored_count,
373                                deleted: report.deleted_count,
374                            },
375                            duration_ms,
376                            None,
377                            Some(exit_code),
378                            false,
379                        )
380                        .await;
381                        if let Some(ref tx) = self.tool_event_tx {
382                            let _ = tx.send(ToolEvent::Rollback {
383                                tool_name: "bash".to_owned(),
384                                command: block.to_owned(),
385                                restored_count: report.restored_count,
386                                deleted_count: report.deleted_count,
387                            });
388                        }
389                    }
390                    Err(e) => {
391                        tracing::error!(err = %e, "transaction rollback failed");
392                    }
393                }
394            }
395            // On success (no rollback): snapshot dropped here; TempDir auto-cleans.
396        }
397
398        let is_timeout = out.contains("[error] command timed out");
399        let audit_result = if is_timeout {
400            AuditResult::Timeout
401        } else if out.contains("[error]") || out.contains("[stderr]") {
402            AuditResult::Error {
403                message: out.clone(),
404            }
405        } else {
406            AuditResult::Success
407        };
408        if is_timeout {
409            self.log_audit(
410                block,
411                audit_result,
412                duration_ms,
413                None,
414                Some(exit_code),
415                false,
416            )
417            .await;
418            self.emit_completed(block, &out, false, None);
419            return Err(ToolError::Timeout {
420                timeout_secs: self.timeout.as_secs(),
421            });
422        }
423
424        if let Some(category) = classify_shell_exit(exit_code, &out) {
425            self.emit_completed(block, &out, false, None);
426            return Err(ToolError::Shell {
427                exit_code,
428                category,
429                message: out.lines().take(3).collect::<Vec<_>>().join("; "),
430            });
431        }
432
433        let sanitized = sanitize_output(&out);
434        let mut per_block_stats: Option<FilterStats> = None;
435        let filtered = if let Some(ref registry) = self.output_filter_registry {
436            match registry.apply(block, &sanitized, exit_code) {
437                Some(fr) => {
438                    tracing::debug!(
439                        command = block,
440                        raw = fr.raw_chars,
441                        filtered = fr.filtered_chars,
442                        savings_pct = fr.savings_pct(),
443                        "output filter applied"
444                    );
445                    per_block_stats = Some(FilterStats {
446                        raw_chars: fr.raw_chars,
447                        filtered_chars: fr.filtered_chars,
448                        raw_lines: fr.raw_lines,
449                        filtered_lines: fr.filtered_lines,
450                        confidence: Some(fr.confidence),
451                        command: Some(block.to_owned()),
452                        kept_lines: fr.kept_lines.clone(),
453                    });
454                    fr.output
455                }
456                None => sanitized,
457            }
458        } else {
459            sanitized
460        };
461
462        self.emit_completed(
463            block,
464            &out,
465            !out.contains("[error]"),
466            per_block_stats.clone(),
467        );
468
469        // Mark truncated if output was shortened during filtering.
470        envelope.truncated = filtered.len() < out.len();
471
472        self.log_audit(
473            block,
474            audit_result,
475            duration_ms,
476            None,
477            Some(exit_code),
478            envelope.truncated,
479        )
480        .await;
481
482        let output_line = if let Some(warn) = snapshot_warning {
483            format!("{warn}\n$ {block}\n{filtered}")
484        } else {
485            format!("$ {block}\n{filtered}")
486        };
487        Ok((output_line, per_block_stats, envelope))
488    }
489
490    fn emit_completed(
491        &self,
492        command: &str,
493        output: &str,
494        success: bool,
495        filter_stats: Option<FilterStats>,
496    ) {
497        if let Some(ref tx) = self.tool_event_tx {
498            let _ = tx.send(ToolEvent::Completed {
499                tool_name: "bash".to_owned(),
500                command: command.to_owned(),
501                output: output.to_owned(),
502                success,
503                filter_stats,
504                diff: None,
505            });
506        }
507    }
508
509    /// Check blocklist, permission policy, and confirmation requirements for `block`.
510    async fn check_permissions(&self, block: &str, skip_confirm: bool) -> Result<(), ToolError> {
511        // Always check the blocklist first — it is a hard security boundary
512        // that must not be bypassed by the PermissionPolicy layer.
513        if let Some(blocked) = self.find_blocked_command(block) {
514            let err = ToolError::Blocked {
515                command: blocked.to_owned(),
516            };
517            self.log_audit(
518                block,
519                AuditResult::Blocked {
520                    reason: format!("blocked command: {blocked}"),
521                },
522                0,
523                Some(&err),
524                None,
525                false,
526            )
527            .await;
528            return Err(err);
529        }
530
531        if let Some(ref policy) = self.permission_policy {
532            match policy.check("bash", block) {
533                PermissionAction::Deny => {
534                    let err = ToolError::Blocked {
535                        command: block.to_owned(),
536                    };
537                    self.log_audit(
538                        block,
539                        AuditResult::Blocked {
540                            reason: "denied by permission policy".to_owned(),
541                        },
542                        0,
543                        Some(&err),
544                        None,
545                        false,
546                    )
547                    .await;
548                    return Err(err);
549                }
550                PermissionAction::Ask if !skip_confirm => {
551                    return Err(ToolError::ConfirmationRequired {
552                        command: block.to_owned(),
553                    });
554                }
555                _ => {}
556            }
557        } else if !skip_confirm && let Some(pattern) = self.find_confirm_command(block) {
558            return Err(ToolError::ConfirmationRequired {
559                command: pattern.to_owned(),
560            });
561        }
562
563        Ok(())
564    }
565
566    fn validate_sandbox(&self, code: &str) -> Result<(), ToolError> {
567        let cwd = std::env::current_dir().unwrap_or_default();
568
569        for token in extract_paths(code) {
570            if has_traversal(&token) {
571                return Err(ToolError::SandboxViolation { path: token });
572            }
573
574            let path = if token.starts_with('/') {
575                PathBuf::from(&token)
576            } else {
577                cwd.join(&token)
578            };
579            let canonical = path
580                .canonicalize()
581                .or_else(|_| std::path::absolute(&path))
582                .unwrap_or(path);
583            if !self
584                .allowed_paths
585                .iter()
586                .any(|allowed| canonical.starts_with(allowed))
587            {
588                return Err(ToolError::SandboxViolation {
589                    path: canonical.display().to_string(),
590                });
591            }
592        }
593        Ok(())
594    }
595
596    /// Scan `code` for commands that match the configured blocklist.
597    ///
598    /// The function normalizes input via [`strip_shell_escapes`] (decoding `$'\xNN'`,
599    /// `$'\NNN'`, backslash escapes, and quote-splitting) and then splits on shell
600    /// metacharacters (`||`, `&&`, `;`, `|`, `\n`) via [`tokenize_commands`].  Each
601    /// resulting token sequence is tested against every entry in `blocked_commands`
602    /// through [`tokens_match_pattern`], which handles transparent prefixes (`env`,
603    /// `command`, `exec`, etc.), absolute paths, and dot-suffixed variants.
604    ///
605    /// # Known limitations
606    ///
607    /// The following constructs are **not** detected by this function:
608    ///
609    /// - **Here-strings** `<<<` with a shell interpreter: the outer command is the
610    ///   shell (`bash`, `sh`), which is not blocked by default; the payload string is
611    ///   opaque to this filter.
612    ///   Example: `bash <<< 'sudo rm -rf /'` — inner payload is not parsed.
613    ///
614    /// - **`eval` and `bash -c` / `sh -c`**: the string argument is not parsed; any
615    ///   blocked command embedded as a string argument passes through undetected.
616    ///   Example: `eval 'sudo rm -rf /'`.
617    ///
618    /// - **Variable expansion**: `strip_shell_escapes` does not resolve variable
619    ///   references, so `cmd=sudo; $cmd rm` bypasses the blocklist.
620    ///
621    /// `$(...)`, backtick, `<(...)`, and `>(...)` substitutions are detected by
622    /// [`extract_subshell_contents`], which extracts the inner command string and
623    /// checks it against the blocklist separately.  The default `confirm_patterns`
624    /// in [`ShellConfig`] additionally include `"$("`, `` "`" ``, `"<("`, `">("`,
625    /// `"<<<"`, and `"eval "`, so those constructs also trigger a confirmation
626    /// request via [`find_confirm_command`] before execution.
627    ///
628    /// For high-security deployments, complement this filter with OS-level sandboxing
629    /// (Linux namespaces, seccomp, or similar) to enforce hard execution boundaries.
630    fn find_blocked_command(&self, code: &str) -> Option<&str> {
631        let cleaned = strip_shell_escapes(&code.to_lowercase());
632        let commands = tokenize_commands(&cleaned);
633        for blocked in &self.blocked_commands {
634            for cmd_tokens in &commands {
635                if tokens_match_pattern(cmd_tokens, blocked) {
636                    return Some(blocked.as_str());
637                }
638            }
639        }
640        // Also check commands embedded inside subshell constructs.
641        for inner in extract_subshell_contents(&cleaned) {
642            let inner_commands = tokenize_commands(&inner);
643            for blocked in &self.blocked_commands {
644                for cmd_tokens in &inner_commands {
645                    if tokens_match_pattern(cmd_tokens, blocked) {
646                        return Some(blocked.as_str());
647                    }
648                }
649            }
650        }
651        None
652    }
653
654    fn find_confirm_command(&self, code: &str) -> Option<&str> {
655        let normalized = code.to_lowercase();
656        for pattern in &self.confirm_patterns {
657            if normalized.contains(pattern.as_str()) {
658                return Some(pattern.as_str());
659            }
660        }
661        None
662    }
663
664    async fn log_audit(
665        &self,
666        command: &str,
667        result: AuditResult,
668        duration_ms: u64,
669        error: Option<&ToolError>,
670        exit_code: Option<i32>,
671        truncated: bool,
672    ) {
673        if let Some(ref logger) = self.audit_logger {
674            let (error_category, error_domain, error_phase) =
675                error.map_or((None, None, None), |e| {
676                    let cat = e.category();
677                    (
678                        Some(cat.label().to_owned()),
679                        Some(cat.domain().label().to_owned()),
680                        Some(cat.phase().label().to_owned()),
681                    )
682                });
683            let entry = AuditEntry {
684                timestamp: chrono_now(),
685                tool: "shell".into(),
686                command: command.into(),
687                result,
688                duration_ms,
689                error_category,
690                error_domain,
691                error_phase,
692                claim_source: Some(ClaimSource::Shell),
693                mcp_server_id: None,
694                injection_flagged: false,
695                embedding_anomalous: false,
696                cross_boundary_mcp_to_acp: false,
697                adversarial_policy_decision: None,
698                exit_code,
699                truncated,
700                caller_id: None,
701                policy_match: None,
702            };
703            logger.log(&entry).await;
704        }
705    }
706}
707
708impl ToolExecutor for ShellExecutor {
709    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
710        self.execute_inner(response, false).await
711    }
712
713    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
714        use crate::registry::{InvocationHint, ToolDef};
715        vec![ToolDef {
716            id: "bash".into(),
717            description: "Execute a shell command and return stdout/stderr.\n\nParameters: command (string, required) - shell command to run\nReturns: stdout and stderr combined, prefixed with exit code\nErrors: Blocked if command matches security policy; Timeout after configured seconds; SandboxViolation if path outside allowed dirs\nExample: {\"command\": \"ls -la /tmp\"}".into(),
718            schema: schemars::schema_for!(BashParams),
719            invocation: InvocationHint::FencedBlock("bash"),
720        }]
721    }
722
723    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
724        if call.tool_id != "bash" {
725            return Ok(None);
726        }
727        let params: BashParams = crate::executor::deserialize_params(&call.params)?;
728        if params.command.is_empty() {
729            return Ok(None);
730        }
731        let command = &params.command;
732        // Wrap as a fenced block so execute_inner can extract and run it
733        let synthetic = format!("```bash\n{command}\n```");
734        self.execute_inner(&synthetic, false).await
735    }
736
737    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
738        ShellExecutor::set_skill_env(self, env);
739    }
740}
741
742/// Strip shell escape sequences that could bypass command detection.
743/// Handles: backslash insertion (`su\do` -> `sudo`), `$'\xNN'` hex and `$'\NNN'` octal
744/// escapes, adjacent quoted segments (`"su""do"` -> `sudo`), backslash-newline continuations.
///
/// Additional forms handled: `$"..."` (locale-translated strings, which bash
/// expands like `"..."` when no translation exists), `$'...'` segments without
/// any escape (`$'sudo'` -> `sudo`), one- or two-digit `\xH` escapes, and
/// `\uHHHH` / `\UHHHHHHHH` Unicode escapes. Input is processed per `char`, so
/// non-ASCII text passes through unchanged.
pub(crate) fn strip_shell_escapes(input: &str) -> String {
    // Read up to `max` consecutive digits in `radix` starting at `start`.
    // Returns the accumulated value and the number of digits consumed.
    // (8 hex digits fit exactly in a u32, so the fold cannot overflow.)
    fn take_radix_digits(chars: &[char], start: usize, max: usize, radix: u32) -> (u32, usize) {
        chars
            .iter()
            .skip(start)
            .take(max)
            .map_while(|c| c.to_digit(radix))
            .fold((0, 0), |(value, count), digit| (value * radix + digit, count + 1))
    }

    // Decode the body of a `$'...'` segment whose contents begin at `start`.
    // Returns the decoded text and the index just past the closing quote, or
    // `None` when the segment is unterminated.
    fn decode_ansi_c_quoted(chars: &[char], start: usize) -> Option<(String, usize)> {
        let mut decoded = String::new();
        let mut j = start;
        loop {
            let c = *chars.get(j)?;
            match (c, chars.get(j + 1).copied()) {
                ('\'', _) => return Some((decoded, j + 1)),
                ('\\', Some(esc @ ('x' | 'u' | 'U'))) => {
                    let max_digits = match esc {
                        'x' => 2,
                        'u' => 4,
                        _ => 8,
                    };
                    let (value, digits) = take_radix_digits(chars, j + 2, max_digits, 16);
                    if digits == 0 {
                        // `\x`, `\u` or `\U` without hex digits: keep the letter.
                        decoded.push(esc);
                        j += 2;
                    } else {
                        let ch = if esc == 'x' {
                            // At most two hex digits, so this is always a single byte.
                            char::from(u8::try_from(value).unwrap_or(u8::MAX))
                        } else {
                            char::from_u32(value).unwrap_or(char::REPLACEMENT_CHARACTER)
                        };
                        decoded.push(ch);
                        j += 2 + digits;
                    }
                }
                ('\\', Some('0'..='7')) => {
                    // `\N`, `\NN` or `\NNN`: octal byte value, wrapped to 8 bits.
                    let (value, digits) = take_radix_digits(chars, j + 1, 3, 8);
                    #[allow(clippy::cast_possible_truncation)] // masked to one byte first
                    let byte = (value & 0xFF) as u8;
                    decoded.push(char::from(byte));
                    j += 1 + digits;
                }
                ('\\', Some(other)) => {
                    // Any other escape (including `\'` and `\\`) yields the character itself.
                    decoded.push(other);
                    j += 2;
                }
                (other, _) => {
                    decoded.push(other);
                    j += 1;
                }
            }
        }
    }

    let chars: Vec<char> = input.chars().collect();
    let mut out = String::with_capacity(input.len());
    let mut i = 0;
    while i < chars.len() {
        let c = chars[i];
        match (c, chars.get(i + 1).copied()) {
            // $'...' ANSI-C quoting: bash decodes every terminated segment, even one
            // without escapes, so `$'sudo'` must normalize to `sudo`.
            ('$', Some('\'')) => match decode_ansi_c_quoted(&chars, i + 2) {
                Some((decoded, after)) => {
                    out.push_str(&decoded);
                    i = after;
                }
                // Unterminated: treat `$` as an ordinary character.
                None => {
                    out.push('$');
                    i += 1;
                }
            },
            // $"..." locale translation: drop the `$`; the quote arm strips the rest.
            ('$', Some('"')) => i += 1,
            // Backslash-newline continuation: remove both characters.
            ('\\', Some('\n')) => i += 2,
            // Intra-word backslash: keep only the escaped character (su\do -> sudo).
            ('\\', Some(escaped)) => {
                out.push(escaped);
                i += 2;
            }
            // Quoted segment: emit the contents without quotes, collapsing adjacent
            // segments ("su""do" -> sudo). An unterminated quote runs to the end.
            ('"' | '\'', _) => {
                let close = chars[i + 1..]
                    .iter()
                    .position(|&q| q == c)
                    .map_or(chars.len(), |off| i + 1 + off);
                out.extend(&chars[i + 1..close]);
                i = close + 1;
            }
            _ => {
                out.push(c);
                i += 1;
            }
        }
    }
    out
}
833
/// Extract inner command strings from subshell constructs in `s`.
///
/// Recognises:
/// - Backtick: `` `cmd` `` → `cmd`
/// - Dollar-paren: `$(cmd)` → `cmd`
/// - Process substitution (lt): `<(cmd)` → `cmd`
/// - Process substitution (gt): `>(cmd)` → `cmd`
///
/// Depth counting handles nested parentheses correctly. Only the outermost
/// construct is returned; nested bodies stay inside its text.
///
/// The scan is not quote-aware. An opener with no matching closer (for example a
/// backtick that is literal text inside single quotes) is skipped, and scanning
/// resumes right after it, so constructs that follow are still extracted.
pub(crate) fn extract_subshell_contents(s: &str) -> Vec<String> {
    // Index of the `)` balancing an already-consumed `(`, scanning from `from`.
    fn find_closing_paren(chars: &[char], from: usize) -> Option<usize> {
        let mut depth = 1usize;
        for (idx, &c) in chars.iter().enumerate().skip(from) {
            match c {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth == 0 {
                        return Some(idx);
                    }
                }
                _ => {}
            }
        }
        None
    }

    let chars: Vec<char> = s.chars().collect();
    let mut results = Vec::new();
    let mut i = 0;

    while i < chars.len() {
        let opener_len = match (chars[i], chars.get(i + 1).copied()) {
            ('`', _) => 1,
            ('$' | '<' | '>', Some('(')) => 2,
            _ => {
                i += 1;
                continue;
            }
        };
        let start = i + opener_len;
        let close = if opener_len == 1 {
            chars[start..]
                .iter()
                .position(|&c| c == '`')
                .map(|off| start + off)
        } else {
            find_closing_paren(&chars, start)
        };
        match close {
            Some(end) => {
                results.push(chars[start..end].iter().collect());
                i = end + 1;
            }
            // Unterminated: skip just the opener and keep scanning.
            None => i = start,
        }
    }

    results
}
896
/// Split normalized shell code into sub-commands, each returned as a `Vec<String>`
/// of whitespace-separated tokens.
///
/// Sub-commands are delimited by bash's control-operator characters:
/// - `;` and newline
/// - `|`, which also covers `||` and `|&`
/// - `&`, which also covers `&&` and background jobs such as `ls & sudo reboot`
/// - `(` and `)`, for subshell grouping such as `(sudo reboot)`
///
/// An `&` that is part of a redirection (`2>&1`, `<&3`, `&>file`) does not split.
/// Empty segments are dropped.
pub(crate) fn tokenize_commands(normalized: &str) -> Vec<Vec<String>> {
    let chars: Vec<char> = normalized.chars().collect();
    let mut segments: Vec<String> = Vec::new();
    let mut current = String::new();

    for (idx, &c) in chars.iter().enumerate() {
        let splits = match c {
            ';' | '|' | '\n' | '(' | ')' => true,
            '&' => {
                let after_redirect = idx > 0 && matches!(chars[idx - 1], '<' | '>');
                let before_redirect = chars.get(idx + 1) == Some(&'>');
                !after_redirect && !before_redirect
            }
            _ => false,
        };
        if splits {
            segments.push(std::mem::take(&mut current));
        } else {
            current.push(c);
        }
    }
    segments.push(current);

    segments
        .iter()
        .map(|seg| {
            seg.split_whitespace()
                .map(str::to_owned)
                .collect::<Vec<String>>()
        })
        .filter(|tokens| !tokens.is_empty())
        .collect()
}
912
/// Transparent prefix commands that invoke the next argument as a command.
/// Skipped when determining the "real" command name being invoked.
const TRANSPARENT_PREFIXES: &[&str] = &[
    "env", "command", "exec", "nice", "nohup", "time", "xargs", "timeout", "setsid", "stdbuf",
];

/// Return the basename of a token (last path component after '/').
fn cmd_basename(tok: &str) -> &str {
    tok.rsplit_once('/').map_or(tok, |(_, base)| base)
}

/// Check if the first tokens of a sub-command match a blocked pattern.
///
/// Handles:
/// - Transparent prefix commands (`env sudo rm` -> checks `sudo`), including the
///   flags and numeric arguments they take (`nice -n 10 sudo`, `timeout 5s sudo`)
/// - Leading variable assignments (`FOO=1 sudo rm`, `env PATH=/x sudo rm`)
/// - Absolute paths (`/usr/bin/sudo rm` -> basename `sudo` is checked)
/// - Dot-suffixed variants (`mkfs` matches `mkfs.ext4`)
/// - Multi-word patterns (`rm -rf /` joined prefix check). Runs of whitespace
///   inside the pattern are treated as a single space.
///
/// A single-word pattern also matches a transparent prefix itself, so blocking
/// `xargs` catches `xargs rm` as well as a bare `xargs`. `command -v NAME` and
/// `command -V NAME` only look a command up, so `NAME` is not treated as invoked.
pub(crate) fn tokens_match_pattern(tokens: &[String], pattern: &str) -> bool {
    // `NAME=value` shell variable assignment (NAME is an identifier).
    fn is_assignment(tok: &str) -> bool {
        tok.split_once('=').is_some_and(|(name, _)| {
            name.starts_with(|c: char| c.is_ascii_alphabetic() || c == '_')
                && name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
        })
    }

    // Flag (`-n`) or numeric value (`10`, `5s`) consumed by a wrapper command.
    fn is_wrapper_argument(tok: &str) -> bool {
        tok.starts_with('-') || tok.starts_with(|c: char| c.is_ascii_digit())
    }

    let pattern_tokens: Vec<&str> = pattern.split_whitespace().collect();
    if tokens.is_empty() || pattern_tokens.is_empty() {
        return false;
    }

    // Walk past wrappers, their arguments, and assignments to reach the token
    // naming the command bash actually executes.
    let mut wrappers: Vec<&str> = Vec::new();
    let mut start = 0;
    while let Some(tok) = tokens.get(start) {
        let base = cmd_basename(tok);
        let after_wrapper = wrappers.last().copied();
        if TRANSPARENT_PREFIXES.contains(&base) {
            wrappers.push(base);
        } else if after_wrapper == Some("command")
            && tok.starts_with('-')
            && tok.contains(['v', 'V'])
        {
            // `command -v NAME` / `command -V NAME` only describe NAME.
            break;
        } else {
            let skippable =
                is_assignment(tok) || (after_wrapper.is_some() && is_wrapper_argument(tok));
            if !skippable {
                break;
            }
        }
        start += 1;
    }
    if start == tokens.len() {
        // Nothing but wrappers/assignments: fall back to the first token.
        start = 0;
    }
    let head = cmd_basename(&tokens[start]);

    if let [pat] = pattern_tokens.as_slice() {
        let pat = *pat;
        // Exact match OR dot-suffixed variant (e.g. "mkfs" matches "mkfs.ext4").
        let names = |base: &str| {
            base == pat || base.strip_prefix(pat).is_some_and(|rest| rest.starts_with('.'))
        };
        return names(head) || wrappers.iter().any(|&w| names(w));
    }

    // Multi-word: join the first N tokens (basename for the command) and check
    // that they start with the whitespace-normalized pattern.
    let width = pattern_tokens.len();
    if tokens.len() - start < width {
        return false;
    }
    let mut parts: Vec<&str> = Vec::with_capacity(width);
    parts.push(head);
    parts.extend(tokens[start + 1..start + width].iter().map(String::as_str));
    parts.join(" ").starts_with(&pattern_tokens.join(" "))
}
971
/// Collect the tokens of `code` that look like filesystem paths.
///
/// Tokens are split on whitespace and on the shell metacharacters `;`, `|`, `&`,
/// `<`, `>`, `(` and `)`. As a result, redirection targets written without a
/// space (`cat </etc/shadow`, `echo x >/tmp/f`) and paths inside subshell groups
/// (`(cd /etc)`) are reported the same way as space-separated ones. Single- and
/// double-quoted runs stay in the current token with the quotes removed.
///
/// A token counts as a path when it matches any of these:
/// - it is absolute;
/// - it starts with `./` or `../`;
/// - it is exactly `..`;
/// - it is a dot-prefixed name containing `/`;
/// - it passes [`is_relative_path_token`].
///
/// Home-relative tokens (`~/...`) match none of these and are not reported.
fn extract_paths(code: &str) -> Vec<String> {
    let mut tokens: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut chars = code.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '"' | '\'' => {
                // Copy up to the matching quote, then consume it (no-op at end of input).
                while let Some(inner) = chars.next_if(|&nc| nc != c) {
                    current.push(inner);
                }
                chars.next();
            }
            c if c.is_whitespace() || matches!(c, ';' | '|' | '&' | '<' | '>' | '(' | ')') => {
                if !current.is_empty() {
                    tokens.push(std::mem::take(&mut current));
                }
            }
            _ => current.push(c),
        }
    }
    if !current.is_empty() {
        tokens.push(current);
    }

    tokens
        .into_iter()
        // Quoted runs may still carry trailing separators (e.g. "/tmp/x;").
        .map(|token| token.trim_end_matches([';', '&', '|']).to_owned())
        .filter(|token| {
            let t = token.as_str();
            !t.is_empty()
                && (t.starts_with('/')
                    || t.starts_with("./")
                    || t.starts_with("../")
                    || t == ".."
                    || (t.starts_with('.') && t.contains('/'))
                    || is_relative_path_token(t))
        })
        .collect()
}

/// Returns `true` if `token` looks like a relative path of the form `word/more`
/// (contains `/` but does not start with `/` or `.`).
///
/// Excluded:
/// - URL schemes (`scheme://`)
/// - Shell variable assignments (`KEY=value`)
fn is_relative_path_token(token: &str) -> bool {
    // Absolute and dot-prefixed paths are classified by the caller.
    if !token.contains('/') || token.starts_with(['/', '.']) {
        return false;
    }
    // URLs such as `https://host/path`.
    if token.contains("://") {
        return false;
    }
    // `KEY=value` assignments: everything before the first `=` is identifier-like.
    if token
        .split_once('=')
        .is_some_and(|(key, _)| key.chars().all(|c| c.is_ascii_alphanumeric() || c == '_'))
    {
        return false;
    }
    // Must begin like a word: ASCII letter, digit, or underscore.
    token.starts_with(|c: char| c.is_ascii_alphanumeric() || c == '_')
}
1049
1050/// Classify shell exit codes and stderr patterns into `ToolErrorCategory`.
1051///
1052/// Returns `Some(category)` only for well-known failure modes that benefit from
1053/// structured feedback (exit 126/127, recognisable stderr patterns). All other
1054/// non-zero exits are left as `Ok` output so they surface verbatim to the LLM.
1055fn classify_shell_exit(
1056    exit_code: i32,
1057    output: &str,
1058) -> Option<crate::error_taxonomy::ToolErrorCategory> {
1059    use crate::error_taxonomy::ToolErrorCategory;
1060    match exit_code {
1061        // exit 126: command found but not executable (OS-level permission/policy)
1062        126 => Some(ToolErrorCategory::PolicyBlocked),
1063        // exit 127: command not found in PATH
1064        127 => Some(ToolErrorCategory::PermanentFailure),
1065        _ => {
1066            let lower = output.to_lowercase();
1067            if lower.contains("permission denied") {
1068                Some(ToolErrorCategory::PolicyBlocked)
1069            } else if lower.contains("no such file or directory") {
1070                Some(ToolErrorCategory::PermanentFailure)
1071            } else {
1072                None
1073            }
1074        }
1075    }
1076}
1077
/// Returns `true` when any `/`-separated component of `path` is exactly `..`
/// (a parent-directory reference).
fn has_traversal(path: &str) -> bool {
    path.split('/').any(|component| matches!(component, ".."))
}
1081
1082fn extract_bash_blocks(text: &str) -> Vec<&str> {
1083    crate::executor::extract_fenced_blocks(text, "bash")
1084}
1085
1086/// Kill a child process and its descendants.
1087/// On unix, sends SIGKILL to child processes via `pkill -KILL -P <pid>` before
1088/// killing the parent, preventing zombie subprocesses.
1089async fn kill_process_tree(child: &mut tokio::process::Child) {
1090    #[cfg(unix)]
1091    if let Some(pid) = child.id() {
1092        let _ = Command::new("pkill")
1093            .args(["-KILL", "-P", &pid.to_string()])
1094            .status()
1095            .await;
1096    }
1097    let _ = child.kill().await;
1098}
1099
1100/// Structured output from a shell command execution.
1101#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
1102pub struct ShellOutputEnvelope {
1103    pub stdout: String,
1104    pub stderr: String,
1105    pub exit_code: i32,
1106    pub truncated: bool,
1107}
1108
1109#[allow(clippy::too_many_lines)]
1110async fn execute_bash(
1111    code: &str,
1112    timeout: Duration,
1113    event_tx: Option<&ToolEventTx>,
1114    cancel_token: Option<&CancellationToken>,
1115    extra_env: Option<&std::collections::HashMap<String, String>>,
1116    env_blocklist: &[String],
1117) -> (ShellOutputEnvelope, String) {
1118    use std::process::Stdio;
1119    use tokio::io::{AsyncBufReadExt, BufReader};
1120
1121    let timeout_secs = timeout.as_secs();
1122
1123    let mut cmd = Command::new("bash");
1124    cmd.arg("-c")
1125        .arg(code)
1126        .stdout(Stdio::piped())
1127        .stderr(Stdio::piped());
1128
1129    for (key, _) in std::env::vars() {
1130        if env_blocklist
1131            .iter()
1132            .any(|prefix| key.starts_with(prefix.as_str()))
1133        {
1134            cmd.env_remove(&key);
1135        }
1136    }
1137
1138    if let Some(env) = extra_env {
1139        cmd.envs(env);
1140    }
1141    let child_result = cmd.spawn();
1142
1143    let mut child = match child_result {
1144        Ok(c) => c,
1145        Err(e) => {
1146            let msg = format!("[error] {e}");
1147            return (
1148                ShellOutputEnvelope {
1149                    stdout: String::new(),
1150                    stderr: msg.clone(),
1151                    exit_code: 1,
1152                    truncated: false,
1153                },
1154                msg,
1155            );
1156        }
1157    };
1158
1159    let stdout = child.stdout.take().expect("stdout piped");
1160    let stderr = child.stderr.take().expect("stderr piped");
1161
1162    // Channel carries (is_stderr, line) so we can accumulate separate buffers
1163    // while still building a combined interleaved string for streaming and LLM context.
1164    let (line_tx, mut line_rx) = tokio::sync::mpsc::channel::<(bool, String)>(64);
1165
1166    let stdout_tx = line_tx.clone();
1167    tokio::spawn(async move {
1168        let mut reader = BufReader::new(stdout);
1169        let mut buf = String::new();
1170        while reader.read_line(&mut buf).await.unwrap_or(0) > 0 {
1171            let _ = stdout_tx.send((false, buf.clone())).await;
1172            buf.clear();
1173        }
1174    });
1175
1176    tokio::spawn(async move {
1177        let mut reader = BufReader::new(stderr);
1178        let mut buf = String::new();
1179        while reader.read_line(&mut buf).await.unwrap_or(0) > 0 {
1180            let _ = line_tx.send((true, buf.clone())).await;
1181            buf.clear();
1182        }
1183    });
1184
1185    let mut combined = String::new();
1186    let mut stdout_buf = String::new();
1187    let mut stderr_buf = String::new();
1188    let deadline = tokio::time::Instant::now() + timeout;
1189
1190    loop {
1191        tokio::select! {
1192            line = line_rx.recv() => {
1193                match line {
1194                    Some((is_stderr, chunk)) => {
1195                        let interleaved = if is_stderr {
1196                            format!("[stderr] {chunk}")
1197                        } else {
1198                            chunk.clone()
1199                        };
1200                        if let Some(tx) = event_tx {
1201                            let _ = tx.send(ToolEvent::OutputChunk {
1202                                tool_name: "bash".to_owned(),
1203                                command: code.to_owned(),
1204                                chunk: interleaved.clone(),
1205                            });
1206                        }
1207                        combined.push_str(&interleaved);
1208                        if is_stderr {
1209                            stderr_buf.push_str(&chunk);
1210                        } else {
1211                            stdout_buf.push_str(&chunk);
1212                        }
1213                    }
1214                    None => break,
1215                }
1216            }
1217            () = tokio::time::sleep_until(deadline) => {
1218                kill_process_tree(&mut child).await;
1219                let msg = format!("[error] command timed out after {timeout_secs}s");
1220                return (
1221                    ShellOutputEnvelope {
1222                        stdout: stdout_buf,
1223                        stderr: format!("{stderr_buf}command timed out after {timeout_secs}s"),
1224                        exit_code: 1,
1225                        truncated: false,
1226                    },
1227                    msg,
1228                );
1229            }
1230            () = async {
1231                match cancel_token {
1232                    Some(t) => t.cancelled().await,
1233                    None => std::future::pending().await,
1234                }
1235            } => {
1236                kill_process_tree(&mut child).await;
1237                return (
1238                    ShellOutputEnvelope {
1239                        stdout: stdout_buf,
1240                        stderr: format!("{stderr_buf}operation aborted"),
1241                        exit_code: 130,
1242                        truncated: false,
1243                    },
1244                    "[cancelled] operation aborted".to_string(),
1245                );
1246            }
1247        }
1248    }
1249
1250    let status = child.wait().await;
1251    let exit_code = status.ok().and_then(|s| s.code()).unwrap_or(1);
1252
1253    let (envelope, combined) = if combined.is_empty() {
1254        (
1255            ShellOutputEnvelope {
1256                stdout: String::new(),
1257                stderr: String::new(),
1258                exit_code,
1259                truncated: false,
1260            },
1261            "(no output)".to_string(),
1262        )
1263    } else {
1264        (
1265            ShellOutputEnvelope {
1266                stdout: stdout_buf.trim_end().to_owned(),
1267                stderr: stderr_buf.trim_end().to_owned(),
1268                exit_code,
1269                truncated: false,
1270            },
1271            combined,
1272        )
1273    };
1274    (envelope, combined)
1275}
1276
1277#[cfg(test)]
1278mod tests;