1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::RwLock;
7use tracing::info;
8
9use crate::error::{Result, WorkspaceError};
10use crate::{CachedEnvironment, EnvironmentInfo, Workspace, WorkspaceMetadata, WorkspaceType};
11use steer_tools::{
12 ExecutionContext as SteerExecutionContext, ToolCall, ToolSchema, result::ToolResult,
13 traits::ExecutableTool,
14};
15
/// A workspace rooted at a directory on the local filesystem.
///
/// Holds the root path, an environment-information cache that starts out
/// empty, the workspace metadata, and a registry of executable tools keyed by
/// tool name.
pub struct LocalWorkspace {
    /// Root directory of the workspace.
    path: PathBuf,
    /// Most recently collected environment information, or `None` when nothing
    /// has been collected yet or the cache was invalidated.
    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
    /// Identifier, type and location describing this workspace.
    metadata: WorkspaceMetadata,
    /// Tools that can be executed in this workspace, keyed by tool name.
    tool_registry: HashMap<String, Box<dyn ExecutableTool>>,
}
23
24impl LocalWorkspace {
25 pub async fn with_path(path: PathBuf) -> Result<Self> {
26 let metadata = WorkspaceMetadata {
27 id: format!("local:{}", path.display()),
28 workspace_type: WorkspaceType::Local,
29 location: path.display().to_string(),
30 };
31
32 let mut tool_registry = HashMap::new();
34 for tool in steer_tools::tools::workspace_tools() {
35 tool_registry.insert(tool.name().to_string(), tool);
36 }
37
38 Ok(Self {
39 path,
40 environment_cache: Arc::new(RwLock::new(None)),
41 metadata,
42 tool_registry,
43 })
44 }
45
46 async fn collect_environment(&self) -> Result<EnvironmentInfo> {
48 EnvironmentInfo::collect_for_path(&self.path)
49 }
50}
51
52impl std::fmt::Debug for LocalWorkspace {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("LocalWorkspace")
55 .field("path", &self.path)
56 .field("metadata", &self.metadata)
57 .field("tool_count", &self.tool_registry.len())
58 .finish_non_exhaustive()
59 }
60}
61
62#[async_trait]
63impl Workspace for LocalWorkspace {
64 async fn environment(&self) -> Result<EnvironmentInfo> {
65 let mut cache = self.environment_cache.write().await;
66
67 if let Some(cached) = cache.as_ref() {
69 if !cached.is_expired() {
70 return Ok(cached.info.clone());
71 }
72 }
73
74 let env_info = self.collect_environment().await?;
76
77 *cache = Some(CachedEnvironment::new(
79 env_info.clone(),
80 Duration::from_secs(300), ));
82
83 Ok(env_info)
84 }
85
86 fn metadata(&self) -> WorkspaceMetadata {
87 self.metadata.clone()
88 }
89
90 async fn invalidate_environment_cache(&self) {
91 let mut cache = self.environment_cache.write().await;
92 *cache = None;
93 }
94
95 async fn list_files(
96 &self,
97 query: Option<&str>,
98 max_results: Option<usize>,
99 ) -> Result<Vec<String>> {
100 use crate::utils::FileListingUtils;
101
102 info!(target: "workspace.list_files", "Listing files in workspace: {:?}", self.path);
103
104 FileListingUtils::list_files(&self.path, query, max_results).map_err(WorkspaceError::from)
105 }
106
107 fn working_directory(&self) -> &std::path::Path {
108 &self.path
109 }
110
111 async fn execute_tool(
112 &self,
113 tool_call: &ToolCall,
114 context: steer_tools::ExecutionContext,
115 ) -> Result<ToolResult> {
116 let tool = self.tool_registry.get(&tool_call.name).ok_or_else(|| {
118 WorkspaceError::ToolExecution(format!("Unknown tool: {}", tool_call.name))
119 })?;
120
121 let steer_context = SteerExecutionContext::new(tool_call.id.clone())
123 .with_cancellation_token(context.cancellation_token.clone())
124 .with_working_directory(self.path.clone());
125
126 tool.run(tool_call.parameters.clone(), &steer_context)
128 .await
129 .map_err(|e| WorkspaceError::ToolExecution(e.to_string()))
130 }
131
132 async fn available_tools(&self) -> Vec<ToolSchema> {
133 self.tool_registry
134 .iter()
135 .map(|(name, tool)| ToolSchema {
136 name: name.clone(),
137 description: tool.description().to_string(),
138 input_schema: tool.input_schema().clone(),
139 })
140 .collect()
141 }
142
143 async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
144 self.tool_registry
145 .get(tool_name)
146 .map(|tool| tool.requires_approval())
147 .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Creates a fresh temporary directory and a workspace rooted in it.
    ///
    /// The `TempDir` guard is returned alongside the workspace so the
    /// directory stays alive for the whole test body.
    async fn temp_workspace() -> (tempfile::TempDir, LocalWorkspace) {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();
        (temp_dir, workspace)
    }

    /// A newly created workspace reports the `Local` workspace type.
    #[tokio::test]
    async fn test_local_workspace_creation() {
        let (_temp_dir, workspace) = temp_workspace().await;

        let metadata = workspace.metadata();
        assert!(matches!(metadata.workspace_type, WorkspaceType::Local));
    }

    /// The metadata location mirrors the path the workspace was built with.
    #[tokio::test]
    async fn test_local_workspace_with_path() {
        let (temp_dir, workspace) = temp_workspace().await;

        assert!(matches!(
            workspace.metadata().workspace_type,
            WorkspaceType::Local
        ));

        let expected_location = temp_dir.path().display().to_string();
        assert_eq!(workspace.metadata().location, expected_location);
    }

    /// Two consecutive `environment()` calls agree on the observed fields.
    #[tokio::test]
    async fn test_environment_caching() {
        let (_temp_dir, workspace) = temp_workspace().await;

        let first = workspace.environment().await.unwrap();
        let second = workspace.environment().await.unwrap();

        assert_eq!(first.working_directory, second.working_directory);
        assert_eq!(first.is_git_repo, second.is_git_repo);
        assert_eq!(first.platform, second.platform);
    }

    /// After invalidation the environment can still be collected.
    #[tokio::test]
    async fn test_cache_invalidation() {
        let (_temp_dir, workspace) = temp_workspace().await;

        let _ = workspace.environment().await.unwrap();
        workspace.invalidate_environment_cache().await;

        let refreshed = workspace.environment().await.unwrap();
        assert!(!refreshed.working_directory.as_os_str().is_empty());
    }

    /// The collected working directory resolves to the workspace root.
    #[tokio::test]
    async fn test_environment_collection() {
        let (temp_dir, workspace) = temp_workspace().await;

        let env = workspace.environment().await.unwrap();

        // Canonicalize when possible, otherwise keep the path as given.
        let canonical_or_same =
            |p: &std::path::Path| p.canonicalize().unwrap_or_else(|_| p.to_path_buf());

        let expected_path = canonical_or_same(temp_dir.path());

        let actual_canonical = canonical_or_same(&env.working_directory);
        let expected_canonical = canonical_or_same(&expected_path);

        assert_eq!(actual_canonical, expected_canonical);
    }

    /// Listing honours the optional query and result limit.
    #[tokio::test]
    async fn test_list_files() {
        let (temp_dir, workspace) = temp_workspace().await;
        let root = temp_dir.path();

        std::fs::write(root.join("test.rs"), "test").unwrap();
        std::fs::write(root.join("main.rs"), "main").unwrap();
        std::fs::create_dir(root.join("src")).unwrap();
        std::fs::write(root.join("src/lib.rs"), "lib").unwrap();

        let all_files = workspace.list_files(None, None).await.unwrap();
        assert_eq!(all_files.len(), 4);
        for expected in ["test.rs", "main.rs", "src/", "src/lib.rs"] {
            assert!(all_files.iter().any(|entry| entry == expected));
        }

        let matching = workspace.list_files(Some("test"), None).await.unwrap();
        assert_eq!(matching.len(), 1);
        assert_eq!(matching[0], "test.rs");

        let limited = workspace.list_files(None, Some(2)).await.unwrap();
        assert_eq!(limited.len(), 2);
    }

    /// Dotfiles are part of the listing.
    #[tokio::test]
    async fn test_list_files_includes_dotfiles() {
        let (temp_dir, workspace) = temp_workspace().await;

        std::fs::write(temp_dir.path().join(".gitignore"), "target/").unwrap();

        let listed = workspace.list_files(None, None).await.unwrap();
        assert!(listed.iter().any(|entry| entry == ".gitignore"));
    }

    /// `working_directory()` returns exactly the path the workspace was built with.
    #[tokio::test]
    async fn test_working_directory() {
        let (temp_dir, workspace) = temp_workspace().await;

        assert_eq!(workspace.working_directory(), temp_dir.path());
    }
}