// Source path: zeph_tools/executor.rs

1use std::fmt;
2
/// Before/after snapshot of a single file, consumed by the TUI to render a diff.
#[derive(Debug, Clone)]
pub struct DiffData {
    /// Path of the file the diff belongs to.
    pub file_path: String,
    /// File contents before the change.
    pub old_content: String,
    /// File contents after the change.
    pub new_content: String,
}
10
11/// Structured tool invocation from LLM.
12#[derive(Debug, Clone)]
13pub struct ToolCall {
14    pub tool_id: String,
15    pub params: serde_json::Map<String, serde_json::Value>,
16}
17
18/// Cumulative filter statistics for a single tool execution.
19#[derive(Debug, Clone, Default)]
20pub struct FilterStats {
21    pub raw_chars: usize,
22    pub filtered_chars: usize,
23    pub raw_lines: usize,
24    pub filtered_lines: usize,
25    pub confidence: Option<crate::FilterConfidence>,
26    pub command: Option<String>,
27}
28
29impl FilterStats {
30    #[must_use]
31    #[allow(clippy::cast_precision_loss)]
32    pub fn savings_pct(&self) -> f64 {
33        if self.raw_chars == 0 {
34            return 0.0;
35        }
36        (1.0 - self.filtered_chars as f64 / self.raw_chars as f64) * 100.0
37    }
38
39    #[must_use]
40    pub fn estimated_tokens_saved(&self) -> usize {
41        self.raw_chars.saturating_sub(self.filtered_chars) / 4
42    }
43
44    #[must_use]
45    pub fn format_inline(&self, tool_name: &str) -> String {
46        let cmd_label = self
47            .command
48            .as_deref()
49            .map(|c| {
50                let trimmed = c.trim();
51                if trimmed.len() > 60 {
52                    format!(" `{}…`", &trimmed[..57])
53                } else {
54                    format!(" `{trimmed}`")
55                }
56            })
57            .unwrap_or_default();
58        format!(
59            "[{tool_name}]{cmd_label} {} lines \u{2192} {} lines, {:.1}% filtered",
60            self.raw_lines,
61            self.filtered_lines,
62            self.savings_pct()
63        )
64    }
65}
66
67/// Structured result from tool execution.
68#[derive(Debug, Clone)]
69pub struct ToolOutput {
70    pub tool_name: String,
71    pub summary: String,
72    pub blocks_executed: u32,
73    pub filter_stats: Option<FilterStats>,
74    pub diff: Option<DiffData>,
75    /// Whether this tool already streamed its output via `ToolEvent` channel.
76    pub streamed: bool,
77}
78
79impl fmt::Display for ToolOutput {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        f.write_str(&self.summary)
82    }
83}
84
/// Largest tool output, measured in bytes of UTF-8, that is passed through untouched.
pub const MAX_TOOL_OUTPUT_CHARS: usize = 30_000;

