// steer_workspace/local.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::RwLock;
7use tracing::info;
8
9use crate::error::{Result, WorkspaceError};
10use crate::{CachedEnvironment, EnvironmentInfo, Workspace, WorkspaceMetadata, WorkspaceType};
11use steer_tools::{
12    ExecutionContext as SteerExecutionContext, ToolCall, ToolSchema, result::ToolResult,
13    traits::ExecutableTool,
14};
15
16/// Local filesystem workspace
17pub struct LocalWorkspace {
18    path: PathBuf,
19    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
20    metadata: WorkspaceMetadata,
21    tool_registry: HashMap<String, Box<dyn ExecutableTool>>,
22}
23
24impl LocalWorkspace {
25    pub async fn with_path(path: PathBuf) -> Result<Self> {
26        let metadata = WorkspaceMetadata {
27            id: format!("local:{}", path.display()),
28            workspace_type: WorkspaceType::Local,
29            location: path.display().to_string(),
30        };
31
32        // Create tool registry from workspace tools
33        let mut tool_registry = HashMap::new();
34        for tool in steer_tools::tools::workspace_tools() {
35            tool_registry.insert(tool.name().to_string(), tool);
36        }
37
38        Ok(Self {
39            path,
40            environment_cache: Arc::new(RwLock::new(None)),
41            metadata,
42            tool_registry,
43        })
44    }
45
46    /// Collect environment information for the local workspace
47    async fn collect_environment(&self) -> Result<EnvironmentInfo> {
48        EnvironmentInfo::collect_for_path(&self.path)
49    }
50}
51
52impl std::fmt::Debug for LocalWorkspace {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.debug_struct("LocalWorkspace")
55            .field("path", &self.path)
56            .field("metadata", &self.metadata)
57            .field("tool_count", &self.tool_registry.len())
58            .finish_non_exhaustive()
59    }
60}
61
62#[async_trait]
63impl Workspace for LocalWorkspace {
64    async fn environment(&self) -> Result<EnvironmentInfo> {
65        let mut cache = self.environment_cache.write().await;
66
67        // Check if we have valid cached data
68        if let Some(cached) = cache.as_ref() {
69            if !cached.is_expired() {
70                return Ok(cached.info.clone());
71            }
72        }
73
74        // Collect fresh environment info
75        let env_info = self.collect_environment().await?;
76
77        // Cache it with 5 minute TTL
78        *cache = Some(CachedEnvironment::new(
79            env_info.clone(),
80            Duration::from_secs(300), // 5 minutes
81        ));
82
83        Ok(env_info)
84    }
85
86    fn metadata(&self) -> WorkspaceMetadata {
87        self.metadata.clone()
88    }
89
90    async fn invalidate_environment_cache(&self) {
91        let mut cache = self.environment_cache.write().await;
92        *cache = None;
93    }
94
95    async fn list_files(
96        &self,
97        query: Option<&str>,
98        max_results: Option<usize>,
99    ) -> Result<Vec<String>> {
100        use crate::utils::FileListingUtils;
101
102        info!(target: "workspace.list_files", "Listing files in workspace: {:?}", self.path);
103
104        FileListingUtils::list_files(&self.path, query, max_results).map_err(WorkspaceError::from)
105    }
106
107    fn working_directory(&self) -> &std::path::Path {
108        &self.path
109    }
110
111    async fn execute_tool(
112        &self,
113        tool_call: &ToolCall,
114        context: steer_tools::ExecutionContext,
115    ) -> Result<ToolResult> {
116        // Get the tool from registry
117        let tool = self.tool_registry.get(&tool_call.name).ok_or_else(|| {
118            WorkspaceError::ToolExecution(format!("Unknown tool: {}", tool_call.name))
119        })?;
120
121        // Set working directory
122        let steer_context = SteerExecutionContext::new(tool_call.id.clone())
123            .with_cancellation_token(context.cancellation_token.clone())
124            .with_working_directory(self.path.clone());
125
126        // Execute the tool
127        match tool.run(tool_call.parameters.clone(), &steer_context).await {
128            Ok(result) => Ok(result),
129            Err(e) => Ok(ToolResult::Error(e)),
130        }
131    }
132
133    async fn available_tools(&self) -> Vec<ToolSchema> {
134        self.tool_registry
135            .iter()
136            .map(|(name, tool)| ToolSchema {
137                name: name.clone(),
138                description: tool.description().to_string(),
139                input_schema: tool.input_schema().clone(),
140            })
141            .collect()
142    }
143
144    async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
145        self.tool_registry
146            .get(tool_name)
147            .map(|tool| tool.requires_approval())
148            .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Canonicalizes `p` so that symlinked temp dirs (e.g. `/var` ->
    /// `/private/var` on macOS) compare equal; falls back to `p` as-is when
    /// canonicalization fails.
    fn canonical(p: &std::path::Path) -> PathBuf {
        p.canonicalize().unwrap_or_else(|_| p.to_path_buf())
    }

    #[tokio::test]
    async fn test_local_workspace_creation() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();
        assert!(matches!(
            workspace.metadata().workspace_type,
            WorkspaceType::Local
        ));
    }

    #[tokio::test]
    async fn test_local_workspace_with_path() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        assert!(matches!(
            workspace.metadata().workspace_type,
            WorkspaceType::Local
        ));
        assert_eq!(
            workspace.metadata().location,
            temp_dir.path().display().to_string()
        );
    }

    #[tokio::test]
    async fn test_environment_caching() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        // Nothing is cached before the first request.
        assert!(workspace.environment_cache.read().await.is_none());

        // First call should collect fresh data and populate the cache.
        let env1 = workspace.environment().await.unwrap();
        assert!(workspace.environment_cache.read().await.is_some());

        // Second call should return cached data.
        let env2 = workspace.environment().await.unwrap();

        // Should be identical.
        assert_eq!(env1.working_directory, env2.working_directory);
        assert_eq!(env1.is_git_repo, env2.is_git_repo);
        assert_eq!(env1.platform, env2.platform);
    }

    #[tokio::test]
    async fn test_cache_invalidation() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        // Get initial environment.
        let _ = workspace.environment().await.unwrap();

        // Invalidation must clear the cached entry.
        workspace.invalidate_environment_cache().await;
        assert!(workspace.environment_cache.read().await.is_none());

        // Should work fine and fetch fresh data.
        let env = workspace.environment().await.unwrap();
        assert!(!env.working_directory.as_os_str().is_empty());
    }

    #[tokio::test]
    async fn test_environment_collection() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        let env = workspace.environment().await.unwrap();

        assert_eq!(canonical(&env.working_directory), canonical(temp_dir.path()));
    }

    #[tokio::test]
    async fn test_list_files() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        // Create some test files.
        std::fs::write(temp_dir.path().join("test.rs"), "test").unwrap();
        std::fs::write(temp_dir.path().join("main.rs"), "main").unwrap();
        std::fs::create_dir(temp_dir.path().join("src")).unwrap();
        std::fs::write(temp_dir.path().join("src/lib.rs"), "lib").unwrap();

        // List all files.
        let files = workspace.list_files(None, None).await.unwrap();
        assert_eq!(files.len(), 4); // 3 files + 1 directory
        assert!(files.contains(&"test.rs".to_string()));
        assert!(files.contains(&"main.rs".to_string()));
        assert!(files.contains(&"src/".to_string())); // Directory with trailing slash
        assert!(files.contains(&"src/lib.rs".to_string()));

        // Test with query.
        let files = workspace.list_files(Some("test"), None).await.unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0], "test.rs");

        // Test with max_results.
        let files = workspace.list_files(None, Some(2)).await.unwrap();
        assert_eq!(files.len(), 2);
    }

    #[tokio::test]
    async fn test_list_files_includes_dotfiles() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        // Create a dotfile.
        std::fs::write(temp_dir.path().join(".gitignore"), "target/").unwrap();

        let files = workspace.list_files(None, None).await.unwrap();
        assert!(files.contains(&".gitignore".to_string()));
    }

    #[tokio::test]
    async fn test_working_directory() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        assert_eq!(workspace.working_directory(), temp_dir.path());
    }

    #[tokio::test]
    async fn test_requires_approval_unknown_tool() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        assert!(workspace
            .requires_approval("__definitely_not_a_tool__")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_available_tools_match_registry() {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();

        let tools = workspace.available_tools().await;
        assert_eq!(tools.len(), workspace.tool_registry.len());

        // Every advertised tool must be resolvable by name.
        for schema in &tools {
            assert!(workspace.requires_approval(&schema.name).await.is_ok());
        }
    }
}