steer_workspace_client/
lib.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::RwLock;
6use tonic::transport::Channel;
7
8use steer_proto::remote_workspace::v1::{
9    ExecuteToolRequest, ExecuteToolResponse, GetEnvironmentInfoRequest, GetEnvironmentInfoResponse,
10    GetToolApprovalRequirementsRequest, GetToolSchemasRequest, ListFilesRequest, ToolErrorDetail,
11    remote_workspace_service_client::RemoteWorkspaceServiceClient, tool_error_detail,
12};
13use steer_tools::{
14    ToolCall, ToolSchema,
15    result::ToolResult,
16    tools::todo::{TodoItem, TodoPriority, TodoStatus, TodoWriteFileOperation},
17};
18use steer_workspace::{
19    EnvironmentInfo, RemoteAuth, Result, Workspace, WorkspaceError, WorkspaceMetadata,
20    WorkspaceType,
21};
22
23/// Serializable version of ExecutionContext for remote transmission
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct SerializableExecutionContext {
26    pub tool_call_id: String,
27    pub working_directory: std::path::PathBuf,
28}
29
30impl From<&steer_tools::ExecutionContext> for SerializableExecutionContext {
31    fn from(context: &steer_tools::ExecutionContext) -> Self {
32        Self {
33            tool_call_id: context.tool_call_id.clone(),
34            working_directory: context.working_directory.clone(),
35        }
36    }
37}
38
39/// Convert gRPC tool response to ToolResult
40fn convert_tool_response(response: ExecuteToolResponse) -> Result<ToolResult> {
41    use steer_proto::remote_workspace::v1::execute_tool_response::Result as ProtoResult;
42    use steer_tools::result::{
43        BashResult, EditResult, ExternalResult, FileContentResult, FileEntry, FileListResult,
44        GlobResult, SearchMatch, SearchResult, TodoListResult, TodoWriteResult,
45    };
46
47    match response.result {
48        Some(ProtoResult::StringResult(s)) => {
49            // Legacy string result - treat as External
50            Ok(ToolResult::External(ExternalResult {
51                tool_name: "remote".to_string(),
52                payload: s,
53            }))
54        }
55        Some(ProtoResult::SearchResult(proto_result)) => {
56            let matches = proto_result
57                .matches
58                .into_iter()
59                .map(|m| SearchMatch {
60                    file_path: m.file_path,
61                    line_number: m.line_number as usize,
62                    line_content: m.line_content,
63                    column_range: m
64                        .column_range
65                        .map(|cr| (cr.start as usize, cr.end as usize)),
66                })
67                .collect();
68
69            Ok(ToolResult::Search(SearchResult {
70                matches,
71                total_files_searched: proto_result.total_files_searched as usize,
72                search_completed: proto_result.search_completed,
73            }))
74        }
75        Some(ProtoResult::FileListResult(proto_result)) => {
76            let entries = proto_result
77                .entries
78                .into_iter()
79                .map(|e| FileEntry {
80                    path: e.path,
81                    is_directory: e.is_directory,
82                    size: e.size,
83                    permissions: e.permissions,
84                })
85                .collect();
86
87            Ok(ToolResult::FileList(FileListResult {
88                entries,
89                base_path: proto_result.base_path,
90            }))
91        }
92        Some(ProtoResult::FileContentResult(proto_result)) => {
93            Ok(ToolResult::FileContent(FileContentResult {
94                content: proto_result.content,
95                file_path: proto_result.file_path,
96                line_count: proto_result.line_count as usize,
97                truncated: proto_result.truncated,
98            }))
99        }
100        Some(ProtoResult::EditResult(proto_result)) => Ok(ToolResult::Edit(EditResult {
101            file_path: proto_result.file_path,
102            changes_made: proto_result.changes_made as usize,
103            file_created: proto_result.file_created,
104            old_content: proto_result.old_content,
105            new_content: proto_result.new_content,
106        })),
107        Some(ProtoResult::BashResult(proto_result)) => Ok(ToolResult::Bash(BashResult {
108            stdout: proto_result.stdout,
109            stderr: proto_result.stderr,
110            exit_code: proto_result.exit_code,
111            command: proto_result.command,
112        })),
113        Some(ProtoResult::GlobResult(proto_result)) => Ok(ToolResult::Glob(GlobResult {
114            matches: proto_result.matches,
115            pattern: proto_result.pattern,
116        })),
117        Some(ProtoResult::TodoListResult(proto_result)) => {
118            let todos = proto_result
119                .todos
120                .into_iter()
121                .map(convert_proto_to_todo_item)
122                .collect();
123
124            Ok(ToolResult::TodoRead(TodoListResult { todos }))
125        }
126        Some(ProtoResult::TodoWriteResult(proto_result)) => {
127            let todos = proto_result
128                .todos
129                .into_iter()
130                .map(convert_proto_to_todo_item)
131                .collect();
132
133            Ok(ToolResult::TodoWrite(TodoWriteResult {
134                todos,
135                operation: convert_proto_to_todo_write_file_operation(
136                    steer_proto::common::v1::TodoWriteFileOperation::try_from(
137                        proto_result.operation,
138                    )
139                    .unwrap_or(steer_proto::common::v1::TodoWriteFileOperation::OperationUnset),
140                ),
141            }))
142        }
143        _ => Err(WorkspaceError::ToolExecution(
144            "No result returned from remote execution".to_string(),
145        )),
146    }
147}
148
149/// Cached environment information with TTL
150#[derive(Debug, Clone)]
151struct CachedEnvironment {
152    pub info: EnvironmentInfo,
153    pub cached_at: std::time::Instant,
154    pub ttl: Duration,
155}
156
157impl CachedEnvironment {
158    pub fn new(info: EnvironmentInfo, ttl: Duration) -> Self {
159        Self {
160            info,
161            cached_at: std::time::Instant::now(),
162            ttl,
163        }
164    }
165
166    pub fn is_expired(&self) -> bool {
167        self.cached_at.elapsed() > self.ttl
168    }
169}
170
171/// Remote workspace that executes tools and collects environment info via gRPC
172pub struct RemoteWorkspace {
173    client: RemoteWorkspaceServiceClient<Channel>,
174    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
175    metadata: WorkspaceMetadata,
176    #[allow(dead_code)]
177    auth: Option<RemoteAuth>,
178}
179
180impl RemoteWorkspace {
181    pub async fn new(address: String, auth: Option<RemoteAuth>) -> Result<Self> {
182        // Create gRPC client
183        let client = RemoteWorkspaceServiceClient::connect(format!("http://{address}"))
184            .await
185            .map_err(|e| WorkspaceError::Transport(format!("Failed to connect: {e}")))?;
186
187        let metadata = WorkspaceMetadata {
188            id: format!("remote:{address}"),
189            workspace_type: WorkspaceType::Remote,
190            location: address.clone(),
191        };
192
193        Ok(Self {
194            client,
195            environment_cache: Arc::new(RwLock::new(None)),
196            metadata,
197            auth,
198        })
199    }
200
201    /// Collect environment information from the remote workspace
202    async fn collect_environment(&self) -> Result<EnvironmentInfo> {
203        let mut client = self.client.clone();
204
205        let request = tonic::Request::new(GetEnvironmentInfoRequest {
206            working_directory: None, // Use remote default
207        });
208
209        let response = client
210            .get_environment_info(request)
211            .await
212            .map_err(|e| WorkspaceError::Status(format!("Failed to get environment info: {e}")))?;
213        let env_response = response.into_inner();
214
215        Self::convert_environment_response(env_response)
216    }
217
218    /// Convert gRPC response to EnvironmentInfo
219    fn convert_environment_response(
220        response: GetEnvironmentInfoResponse,
221    ) -> Result<EnvironmentInfo> {
222        use std::path::PathBuf;
223
224        Ok(EnvironmentInfo {
225            working_directory: PathBuf::from(response.working_directory),
226            is_git_repo: response.is_git_repo,
227            platform: response.platform,
228            date: response.date,
229            directory_structure: response.directory_structure,
230            git_status: response.git_status,
231            readme_content: response.readme_content,
232            claude_md_content: response.claude_md_content,
233        })
234    }
235}
236
237impl std::fmt::Debug for RemoteWorkspace {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        f.debug_struct("RemoteWorkspace")
240            .field("metadata", &self.metadata)
241            .field("auth", &self.auth)
242            .finish_non_exhaustive()
243    }
244}
245
246#[async_trait]
247impl Workspace for RemoteWorkspace {
248    async fn environment(&self) -> Result<EnvironmentInfo> {
249        let mut cache = self.environment_cache.write().await;
250
251        // Check if we have valid cached data
252        if let Some(cached) = cache.as_ref() {
253            if !cached.is_expired() {
254                return Ok(cached.info.clone());
255            }
256        }
257
258        // Collect fresh environment info from remote
259        let env_info = self.collect_environment().await?;
260
261        // Cache it with 10 minute TTL (longer than local since remote calls are expensive)
262        *cache = Some(CachedEnvironment::new(
263            env_info.clone(),
264            Duration::from_secs(600), // 10 minutes
265        ));
266
267        Ok(env_info)
268    }
269
270    fn metadata(&self) -> WorkspaceMetadata {
271        self.metadata.clone()
272    }
273
274    async fn invalidate_environment_cache(&self) {
275        let mut cache = self.environment_cache.write().await;
276        *cache = None;
277    }
278
279    async fn list_files(
280        &self,
281        query: Option<&str>,
282        max_results: Option<usize>,
283    ) -> Result<Vec<String>> {
284        let mut client = self.client.clone();
285
286        let request = tonic::Request::new(ListFilesRequest {
287            query: query.unwrap_or("").to_string(),
288            max_results: max_results.unwrap_or(0) as u32,
289        });
290
291        let mut stream = client
292            .list_files(request)
293            .await
294            .map_err(|e| WorkspaceError::Status(format!("Failed to list files: {e}")))?
295            .into_inner();
296        let mut all_files = Vec::new();
297
298        // Collect all files from the stream
299        while let Some(response) = stream
300            .message()
301            .await
302            .map_err(|e| WorkspaceError::Status(format!("Stream error: {e}")))?
303        {
304            all_files.extend(response.paths);
305        }
306
307        Ok(all_files)
308    }
309
310    fn working_directory(&self) -> &std::path::Path {
311        // For remote workspaces, we return a placeholder path
312        // The actual working directory is on the remote machine
313        std::path::Path::new("/remote")
314    }
315
316    async fn execute_tool(
317        &self,
318        tool_call: &ToolCall,
319        context: steer_tools::ExecutionContext,
320    ) -> Result<ToolResult> {
321        let mut client = self.client.clone();
322
323        // Serialize the execution context
324        let context_json = serde_json::to_string(&SerializableExecutionContext::from(&context))
325            .map_err(|e| {
326                WorkspaceError::ToolExecution(format!("Failed to serialize context: {e}"))
327            })?;
328
329        // Serialize tool parameters
330        let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
331            WorkspaceError::ToolExecution(format!("Failed to serialize parameters: {e}"))
332        })?;
333
334        let request = tonic::Request::new(ExecuteToolRequest {
335            tool_call_id: tool_call.id.clone(),
336            tool_name: tool_call.name.clone(),
337            parameters_json,
338            context_json,
339            timeout_ms: Some(30000), // 30 second default
340        });
341
342        let response = client
343            .execute_tool(request)
344            .await
345            .map_err(|e| WorkspaceError::Status(format!("Failed to execute tool: {e}")))?
346            .into_inner();
347
348        if !response.success {
349            // Use typed error_detail if available, otherwise fall back to string
350            let tool_error = if let Some(detail) = response.error_detail {
351                convert_error_detail_to_tool_error(detail)
352            } else {
353                // Fallback for older servers
354                steer_tools::ToolError::Execution {
355                    tool_name: tool_call.name.clone(),
356                    message: response.error,
357                }
358            };
359            return Ok(ToolResult::Error(tool_error));
360        }
361
362        convert_tool_response(response)
363    }
364
365    async fn available_tools(&self) -> Vec<ToolSchema> {
366        let mut client = self.client.clone();
367
368        let request = tonic::Request::new(GetToolSchemasRequest {});
369
370        match client.get_tool_schemas(request).await {
371            Ok(response) => {
372                response
373                    .into_inner()
374                    .tools
375                    .into_iter()
376                    .map(|schema| {
377                        // Parse the JSON input schema
378                        let input_schema = serde_json::from_str(&schema.input_schema_json)
379                            .unwrap_or_else(|_| steer_tools::InputSchema {
380                                properties: serde_json::Map::new(),
381                                required: Vec::new(),
382                                schema_type: "object".to_string(),
383                            });
384
385                        ToolSchema {
386                            name: schema.name,
387                            description: schema.description,
388                            input_schema,
389                        }
390                    })
391                    .collect()
392            }
393            Err(_) => Vec::new(),
394        }
395    }
396
397    async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
398        let mut client = self.client.clone();
399
400        let request = tonic::Request::new(GetToolApprovalRequirementsRequest {
401            tool_names: vec![tool_name.to_string()],
402        });
403
404        let response = client
405            .get_tool_approval_requirements(request)
406            .await
407            .map_err(|e| {
408                WorkspaceError::ToolExecution(format!("Failed to get approval requirements: {e}"))
409            })?
410            .into_inner();
411
412        response
413            .approval_requirements
414            .get(tool_name)
415            .copied()
416            .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
417    }
418}
419
420fn convert_proto_to_todo_item(item: steer_proto::common::v1::TodoItem) -> TodoItem {
421    TodoItem {
422        id: item.id.clone(),
423        content: item.content.clone(),
424        status: match steer_proto::common::v1::TodoStatus::try_from(item.status) {
425            Ok(steer_proto::common::v1::TodoStatus::Pending) => TodoStatus::Pending,
426            Ok(steer_proto::common::v1::TodoStatus::InProgress) => TodoStatus::InProgress,
427            Ok(steer_proto::common::v1::TodoStatus::Completed) => TodoStatus::Completed,
428            Ok(steer_proto::common::v1::TodoStatus::StatusUnset) => TodoStatus::Pending,
429            Err(_) => TodoStatus::Pending,
430        },
431        priority: match steer_proto::common::v1::TodoPriority::try_from(item.priority) {
432            Ok(steer_proto::common::v1::TodoPriority::High) => TodoPriority::High,
433            Ok(steer_proto::common::v1::TodoPriority::Medium) => TodoPriority::Medium,
434            Ok(steer_proto::common::v1::TodoPriority::Low) => TodoPriority::Low,
435            Ok(steer_proto::common::v1::TodoPriority::PriorityUnset) => TodoPriority::Low,
436            Err(_) => TodoPriority::Low,
437        },
438    }
439}
440
441fn convert_proto_to_todo_write_file_operation(
442    operation: steer_proto::common::v1::TodoWriteFileOperation,
443) -> TodoWriteFileOperation {
444    match operation {
445        steer_proto::common::v1::TodoWriteFileOperation::Created => TodoWriteFileOperation::Created,
446        steer_proto::common::v1::TodoWriteFileOperation::Modified => {
447            TodoWriteFileOperation::Modified
448        }
449        steer_proto::common::v1::TodoWriteFileOperation::OperationUnset => {
450            TodoWriteFileOperation::Created
451        }
452    }
453}
454
455/// Convert proto ToolErrorDetail to ToolError
456fn convert_error_detail_to_tool_error(detail: ToolErrorDetail) -> steer_tools::ToolError {
457    use tool_error_detail::Kind;
458
459    let kind = Kind::try_from(detail.kind).unwrap_or(Kind::Internal);
460
461    match kind {
462        Kind::Execution => steer_tools::ToolError::Execution {
463            tool_name: detail.tool_name,
464            message: detail.message,
465        },
466        Kind::Io => steer_tools::ToolError::Io {
467            tool_name: detail.tool_name,
468            message: detail.message,
469        },
470        Kind::InvalidParams => {
471            steer_tools::ToolError::InvalidParams(detail.tool_name, detail.message)
472        }
473        Kind::Cancelled => steer_tools::ToolError::Cancelled(detail.tool_name),
474        Kind::Timeout => steer_tools::ToolError::Timeout(detail.tool_name),
475        Kind::UnknownTool => steer_tools::ToolError::UnknownTool(detail.tool_name),
476        Kind::DeniedByUser => steer_tools::ToolError::DeniedByUser(detail.tool_name),
477        Kind::Internal | Kind::Unset => steer_tools::ToolError::InternalError(detail.message),
478    }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the metadata layout used for remote workspaces.
    ///
    /// The metadata is built by hand in the same shape `RemoteWorkspace::new`
    /// uses, so no connection is made and no backend or async runtime is needed.
    #[test]
    fn test_remote_workspace_metadata() {
        let address = "localhost:50051".to_string();

        let metadata = WorkspaceMetadata {
            id: format!("remote:{address}"),
            workspace_type: WorkspaceType::Remote,
            location: address.clone(),
        };

        assert_eq!(metadata.id, "remote:localhost:50051");
        assert!(matches!(metadata.workspace_type, WorkspaceType::Remote));
        assert_eq!(metadata.location, address);
    }

    /// Every field of the gRPC response must be carried over unchanged.
    #[test]
    fn test_convert_environment_response() {
        use std::path::PathBuf;

        let response = GetEnvironmentInfoResponse {
            working_directory: "/home/user/project".to_string(),
            is_git_repo: true,
            platform: "linux".to_string(),
            date: "2025-06-17".to_string(),
            directory_structure: "project/\nsrc/\nmain.rs\n".to_string(),
            git_status: Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string()),
            readme_content: Some("# My Project".to_string()),
            claude_md_content: None,
        };

        // Test the static conversion function directly
        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert_eq!(
            env_info.working_directory,
            PathBuf::from("/home/user/project")
        );
        assert!(env_info.is_git_repo);
        assert_eq!(env_info.platform, "linux");
        assert_eq!(env_info.date, "2025-06-17");
        assert_eq!(env_info.directory_structure, "project/\nsrc/\nmain.rs\n");
        assert_eq!(
            env_info.git_status,
            Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string())
        );
        assert_eq!(env_info.readme_content, Some("# My Project".to_string()));
        assert_eq!(env_info.claude_md_content, None);
    }
}
533}