syncable_cli/agent/ui/
hooks.rs

1//! Rig PromptHook implementations for UI updates
2//!
3//! Provides hooks that update the UI when tools are called during agent execution.
4
5use crate::agent::ui::Spinner;
6use rig::agent::CancelSignal;
7use rig::completion::CompletionModel;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10
11/// A hook that updates the spinner when tools are executed
12#[derive(Clone)]
13pub struct ToolDisplayHook {
14    sender: mpsc::Sender<ToolEvent>,
15}
16
/// A tool lifecycle notification sent from [`ToolDisplayHook`] to the UI task.
#[derive(Debug, Clone)]
pub enum ToolEvent {
    /// A tool call has begun.
    ToolStart {
        /// Name of the tool being invoked.
        name: String,
        /// Raw argument string passed to the tool, forwarded unchanged.
        args: String,
    },
    /// A tool call has produced its result.
    ToolComplete {
        /// Name of the tool that finished.
        name: String,
        /// Raw result string returned by the tool, forwarded unchanged.
        result: String,
    },
}
23
24impl ToolDisplayHook {
25    /// Create a new hook with a channel to send tool events
26    pub fn new() -> (Self, mpsc::Receiver<ToolEvent>) {
27        let (sender, receiver) = mpsc::channel(32);
28        (Self { sender }, receiver)
29    }
30
31    /// Create a hook from an existing sender
32    pub fn from_sender(sender: mpsc::Sender<ToolEvent>) -> Self {
33        Self { sender }
34    }
35}
36
37impl Default for ToolDisplayHook {
38    fn default() -> Self {
39        let (hook, _) = Self::new();
40        hook
41    }
42}
43
44impl<M> rig::agent::PromptHook<M> for ToolDisplayHook
45where
46    M: CompletionModel,
47{
48    fn on_tool_call(
49        &self,
50        tool_name: &str,
51        args: &str,
52        _cancel: CancelSignal,
53    ) -> impl std::future::Future<Output = ()> + Send {
54        let sender = self.sender.clone();
55        let name = tool_name.to_string();
56        let args_str = args.to_string();
57
58        async move {
59            let _ = sender
60                .send(ToolEvent::ToolStart {
61                    name,
62                    args: args_str,
63                })
64                .await;
65        }
66    }
67
68    fn on_tool_result(
69        &self,
70        tool_name: &str,
71        _args: &str,
72        result: &str,
73        _cancel: CancelSignal,
74    ) -> impl std::future::Future<Output = ()> + Send {
75        let sender = self.sender.clone();
76        let name = tool_name.to_string();
77        let result_str = result.to_string();
78
79        async move {
80            let _ = sender
81                .send(ToolEvent::ToolComplete {
82                    name,
83                    result: result_str,
84                })
85                .await;
86        }
87    }
88}
89
90/// Spawns a task that listens for tool events and updates the spinner
91pub fn spawn_tool_display_handler(
92    mut receiver: mpsc::Receiver<ToolEvent>,
93    spinner: Arc<Spinner>,
94) -> tokio::task::JoinHandle<()> {
95    tokio::spawn(async move {
96        while let Some(event) = receiver.recv().await {
97            match event {
98                ToolEvent::ToolStart { name, args } => {
99                    // Format a nice description from the tool name
100                    let description = format_tool_description(&name, &args);
101                    spinner.tool_executing(&name, &description).await;
102                }
103                ToolEvent::ToolComplete { name, .. } => {
104                    spinner.tool_complete(&name).await;
105                }
106            }
107        }
108    })
109}
110
111/// Format a user-friendly description for a tool based on its name and args
112fn format_tool_description(name: &str, args: &str) -> String {
113    match name {
114        "analyze_project" => "Analyzing project structure...".to_string(),
115        "security_scan" => "Running security scan...".to_string(),
116        "check_vulnerabilities" => "Checking for vulnerabilities...".to_string(),
117        "read_file" => {
118            // Try to extract the file path from args
119            if let Ok(args_value) = serde_json::from_str::<serde_json::Value>(args) {
120                if let Some(path) = args_value.get("path").and_then(|p| p.as_str()) {
121                    return format!("Reading {}", truncate_path(path));
122                }
123            }
124            "Reading file...".to_string()
125        }
126        "list_directory" => {
127            if let Ok(args_value) = serde_json::from_str::<serde_json::Value>(args) {
128                if let Some(path) = args_value.get("path").and_then(|p| p.as_str()) {
129                    return format!("Listing {}", truncate_path(path));
130                }
131            }
132            "Listing directory...".to_string()
133        }
134        "search_code" => {
135            if let Ok(args_value) = serde_json::from_str::<serde_json::Value>(args) {
136                if let Some(pattern) = args_value.get("pattern").and_then(|p| p.as_str()) {
137                    return format!("Searching for '{}'...", truncate_text(pattern, 30));
138                }
139            }
140            "Searching code...".to_string()
141        }
142        "find_files" => "Finding files...".to_string(),
143        "generate_iac" => "Generating infrastructure config...".to_string(),
144        "discover_services" => "Discovering services...".to_string(),
145        _ => format!("Executing {}...", name),
146    }
147}
148
/// Truncate a path for display, keeping its most informative tail.
///
/// Paths of at most 40 characters are returned unchanged. Longer paths are
/// cut to their last 37 characters and prefixed with `...`, so the result is
/// exactly 40 characters.
///
/// Lengths are counted in `char`s rather than bytes. The cut therefore never
/// lands inside a multi-byte UTF-8 sequence; byte slicing would panic there.
fn truncate_path(path: &str) -> String {
    const MAX_CHARS: usize = 40;
    const ELLIPSIS: &str = "...";

    let char_count = path.chars().count();
    if char_count <= MAX_CHARS {
        return path.to_string();
    }

    let keep = MAX_CHARS - ELLIPSIS.len();
    // Byte offset of the first kept char; `char_indices` guarantees it is a
    // valid char boundary. `nth` always succeeds since `char_count > keep`.
    let start = path
        .char_indices()
        .nth(char_count - keep)
        .map_or(path.len(), |(idx, _)| idx);
    format!("{}{}", ELLIPSIS, &path[start..])
}
158
/// Truncate `text` to at most `max_len` characters for display.
///
/// If `text` fits, it is returned unchanged. Otherwise it is cut and `...` is
/// appended, so the whole result is `max_len` characters long.
///
/// Lengths are counted in `char`s, so multi-byte UTF-8 input is never split
/// mid-character (byte slicing would panic). When `max_len` is smaller than
/// the ellipsis itself, the result is just `...` instead of underflowing.
fn truncate_text(text: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if text.chars().count() <= max_len {
        return text.to_string();
    }

    let keep = max_len.saturating_sub(ELLIPSIS.len());
    let mut out: String = text.chars().take(keep).collect();
    out.push_str(ELLIPSIS);
    out
}
167
#[cfg(test)]
mod tests {
    use super::*;

