Skip to main content

zeph_tools/
executor.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt;
5
/// Data for rendering file diffs in the TUI.
///
/// Carries the complete file text on both sides of a change; no diff is
/// precomputed here.
#[derive(Clone, Debug)]
pub struct DiffData {
    /// Path of the file that changed.
    pub file_path: String,
    /// File content before the change.
    pub old_content: String,
    /// File content after the change.
    pub new_content: String,
}
13
14/// Structured tool invocation from LLM.
15#[derive(Debug, Clone)]
16pub struct ToolCall {
17    pub tool_id: String,
18    pub params: serde_json::Map<String, serde_json::Value>,
19}
20
21/// Cumulative filter statistics for a single tool execution.
22#[derive(Debug, Clone, Default)]
23pub struct FilterStats {
24    pub raw_chars: usize,
25    pub filtered_chars: usize,
26    pub raw_lines: usize,
27    pub filtered_lines: usize,
28    pub confidence: Option<crate::FilterConfidence>,
29    pub command: Option<String>,
30    pub kept_lines: Vec<usize>,
31}
32
33impl FilterStats {
34    #[must_use]
35    #[allow(clippy::cast_precision_loss)]
36    pub fn savings_pct(&self) -> f64 {
37        if self.raw_chars == 0 {
38            return 0.0;
39        }
40        (1.0 - self.filtered_chars as f64 / self.raw_chars as f64) * 100.0
41    }
42
43    #[must_use]
44    pub fn estimated_tokens_saved(&self) -> usize {
45        self.raw_chars.saturating_sub(self.filtered_chars) / 4
46    }
47
48    #[must_use]
49    pub fn format_inline(&self, tool_name: &str) -> String {
50        let cmd_label = self
51            .command
52            .as_deref()
53            .map(|c| {
54                let trimmed = c.trim();
55                if trimmed.len() > 60 {
56                    format!(" `{}…`", &trimmed[..57])
57                } else {
58                    format!(" `{trimmed}`")
59                }
60            })
61            .unwrap_or_default();
62        format!(
63            "[{tool_name}]{cmd_label} {} lines \u{2192} {} lines, {:.1}% filtered",
64            self.raw_lines,
65            self.filtered_lines,
66            self.savings_pct()
67        )
68    }
69}
70
71/// Structured result from tool execution.
72#[derive(Debug, Clone)]
73pub struct ToolOutput {
74    pub tool_name: String,
75    pub summary: String,
76    pub blocks_executed: u32,
77    pub filter_stats: Option<FilterStats>,
78    pub diff: Option<DiffData>,
79    /// Whether this tool already streamed its output via `ToolEvent` channel.
80    pub streamed: bool,
81    /// Terminal ID when the tool was executed via IDE terminal (ACP terminal/* protocol).
82    pub terminal_id: Option<String>,
83}
84
85impl fmt::Display for ToolOutput {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        f.write_str(&self.summary)
88    }
89}
90
/// Size limit applied by [`truncate_tool_output`].
///
/// NOTE: despite the name, the limit is compared against `str::len()`, so it
/// counts UTF-8 bytes rather than Unicode characters.
pub const MAX_TOOL_OUTPUT_CHARS: usize = 30_000;

