Skip to main content

zeph_tools/
executor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt;
5
/// Data for rendering file diffs in the TUI.
///
/// A plain value type: it only stores the path and the two versions of the
/// text; no diffing is performed here.
#[derive(Debug, Clone)]
pub struct DiffData {
    /// Path of the file the diff refers to.
    pub file_path: String,
    /// File content before the change.
    pub old_content: String,
    /// File content after the change.
    pub new_content: String,
}
13
/// Structured tool invocation from LLM.
#[derive(Debug, Clone)]
pub struct ToolCall {
    /// Identifier of the tool being invoked. Executors return `None` from
    /// `ToolExecutor::execute_tool_call` for ids they do not handle.
    pub tool_id: String,
    /// Raw JSON arguments of the call. `deserialize_params` converts them
    /// into a typed struct.
    pub params: serde_json::Map<String, serde_json::Value>,
}
20
21/// Cumulative filter statistics for a single tool execution.
22#[derive(Debug, Clone, Default)]
23pub struct FilterStats {
24    pub raw_chars: usize,
25    pub filtered_chars: usize,
26    pub raw_lines: usize,
27    pub filtered_lines: usize,
28    pub confidence: Option<crate::FilterConfidence>,
29    pub command: Option<String>,
30    pub kept_lines: Vec<usize>,
31}
32
33impl FilterStats {
34    #[must_use]
35    #[allow(clippy::cast_precision_loss)]
36    pub fn savings_pct(&self) -> f64 {
37        if self.raw_chars == 0 {
38            return 0.0;
39        }
40        (1.0 - self.filtered_chars as f64 / self.raw_chars as f64) * 100.0
41    }
42
43    #[must_use]
44    pub fn estimated_tokens_saved(&self) -> usize {
45        self.raw_chars.saturating_sub(self.filtered_chars) / 4
46    }
47
48    #[must_use]
49    pub fn format_inline(&self, tool_name: &str) -> String {
50        let cmd_label = self
51            .command
52            .as_deref()
53            .map(|c| {
54                let trimmed = c.trim();
55                if trimmed.len() > 60 {
56                    format!(" `{}…`", &trimmed[..57])
57                } else {
58                    format!(" `{trimmed}`")
59                }
60            })
61            .unwrap_or_default();
62        format!(
63            "[{tool_name}]{cmd_label} {} lines \u{2192} {} lines, {:.1}% filtered",
64            self.raw_lines,
65            self.filtered_lines,
66            self.savings_pct()
67        )
68    }
69}
70
71/// Structured result from tool execution.
72#[derive(Debug, Clone)]
73pub struct ToolOutput {
74    pub tool_name: String,
75    pub summary: String,
76    pub blocks_executed: u32,
77    pub filter_stats: Option<FilterStats>,
78    pub diff: Option<DiffData>,
79    /// Whether this tool already streamed its output via `ToolEvent` channel.
80    pub streamed: bool,
81}
82
83impl fmt::Display for ToolOutput {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.write_str(&self.summary)
86    }
87}
88
89pub const MAX_TOOL_OUTPUT_CHARS: usize = 30_000;
90
91/// Truncate tool output that exceeds `MAX_TOOL_OUTPUT_CHARS` using head+tail split.
92#[must_use]
93pub fn truncate_tool_output(output: &str) -> String {
94    if output.len() <= MAX_TOOL_OUTPUT_CHARS {
95        return output.to_string();
96    }
97
98    let half = MAX_TOOL_OUTPUT_CHARS / 2;
99    let head_end = output.floor_char_boundary(half);
100    let tail_start = output.ceil_char_boundary(output.len() - half);
101    let head = &output[..head_end];
102    let tail = &output[tail_start..];
103    let truncated = output.len() - head_end - (output.len() - tail_start);
104
105    format!(
106        "{head}\n\n... [truncated {truncated} chars, showing first and last ~{half} chars] ...\n\n{tail}"
107    )
108}
109
/// Event emitted during tool execution for real-time UI updates.
///
/// Every variant carries the tool name and the command it relates to.
#[derive(Debug, Clone)]
pub enum ToolEvent {
    /// A tool invocation has started.
    Started {
        tool_name: String,
        command: String,
    },
    /// A piece of output produced by the invocation.
    OutputChunk {
        tool_name: String,
        command: String,
        /// The output fragment carried by this event.
        chunk: String,
    },
    /// The invocation has finished.
    Completed {
        tool_name: String,
        command: String,
        /// Output reported for the finished invocation.
        output: String,
        /// Whether the invocation is reported as successful.
        success: bool,
        /// Filter statistics for the output, when available.
        filter_stats: Option<FilterStats>,
        /// Diff payload for rendering, when available.
        diff: Option<DiffData>,
    },
}
131
/// Sending half of an unbounded Tokio MPSC channel carrying [`ToolEvent`]s.
pub type ToolEventTx = tokio::sync::mpsc::UnboundedSender<ToolEvent>;
133
/// Errors that can occur during tool execution.
///
/// `Display` output comes from the `#[error(...)]` attribute on each variant.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The command was rejected by policy.
    #[error("command blocked by policy: {command}")]
    Blocked { command: String },

    /// The path is not permitted by the sandbox.
    #[error("path not allowed by sandbox: {path}")]
    SandboxViolation { path: String },

    /// The command needs confirmation before it may run
    /// (see `ToolExecutor::execute_confirmed`).
    #[error("command requires confirmation: {command}")]
    ConfirmationRequired { command: String },

    /// The command did not finish within `timeout_secs` seconds.
    #[error("command timed out after {timeout_secs}s")]
    Timeout { timeout_secs: u64 },

    /// The operation was cancelled.
    #[error("operation cancelled")]
    Cancelled,

    /// Tool parameters were invalid; produced by `deserialize_params` when
    /// deserialization fails.
    #[error("invalid tool parameters: {message}")]
    InvalidParams { message: String },

    /// An I/O error occurred; converts from `std::io::Error` via `#[from]`, so `?` works.
    #[error("execution failed: {0}")]
    Execution(#[from] std::io::Error),
}
158
159/// Deserialize tool call params from a `serde_json::Map<String, Value>` into a typed struct.
160///
161/// # Errors
162///
163/// Returns `ToolError::InvalidParams` when deserialization fails.
164pub fn deserialize_params<T: serde::de::DeserializeOwned>(
165    params: &serde_json::Map<String, serde_json::Value>,
166) -> Result<T, ToolError> {
167    let obj = serde_json::Value::Object(params.clone());
168    serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams {
169        message: e.to_string(),
170    })
171}
172
/// Async trait for tool execution backends (shell, future MCP, A2A).
///
/// Accepts the full LLM response and returns an optional output.
/// Returns `None` when no tool invocation is detected in the response.
///
/// Only `execute` is required; every other method has a default body.
pub trait ToolExecutor: Send + Sync {
    /// Process a full LLM response.
    ///
    /// The returned future resolves to `Ok(None)` when no tool invocation is
    /// detected in `response`, or to a [`ToolError`] on failure.
    fn execute(
        &self,
        response: &str,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send;

    /// Execute bypassing confirmation checks (called after user approves).
    /// Default: delegates to `execute`.
    fn execute_confirmed(
        &self,
        response: &str,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
        self.execute(response)
    }

    /// Return tool definitions this executor can handle.
    /// Default: an empty list.
    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
        vec![]
    }

