steer_workspace/
local.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::RwLock;
7use tracing::info;
8
9use crate::error::{Result, WorkspaceError};
10use crate::{CachedEnvironment, EnvironmentInfo, Workspace, WorkspaceMetadata, WorkspaceType};
11use steer_tools::{
12    ExecutionContext as SteerExecutionContext, ToolCall, ToolSchema, result::ToolResult,
13    traits::ExecutableTool,
14};
15
/// Local filesystem workspace
///
/// Tools are looked up in an in-process registry and run with `path` as their
/// working directory.
pub struct LocalWorkspace {
    /// Root directory of the workspace; returned by `working_directory()` and set as
    /// the working directory of every tool execution.
    path: PathBuf,
    /// Cached environment snapshot; `None` until first requested and after
    /// `invalidate_environment_cache()`.
    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
    /// Identity/location info returned (cloned) by `metadata()`.
    metadata: WorkspaceMetadata,
    /// Tools available in this workspace, keyed by the tool's `name()`.
    tool_registry: HashMap<String, Box<dyn ExecutableTool>>,
}
23
24impl LocalWorkspace {
25    pub async fn with_path(path: PathBuf) -> Result<Self> {
26        let metadata = WorkspaceMetadata {
27            id: format!("local:{}", path.display()),
28            workspace_type: WorkspaceType::Local,
29            location: path.display().to_string(),
30        };
31
32        // Create tool registry from workspace tools
33        let mut tool_registry = HashMap::new();
34        for tool in steer_tools::tools::workspace_tools() {
35            tool_registry.insert(tool.name().to_string(), tool);
36        }
37
38        Ok(Self {
39            path,
40            environment_cache: Arc::new(RwLock::new(None)),
41            metadata,
42            tool_registry,
43        })
44    }
45
46    /// Collect environment information for the local workspace
47    async fn collect_environment(&self) -> Result<EnvironmentInfo> {
48        EnvironmentInfo::collect_for_path(&self.path)
49    }
50}
51
52impl std::fmt::Debug for LocalWorkspace {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("LocalWorkspace")
55            .field("path", &self.path)
56            .field("metadata", &self.metadata)
57            .field("tool_count", &self.tool_registry.len())
58            .finish_non_exhaustive()
59    }
60}
61
62#[async_trait]
63impl Workspace for LocalWorkspace {
64    async fn environment(&self) -> Result<EnvironmentInfo> {
65        let mut cache = self.environment_cache.write().await;
66
67        // Check if we have valid cached data
68        if let Some(cached) = cache.as_ref() {
69            if !cached.is_expired() {
70                return Ok(cached.info.clone());
71            }
72        }
73
74        // Collect fresh environment info
75        let env_info = self.collect_environment().await?;
76
77        // Cache it with 5 minute TTL
78        *cache = Some(CachedEnvironment::new(
79            env_info.clone(),
80            Duration::from_secs(300), // 5 minutes
81        ));
82
83        Ok(env_info)
84    }
85
86    fn metadata(&self) -> WorkspaceMetadata {
87        self.metadata.clone()
88    }
89
90    async fn invalidate_environment_cache(&self) {
91        let mut cache = self.environment_cache.write().await;
92        *cache = None;
93    }
94
95    async fn list_files(
96        &self,
97        query: Option<&str>,
98        max_results: Option<usize>,
99    ) -> Result<Vec<String>> {
100        use crate::utils::FileListingUtils;
101
102        info!(target: "workspace.list_files", "Listing files in workspace: {:?}", self.path);
103
104        FileListingUtils::list_files(&self.path, query, max_results).map_err(WorkspaceError::from)
105    }
106
107    fn working_directory(&self) -> &std::path::Path {
108        &self.path
109    }
110
111    async fn execute_tool(
112        &self,
113        tool_call: &ToolCall,
114        context: steer_tools::ExecutionContext,
115    ) -> Result<ToolResult> {
116        // Get the tool from registry
117        let tool = self.tool_registry.get(&tool_call.name).ok_or_else(|| {
118            WorkspaceError::ToolExecution(format!("Unknown tool: {}", tool_call.name))
119        })?;
120
121        // Set working directory
122        let steer_context = SteerExecutionContext::new(tool_call.id.clone())
123            .with_cancellation_token(context.cancellation_token.clone())
124            .with_working_directory(self.path.clone());
125
126        // Execute the tool
127        tool.run(tool_call.parameters.clone(), &steer_context)
128            .await
129            .map_err(|e| WorkspaceError::ToolExecution(e.to_string()))
130    }
131
132    async fn available_tools(&self) -> Vec<ToolSchema> {
133        self.tool_registry
134            .iter()
135            .map(|(name, tool)| ToolSchema {
136                name: name.clone(),
137                description: tool.description().to_string(),
138                input_schema: tool.input_schema().clone(),
139            })
140            .collect()
141    }
142
143    async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
144        self.tool_registry
145            .get(tool_name)
146            .map(|tool| tool.requires_approval())
147            .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;
    use tempfile::{TempDir, tempdir};

    /// Create a `LocalWorkspace` rooted at `dir`.
    async fn workspace_in(dir: &TempDir) -> LocalWorkspace {
        LocalWorkspace::with_path(dir.path().to_path_buf())
            .await
            .unwrap()
    }

    /// Canonicalize `path`, falling back to the path as given if that fails.
    ///
    /// Comparing canonical forms keeps path assertions valid on macOS, where temp
    /// directories may resolve through a symlink.
    fn canonical(path: &Path) -> PathBuf {
        path.canonicalize().unwrap_or_else(|_| path.to_path_buf())
    }

    #[tokio::test]
    async fn test_local_workspace_creation() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;
        assert!(matches!(
            workspace.metadata().workspace_type,
            WorkspaceType::Local
        ));
    }

    #[tokio::test]
    async fn test_local_workspace_with_path() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        let metadata = workspace.metadata();
        assert!(matches!(metadata.workspace_type, WorkspaceType::Local));
        assert_eq!(metadata.location, temp_dir.path().display().to_string());
        assert_eq!(metadata.id, format!("local:{}", temp_dir.path().display()));
    }

    #[tokio::test]
    async fn test_environment_caching() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        // First call should collect fresh data
        let env1 = workspace.environment().await.unwrap();

        // Second call should return cached data
        let env2 = workspace.environment().await.unwrap();

        // Should be identical
        assert_eq!(env1.working_directory, env2.working_directory);
        assert_eq!(env1.is_git_repo, env2.is_git_repo);
        assert_eq!(env1.platform, env2.platform);
    }

    #[tokio::test]
    async fn test_cache_invalidation() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        // Get initial environment
        let _ = workspace.environment().await.unwrap();

        // Invalidate cache
        workspace.invalidate_environment_cache().await;

        // Should work fine and fetch fresh data
        let env = workspace.environment().await.unwrap();
        assert!(!env.working_directory.as_os_str().is_empty());
    }

    #[tokio::test]
    async fn test_environment_collection() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        let env = workspace.environment().await.unwrap();

        assert_eq!(
            canonical(&env.working_directory),
            canonical(temp_dir.path())
        );
    }

    #[tokio::test]
    async fn test_list_files() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        // Create some test files
        std::fs::write(temp_dir.path().join("test.rs"), "test").unwrap();
        std::fs::write(temp_dir.path().join("main.rs"), "main").unwrap();
        std::fs::create_dir(temp_dir.path().join("src")).unwrap();
        std::fs::write(temp_dir.path().join("src/lib.rs"), "lib").unwrap();

        // List all files
        let files = workspace.list_files(None, None).await.unwrap();
        assert_eq!(files.len(), 4); // 3 files + 1 directory
        assert!(files.contains(&"test.rs".to_string()));
        assert!(files.contains(&"main.rs".to_string()));
        assert!(files.contains(&"src/".to_string())); // Directory with trailing slash
        assert!(files.contains(&"src/lib.rs".to_string()));

        // Test with query
        let files = workspace.list_files(Some("test"), None).await.unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0], "test.rs");

        // Test with max_results
        let files = workspace.list_files(None, Some(2)).await.unwrap();
        assert_eq!(files.len(), 2);
    }

    #[tokio::test]
    async fn test_list_files_includes_dotfiles() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        // Create a dotfile
        std::fs::write(temp_dir.path().join(".gitignore"), "target/").unwrap();

        let files = workspace.list_files(None, None).await.unwrap();
        assert!(files.contains(&".gitignore".to_string()));
    }

    #[tokio::test]
    async fn test_working_directory() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        assert_eq!(workspace.working_directory(), temp_dir.path());
    }

    #[tokio::test]
    async fn test_requires_approval_unknown_tool() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        let err = workspace
            .requires_approval("__definitely_not_a_registered_tool__")
            .await
            .unwrap_err();
        assert!(matches!(err, WorkspaceError::ToolExecution(_)));
    }

    #[tokio::test]
    async fn test_available_tools_are_registered() {
        let temp_dir = tempdir().unwrap();
        let workspace = workspace_in(&temp_dir).await;

        // Every advertised schema must resolve back to a registered tool.
        for schema in workspace.available_tools().await {
            assert!(workspace.requires_approval(&schema.name).await.is_ok());
        }
    }
}