Skip to main content

vtcode_core/tools/registry/
progress_facade.rs

1//! Progress callback accessors for ToolRegistry.
2
3use super::{ToolProgressCallback, ToolRegistry};
4
5impl ToolRegistry {
6    /// Replace the callback for streaming tool output and progress, returning the previous callback.
7    pub fn replace_progress_callback(
8        &self,
9        callback: Option<ToolProgressCallback>,
10    ) -> Option<ToolProgressCallback> {
11        let Ok(mut slot) = self.progress_callback.write() else {
12            return None;
13        };
14        std::mem::replace(&mut *slot, callback)
15    }
16
17    /// Set the callback for streaming tool output and progress
18    pub fn set_progress_callback(&self, callback: ToolProgressCallback) {
19        let _ = self.replace_progress_callback(Some(callback));
20    }
21
22    /// Clear the progress callback
23    pub fn clear_progress_callback(&self) {
24        let _ = self.replace_progress_callback(None);
25    }
26
27    /// Get the current progress callback if set
28    pub fn progress_callback(&self) -> Option<ToolProgressCallback> {
29        self.progress_callback.read().ok().and_then(|g| g.clone())
30    }
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tempfile::TempDir;

    /// Build a progress callback that bumps `hits` every time it is invoked.
    fn counting_callback(hits: &Arc<AtomicUsize>) -> ToolProgressCallback {
        let hits = Arc::clone(hits);
        Arc::new(move |_, _| {
            let _ = hits.fetch_add(1, Ordering::SeqCst);
        })
    }

    #[tokio::test]
    async fn replace_progress_callback_restores_previous() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let registry = ToolRegistry::new(temp_dir.path().to_path_buf()).await;

        let first_hits = Arc::new(AtomicUsize::new(0));
        let second_hits = Arc::new(AtomicUsize::new(0));

        registry.set_progress_callback(counting_callback(&first_hits));
        let previous = registry.replace_progress_callback(Some(counting_callback(&second_hits)));
        assert!(
            previous.is_some(),
            "replacing should hand back the first callback"
        );

        // While replaced, only the second callback may observe events.
        let current = registry
            .progress_callback()
            .expect("second callback should be registered");
        current("run_pty_cmd", "chunk");
        assert_eq!(first_hits.load(Ordering::SeqCst), 0);
        assert_eq!(second_hits.load(Ordering::SeqCst), 1);

        // After restoring, only the first callback may observe events.
        let displaced = registry.replace_progress_callback(previous);
        assert!(
            displaced.is_some(),
            "restoring should hand back the second callback"
        );
        let current = registry
            .progress_callback()
            .expect("first callback should be restored");
        current("run_pty_cmd", "chunk");
        assert_eq!(first_hits.load(Ordering::SeqCst), 1);
        assert_eq!(second_hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn clear_progress_callback_removes_registered_callback() {
        let temp_dir = TempDir::new().expect("create temp dir");
        let registry = ToolRegistry::new(temp_dir.path().to_path_buf()).await;

        let hits = Arc::new(AtomicUsize::new(0));
        registry.set_progress_callback(counting_callback(&hits));
        assert!(registry.progress_callback().is_some());

        registry.clear_progress_callback();
        assert!(registry.progress_callback().is_none());
        assert_eq!(hits.load(Ordering::SeqCst), 0);
    }
}