    /// Execute a structured tool call. Returns `None` if `tool_id` is not handled.
    /// Default: an already-completed future yielding `Ok(None)`.
    fn execute_tool_call(
        &self,
        _call: &ToolCall,
    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
        std::future::ready(Ok(None))
    }

    /// Inject environment variables for the currently active skill. No-op by default.
    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
}
208
/// Object-safe erased version of [`ToolExecutor`] using boxed futures.
///
/// Implemented automatically for every `T: ToolExecutor` through a blanket
/// impl. Storing one as `Box<dyn ErasedToolExecutor>` (implicitly
/// `dyn ... + 'static`) additionally requires `T: 'static`.
/// Use `Box<dyn ErasedToolExecutor>` when dynamic dispatch is required.
pub trait ErasedToolExecutor: Send + Sync {
    /// Boxed-future counterpart of `ToolExecutor::execute`.
    ///
    /// The returned future may borrow both `self` and `response` for `'a`.
    fn execute_erased<'a>(
        &'a self,
        response: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Boxed-future counterpart of `ToolExecutor::execute_confirmed`.
    fn execute_confirmed_erased<'a>(
        &'a self,
        response: &'a str,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Counterpart of `ToolExecutor::tool_definitions`.
    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef>;

    /// Boxed-future counterpart of `ToolExecutor::execute_tool_call`.
    fn execute_tool_call_erased<'a>(
        &'a self,
        call: &'a ToolCall,
    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;

