1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::sync::RwLock;
7use tracing::info;
8
9use crate::error::{Result, WorkspaceError};
10use crate::{CachedEnvironment, EnvironmentInfo, Workspace, WorkspaceMetadata, WorkspaceType};
11use steer_tools::{
12 ExecutionContext as SteerExecutionContext, ToolCall, ToolSchema, result::ToolResult,
13 traits::ExecutableTool,
14};
15
16pub struct LocalWorkspace {
18 path: PathBuf,
19 environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
20 metadata: WorkspaceMetadata,
21 tool_registry: HashMap<String, Box<dyn ExecutableTool>>,
22}
23
24impl LocalWorkspace {
25 pub async fn with_path(path: PathBuf) -> Result<Self> {
26 let metadata = WorkspaceMetadata {
27 id: format!("local:{}", path.display()),
28 workspace_type: WorkspaceType::Local,
29 location: path.display().to_string(),
30 };
31
32 let mut tool_registry = HashMap::new();
34 for tool in steer_tools::tools::workspace_tools() {
35 tool_registry.insert(tool.name().to_string(), tool);
36 }
37
38 Ok(Self {
39 path,
40 environment_cache: Arc::new(RwLock::new(None)),
41 metadata,
42 tool_registry,
43 })
44 }
45
46 async fn collect_environment(&self) -> Result<EnvironmentInfo> {
48 EnvironmentInfo::collect_for_path(&self.path)
49 }
50}
51
52impl std::fmt::Debug for LocalWorkspace {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("LocalWorkspace")
55 .field("path", &self.path)
56 .field("metadata", &self.metadata)
57 .field("tool_count", &self.tool_registry.len())
58 .finish_non_exhaustive()
59 }
60}
61
62#[async_trait]
63impl Workspace for LocalWorkspace {
64 async fn environment(&self) -> Result<EnvironmentInfo> {
65 let mut cache = self.environment_cache.write().await;
66
67 if let Some(cached) = cache.as_ref() {
69 if !cached.is_expired() {
70 return Ok(cached.info.clone());
71 }
72 }
73
74 let env_info = self.collect_environment().await?;
76
77 *cache = Some(CachedEnvironment::new(
79 env_info.clone(),
80 Duration::from_secs(300), ));
82
83 Ok(env_info)
84 }
85
86 fn metadata(&self) -> WorkspaceMetadata {
87 self.metadata.clone()
88 }
89
90 async fn invalidate_environment_cache(&self) {
91 let mut cache = self.environment_cache.write().await;
92 *cache = None;
93 }
94
95 async fn list_files(
96 &self,
97 query: Option<&str>,
98 max_results: Option<usize>,
99 ) -> Result<Vec<String>> {
100 use crate::utils::FileListingUtils;
101
102 info!(target: "workspace.list_files", "Listing files in workspace: {:?}", self.path);
103
104 FileListingUtils::list_files(&self.path, query, max_results).map_err(WorkspaceError::from)
105 }
106
107 fn working_directory(&self) -> &std::path::Path {
108 &self.path
109 }
110
111 async fn execute_tool(
112 &self,
113 tool_call: &ToolCall,
114 context: steer_tools::ExecutionContext,
115 ) -> Result<ToolResult> {
116 let tool = self.tool_registry.get(&tool_call.name).ok_or_else(|| {
118 WorkspaceError::ToolExecution(format!("Unknown tool: {}", tool_call.name))
119 })?;
120
121 let steer_context = SteerExecutionContext::new(tool_call.id.clone())
123 .with_cancellation_token(context.cancellation_token.clone())
124 .with_working_directory(self.path.clone());
125
126 match tool.run(tool_call.parameters.clone(), &steer_context).await {
128 Ok(result) => Ok(result),
129 Err(e) => Ok(ToolResult::Error(e)),
130 }
131 }
132
133 async fn available_tools(&self) -> Vec<ToolSchema> {
134 self.tool_registry
135 .iter()
136 .map(|(name, tool)| ToolSchema {
137 name: name.clone(),
138 description: tool.description().to_string(),
139 input_schema: tool.input_schema().clone(),
140 })
141 .collect()
142 }
143
144 async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
145 self.tool_registry
146 .get(tool_name)
147 .map(|tool| tool.requires_approval())
148 .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Creates a `LocalWorkspace` rooted at a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned too, so the directory outlives the workspace in
    /// each test.
    async fn temp_workspace() -> (tempfile::TempDir, LocalWorkspace) {
        let temp_dir = tempdir().unwrap();
        let workspace = LocalWorkspace::with_path(temp_dir.path().to_path_buf())
            .await
            .unwrap();
        (temp_dir, workspace)
    }

    /// Canonicalizes `p`, falling back to the path as given if that fails, so that
    /// symlinked temp-dir paths compare equal.
    fn canonical(p: &std::path::Path) -> std::path::PathBuf {
        p.canonicalize().unwrap_or_else(|_| p.to_path_buf())
    }

    #[tokio::test]
    async fn test_local_workspace_creation() {
        let (_temp_dir, workspace) = temp_workspace().await;
        assert!(matches!(
            workspace.metadata().workspace_type,
            WorkspaceType::Local
        ));
    }

    #[tokio::test]
    async fn test_local_workspace_with_path() {
        let (temp_dir, workspace) = temp_workspace().await;

        assert!(matches!(
            workspace.metadata().workspace_type,
            WorkspaceType::Local
        ));
        assert_eq!(
            workspace.metadata().location,
            temp_dir.path().display().to_string()
        );
    }

    #[tokio::test]
    async fn test_environment_caching() {
        let (_temp_dir, workspace) = temp_workspace().await;

        assert!(workspace.environment_cache.read().await.is_none());

        let env1 = workspace.environment().await.unwrap();
        // The first call must populate the cache.
        assert!(workspace.environment_cache.read().await.is_some());

        let env2 = workspace.environment().await.unwrap();

        assert_eq!(env1.working_directory, env2.working_directory);
        assert_eq!(env1.is_git_repo, env2.is_git_repo);
        assert_eq!(env1.platform, env2.platform);
    }

    #[tokio::test]
    async fn test_cache_invalidation() {
        let (_temp_dir, workspace) = temp_workspace().await;

        let _ = workspace.environment().await.unwrap();
        assert!(workspace.environment_cache.read().await.is_some());

        workspace.invalidate_environment_cache().await;
        // Invalidation must actually clear the cache, not just leave it stale.
        assert!(workspace.environment_cache.read().await.is_none());

        let env = workspace.environment().await.unwrap();
        assert!(!env.working_directory.as_os_str().is_empty());
        assert!(workspace.environment_cache.read().await.is_some());
    }

    #[tokio::test]
    async fn test_environment_collection() {
        let (temp_dir, workspace) = temp_workspace().await;

        let env = workspace.environment().await.unwrap();

        assert_eq!(canonical(&env.working_directory), canonical(temp_dir.path()));
    }

    #[tokio::test]
    async fn test_list_files() {
        let (temp_dir, workspace) = temp_workspace().await;

        std::fs::write(temp_dir.path().join("test.rs"), "test").unwrap();
        std::fs::write(temp_dir.path().join("main.rs"), "main").unwrap();
        std::fs::create_dir(temp_dir.path().join("src")).unwrap();
        std::fs::write(temp_dir.path().join("src").join("lib.rs"), "lib").unwrap();

        let files = workspace.list_files(None, None).await.unwrap();
        assert_eq!(files.len(), 4);
        assert!(files.contains(&"test.rs".to_string()));
        assert!(files.contains(&"main.rs".to_string()));
        assert!(files.contains(&"src/".to_string()));
        assert!(files.contains(&"src/lib.rs".to_string()));

        let files = workspace.list_files(Some("test"), None).await.unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0], "test.rs");

        let files = workspace.list_files(None, Some(2)).await.unwrap();
        assert_eq!(files.len(), 2);
    }

    #[tokio::test]
    async fn test_list_files_includes_dotfiles() {
        let (temp_dir, workspace) = temp_workspace().await;

        std::fs::write(temp_dir.path().join(".gitignore"), "target/").unwrap();

        let files = workspace.list_files(None, None).await.unwrap();
        assert!(files.contains(&".gitignore".to_string()));
    }

    #[tokio::test]
    async fn test_working_directory() {
        let (temp_dir, workspace) = temp_workspace().await;

        assert_eq!(workspace.working_directory(), temp_dir.path());
    }
}