    /// A 53-character ASCII path, comfortably over the 40-char display limit.
    const LONG_PATH: &str = "/very/long/path/that/exceeds/forty/characters/file.rs";

    #[test]
    fn test_truncate_path() {
        assert_eq!(truncate_path("short.txt"), "short.txt");
        assert!(truncate_path(LONG_PATH).len() <= 40);
        assert!(truncate_path(LONG_PATH).starts_with("..."));
    }

    #[test]
    fn test_truncate_path_boundaries_and_tail() {
        let exactly_forty = "a".repeat(40);
        assert_eq!(truncate_path(&exactly_forty), exactly_forty);

        let forty_one = "b".repeat(41);
        assert_eq!(truncate_path(&forty_one), format!("...{}", "b".repeat(37)));

        // The visible part must be the real end of the path.
        let shown = truncate_path(LONG_PATH);
        assert_eq!(shown.len(), 40);
        assert!(LONG_PATH.ends_with(&shown[3..]));
    }

    #[test]
    fn test_truncate_text() {
        assert_eq!(truncate_text("short", 30), "short");
        assert_eq!(truncate_text("abcdefghij", 10), "abcdefghij");
        assert_eq!(truncate_text("abcdefghijk", 10), "abcdefg...");
    }

    #[test]
    fn test_format_tool_description() {
        assert_eq!(
            format_tool_description("read_file", r#"{"path":"src/main.rs"}"#),
            "Reading src/main.rs"
        );
        assert_eq!(
            format_tool_description("read_file", "not json"),
            "Reading file..."
        );
        assert_eq!(
            format_tool_description("list_directory", r#"{"path":42}"#),
            "Listing directory..."
        );
        assert_eq!(
            format_tool_description("search_code", r#"{"pattern":"fn main"}"#),
            "Searching for 'fn main'..."
        );
        assert_eq!(
            format_tool_description("analyze_project", ""),
            "Analyzing project structure..."
        );
        assert_eq!(
            format_tool_description("some_new_tool", "{}"),
            "Executing some_new_tool..."
        );
    }
}