    /// Inject environment variables for the currently active skill. No-op by default.
    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
}
234
235impl<T: ToolExecutor> ErasedToolExecutor for T {
236    fn execute_erased<'a>(
237        &'a self,
238        response: &'a str,
239    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
240    {
241        Box::pin(self.execute(response))
242    }
243
244    fn execute_confirmed_erased<'a>(
245        &'a self,
246        response: &'a str,
247    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
248    {
249        Box::pin(self.execute_confirmed(response))
250    }
251
252    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef> {
253        self.tool_definitions()
254    }
255
256    fn execute_tool_call_erased<'a>(
257        &'a self,
258        call: &'a ToolCall,
259    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
260    {
261        Box::pin(self.execute_tool_call(call))
262    }
263
264    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
265        ToolExecutor::set_skill_env(self, env);
266    }
267}
268
269/// Wraps `Arc<dyn ErasedToolExecutor>` so it can be used as a concrete `ToolExecutor`.
270///
271/// Enables dynamic composition of tool executors at runtime without static type chains.
272pub struct DynExecutor(pub std::sync::Arc<dyn ErasedToolExecutor>);
273
274impl ToolExecutor for DynExecutor {
275    fn execute(
276        &self,
277        response: &str,
278    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
279        // Clone data to satisfy the 'static-ish bound: erased futures must not borrow self.
280        let inner = std::sync::Arc::clone(&self.0);
281        let response = response.to_owned();
282        async move { inner.execute_erased(&response).await }
283    }
284
285    fn execute_confirmed(
286        &self,
287        response: &str,
288    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
289        let inner = std::sync::Arc::clone(&self.0);
290        let response = response.to_owned();
291        async move { inner.execute_confirmed_erased(&response).await }
292    }
293
294    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
295        self.0.tool_definitions_erased()
296    }
297
298    fn execute_tool_call(
299        &self,
300        call: &ToolCall,
301    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
302        let inner = std::sync::Arc::clone(&self.0);
303        let call = call.clone();
304        async move { inner.execute_tool_call_erased(&call).await }
305    }
306
307    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
308        ErasedToolExecutor::set_skill_env(self.0.as_ref(), env);
309    }
310}
311
/// Extract fenced code blocks with the given language marker from text.
///
/// Searches for `` ```{lang} `` … `` ``` `` pairs, returning trimmed content
/// in order of appearance.
///
/// The language tag must match exactly: the opening marker has to be followed
/// by whitespace or the end of the text. A request for `sh` therefore does
/// not pick up a `` ```shell `` block (which previously yielded content
/// starting with the leftover `ell`). When `lang` is empty, every `` ``` ``
/// is accepted as an opening fence, as before.
///
/// An opening fence without a closing `` ``` `` ends the scan; blocks found
/// before it are still returned.
#[must_use]
pub fn extract_fenced_blocks<'a>(text: &'a str, lang: &str) -> Vec<&'a str> {
    const FENCE: &str = "```";

    let marker = format!("{FENCE}{lang}");
    let mut blocks = Vec::new();
    let mut rest = text;

    while let Some(start) = rest.find(&marker) {
        let after = &rest[start + marker.len()..];

        // Reject fences whose language merely starts with `lang`
        // (e.g. "```shell" when looking for "sh").
        let tag_ends_here =
            lang.is_empty() || after.chars().next().map_or(true, char::is_whitespace);
        if !tag_ends_here {
            // Skip only the false marker and keep scanning the remainder.
            rest = after;
            continue;
        }

        let Some(end) = after.find(FENCE) else {
            break;
        };
        blocks.push(after[..end].trim());
        rest = &after[end + FENCE.len()..];
    }

    blocks
}
334
// Unit tests for the executor types: `Display` output of `ToolOutput` and
// `ToolError`, `deserialize_params`, `truncate_tool_output`, `FilterStats`
// helpers, the default `ToolExecutor` methods, and `DynExecutor` delegation.
#[cfg(test)]
mod tests {
    use super::*;

