1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::RwLock;
6use tonic::transport::Channel;
7
8use steer_proto::remote_workspace::v1::{
9 ExecuteToolRequest, ExecuteToolResponse, GetEnvironmentInfoRequest, GetEnvironmentInfoResponse,
10 GetToolApprovalRequirementsRequest, GetToolSchemasRequest, ListFilesRequest,
11 remote_workspace_service_client::RemoteWorkspaceServiceClient,
12};
13use steer_tools::{
14 ToolCall, ToolSchema,
15 result::ToolResult,
16 tools::todo::{TodoItem, TodoPriority, TodoStatus, TodoWriteFileOperation},
17};
18use steer_workspace::{
19 EnvironmentInfo, RemoteAuth, Result, Workspace, WorkspaceError, WorkspaceMetadata,
20 WorkspaceType,
21};
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct SerializableExecutionContext {
26 pub tool_call_id: String,
27 pub working_directory: std::path::PathBuf,
28}
29
30impl From<&steer_tools::ExecutionContext> for SerializableExecutionContext {
31 fn from(context: &steer_tools::ExecutionContext) -> Self {
32 Self {
33 tool_call_id: context.tool_call_id.clone(),
34 working_directory: context.working_directory.clone(),
35 }
36 }
37}
38
39fn convert_tool_response(response: ExecuteToolResponse) -> Result<ToolResult> {
41 use steer_proto::remote_workspace::v1::execute_tool_response::Result as ProtoResult;
42 use steer_tools::result::{
43 BashResult, EditResult, ExternalResult, FileContentResult, FileEntry, FileListResult,
44 GlobResult, SearchMatch, SearchResult, TodoListResult, TodoWriteResult,
45 };
46
47 match response.result {
48 Some(ProtoResult::StringResult(s)) => {
49 Ok(ToolResult::External(ExternalResult {
51 tool_name: "remote".to_string(),
52 payload: s,
53 }))
54 }
55 Some(ProtoResult::SearchResult(proto_result)) => {
56 let matches = proto_result
57 .matches
58 .into_iter()
59 .map(|m| SearchMatch {
60 file_path: m.file_path,
61 line_number: m.line_number as usize,
62 line_content: m.line_content,
63 column_range: m
64 .column_range
65 .map(|cr| (cr.start as usize, cr.end as usize)),
66 })
67 .collect();
68
69 Ok(ToolResult::Search(SearchResult {
70 matches,
71 total_files_searched: proto_result.total_files_searched as usize,
72 search_completed: proto_result.search_completed,
73 }))
74 }
75 Some(ProtoResult::FileListResult(proto_result)) => {
76 let entries = proto_result
77 .entries
78 .into_iter()
79 .map(|e| FileEntry {
80 path: e.path,
81 is_directory: e.is_directory,
82 size: e.size,
83 permissions: e.permissions,
84 })
85 .collect();
86
87 Ok(ToolResult::FileList(FileListResult {
88 entries,
89 base_path: proto_result.base_path,
90 }))
91 }
92 Some(ProtoResult::FileContentResult(proto_result)) => {
93 Ok(ToolResult::FileContent(FileContentResult {
94 content: proto_result.content,
95 file_path: proto_result.file_path,
96 line_count: proto_result.line_count as usize,
97 truncated: proto_result.truncated,
98 }))
99 }
100 Some(ProtoResult::EditResult(proto_result)) => Ok(ToolResult::Edit(EditResult {
101 file_path: proto_result.file_path,
102 changes_made: proto_result.changes_made as usize,
103 file_created: proto_result.file_created,
104 old_content: proto_result.old_content,
105 new_content: proto_result.new_content,
106 })),
107 Some(ProtoResult::BashResult(proto_result)) => Ok(ToolResult::Bash(BashResult {
108 stdout: proto_result.stdout,
109 stderr: proto_result.stderr,
110 exit_code: proto_result.exit_code,
111 command: proto_result.command,
112 })),
113 Some(ProtoResult::GlobResult(proto_result)) => Ok(ToolResult::Glob(GlobResult {
114 matches: proto_result.matches,
115 pattern: proto_result.pattern,
116 })),
117 Some(ProtoResult::TodoListResult(proto_result)) => {
118 let todos = proto_result
119 .todos
120 .into_iter()
121 .map(convert_proto_to_todo_item)
122 .collect();
123
124 Ok(ToolResult::TodoRead(TodoListResult { todos }))
125 }
126 Some(ProtoResult::TodoWriteResult(proto_result)) => {
127 let todos = proto_result
128 .todos
129 .into_iter()
130 .map(convert_proto_to_todo_item)
131 .collect();
132
133 Ok(ToolResult::TodoWrite(TodoWriteResult {
134 todos,
135 operation: convert_proto_to_todo_write_file_operation(
136 steer_proto::common::v1::TodoWriteFileOperation::try_from(
137 proto_result.operation,
138 )
139 .unwrap_or(steer_proto::common::v1::TodoWriteFileOperation::OperationUnset),
140 ),
141 }))
142 }
143 _ => Err(WorkspaceError::ToolExecution(
144 "No result returned from remote execution".to_string(),
145 )),
146 }
147}
148
149#[derive(Debug, Clone)]
151struct CachedEnvironment {
152 pub info: EnvironmentInfo,
153 pub cached_at: std::time::Instant,
154 pub ttl: Duration,
155}
156
157impl CachedEnvironment {
158 pub fn new(info: EnvironmentInfo, ttl: Duration) -> Self {
159 Self {
160 info,
161 cached_at: std::time::Instant::now(),
162 ttl,
163 }
164 }
165
166 pub fn is_expired(&self) -> bool {
167 self.cached_at.elapsed() > self.ttl
168 }
169}
170
171pub struct RemoteWorkspace {
173 client: RemoteWorkspaceServiceClient<Channel>,
174 environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
175 metadata: WorkspaceMetadata,
176 #[allow(dead_code)]
177 auth: Option<RemoteAuth>,
178}
179
180impl RemoteWorkspace {
181 pub async fn new(address: String, auth: Option<RemoteAuth>) -> Result<Self> {
182 let client = RemoteWorkspaceServiceClient::connect(format!("http://{address}"))
184 .await
185 .map_err(|e| WorkspaceError::Transport(format!("Failed to connect: {e}")))?;
186
187 let metadata = WorkspaceMetadata {
188 id: format!("remote:{address}"),
189 workspace_type: WorkspaceType::Remote,
190 location: address.clone(),
191 };
192
193 Ok(Self {
194 client,
195 environment_cache: Arc::new(RwLock::new(None)),
196 metadata,
197 auth,
198 })
199 }
200
201 async fn collect_environment(&self) -> Result<EnvironmentInfo> {
203 let mut client = self.client.clone();
204
205 let request = tonic::Request::new(GetEnvironmentInfoRequest {
206 working_directory: None, });
208
209 let response = client
210 .get_environment_info(request)
211 .await
212 .map_err(|e| WorkspaceError::Status(format!("Failed to get environment info: {e}")))?;
213 let env_response = response.into_inner();
214
215 Self::convert_environment_response(env_response)
216 }
217
218 fn convert_environment_response(
220 response: GetEnvironmentInfoResponse,
221 ) -> Result<EnvironmentInfo> {
222 use std::path::PathBuf;
223
224 Ok(EnvironmentInfo {
225 working_directory: PathBuf::from(response.working_directory),
226 is_git_repo: response.is_git_repo,
227 platform: response.platform,
228 date: response.date,
229 directory_structure: response.directory_structure,
230 git_status: response.git_status,
231 readme_content: response.readme_content,
232 claude_md_content: response.claude_md_content,
233 })
234 }
235}
236
237impl std::fmt::Debug for RemoteWorkspace {
238 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239 f.debug_struct("RemoteWorkspace")
240 .field("metadata", &self.metadata)
241 .field("auth", &self.auth)
242 .finish_non_exhaustive()
243 }
244}
245
246#[async_trait]
247impl Workspace for RemoteWorkspace {
248 async fn environment(&self) -> Result<EnvironmentInfo> {
249 let mut cache = self.environment_cache.write().await;
250
251 if let Some(cached) = cache.as_ref() {
253 if !cached.is_expired() {
254 return Ok(cached.info.clone());
255 }
256 }
257
258 let env_info = self.collect_environment().await?;
260
261 *cache = Some(CachedEnvironment::new(
263 env_info.clone(),
264 Duration::from_secs(600), ));
266
267 Ok(env_info)
268 }
269
270 fn metadata(&self) -> WorkspaceMetadata {
271 self.metadata.clone()
272 }
273
274 async fn invalidate_environment_cache(&self) {
275 let mut cache = self.environment_cache.write().await;
276 *cache = None;
277 }
278
279 async fn list_files(
280 &self,
281 query: Option<&str>,
282 max_results: Option<usize>,
283 ) -> Result<Vec<String>> {
284 let mut client = self.client.clone();
285
286 let request = tonic::Request::new(ListFilesRequest {
287 query: query.unwrap_or("").to_string(),
288 max_results: max_results.unwrap_or(0) as u32,
289 });
290
291 let mut stream = client
292 .list_files(request)
293 .await
294 .map_err(|e| WorkspaceError::Status(format!("Failed to list files: {e}")))?
295 .into_inner();
296 let mut all_files = Vec::new();
297
298 while let Some(response) = stream
300 .message()
301 .await
302 .map_err(|e| WorkspaceError::Status(format!("Stream error: {e}")))?
303 {
304 all_files.extend(response.paths);
305 }
306
307 Ok(all_files)
308 }
309
310 fn working_directory(&self) -> &std::path::Path {
311 std::path::Path::new("/remote")
314 }
315
316 async fn execute_tool(
317 &self,
318 tool_call: &ToolCall,
319 context: steer_tools::ExecutionContext,
320 ) -> Result<ToolResult> {
321 let mut client = self.client.clone();
322
323 let context_json = serde_json::to_string(&SerializableExecutionContext::from(&context))
325 .map_err(|e| {
326 WorkspaceError::ToolExecution(format!("Failed to serialize context: {e}"))
327 })?;
328
329 let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
331 WorkspaceError::ToolExecution(format!("Failed to serialize parameters: {e}"))
332 })?;
333
334 let request = tonic::Request::new(ExecuteToolRequest {
335 tool_call_id: tool_call.id.clone(),
336 tool_name: tool_call.name.clone(),
337 parameters_json,
338 context_json,
339 timeout_ms: Some(30000), });
341
342 let response = client
343 .execute_tool(request)
344 .await
345 .map_err(|e| WorkspaceError::ToolExecution(format!("Failed to execute tool: {e}")))?
346 .into_inner();
347
348 if !response.success {
349 return Err(WorkspaceError::ToolExecution(format!(
350 "Tool execution failed: {}",
351 response.error
352 )));
353 }
354
355 convert_tool_response(response)
356 }
357
358 async fn available_tools(&self) -> Vec<ToolSchema> {
359 let mut client = self.client.clone();
360
361 let request = tonic::Request::new(GetToolSchemasRequest {});
362
363 match client.get_tool_schemas(request).await {
364 Ok(response) => {
365 response
366 .into_inner()
367 .tools
368 .into_iter()
369 .map(|schema| {
370 let input_schema = serde_json::from_str(&schema.input_schema_json)
372 .unwrap_or_else(|_| steer_tools::InputSchema {
373 properties: serde_json::Map::new(),
374 required: Vec::new(),
375 schema_type: "object".to_string(),
376 });
377
378 ToolSchema {
379 name: schema.name,
380 description: schema.description,
381 input_schema,
382 }
383 })
384 .collect()
385 }
386 Err(_) => Vec::new(),
387 }
388 }
389
390 async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
391 let mut client = self.client.clone();
392
393 let request = tonic::Request::new(GetToolApprovalRequirementsRequest {
394 tool_names: vec![tool_name.to_string()],
395 });
396
397 let response = client
398 .get_tool_approval_requirements(request)
399 .await
400 .map_err(|e| {
401 WorkspaceError::ToolExecution(format!("Failed to get approval requirements: {e}"))
402 })?
403 .into_inner();
404
405 response
406 .approval_requirements
407 .get(tool_name)
408 .copied()
409 .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
410 }
411}
412
413fn convert_proto_to_todo_item(item: steer_proto::common::v1::TodoItem) -> TodoItem {
414 TodoItem {
415 id: item.id.clone(),
416 content: item.content.clone(),
417 status: match steer_proto::common::v1::TodoStatus::try_from(item.status) {
418 Ok(steer_proto::common::v1::TodoStatus::Pending) => TodoStatus::Pending,
419 Ok(steer_proto::common::v1::TodoStatus::InProgress) => TodoStatus::InProgress,
420 Ok(steer_proto::common::v1::TodoStatus::Completed) => TodoStatus::Completed,
421 Ok(steer_proto::common::v1::TodoStatus::StatusUnset) => TodoStatus::Pending,
422 Err(_) => TodoStatus::Pending,
423 },
424 priority: match steer_proto::common::v1::TodoPriority::try_from(item.priority) {
425 Ok(steer_proto::common::v1::TodoPriority::High) => TodoPriority::High,
426 Ok(steer_proto::common::v1::TodoPriority::Medium) => TodoPriority::Medium,
427 Ok(steer_proto::common::v1::TodoPriority::Low) => TodoPriority::Low,
428 Ok(steer_proto::common::v1::TodoPriority::PriorityUnset) => TodoPriority::Low,
429 Err(_) => TodoPriority::Low,
430 },
431 }
432}
433
434fn convert_proto_to_todo_write_file_operation(
435 operation: steer_proto::common::v1::TodoWriteFileOperation,
436) -> TodoWriteFileOperation {
437 match operation {
438 steer_proto::common::v1::TodoWriteFileOperation::Created => TodoWriteFileOperation::Created,
439 steer_proto::common::v1::TodoWriteFileOperation::Modified => {
440 TodoWriteFileOperation::Modified
441 }
442 steer_proto::common::v1::TodoWriteFileOperation::OperationUnset => {
443 TodoWriteFileOperation::Created
444 }
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remote_workspace_metadata() {
        // Builds the same metadata shape as `RemoteWorkspace::new` without
        // requiring a live server to connect to.
        let address = "localhost:50051".to_string();

        let metadata = WorkspaceMetadata {
            id: format!("remote:{address}"),
            workspace_type: WorkspaceType::Remote,
            location: address.clone(),
        };

        assert_eq!(metadata.id, "remote:localhost:50051");
        assert!(matches!(metadata.workspace_type, WorkspaceType::Remote));
        assert_eq!(metadata.location, address);
    }

    #[test]
    fn test_convert_environment_response() {
        use std::path::PathBuf;

        let response = GetEnvironmentInfoResponse {
            working_directory: "/home/user/project".to_string(),
            is_git_repo: true,
            platform: "linux".to_string(),
            date: "2025-06-17".to_string(),
            directory_structure: "project/\nsrc/\nmain.rs\n".to_string(),
            git_status: Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string()),
            readme_content: Some("# My Project".to_string()),
            claude_md_content: None,
        };

        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert_eq!(
            env_info.working_directory,
            PathBuf::from("/home/user/project")
        );
        assert!(env_info.is_git_repo);
        assert_eq!(env_info.platform, "linux");
        assert_eq!(env_info.date, "2025-06-17");
        assert_eq!(env_info.directory_structure, "project/\nsrc/\nmain.rs\n");
        assert_eq!(
            env_info.git_status,
            Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string())
        );
        assert_eq!(env_info.readme_content, Some("# My Project".to_string()));
        assert_eq!(env_info.claude_md_content, None);
    }

    #[test]
    fn test_convert_tool_response_without_result_is_error() {
        let response = ExecuteToolResponse {
            result: None,
            ..Default::default()
        };

        assert!(matches!(
            convert_tool_response(response),
            Err(WorkspaceError::ToolExecution(_))
        ));
    }

    #[test]
    fn test_convert_todo_write_file_operation() {
        use steer_proto::common::v1::TodoWriteFileOperation as ProtoOperation;

        assert!(matches!(
            convert_proto_to_todo_write_file_operation(ProtoOperation::Created),
            TodoWriteFileOperation::Created
        ));
        assert!(matches!(
            convert_proto_to_todo_write_file_operation(ProtoOperation::Modified),
            TodoWriteFileOperation::Modified
        ));
        assert!(matches!(
            convert_proto_to_todo_write_file_operation(ProtoOperation::OperationUnset),
            TodoWriteFileOperation::Created
        ));
    }
}