syncable_cli/agent/ui/
hooks.rs1use crate::agent::ui::Spinner;
6use rig::agent::CancelSignal;
7use rig::completion::CompletionModel;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10
11#[derive(Clone)]
13pub struct ToolDisplayHook {
14 sender: mpsc::Sender<ToolEvent>,
15}
16
/// A tool lifecycle notification emitted by `ToolDisplayHook`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolEvent {
    /// A tool is about to be invoked; `args` is the raw argument string.
    ToolStart { name: String, args: String },
    /// A tool finished; `result` is the raw result string.
    ToolComplete { name: String, result: String },
}
23
24impl ToolDisplayHook {
25 pub fn new() -> (Self, mpsc::Receiver<ToolEvent>) {
27 let (sender, receiver) = mpsc::channel(32);
28 (Self { sender }, receiver)
29 }
30
31 pub fn from_sender(sender: mpsc::Sender<ToolEvent>) -> Self {
33 Self { sender }
34 }
35}
36
37impl Default for ToolDisplayHook {
38 fn default() -> Self {
39 let (hook, _) = Self::new();
40 hook
41 }
42}
43
44impl<M> rig::agent::PromptHook<M> for ToolDisplayHook
45where
46 M: CompletionModel,
47{
48 fn on_tool_call(
49 &self,
50 tool_name: &str,
51 args: &str,
52 _cancel: CancelSignal,
53 ) -> impl std::future::Future<Output = ()> + Send {
54 let sender = self.sender.clone();
55 let name = tool_name.to_string();
56 let args_str = args.to_string();
57
58 async move {
59 let _ = sender
60 .send(ToolEvent::ToolStart {
61 name,
62 args: args_str,
63 })
64 .await;
65 }
66 }
67
68 fn on_tool_result(
69 &self,
70 tool_name: &str,
71 _args: &str,
72 result: &str,
73 _cancel: CancelSignal,
74 ) -> impl std::future::Future<Output = ()> + Send {
75 let sender = self.sender.clone();
76 let name = tool_name.to_string();
77 let result_str = result.to_string();
78
79 async move {
80 let _ = sender
81 .send(ToolEvent::ToolComplete {
82 name,
83 result: result_str,
84 })
85 .await;
86 }
87 }
88}
89
90pub fn spawn_tool_display_handler(
92 mut receiver: mpsc::Receiver<ToolEvent>,
93 spinner: Arc<Spinner>,
94) -> tokio::task::JoinHandle<()> {
95 tokio::spawn(async move {
96 while let Some(event) = receiver.recv().await {
97 match event {
98 ToolEvent::ToolStart { name, args } => {
99 let description = format_tool_description(&name, &args);
101 spinner.tool_executing(&name, &description).await;
102 }
103 ToolEvent::ToolComplete { name, .. } => {
104 spinner.tool_complete(&name).await;
105 }
106 }
107 }
108 })
109}
110
111fn format_tool_description(name: &str, args: &str) -> String {
113 match name {
114 "analyze_project" => "Analyzing project structure...".to_string(),
115 "security_scan" => "Running security scan...".to_string(),
116 "check_vulnerabilities" => "Checking for vulnerabilities...".to_string(),
117 "read_file" => {
118 if let Ok(args_value) = serde_json::from_str::<serde_json::Value>(args) {
120 if let Some(path) = args_value.get("path").and_then(|p| p.as_str()) {
121 return format!("Reading {}", truncate_path(path));
122 }
123 }
124 "Reading file...".to_string()
125 }
126 "list_directory" => {
127 if let Ok(args_value) = serde_json::from_str::<serde_json::Value>(args) {
128 if let Some(path) = args_value.get("path").and_then(|p| p.as_str()) {
129 return format!("Listing {}", truncate_path(path));
130 }
131 }
132 "Listing directory...".to_string()
133 }
134 "search_code" => {
135 if let Ok(args_value) = serde_json::from_str::<serde_json::Value>(args) {
136 if let Some(pattern) = args_value.get("pattern").and_then(|p| p.as_str()) {
137 return format!("Searching for '{}'...", truncate_text(pattern, 30));
138 }
139 }
140 "Searching code...".to_string()
141 }
142 "find_files" => "Finding files...".to_string(),
143 "generate_iac" => "Generating infrastructure config...".to_string(),
144 "discover_services" => "Discovering services...".to_string(),
145 _ => format!("Executing {}...", name),
146 }
147}
148
/// Shortens `path` for display so it is at most 40 characters long.
///
/// A path of 40 characters or fewer is returned unchanged. A longer path keeps
/// only its last 37 characters behind a `...` prefix, so the file name
/// survives. Lengths are counted in `char`s and the cut always falls on a
/// character boundary; slicing by byte offset would panic on multi-byte UTF-8
/// paths.
fn truncate_path(path: &str) -> String {
    const MAX_CHARS: usize = 40;
    const ELLIPSIS: &str = "...";

    let char_count = path.chars().count();
    if char_count <= MAX_CHARS {
        return path.to_string();
    }

    let keep = MAX_CHARS - ELLIPSIS.len();
    // Byte offset of the first kept character; `char_indices` only yields
    // valid boundaries, so the slice below cannot panic.
    let start = path
        .char_indices()
        .nth(char_count - keep)
        .map_or(0, |(idx, _)| idx);
    format!("{}{}", ELLIPSIS, &path[start..])
}
158
/// Shortens `text` to at most `max_len` characters, marking a cut with `...`.
///
/// Text of `max_len` characters or fewer is returned unchanged. Otherwise the
/// first `max_len - 3` characters are kept and `...` is appended. If `max_len`
/// is below 3 there is no room for the marker, so the text is hard-cut to
/// `max_len` characters instead. Lengths are counted in `char`s, so multi-byte
/// UTF-8 text is never split mid-character.
fn truncate_text(text: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if text.chars().count() <= max_len {
        return text.to_string();
    }

    let Some(keep) = max_len.checked_sub(ELLIPSIS.len()) else {
        // The budget is smaller than the marker itself; stay within it.
        return text.chars().take(max_len).collect();
    };

    let mut shortened: String = text.chars().take(keep).collect();
    shortened.push_str(ELLIPSIS);
    shortened
}
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_path() {
        assert_eq!(truncate_path("short.txt"), "short.txt");
        let long_path = "/very/long/path/that/exceeds/forty/characters/file.rs";
        let shortened = truncate_path(long_path);
        assert!(shortened.len() <= 40);
        assert!(shortened.starts_with("..."));
        // The tail of the path (including the file name) is what is kept.
        assert!(shortened.ends_with("/characters/file.rs"));
    }

    #[test]
    fn test_truncate_text() {
        assert_eq!(truncate_text("short", 30), "short");
        assert_eq!(truncate_text("abcdefghij", 8), "abcde...");
    }

    #[test]
    fn test_format_tool_description() {
        assert_eq!(
            format_tool_description("analyze_project", ""),
            "Analyzing project structure..."
        );
        // Non-JSON arguments fall back to the generic message.
        assert_eq!(format_tool_description("read_file", "not json"), "Reading file...");
        assert_eq!(
            format_tool_description("read_file", r#"{"path":"src/main.rs"}"#),
            "Reading src/main.rs"
        );
        assert_eq!(
            format_tool_description("list_directory", r#"{"path":"src"}"#),
            "Listing src"
        );
        assert_eq!(
            format_tool_description("search_code", r#"{"pattern":"fn main"}"#),
            "Searching for 'fn main'..."
        );
        assert_eq!(format_tool_description("mystery", "{}"), "Executing mystery...");
    }
}