    // `Display` for `ToolOutput` prints the summary verbatim.
    #[test]
    fn tool_output_display() {
        let output = ToolOutput {
            tool_name: "bash".to_owned(),
            summary: "$ echo hello\nhello".to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
        };
        assert_eq!(output.to_string(), "$ echo hello\nhello");
    }

    // The `tool_error_*_display` tests pin the `#[error(...)]` message of each variant.
    #[test]
    fn tool_error_blocked_display() {
        let err = ToolError::Blocked {
            command: "rm -rf /".to_owned(),
        };
        assert_eq!(err.to_string(), "command blocked by policy: rm -rf /");
    }

    #[test]
    fn tool_error_sandbox_violation_display() {
        let err = ToolError::SandboxViolation {
            path: "/etc/shadow".to_owned(),
        };
        assert_eq!(err.to_string(), "path not allowed by sandbox: /etc/shadow");
    }

    #[test]
    fn tool_error_confirmation_required_display() {
        let err = ToolError::ConfirmationRequired {
            command: "rm -rf /tmp".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "command requires confirmation: rm -rf /tmp"
        );
    }

    #[test]
    fn tool_error_timeout_display() {
        let err = ToolError::Timeout { timeout_secs: 30 };
        assert_eq!(err.to_string(), "command timed out after 30s");
    }

    #[test]
    fn tool_error_invalid_params_display() {
        let err = ToolError::InvalidParams {
            message: "missing field `command`".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "invalid tool parameters: missing field `command`"
        );
    }

