1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::RwLock;
6use tonic::transport::Channel;
7
8use steer_proto::remote_workspace::v1::{
9 ExecuteToolRequest, ExecuteToolResponse, GetEnvironmentInfoRequest, GetEnvironmentInfoResponse,
10 GetToolApprovalRequirementsRequest, GetToolSchemasRequest, ListFilesRequest, ToolErrorDetail,
11 remote_workspace_service_client::RemoteWorkspaceServiceClient, tool_error_detail,
12};
13use steer_tools::{
14 ToolCall, ToolSchema,
15 result::ToolResult,
16 tools::todo::{TodoItem, TodoPriority, TodoStatus, TodoWriteFileOperation},
17};
18use steer_workspace::{
19 EnvironmentInfo, RemoteAuth, Result, Workspace, WorkspaceError, WorkspaceMetadata,
20 WorkspaceType,
21};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct SerializableExecutionContext {
26 pub tool_call_id: String,
27 pub working_directory: std::path::PathBuf,
28}
29
30impl From<&steer_tools::ExecutionContext> for SerializableExecutionContext {
31 fn from(context: &steer_tools::ExecutionContext) -> Self {
32 Self {
33 tool_call_id: context.tool_call_id.clone(),
34 working_directory: context.working_directory.clone(),
35 }
36 }
37}
38
39fn convert_tool_response(response: ExecuteToolResponse) -> Result<ToolResult> {
41 use steer_proto::remote_workspace::v1::execute_tool_response::Result as ProtoResult;
42 use steer_tools::result::{
43 BashResult, EditResult, ExternalResult, FileContentResult, FileEntry, FileListResult,
44 GlobResult, SearchMatch, SearchResult, TodoListResult, TodoWriteResult,
45 };
46
47 match response.result {
48 Some(ProtoResult::StringResult(s)) => {
49 Ok(ToolResult::External(ExternalResult {
51 tool_name: "remote".to_string(),
52 payload: s,
53 }))
54 }
55 Some(ProtoResult::SearchResult(proto_result)) => {
56 let matches = proto_result
57 .matches
58 .into_iter()
59 .map(|m| SearchMatch {
60 file_path: m.file_path,
61 line_number: m.line_number as usize,
62 line_content: m.line_content,
63 column_range: m
64 .column_range
65 .map(|cr| (cr.start as usize, cr.end as usize)),
66 })
67 .collect();
68
69 Ok(ToolResult::Search(SearchResult {
70 matches,
71 total_files_searched: proto_result.total_files_searched as usize,
72 search_completed: proto_result.search_completed,
73 }))
74 }
75 Some(ProtoResult::FileListResult(proto_result)) => {
76 let entries = proto_result
77 .entries
78 .into_iter()
79 .map(|e| FileEntry {
80 path: e.path,
81 is_directory: e.is_directory,
82 size: e.size,
83 permissions: e.permissions,
84 })
85 .collect();
86
87 Ok(ToolResult::FileList(FileListResult {
88 entries,
89 base_path: proto_result.base_path,
90 }))
91 }
92 Some(ProtoResult::FileContentResult(proto_result)) => {
93 Ok(ToolResult::FileContent(FileContentResult {
94 content: proto_result.content,
95 file_path: proto_result.file_path,
96 line_count: proto_result.line_count as usize,
97 truncated: proto_result.truncated,
98 }))
99 }
100 Some(ProtoResult::EditResult(proto_result)) => Ok(ToolResult::Edit(EditResult {
101 file_path: proto_result.file_path,
102 changes_made: proto_result.changes_made as usize,
103 file_created: proto_result.file_created,
104 old_content: proto_result.old_content,
105 new_content: proto_result.new_content,
106 })),
107 Some(ProtoResult::BashResult(proto_result)) => Ok(ToolResult::Bash(BashResult {
108 stdout: proto_result.stdout,
109 stderr: proto_result.stderr,
110 exit_code: proto_result.exit_code,
111 command: proto_result.command,
112 })),
113 Some(ProtoResult::GlobResult(proto_result)) => Ok(ToolResult::Glob(GlobResult {
114 matches: proto_result.matches,
115 pattern: proto_result.pattern,
116 })),
117 Some(ProtoResult::TodoListResult(proto_result)) => {
118 let todos = proto_result
119 .todos
120 .into_iter()
121 .map(convert_proto_to_todo_item)
122 .collect();
123
124 Ok(ToolResult::TodoRead(TodoListResult { todos }))
125 }
126 Some(ProtoResult::TodoWriteResult(proto_result)) => {
127 let todos = proto_result
128 .todos
129 .into_iter()
130 .map(convert_proto_to_todo_item)
131 .collect();
132
133 Ok(ToolResult::TodoWrite(TodoWriteResult {
134 todos,
135 operation: convert_proto_to_todo_write_file_operation(
136 steer_proto::common::v1::TodoWriteFileOperation::try_from(
137 proto_result.operation,
138 )
139 .unwrap_or(steer_proto::common::v1::TodoWriteFileOperation::OperationUnset),
140 ),
141 }))
142 }
143 _ => Err(WorkspaceError::ToolExecution(
144 "No result returned from remote execution".to_string(),
145 )),
146 }
147}
148
149#[derive(Debug, Clone)]
151struct CachedEnvironment {
152 pub info: EnvironmentInfo,
153 pub cached_at: std::time::Instant,
154 pub ttl: Duration,
155}
156
157impl CachedEnvironment {
158 pub fn new(info: EnvironmentInfo, ttl: Duration) -> Self {
159 Self {
160 info,
161 cached_at: std::time::Instant::now(),
162 ttl,
163 }
164 }
165
166 pub fn is_expired(&self) -> bool {
167 self.cached_at.elapsed() > self.ttl
168 }
169}
170
171pub struct RemoteWorkspace {
173 client: RemoteWorkspaceServiceClient<Channel>,
174 environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
175 metadata: WorkspaceMetadata,
176 #[allow(dead_code)]
177 auth: Option<RemoteAuth>,
178}
179
180impl RemoteWorkspace {
181 pub async fn new(address: String, auth: Option<RemoteAuth>) -> Result<Self> {
182 let client = RemoteWorkspaceServiceClient::connect(format!("http://{address}"))
184 .await
185 .map_err(|e| WorkspaceError::Transport(format!("Failed to connect: {e}")))?;
186
187 let metadata = WorkspaceMetadata {
188 id: format!("remote:{address}"),
189 workspace_type: WorkspaceType::Remote,
190 location: address.clone(),
191 };
192
193 Ok(Self {
194 client,
195 environment_cache: Arc::new(RwLock::new(None)),
196 metadata,
197 auth,
198 })
199 }
200
201 async fn collect_environment(&self) -> Result<EnvironmentInfo> {
203 let mut client = self.client.clone();
204
205 let request = tonic::Request::new(GetEnvironmentInfoRequest {
206 working_directory: None, });
208
209 let response = client
210 .get_environment_info(request)
211 .await
212 .map_err(|e| WorkspaceError::Status(format!("Failed to get environment info: {e}")))?;
213 let env_response = response.into_inner();
214
215 Self::convert_environment_response(env_response)
216 }
217
218 fn convert_environment_response(
220 response: GetEnvironmentInfoResponse,
221 ) -> Result<EnvironmentInfo> {
222 use std::path::PathBuf;
223
224 Ok(EnvironmentInfo {
225 working_directory: PathBuf::from(response.working_directory),
226 is_git_repo: response.is_git_repo,
227 platform: response.platform,
228 date: response.date,
229 directory_structure: response.directory_structure,
230 git_status: response.git_status,
231 readme_content: response.readme_content,
232 claude_md_content: response.claude_md_content,
233 })
234 }
235}
236
237impl std::fmt::Debug for RemoteWorkspace {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.debug_struct("RemoteWorkspace")
240 .field("metadata", &self.metadata)
241 .field("auth", &self.auth)
242 .finish_non_exhaustive()
243 }
244}
245
246#[async_trait]
247impl Workspace for RemoteWorkspace {
248 async fn environment(&self) -> Result<EnvironmentInfo> {
249 let mut cache = self.environment_cache.write().await;
250
251 if let Some(cached) = cache.as_ref() {
253 if !cached.is_expired() {
254 return Ok(cached.info.clone());
255 }
256 }
257
258 let env_info = self.collect_environment().await?;
260
261 *cache = Some(CachedEnvironment::new(
263 env_info.clone(),
264 Duration::from_secs(600), ));
266
267 Ok(env_info)
268 }
269
270 fn metadata(&self) -> WorkspaceMetadata {
271 self.metadata.clone()
272 }
273
274 async fn invalidate_environment_cache(&self) {
275 let mut cache = self.environment_cache.write().await;
276 *cache = None;
277 }
278
279 async fn list_files(
280 &self,
281 query: Option<&str>,
282 max_results: Option<usize>,
283 ) -> Result<Vec<String>> {
284 let mut client = self.client.clone();
285
286 let request = tonic::Request::new(ListFilesRequest {
287 query: query.unwrap_or("").to_string(),
288 max_results: max_results.unwrap_or(0) as u32,
289 });
290
291 let mut stream = client
292 .list_files(request)
293 .await
294 .map_err(|e| WorkspaceError::Status(format!("Failed to list files: {e}")))?
295 .into_inner();
296 let mut all_files = Vec::new();
297
298 while let Some(response) = stream
300 .message()
301 .await
302 .map_err(|e| WorkspaceError::Status(format!("Stream error: {e}")))?
303 {
304 all_files.extend(response.paths);
305 }
306
307 Ok(all_files)
308 }
309
310 fn working_directory(&self) -> &std::path::Path {
311 std::path::Path::new("/remote")
314 }
315
316 async fn execute_tool(
317 &self,
318 tool_call: &ToolCall,
319 context: steer_tools::ExecutionContext,
320 ) -> Result<ToolResult> {
321 let mut client = self.client.clone();
322
323 let context_json = serde_json::to_string(&SerializableExecutionContext::from(&context))
325 .map_err(|e| {
326 WorkspaceError::ToolExecution(format!("Failed to serialize context: {e}"))
327 })?;
328
329 let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
331 WorkspaceError::ToolExecution(format!("Failed to serialize parameters: {e}"))
332 })?;
333
334 let request = tonic::Request::new(ExecuteToolRequest {
335 tool_call_id: tool_call.id.clone(),
336 tool_name: tool_call.name.clone(),
337 parameters_json,
338 context_json,
339 timeout_ms: Some(30000), });
341
342 let response = client
343 .execute_tool(request)
344 .await
345 .map_err(|e| WorkspaceError::Status(format!("Failed to execute tool: {e}")))?
346 .into_inner();
347
348 if !response.success {
349 let tool_error = if let Some(detail) = response.error_detail {
351 convert_error_detail_to_tool_error(detail)
352 } else {
353 steer_tools::ToolError::Execution {
355 tool_name: tool_call.name.clone(),
356 message: response.error,
357 }
358 };
359 return Ok(ToolResult::Error(tool_error));
360 }
361
362 convert_tool_response(response)
363 }
364
365 async fn available_tools(&self) -> Vec<ToolSchema> {
366 let mut client = self.client.clone();
367
368 let request = tonic::Request::new(GetToolSchemasRequest {});
369
370 match client.get_tool_schemas(request).await {
371 Ok(response) => {
372 response
373 .into_inner()
374 .tools
375 .into_iter()
376 .map(|schema| {
377 let input_schema = serde_json::from_str(&schema.input_schema_json)
379 .unwrap_or_else(|_| steer_tools::InputSchema {
380 properties: serde_json::Map::new(),
381 required: Vec::new(),
382 schema_type: "object".to_string(),
383 });
384
385 ToolSchema {
386 name: schema.name,
387 description: schema.description,
388 input_schema,
389 }
390 })
391 .collect()
392 }
393 Err(_) => Vec::new(),
394 }
395 }
396
397 async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
398 let mut client = self.client.clone();
399
400 let request = tonic::Request::new(GetToolApprovalRequirementsRequest {
401 tool_names: vec![tool_name.to_string()],
402 });
403
404 let response = client
405 .get_tool_approval_requirements(request)
406 .await
407 .map_err(|e| {
408 WorkspaceError::ToolExecution(format!("Failed to get approval requirements: {e}"))
409 })?
410 .into_inner();
411
412 response
413 .approval_requirements
414 .get(tool_name)
415 .copied()
416 .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
417 }
418}
419
420fn convert_proto_to_todo_item(item: steer_proto::common::v1::TodoItem) -> TodoItem {
421 TodoItem {
422 id: item.id.clone(),
423 content: item.content.clone(),
424 status: match steer_proto::common::v1::TodoStatus::try_from(item.status) {
425 Ok(steer_proto::common::v1::TodoStatus::Pending) => TodoStatus::Pending,
426 Ok(steer_proto::common::v1::TodoStatus::InProgress) => TodoStatus::InProgress,
427 Ok(steer_proto::common::v1::TodoStatus::Completed) => TodoStatus::Completed,
428 Ok(steer_proto::common::v1::TodoStatus::StatusUnset) => TodoStatus::Pending,
429 Err(_) => TodoStatus::Pending,
430 },
431 priority: match steer_proto::common::v1::TodoPriority::try_from(item.priority) {
432 Ok(steer_proto::common::v1::TodoPriority::High) => TodoPriority::High,
433 Ok(steer_proto::common::v1::TodoPriority::Medium) => TodoPriority::Medium,
434 Ok(steer_proto::common::v1::TodoPriority::Low) => TodoPriority::Low,
435 Ok(steer_proto::common::v1::TodoPriority::PriorityUnset) => TodoPriority::Low,
436 Err(_) => TodoPriority::Low,
437 },
438 }
439}
440
441fn convert_proto_to_todo_write_file_operation(
442 operation: steer_proto::common::v1::TodoWriteFileOperation,
443) -> TodoWriteFileOperation {
444 match operation {
445 steer_proto::common::v1::TodoWriteFileOperation::Created => TodoWriteFileOperation::Created,
446 steer_proto::common::v1::TodoWriteFileOperation::Modified => {
447 TodoWriteFileOperation::Modified
448 }
449 steer_proto::common::v1::TodoWriteFileOperation::OperationUnset => {
450 TodoWriteFileOperation::Created
451 }
452 }
453}
454
455fn convert_error_detail_to_tool_error(detail: ToolErrorDetail) -> steer_tools::ToolError {
457 use tool_error_detail::Kind;
458
459 let kind = Kind::try_from(detail.kind).unwrap_or(Kind::Internal);
460
461 match kind {
462 Kind::Execution => steer_tools::ToolError::Execution {
463 tool_name: detail.tool_name,
464 message: detail.message,
465 },
466 Kind::Io => steer_tools::ToolError::Io {
467 tool_name: detail.tool_name,
468 message: detail.message,
469 },
470 Kind::InvalidParams => {
471 steer_tools::ToolError::InvalidParams(detail.tool_name, detail.message)
472 }
473 Kind::Cancelled => steer_tools::ToolError::Cancelled(detail.tool_name),
474 Kind::Timeout => steer_tools::ToolError::Timeout(detail.tool_name),
475 Kind::UnknownTool => steer_tools::ToolError::UnknownTool(detail.tool_name),
476 Kind::DeniedByUser => steer_tools::ToolError::DeniedByUser(detail.tool_name),
477 Kind::Internal | Kind::Unset => steer_tools::ToolError::InternalError(detail.message),
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata built the same way `RemoteWorkspace::new` builds it carries the address
    /// through to both `id` and `location`. Nothing is awaited, so no async runtime is needed.
    #[test]
    fn test_remote_workspace_metadata() {
        let address = "localhost:50051".to_string();

        let metadata = WorkspaceMetadata {
            id: format!("remote:{address}"),
            workspace_type: WorkspaceType::Remote,
            location: address.clone(),
        };

        assert!(matches!(metadata.workspace_type, WorkspaceType::Remote));
        assert_eq!(metadata.id, "remote:localhost:50051");
        assert_eq!(metadata.location, address);
    }

    /// Every field of the protobuf response is carried over unchanged.
    #[test]
    fn test_convert_environment_response() {
        use std::path::PathBuf;

        let response = GetEnvironmentInfoResponse {
            working_directory: "/home/user/project".to_string(),
            is_git_repo: true,
            platform: "linux".to_string(),
            date: "2025-06-17".to_string(),
            directory_structure: "project/\nsrc/\nmain.rs\n".to_string(),
            git_status: Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string()),
            readme_content: Some("# My Project".to_string()),
            claude_md_content: None,
        };

        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert_eq!(
            env_info.working_directory,
            PathBuf::from("/home/user/project")
        );
        assert!(env_info.is_git_repo);
        assert_eq!(env_info.platform, "linux");
        assert_eq!(env_info.date, "2025-06-17");
        assert_eq!(env_info.directory_structure, "project/\nsrc/\nmain.rs\n");
        assert_eq!(
            env_info.git_status,
            Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string())
        );
        assert_eq!(env_info.readme_content, Some("# My Project".to_string()));
        assert_eq!(env_info.claude_md_content, None);
    }
}