/// Truncate tool output that exceeds `MAX_TOOL_OUTPUT_CHARS` using head+tail split.
///
/// Output at or under the limit is returned unchanged. Longer output keeps
/// roughly the first and last `MAX_TOOL_OUTPUT_CHARS / 2` bytes. Each cut is
/// moved inward to the nearest UTF-8 character boundary, so slicing never
/// panics. A notice between the two halves reports how many bytes were
/// dropped (the notice calls them "chars").
#[must_use]
pub fn truncate_tool_output(output: &str) -> String {
    let total = output.len();
    if total <= MAX_TOOL_OUTPUT_CHARS {
        return output.to_owned();
    }

    let half = MAX_TOOL_OUTPUT_CHARS / 2;

    // Walk the head cut backwards and the tail cut forwards until each sits
    // on a char boundary (0 and `total` always qualify, so both loops stop).
    let mut head_end = half;
    while !output.is_char_boundary(head_end) {
        head_end -= 1;
    }
    let mut tail_start = total - half;
    while !output.is_char_boundary(tail_start) {
        tail_start += 1;
    }

    let truncated = tail_start - head_end;
    let head = &output[..head_end];
    let tail = &output[tail_start..];

    format!(
        "{head}\n\n... [truncated {truncated} chars, showing first and last ~{half} chars] ...\n\n{tail}"
    )
}
111
112/// Event emitted during tool execution for real-time UI updates.
113#[derive(Debug, Clone)]
114pub enum ToolEvent {
115    Started {
116        tool_name: String,
117        command: String,
118    },
119    OutputChunk {
120        tool_name: String,
121        command: String,
122        chunk: String,
123    },
124    Completed {
125        tool_name: String,
126        command: String,
127        output: String,
128        success: bool,
129        filter_stats: Option<FilterStats>,
130        diff: Option<DiffData>,
131    },
132}
133
134pub type ToolEventTx = tokio::sync::mpsc::UnboundedSender<ToolEvent>;
135
136/// Errors that can occur during tool execution.
137#[derive(Debug, thiserror::Error)]
138pub enum ToolError {
139    #[error("command blocked by policy: {command}")]
140    Blocked { command: String },
141
142    #[error("path not allowed by sandbox: {path}")]
143    SandboxViolation { path: String },
144
145    #[error("command requires confirmation: {command}")]
146    ConfirmationRequired { command: String },
147
148    #[error("command timed out after {timeout_secs}s")]
149    Timeout { timeout_secs: u64 },
150
151    #[error("operation cancelled")]
152    Cancelled,
153
154    #[error("invalid tool parameters: {message}")]
155    InvalidParams { message: String },
156
157    #[error("execution failed: {0}")]
158    Execution(#[from] std::io::Error),
159}
160
161/// Deserialize tool call params from a `serde_json::Map<String, Value>` into a typed struct.
162///
163/// # Errors
164///
165/// Returns `ToolError::InvalidParams` when deserialization fails.
166pub fn deserialize_params<T: serde::de::DeserializeOwned>(
167    params: &serde_json::Map<String, serde_json::Value>,
168) -> Result<T, ToolError> {
169    let obj = serde_json::Value::Object(params.clone());
170    serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams {
171        message: e.to_string(),
172    })
173}
174
175/// Async trait for tool execution backends (shell, future MCP, A2A).
176///
177/// Accepts the full LLM response and returns an optional output.
178/// Returns `None` when no tool invocation is detected in the response.
179pub trait ToolExecutor: Send + Sync {
180    fn execute(
181        &self,
182        response: &str,
183    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send;
184
185    /// Execute bypassing confirmation checks (called after user approves).
186    /// Default: delegates to `execute`.
187    fn execute_confirmed(
188        &self,
189        response: &str,
190    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
191        self.execute(response)
192    }
193
194    /// Return tool definitions this executor can handle.
195    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
196        vec![]
197    }
198
199    /// Execute a structured tool call. Returns `None` if `tool_id` is not handled.
200    fn execute_tool_call(
201        &self,
202        _call: &ToolCall,
203    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
204        std::future::ready(Ok(None))
205    }
206
207    /// Inject environment variables for the currently active skill. No-op by default.
208    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
209}
210
211/// Object-safe erased version of [`ToolExecutor`] using boxed futures.
212///
213/// Implemented automatically for all `T: ToolExecutor + 'static`.
214/// Use `Box<dyn ErasedToolExecutor>` when dynamic dispatch is required.
215pub trait ErasedToolExecutor: Send + Sync {
216    fn execute_erased<'a>(
217        &'a self,
218        response: &'a str,
219    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
220
221    fn execute_confirmed_erased<'a>(
222        &'a self,
223        response: &'a str,
224    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
225
226    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef>;
227
228    fn execute_tool_call_erased<'a>(
229        &'a self,
230        call: &'a ToolCall,
231    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
232
233    /// Inject environment variables for the currently active skill. No-op by default.
234    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
235}
236
237impl<T: ToolExecutor> ErasedToolExecutor for T {
238    fn execute_erased<'a>(
239        &'a self,
240        response: &'a str,
241    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
242    {
243        Box::pin(self.execute(response))
244    }
245
246    fn execute_confirmed_erased<'a>(
247        &'a self,
248        response: &'a str,
249    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
250    {
251        Box::pin(self.execute_confirmed(response))
252    }
253
254    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef> {
255        self.tool_definitions()
256    }
257
258    fn execute_tool_call_erased<'a>(
259        &'a self,
260        call: &'a ToolCall,
261    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
262    {
263        Box::pin(self.execute_tool_call(call))
264    }
265
266    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
267        ToolExecutor::set_skill_env(self, env);
268    }
269}
270
271/// Wraps `Arc<dyn ErasedToolExecutor>` so it can be used as a concrete `ToolExecutor`.
272///
273/// Enables dynamic composition of tool executors at runtime without static type chains.
274pub struct DynExecutor(pub std::sync::Arc<dyn ErasedToolExecutor>);
275
276impl ToolExecutor for DynExecutor {
277    fn execute(
278        &self,
279        response: &str,
280    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
281        // Clone data to satisfy the 'static-ish bound: erased futures must not borrow self.
282        let inner = std::sync::Arc::clone(&self.0);
283        let response = response.to_owned();
284        async move { inner.execute_erased(&response).await }
285    }
286
287    fn execute_confirmed(
288        &self,
289        response: &str,
290    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
291        let inner = std::sync::Arc::clone(&self.0);
292        let response = response.to_owned();
293        async move { inner.execute_confirmed_erased(&response).await }
294    }
295
296    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
297        self.0.tool_definitions_erased()
298    }
299
300    fn execute_tool_call(
301        &self,
302        call: &ToolCall,
303    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
304        let inner = std::sync::Arc::clone(&self.0);
305        let call = call.clone();
306        async move { inner.execute_tool_call_erased(&call).await }
307    }
308
309    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
310        ErasedToolExecutor::set_skill_env(self.0.as_ref(), env);
311    }
312}
313
/// Extract fenced code blocks with the given language marker from text.
///
/// Searches for `` ```{lang} `` … `` ``` `` pairs and returns their trimmed content.
/// The opening fence counts only when the first word of its info string is
/// exactly `lang`: the marker must be followed by whitespace or end of text. As a
/// result, `lang = "sh"` does not match a ```` ```shell ```` fence, and an empty
/// `lang` matches only bare ```` ``` ```` fences. A final fence without a closing
/// ```` ``` ```` is ignored.
#[must_use]
pub fn extract_fenced_blocks<'a>(text: &'a str, lang: &str) -> Vec<&'a str> {
    let marker = format!("```{lang}");
    let marker_len = marker.len();
    let mut blocks = Vec::new();
    let mut rest = text;

    while let Some(start) = rest.find(&marker) {
        let after = &rest[start + marker_len..];
        // Reject prefix matches (```shell for "sh", ```cpp for "c"). Skipping past
        // the marker always advances `rest`, so the loop terminates.
        if after.chars().next().is_some_and(|c| !c.is_whitespace()) {
            rest = after;
            continue;
        }
        let Some(end) = after.find("```") else {
            break;
        };
        blocks.push(after[..end].trim());
        rest = &after[end + 3..];
    }

    blocks
}
336
#[cfg(test)]
mod tests {
    // Unit tests for the executor data types, error messages, parameter decoding,
    // output truncation, fenced-block extraction and the `DynExecutor` wrapper.
    use super::*;