/// Shorten an over-long tool output by keeping its head and tail.
///
/// Inputs of at most `MAX_TOOL_OUTPUT_CHARS` bytes are returned unchanged. Longer inputs
/// keep roughly half the limit from each end, joined by a notice that states how many
/// bytes were dropped. Both cut points are moved onto UTF-8 character boundaries.
#[must_use]
pub fn truncate_tool_output(output: &str) -> String {
    let total = output.len();
    if total <= MAX_TOOL_OUTPUT_CHARS {
        return output.to_owned();
    }

    let half = MAX_TOOL_OUTPUT_CHARS / 2;

    // Last boundary at or before `half` (same result as `floor_char_boundary(half)`).
    let head_end = (0..=half)
        .rev()
        .find(|&idx| output.is_char_boundary(idx))
        .unwrap_or(0);
    // First boundary at or after `total - half` (same result as `ceil_char_boundary`).
    let tail_start = ((total - half)..=total)
        .find(|&idx| output.is_char_boundary(idx))
        .unwrap_or(total);

    let (head, tail) = (&output[..head_end], &output[tail_start..]);
    // Everything strictly between the two cut points is dropped.
    let truncated = tail_start - head_end;

    format!(
        "{head}\n\n... [truncated {truncated} chars, showing first and last ~{half} chars] ...\n\n{tail}"
    )
}
105
106/// Event emitted during tool execution for real-time UI updates.
107#[derive(Debug, Clone)]
108pub enum ToolEvent {
109    Started {
110        tool_name: String,
111        command: String,
112    },
113    OutputChunk {
114        tool_name: String,
115        command: String,
116        chunk: String,
117    },
118    Completed {
119        tool_name: String,
120        command: String,
121        output: String,
122        success: bool,
123        filter_stats: Option<FilterStats>,
124        diff: Option<DiffData>,
125    },
126}
127
128pub type ToolEventTx = tokio::sync::mpsc::UnboundedSender<ToolEvent>;
129
130/// Errors that can occur during tool execution.
131#[derive(Debug, thiserror::Error)]
132pub enum ToolError {
133    #[error("command blocked by policy: {command}")]
134    Blocked { command: String },
135
136    #[error("path not allowed by sandbox: {path}")]
137    SandboxViolation { path: String },
138
139    #[error("command requires confirmation: {command}")]
140    ConfirmationRequired { command: String },
141
142    #[error("command timed out after {timeout_secs}s")]
143    Timeout { timeout_secs: u64 },
144
145    #[error("operation cancelled")]
146    Cancelled,
147
148    #[error("invalid tool parameters: {message}")]
149    InvalidParams { message: String },
150
151    #[error("execution failed: {0}")]
152    Execution(#[from] std::io::Error),
153}
154
155/// Deserialize tool call params from a `serde_json::Map<String, Value>` into a typed struct.
156///
157/// # Errors
158///
159/// Returns `ToolError::InvalidParams` when deserialization fails.
160pub fn deserialize_params<T: serde::de::DeserializeOwned>(
161    params: &serde_json::Map<String, serde_json::Value>,
162) -> Result<T, ToolError> {
163    let obj = serde_json::Value::Object(params.clone());
164    serde_json::from_value(obj).map_err(|e| ToolError::InvalidParams {
165        message: e.to_string(),
166    })
167}
168
169/// Async trait for tool execution backends (shell, future MCP, A2A).
170///
171/// Accepts the full LLM response and returns an optional output.
172/// Returns `None` when no tool invocation is detected in the response.
173pub trait ToolExecutor: Send + Sync {
174    fn execute(
175        &self,
176        response: &str,
177    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send;
178
179    /// Execute bypassing confirmation checks (called after user approves).
180    /// Default: delegates to `execute`.
181    fn execute_confirmed(
182        &self,
183        response: &str,
184    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
185        self.execute(response)
186    }
187
188    /// Return tool definitions this executor can handle.
189    fn tool_definitions(&self) -> Vec<crate::registry::ToolDef> {
190        vec![]
191    }
192
193    /// Execute a structured tool call. Returns `None` if `tool_id` is not handled.
194    fn execute_tool_call(
195        &self,
196        _call: &ToolCall,
197    ) -> impl Future<Output = Result<Option<ToolOutput>, ToolError>> + Send {
198        std::future::ready(Ok(None))
199    }
200
201    /// Inject environment variables for the currently active skill. No-op by default.
202    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
203}
204
205/// Object-safe erased version of [`ToolExecutor`] using boxed futures.
206///
207/// Implemented automatically for all `T: ToolExecutor + 'static`.
208/// Use `Box<dyn ErasedToolExecutor>` when dynamic dispatch is required.
209pub trait ErasedToolExecutor: Send + Sync {
210    fn execute_erased<'a>(
211        &'a self,
212        response: &'a str,
213    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
214
215    fn execute_confirmed_erased<'a>(
216        &'a self,
217        response: &'a str,
218    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
219
220    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef>;
221
222    fn execute_tool_call_erased<'a>(
223        &'a self,
224        call: &'a ToolCall,
225    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>;
226
227    /// Inject environment variables for the currently active skill. No-op by default.
228    fn set_skill_env(&self, _env: Option<std::collections::HashMap<String, String>>) {}
229}
230
231impl<T: ToolExecutor> ErasedToolExecutor for T {
232    fn execute_erased<'a>(
233        &'a self,
234        response: &'a str,
235    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
236    {
237        Box::pin(self.execute(response))
238    }
239
240    fn execute_confirmed_erased<'a>(
241        &'a self,
242        response: &'a str,
243    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
244    {
245        Box::pin(self.execute_confirmed(response))
246    }
247
248    fn tool_definitions_erased(&self) -> Vec<crate::registry::ToolDef> {
249        self.tool_definitions()
250    }
251
252    fn execute_tool_call_erased<'a>(
253        &'a self,
254        call: &'a ToolCall,
255    ) -> std::pin::Pin<Box<dyn Future<Output = Result<Option<ToolOutput>, ToolError>> + Send + 'a>>
256    {
257        Box::pin(self.execute_tool_call(call))
258    }
259
260    fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
261        ToolExecutor::set_skill_env(self, env);
262    }
263}
264
/// Collect the bodies of fenced code blocks that open with `` ```{lang} ``.
///
/// The text is scanned left to right. Each opening marker is paired with the next
/// `` ``` `` that follows it, and the text between them is returned with surrounding
/// whitespace trimmed. Scanning stops at the first opening marker without a closing fence.
#[must_use]
pub fn extract_fenced_blocks<'a>(text: &'a str, lang: &str) -> Vec<&'a str> {
    let opening = format!("```{lang}");
    let mut found = Vec::new();
    let mut cursor = text;

    loop {
        // Skip ahead to just past the next opening marker.
        let Some((_, after_opening)) = cursor.split_once(opening.as_str()) else {
            return found;
        };
        // An opening marker with no closing fence ends the scan.
        let Some((body, remainder)) = after_opening.split_once("```") else {
            return found;
        };
        found.push(body.trim());
        cursor = remainder;
    }
}
287
/// Unit tests for the executor data types, error messages and helper functions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tool_output_display() {
        let output = ToolOutput {
            tool_name: "bash".to_owned(),
            summary: "$ echo hello\nhello".to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
        };
        assert_eq!(output.to_string(), "$ echo hello\nhello");
    }

    #[test]
    fn tool_error_blocked_display() {
        let err = ToolError::Blocked {
            command: "rm -rf /".to_owned(),
        };
        assert_eq!(err.to_string(), "command blocked by policy: rm -rf /");
    }

    #[test]
    fn tool_error_sandbox_violation_display() {
        let err = ToolError::SandboxViolation {
            path: "/etc/shadow".to_owned(),
        };
        assert_eq!(err.to_string(), "path not allowed by sandbox: /etc/shadow");
    }

    #[test]
    fn tool_error_confirmation_required_display() {
        let err = ToolError::ConfirmationRequired {
            command: "rm -rf /tmp".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "command requires confirmation: rm -rf /tmp"
        );
    }

    #[test]
    fn tool_error_timeout_display() {
        let err = ToolError::Timeout { timeout_secs: 30 };
        assert_eq!(err.to_string(), "command timed out after 30s");
    }

    #[test]
    fn tool_error_invalid_params_display() {
        let err = ToolError::InvalidParams {
            message: "missing field `command`".to_owned(),
        };
        assert_eq!(
            err.to_string(),
            "invalid tool parameters: missing field `command`"
        );
    }

    #[test]
    fn deserialize_params_valid() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("count".to_owned(), serde_json::json!(42));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned(),
                count: 42
            }
        );
    }

    #[test]
    fn deserialize_params_missing_required_field() {
        #[derive(Debug, serde::Deserialize)]
        struct P {
            #[allow(dead_code)]
            name: String,
        }
        let map = serde_json::Map::new();
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_wrong_type() {
        #[derive(Debug, serde::Deserialize)]
        struct P {
            #[allow(dead_code)]
            count: u32,
        }
        let mut map = serde_json::Map::new();
        map.insert("count".to_owned(), serde_json::json!("not a number"));
        let err = deserialize_params::<P>(&map).unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[test]
    fn deserialize_params_all_optional_empty() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: Option<String>,
        }
        let map = serde_json::Map::new();
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(p, P { name: None });
    }

    #[test]
    fn deserialize_params_ignores_extra_fields() {
        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct P {
            name: String,
        }
        let mut map = serde_json::Map::new();
        map.insert("name".to_owned(), serde_json::json!("test"));
        map.insert("extra".to_owned(), serde_json::json!(true));
        let p: P = deserialize_params(&map).unwrap();
        assert_eq!(
            p,
            P {
                name: "test".to_owned()
            }
        );
    }

    #[test]
    fn tool_error_execution_display() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "bash not found");
        let err = ToolError::Execution(io_err);
        assert!(err.to_string().starts_with("execution failed:"));
        assert!(err.to_string().contains("bash not found"));
    }

    #[test]
    fn truncate_tool_output_short_passthrough() {
        let short = "hello world";
        assert_eq!(truncate_tool_output(short), short);
    }

    #[test]
    fn truncate_tool_output_exact_limit() {
        let exact = "a".repeat(MAX_TOOL_OUTPUT_CHARS);
        assert_eq!(truncate_tool_output(&exact), exact);
    }

    #[test]
    fn truncate_tool_output_long_split() {
        let long = "x".repeat(MAX_TOOL_OUTPUT_CHARS + 1000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.len() < long.len());
    }

    #[test]
    fn truncate_tool_output_notice_contains_count() {
        let long = "y".repeat(MAX_TOOL_OUTPUT_CHARS + 2000);
        let result = truncate_tool_output(&long);
        assert!(result.contains("truncated"));
        assert!(result.contains("chars"));
        // The input exceeds the limit by 2000 single-byte characters, and exactly that
        // many are dropped between the head and tail halves.
        assert!(result.contains("truncated 2000 chars"));
    }

    #[derive(Debug)]
    struct DefaultExecutor;
    impl ToolExecutor for DefaultExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    #[tokio::test]
    async fn execute_tool_call_default_returns_none() {
        let exec = DefaultExecutor;
        let call = ToolCall {
            tool_id: "anything".to_owned(),
            params: serde_json::Map::new(),
        };
        let result = exec.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn filter_stats_savings_pct() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert!((fs.savings_pct() - 80.0).abs() < 0.01);
    }

    #[test]
    fn filter_stats_savings_pct_zero() {
        let fs = FilterStats::default();
        assert!((fs.savings_pct()).abs() < 0.01);
    }

    #[test]
    fn filter_stats_estimated_tokens_saved() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            ..Default::default()
        };
        assert_eq!(fs.estimated_tokens_saved(), 200); // (1000 - 200) / 4
    }

    #[test]
    fn filter_stats_format_inline() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(line, "[shell] 342 lines \u{2192} 28 lines, 80.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_zero() {
        let fs = FilterStats::default();
        let line = fs.format_inline("bash");
        assert_eq!(line, "[bash] 0 lines \u{2192} 0 lines, 0.0% filtered");
    }

    #[test]
    fn filter_stats_format_inline_with_command() {
        let fs = FilterStats {
            raw_chars: 1000,
            filtered_chars: 200,
            raw_lines: 342,
            filtered_lines: 28,
            command: Some("  ls -la  ".to_owned()),
            ..Default::default()
        };
        let line = fs.format_inline("shell");
        assert_eq!(
            line,
            "[shell] `ls -la` 342 lines \u{2192} 28 lines, 80.0% filtered"
        );
    }

    #[test]
    fn filter_stats_format_inline_shortens_long_command() {
        let fs = FilterStats {
            raw_lines: 3,
            filtered_lines: 1,
            command: Some("x".repeat(61)),
            ..Default::default()
        };
        let expected = format!(
            "[shell] `{}\u{2026}` 3 lines \u{2192} 1 lines, 0.0% filtered",
            "x".repeat(57)
        );
        assert_eq!(fs.format_inline("shell"), expected);
    }

    #[test]
    fn extract_fenced_blocks_returns_trimmed_bodies() {
        let text = "intro\n```bash\nls -la\n```\nmiddle\n```bash\n  pwd  \n```\nend";
        assert_eq!(extract_fenced_blocks(text, "bash"), vec!["ls -la", "pwd"]);
    }

    #[test]
    fn extract_fenced_blocks_without_marker_is_empty() {
        assert!(extract_fenced_blocks("plain text", "bash").is_empty());
        assert!(extract_fenced_blocks("```python\nx\n```", "bash").is_empty());
    }

    #[test]
    fn extract_fenced_blocks_stops_at_unterminated_block() {
        assert!(extract_fenced_blocks("```bash\nls", "bash").is_empty());
        assert_eq!(
            extract_fenced_blocks("```bash\na\n```\n```bash\nb", "bash"),
            vec!["a"]
        );
    }
}