    // A map with every expected field deserializes into the typed struct.
    #[test]
    fn deserialize_params_valid() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("count".to_owned(), serde_json::json!(42));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned(),
                count: 42
            }
        );
    }

    // A missing required field is reported as `InvalidParams`.
    #[test]
    fn deserialize_params_missing_required_field() {
        #[derive(Debug, serde::Deserialize)]
        struct P {
            #[allow(dead_code)]
            name: String,
        }
        let map = serde_json::Map::new();
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    // A value of the wrong JSON type is reported as `InvalidParams`.
    #[test]
    fn deserialize_params_wrong_type() {
        #[derive(Debug, serde::Deserialize)]
        struct P {
            #[allow(dead_code)]
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("count".to_owned(), serde_json::json!("not a number"));
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    // An empty map is accepted when every field is optional.
    #[test]
    fn deserialize_params_all_optional_empty() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: Option<String>,
        }
        let map = serde_json::Map::new();
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(p, P { name: None });
    }

    // Unknown keys in the map are ignored.
    #[test]
    fn deserialize_params_ignores_extra_fields() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("extra".to_owned(), serde_json::json!(true));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned()
            }
        );
    }

    // The wrapped I/O error text is included in the message.
    #[test]
    fn tool_error_execution_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "bash not found");
        let err = ToolError::Execution(io_err);
        assert!(err.to_string().starts_with("execution failed:"));
        assert!(err.to_string().contains("bash not found"));
    }

    // Output within the limit is returned unchanged.
    #[test]
    fn truncate_tool_output_short_passthrough() {
        let short = "hello world";
        assert_eq!(truncate_tool_output(short), short);
    }

    // Output exactly at the limit is still returned unchanged.
    #[test]
    fn truncate_tool_output_exact_limit() {
        let exact = "a".repeat(MAX_TOOL_OUTPUT_CHARS);
        assert_eq!(truncate_tool_output(&exact), exact);
    }

    // Output over the limit is shortened and carries the truncation notice.
    #[test]
    fn truncate_tool_output_long_split() {
        let long = "x".repeat(MAX_TOOL_OUTPUT_CHARS + 1000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.len() < long.len());
    }

    #[test]
    fn truncate_tool_output_notice_contains_count() {
        let long = "y".repeat(MAX_TOOL_OUTPUT_CHARS + 2000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.contains("chars"));
    }

    // Executor that implements only the required method, so the trait's
    // default methods can be exercised.
    #[derive(Debug)]
    struct DefaultExecutor;
    impl ToolExecutor for DefaultExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    // The default `execute_tool_call` resolves to `Ok(None)`.
    #[tokio::test]
    async fn execute_tool_call_default_returns_none() {
        let exec = DefaultExecutor;
        let call = ToolCall {
            tool_id: "anything".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn filter_stats_savings_pct() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert!((fs.savings_pct() - 80.0).abs() < 0.01);
    }

    // With `raw_chars == 0` the percentage is 0 rather than a division by zero.
    #[test]
    fn filter_stats_savings_pct_zero() {
        let fs = FilterStats::default();
        assert!((fs.savings_pct()).abs() < 0.01);
    }

    #[test]
    fn filter_stats_estimated_tokens_saved() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert_eq!(fs.estimated_tokens_saved(), 200); // (1000 - 200) / 4
    }

    // Without a command, no command label appears in the inline summary.
    #[test]
    fn filter_stats_format_inline() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(line, "[shell] 342 lines \u{2192} 28 lines, 80.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_zero() {
        let fs = FilterStats::default();
        let line = fs.format_inline("bash");
        assert_eq!(line, "[bash] 0 lines \u{2192} 0 lines, 0.0% filtered");
    }

    // DynExecutor tests

    // Executor that always reports a fixed tool name and summary, used to
    // observe delegation through `DynExecutor`.
    struct FixedExecutor {
        tool_id: &'static str,
        output: &'static str,
    }

    impl ToolExecutor for FixedExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: self.tool_id.to_owned(),
                summary: self.output.to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
            }))
        }

        fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
            vec![]
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: self.tool_id.to_owned(),
                summary: self.output.to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
            }))
        }
    }

    #[tokio::test]
    async fn dyn_executor_execute_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "hello",
        });
        let exec = DynExecutor(inner);
        let result = exec.execute("```bash\necho hello\n```").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "hello");
    }

    // `FixedExecutor` does not override `execute_confirmed`, so this goes
    // through the trait's default, which delegates to `execute`.
    #[tokio::test]
    async fn dyn_executor_execute_confirmed_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "confirmed",
        });
        let exec = DynExecutor(inner);
        let result = exec.execute_confirmed("...").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "confirmed");
    }

    #[test]
    fn dyn_executor_tool_definitions_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "my_tool",
            output: "",
        });
        let exec = DynExecutor(inner);
        // FixedExecutor returns empty definitions; verify delegation occurs without panic.
        let defs = exec.tool_definitions();
        assert!(defs.is_empty());
    }

    #[tokio::test]
    async fn dyn_executor_execute_tool_call_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "tool_call_result",
        });
        let exec = DynExecutor(inner);
        let call = ToolCall {
            tool_id: "bash".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "tool_call_result");
    }
}