    #[test]
    fn tool_output_display() {
        let output = ToolOutput {
            tool_name: "bash".to_owned(),
            summary: "$ echo hello\nhello".to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
        };
        assert_eq!(output.to_string(), "$ echo hello\nhello");
    }

    #[test]
    fn tool_error_blocked_display() {
        let err = ToolError::Blocked {
            command: "rm -rf /".to_owned(),
        };
        assert_eq!(err.to_string(), "command blocked by policy: rm -rf /");
    }

    #[test]
    fn tool_error_sandbox_violation_display() {
        let err = ToolError::SandboxViolation {
            path: "/etc/shadow".to_owned(),
        };
        assert_eq!(err.to_string(), "path not allowed by sandbox: /etc/shadow");
    }

    #[test]
    fn tool_error_confirmation_required_display() {
        let err = ToolError::ConfirmationRequired {
            command: "rm -rf /tmp".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "command requires confirmation: rm -rf /tmp"
        );
    }

    #[test]
    fn tool_error_timeout_display() {
        let err = ToolError::Timeout { timeout_secs: 30 };
        assert_eq!(err.to_string(), "command timed out after 30s");
    }

    #[test]
    fn tool_error_invalid_params_display() {
        let err = ToolError::InvalidParams {
            message: "missing field `command`".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "invalid tool parameters: missing field `command`"
        );
    }

