1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::RwLock;
6use tonic::transport::Channel;
7
8use steer_proto::remote_workspace::v1::{
9 ExecuteToolRequest, ExecuteToolResponse, GetEnvironmentInfoRequest, GetEnvironmentInfoResponse,
10 GetToolApprovalRequirementsRequest, GetToolSchemasRequest, ListFilesRequest,
11 remote_workspace_service_client::RemoteWorkspaceServiceClient,
12};
13use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
14use steer_workspace::{
15 EnvironmentInfo, RemoteAuth, Result, Workspace, WorkspaceError, WorkspaceMetadata,
16 WorkspaceType,
17};
18
/// JSON-serializable subset of [`steer_tools::ExecutionContext`].
///
/// `execute_tool` serializes this with `serde_json` into the `context_json` field of the
/// `ExecuteToolRequest` sent to the remote workspace service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableExecutionContext {
    /// Identifier of the tool call, copied from the originating `ExecutionContext`.
    pub tool_call_id: String,
    /// Working directory, copied from the originating `ExecutionContext`.
    pub working_directory: std::path::PathBuf,
}
25
26impl From<&steer_tools::ExecutionContext> for SerializableExecutionContext {
27 fn from(context: &steer_tools::ExecutionContext) -> Self {
28 Self {
29 tool_call_id: context.tool_call_id.clone(),
30 working_directory: context.working_directory.clone(),
31 }
32 }
33}
34
35fn convert_tool_response(response: ExecuteToolResponse) -> Result<ToolResult> {
37 use steer_proto::remote_workspace::v1::execute_tool_response::Result as ProtoResult;
38 use steer_tools::result::{
39 BashResult, EditResult, ExternalResult, FileContentResult, FileEntry, FileListResult,
40 GlobResult, SearchMatch, SearchResult, TodoItem, TodoListResult, TodoWriteResult,
41 };
42
43 match response.result {
44 Some(ProtoResult::StringResult(s)) => {
45 Ok(ToolResult::External(ExternalResult {
47 tool_name: "remote".to_string(),
48 payload: s,
49 }))
50 }
51 Some(ProtoResult::SearchResult(proto_result)) => {
52 let matches = proto_result
53 .matches
54 .into_iter()
55 .map(|m| SearchMatch {
56 file_path: m.file_path,
57 line_number: m.line_number as usize,
58 line_content: m.line_content,
59 column_range: m
60 .column_range
61 .map(|cr| (cr.start as usize, cr.end as usize)),
62 })
63 .collect();
64
65 Ok(ToolResult::Search(SearchResult {
66 matches,
67 total_files_searched: proto_result.total_files_searched as usize,
68 search_completed: proto_result.search_completed,
69 }))
70 }
71 Some(ProtoResult::FileListResult(proto_result)) => {
72 let entries = proto_result
73 .entries
74 .into_iter()
75 .map(|e| FileEntry {
76 path: e.path,
77 is_directory: e.is_directory,
78 size: e.size,
79 permissions: e.permissions,
80 })
81 .collect();
82
83 Ok(ToolResult::FileList(FileListResult {
84 entries,
85 base_path: proto_result.base_path,
86 }))
87 }
88 Some(ProtoResult::FileContentResult(proto_result)) => {
89 Ok(ToolResult::FileContent(FileContentResult {
90 content: proto_result.content,
91 file_path: proto_result.file_path,
92 line_count: proto_result.line_count as usize,
93 truncated: proto_result.truncated,
94 }))
95 }
96 Some(ProtoResult::EditResult(proto_result)) => Ok(ToolResult::Edit(EditResult {
97 file_path: proto_result.file_path,
98 changes_made: proto_result.changes_made as usize,
99 file_created: proto_result.file_created,
100 old_content: proto_result.old_content,
101 new_content: proto_result.new_content,
102 })),
103 Some(ProtoResult::BashResult(proto_result)) => Ok(ToolResult::Bash(BashResult {
104 stdout: proto_result.stdout,
105 stderr: proto_result.stderr,
106 exit_code: proto_result.exit_code,
107 command: proto_result.command,
108 })),
109 Some(ProtoResult::GlobResult(proto_result)) => Ok(ToolResult::Glob(GlobResult {
110 matches: proto_result.matches,
111 pattern: proto_result.pattern,
112 })),
113 Some(ProtoResult::TodoListResult(proto_result)) => {
114 let todos = proto_result
115 .todos
116 .into_iter()
117 .map(|t| TodoItem {
118 id: t.id,
119 content: t.content,
120 status: t.status,
121 priority: t.priority,
122 })
123 .collect();
124
125 Ok(ToolResult::TodoRead(TodoListResult { todos }))
126 }
127 Some(ProtoResult::TodoWriteResult(proto_result)) => {
128 let todos = proto_result
129 .todos
130 .into_iter()
131 .map(|t| TodoItem {
132 id: t.id,
133 content: t.content,
134 status: t.status,
135 priority: t.priority,
136 })
137 .collect();
138
139 Ok(ToolResult::TodoWrite(TodoWriteResult {
140 todos,
141 operation: proto_result.operation,
142 }))
143 }
144 _ => Err(WorkspaceError::ToolExecution(
145 "No result returned from remote execution".to_string(),
146 )),
147 }
148}
149
/// A snapshot of remote [`EnvironmentInfo`] together with when it was stored and how long
/// it remains valid.
#[derive(Debug, Clone)]
struct CachedEnvironment {
    /// The cached environment information.
    pub info: EnvironmentInfo,
    /// Moment the entry was created (`Instant::now()` in `CachedEnvironment::new`).
    pub cached_at: std::time::Instant,
    /// How long after `cached_at` the entry is still considered fresh.
    pub ttl: Duration,
}
157
158impl CachedEnvironment {
159 pub fn new(info: EnvironmentInfo, ttl: Duration) -> Self {
160 Self {
161 info,
162 cached_at: std::time::Instant::now(),
163 ttl,
164 }
165 }
166
167 pub fn is_expired(&self) -> bool {
168 self.cached_at.elapsed() > self.ttl
169 }
170}
171
/// A [`Workspace`] backed by a remote workspace service reached over gRPC.
pub struct RemoteWorkspace {
    /// gRPC client for the remote service; every request method clones it before use.
    client: RemoteWorkspaceServiceClient<Channel>,
    /// Cached environment info behind an async lock; `None` until first fetched and again
    /// after `invalidate_environment_cache`.
    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
    /// Identity and location of this workspace, built once in `new`.
    metadata: WorkspaceMetadata,
    // NOTE(review): `auth` is stored, but none of the requests built in this file attach it
    // (every RPC uses a plain `tonic::Request::new`), so credentials are currently not sent.
    // Only the `Debug` impl reads it. Confirm whether that is intended; the `allow` below
    // silences the unused-field lint in the meantime.
    #[allow(dead_code)]
    auth: Option<RemoteAuth>,
}
180
181impl RemoteWorkspace {
182 pub async fn new(address: String, auth: Option<RemoteAuth>) -> Result<Self> {
183 let client = RemoteWorkspaceServiceClient::connect(format!("http://{address}"))
185 .await
186 .map_err(|e| WorkspaceError::Transport(format!("Failed to connect: {e}")))?;
187
188 let metadata = WorkspaceMetadata {
189 id: format!("remote:{address}"),
190 workspace_type: WorkspaceType::Remote,
191 location: address.clone(),
192 };
193
194 Ok(Self {
195 client,
196 environment_cache: Arc::new(RwLock::new(None)),
197 metadata,
198 auth,
199 })
200 }
201
202 async fn collect_environment(&self) -> Result<EnvironmentInfo> {
204 let mut client = self.client.clone();
205
206 let request = tonic::Request::new(GetEnvironmentInfoRequest {
207 working_directory: None, });
209
210 let response = client
211 .get_environment_info(request)
212 .await
213 .map_err(|e| WorkspaceError::Status(format!("Failed to get environment info: {e}")))?;
214 let env_response = response.into_inner();
215
216 Self::convert_environment_response(env_response)
217 }
218
219 fn convert_environment_response(
221 response: GetEnvironmentInfoResponse,
222 ) -> Result<EnvironmentInfo> {
223 use std::path::PathBuf;
224
225 Ok(EnvironmentInfo {
226 working_directory: PathBuf::from(response.working_directory),
227 is_git_repo: response.is_git_repo,
228 platform: response.platform,
229 date: response.date,
230 directory_structure: response.directory_structure,
231 git_status: response.git_status,
232 readme_content: response.readme_content,
233 claude_md_content: response.claude_md_content,
234 })
235 }
236}
237
238impl std::fmt::Debug for RemoteWorkspace {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 f.debug_struct("RemoteWorkspace")
241 .field("metadata", &self.metadata)
242 .field("auth", &self.auth)
243 .finish_non_exhaustive()
244 }
245}
246
247#[async_trait]
248impl Workspace for RemoteWorkspace {
249 async fn environment(&self) -> Result<EnvironmentInfo> {
250 let mut cache = self.environment_cache.write().await;
251
252 if let Some(cached) = cache.as_ref() {
254 if !cached.is_expired() {
255 return Ok(cached.info.clone());
256 }
257 }
258
259 let env_info = self.collect_environment().await?;
261
262 *cache = Some(CachedEnvironment::new(
264 env_info.clone(),
265 Duration::from_secs(600), ));
267
268 Ok(env_info)
269 }
270
271 fn metadata(&self) -> WorkspaceMetadata {
272 self.metadata.clone()
273 }
274
275 async fn invalidate_environment_cache(&self) {
276 let mut cache = self.environment_cache.write().await;
277 *cache = None;
278 }
279
280 async fn list_files(
281 &self,
282 query: Option<&str>,
283 max_results: Option<usize>,
284 ) -> Result<Vec<String>> {
285 let mut client = self.client.clone();
286
287 let request = tonic::Request::new(ListFilesRequest {
288 query: query.unwrap_or("").to_string(),
289 max_results: max_results.unwrap_or(0) as u32,
290 });
291
292 let mut stream = client
293 .list_files(request)
294 .await
295 .map_err(|e| WorkspaceError::Status(format!("Failed to list files: {e}")))?
296 .into_inner();
297 let mut all_files = Vec::new();
298
299 while let Some(response) = stream
301 .message()
302 .await
303 .map_err(|e| WorkspaceError::Status(format!("Stream error: {e}")))?
304 {
305 all_files.extend(response.paths);
306 }
307
308 Ok(all_files)
309 }
310
311 fn working_directory(&self) -> &std::path::Path {
312 std::path::Path::new("/remote")
315 }
316
317 async fn execute_tool(
318 &self,
319 tool_call: &ToolCall,
320 context: steer_tools::ExecutionContext,
321 ) -> Result<ToolResult> {
322 let mut client = self.client.clone();
323
324 let context_json = serde_json::to_string(&SerializableExecutionContext::from(&context))
326 .map_err(|e| {
327 WorkspaceError::ToolExecution(format!("Failed to serialize context: {e}"))
328 })?;
329
330 let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
332 WorkspaceError::ToolExecution(format!("Failed to serialize parameters: {e}"))
333 })?;
334
335 let request = tonic::Request::new(ExecuteToolRequest {
336 tool_call_id: tool_call.id.clone(),
337 tool_name: tool_call.name.clone(),
338 parameters_json,
339 context_json,
340 timeout_ms: Some(30000), });
342
343 let response = client
344 .execute_tool(request)
345 .await
346 .map_err(|e| WorkspaceError::ToolExecution(format!("Failed to execute tool: {e}")))?
347 .into_inner();
348
349 if !response.success {
350 return Err(WorkspaceError::ToolExecution(format!(
351 "Tool execution failed: {}",
352 response.error
353 )));
354 }
355
356 convert_tool_response(response)
358 }
359
360 async fn available_tools(&self) -> Vec<ToolSchema> {
361 let mut client = self.client.clone();
362
363 let request = tonic::Request::new(GetToolSchemasRequest {});
364
365 match client.get_tool_schemas(request).await {
366 Ok(response) => {
367 response
368 .into_inner()
369 .tools
370 .into_iter()
371 .map(|schema| {
372 let input_schema = serde_json::from_str(&schema.input_schema_json)
374 .unwrap_or_else(|_| steer_tools::InputSchema {
375 properties: serde_json::Map::new(),
376 required: Vec::new(),
377 schema_type: "object".to_string(),
378 });
379
380 ToolSchema {
381 name: schema.name,
382 description: schema.description,
383 input_schema,
384 }
385 })
386 .collect()
387 }
388 Err(_) => Vec::new(),
389 }
390 }
391
392 async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
393 let mut client = self.client.clone();
394
395 let request = tonic::Request::new(GetToolApprovalRequirementsRequest {
396 tool_names: vec![tool_name.to_string()],
397 });
398
399 let response = client
400 .get_tool_approval_requirements(request)
401 .await
402 .map_err(|e| {
403 WorkspaceError::ToolExecution(format!("Failed to get approval requirements: {e}"))
404 })?
405 .into_inner();
406
407 response
408 .approval_requirements
409 .get(tool_name)
410 .copied()
411 .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use steer_proto::remote_workspace::v1::execute_tool_response::Result as ProtoResult;
    use steer_tools::result::ExternalResult;

    /// Checks the metadata shape that `RemoteWorkspace::new` builds. A real
    /// `RemoteWorkspace` cannot be constructed here without a live gRPC server, so the
    /// metadata is built the same way `new` does it.
    #[test]
    fn test_remote_workspace_metadata() {
        let address = "localhost:50051".to_string();

        let metadata = WorkspaceMetadata {
            id: format!("remote:{address}"),
            workspace_type: WorkspaceType::Remote,
            location: address.clone(),
        };

        assert!(matches!(metadata.workspace_type, WorkspaceType::Remote));
        assert_eq!(metadata.location, address);
        assert_eq!(metadata.id, "remote:localhost:50051");
    }

    #[test]
    fn test_convert_environment_response() {
        use std::path::PathBuf;

        let response = GetEnvironmentInfoResponse {
            working_directory: "/home/user/project".to_string(),
            is_git_repo: true,
            platform: "linux".to_string(),
            date: "2025-06-17".to_string(),
            directory_structure: "project/\nsrc/\nmain.rs\n".to_string(),
            git_status: Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string()),
            readme_content: Some("# My Project".to_string()),
            claude_md_content: None,
        };

        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert_eq!(
            env_info.working_directory,
            PathBuf::from("/home/user/project")
        );
        assert!(env_info.is_git_repo);
        assert_eq!(env_info.platform, "linux");
        assert_eq!(env_info.date, "2025-06-17");
        assert_eq!(env_info.directory_structure, "project/\nsrc/\nmain.rs\n");
        assert_eq!(
            env_info.git_status,
            Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string())
        );
        assert_eq!(env_info.readme_content, Some("# My Project".to_string()));
        assert_eq!(env_info.claude_md_content, None);
    }

    /// A string result is surfaced as an external result tagged with the tool name "remote".
    #[test]
    fn test_convert_tool_response_string_result() {
        let response = ExecuteToolResponse {
            result: Some(ProtoResult::StringResult("hello".to_string())),
            ..Default::default()
        };

        match convert_tool_response(response) {
            Ok(ToolResult::External(ExternalResult {
                tool_name, payload, ..
            })) => {
                assert_eq!(tool_name, "remote");
                assert_eq!(payload, "hello");
            }
            _ => panic!("expected an external tool result"),
        }
    }

    /// A response without any result is reported as a tool-execution error.
    #[test]
    fn test_convert_tool_response_missing_result() {
        let response = ExecuteToolResponse {
            success: true,
            ..Default::default()
        };

        assert!(matches!(
            convert_tool_response(response),
            Err(WorkspaceError::ToolExecution(_))
        ));
    }
}