    #[test]
    fn deserialize_params_valid() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("count".to_owned(), serde_json::json!(42));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned(),
                count: 42
            }
        );
    }

    #[test]
    fn deserialize_params_missing_required_field() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct P {
            name: String,
        }
        let map = serde_json::Map::new();
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_wrong_type() {
        #[derive(Debug, serde::Deserialize)]
        #[allow(dead_code)]
        struct P {
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("count".to_owned(), serde_json::json!("not a number"));
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_all_optional_empty() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: Option<String>,
        }
        let map = serde_json::Map::new();
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(p, P { name: None });
    }

    #[test]
    fn deserialize_params_ignores_extra_fields() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("extra".to_owned(), serde_json::json!(true));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned()
            }
        );
    }

    #[test]
    fn tool_error_execution_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "bash not found");
        let err = ToolError::Execution(io_err);
        assert!(err.to_string().starts_with("execution failed:"));
        assert!(err.to_string().contains("bash not found"));
    }

    #[test]
    fn truncate_tool_output_short_passthrough() {
        let short = "hello world";
        assert_eq!(truncate_tool_output(short), short);
    }

    #[test]
    fn truncate_tool_output_exact_limit() {
        let exact = "a".repeat(MAX_TOOL_OUTPUT_CHARS);
        assert_eq!(truncate_tool_output(&exact), exact);
    }

    #[test]
    fn truncate_tool_output_long_split() {
        let long = "x".repeat(MAX_TOOL_OUTPUT_CHARS + 1000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.len() < long.len());
    }

    #[test]
    fn truncate_tool_output_notice_contains_count() {
        let long = "y".repeat(MAX_TOOL_OUTPUT_CHARS + 2000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.contains("chars"));
    }

    #[test]
    fn truncate_tool_output_multibyte_respects_char_boundaries() {
        // 1 ASCII byte followed by 3-byte chars: both split points land mid-char,
        // so the cuts must be moved onto boundaries instead of panicking.
        let long = format!("a{}", "€".repeat(10_000));
        let result = truncate_tool_output(&long);
        assert!(result.starts_with("a€"));
        assert!(result.ends_with('€'));
        assert!(result.contains("truncated 3 chars"));
    }

    #[derive(Debug)]
    struct DefaultExecutor;
    impl ToolExecutor for DefaultExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    #[tokio::test]
    async fn execute_tool_call_default_returns_none() {
        let exec = DefaultExecutor;
        let call = ToolCall {
            tool_id: "anything".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn filter_stats_savings_pct() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert!((fs.savings_pct() - 80.0).abs() < 0.01);
    }

    #[test]
    fn filter_stats_savings_pct_zero() {
        let fs = FilterStats::default();
        assert!((fs.savings_pct()).abs() < 0.01);
    }

    #[test]
    fn filter_stats_estimated_tokens_saved() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert_eq!(fs.estimated_tokens_saved(), 200); // (1000 - 200) / 4
    }

    #[test]
    fn filter_stats_format_inline() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(line, "[shell] 342 lines \u{2192} 28 lines, 80.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_zero() {
        let fs = FilterStats::default();
        let line = fs.format_inline("bash");
        assert_eq!(line, "[bash] 0 lines \u{2192} 0 lines, 0.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_long_ascii_command_truncated() {
        let fs = FilterStats {
            command: Some("a".repeat(80)),
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(
            line,
            format!(
                "[shell] `{}…` 0 lines \u{2192} 0 lines, 0.0% filtered",
                "a".repeat(57)
            )
        );
    }

    #[test]
    fn extract_fenced_blocks_collects_terminated_blocks() {
        // The trailing block has no closing fence and must be ignored.
        let text = "```bash\none\n```\ntext\n```bash\ntwo\n```\n```bash\nopen";
        assert_eq!(extract_fenced_blocks(text, "bash"), vec!["one", "two"]);
    }

    // DynExecutor tests

    struct FixedExecutor {
        tool_id: &'static str,
        output: &'static str,
    }

    impl ToolExecutor for FixedExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: self.tool_id.to_owned(),
                summary: self.output.to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
            }))
        }

        fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
            vec![]
        }

        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: self.tool_id.to_owned(),
                summary: self.output.to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
            }))
        }
    }

    #[tokio::test]
    async fn dyn_executor_execute_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "hello",
        });
        let exec = DynExecutor(inner);
        let result = exec.execute("```bash\necho hello\n```").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "hello");
    }

    #[tokio::test]
    async fn dyn_executor_execute_confirmed_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "confirmed",
        });
        let exec = DynExecutor(inner);
        let result = exec.execute_confirmed("...").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "confirmed");
    }

    #[test]
    fn dyn_executor_tool_definitions_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "my_tool",
            output: "",
        });
        let exec = DynExecutor(inner);
        // FixedExecutor returns empty definitions; verify delegation occurs without panic.
        let defs = exec.tool_definitions();
        assert!(defs.is_empty());
    }

    #[tokio::test]
    async fn dyn_executor_execute_tool_call_delegates() {
        let inner = std::sync::Arc::new(FixedExecutor {
            tool_id: "bash",
            output: "tool_call_result",
        });
        let exec = DynExecutor(inner);
        let call = ToolCall {
            tool_id: "bash".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().summary, "tool_call_result");
    }
}