steer_workspace_client/
lib.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::RwLock;
6use tonic::transport::Channel;
7
8use steer_proto::remote_workspace::v1::{
9    ExecuteToolRequest, ExecuteToolResponse, GetEnvironmentInfoRequest, GetEnvironmentInfoResponse,
10    GetToolApprovalRequirementsRequest, GetToolSchemasRequest, ListFilesRequest,
11    remote_workspace_service_client::RemoteWorkspaceServiceClient,
12};
13use steer_tools::{
14    ToolCall, ToolSchema,
15    result::ToolResult,
16    tools::todo::{TodoItem, TodoPriority, TodoStatus, TodoWriteFileOperation},
17};
18use steer_workspace::{
19    EnvironmentInfo, RemoteAuth, Result, Workspace, WorkspaceError, WorkspaceMetadata,
20    WorkspaceType,
21};
22
/// Serializable version of ExecutionContext for remote transmission
///
/// Holds the tool call id and working directory copied from a
/// `steer_tools::ExecutionContext`. `RemoteWorkspace::execute_tool` serializes
/// it with `serde_json` and sends it as `context_json` in an `ExecuteToolRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializableExecutionContext {
    /// Identifier of the tool call this context belongs to.
    pub tool_call_id: String,
    /// Working directory copied from the local execution context.
    pub working_directory: std::path::PathBuf,
}
29
30impl From<&steer_tools::ExecutionContext> for SerializableExecutionContext {
31    fn from(context: &steer_tools::ExecutionContext) -> Self {
32        Self {
33            tool_call_id: context.tool_call_id.clone(),
34            working_directory: context.working_directory.clone(),
35        }
36    }
37}
38
39/// Convert gRPC tool response to ToolResult
40fn convert_tool_response(response: ExecuteToolResponse) -> Result<ToolResult> {
41    use steer_proto::remote_workspace::v1::execute_tool_response::Result as ProtoResult;
42    use steer_tools::result::{
43        BashResult, EditResult, ExternalResult, FileContentResult, FileEntry, FileListResult,
44        GlobResult, SearchMatch, SearchResult, TodoListResult, TodoWriteResult,
45    };
46
47    match response.result {
48        Some(ProtoResult::StringResult(s)) => {
49            // Legacy string result - treat as External
50            Ok(ToolResult::External(ExternalResult {
51                tool_name: "remote".to_string(),
52                payload: s,
53            }))
54        }
55        Some(ProtoResult::SearchResult(proto_result)) => {
56            let matches = proto_result
57                .matches
58                .into_iter()
59                .map(|m| SearchMatch {
60                    file_path: m.file_path,
61                    line_number: m.line_number as usize,
62                    line_content: m.line_content,
63                    column_range: m
64                        .column_range
65                        .map(|cr| (cr.start as usize, cr.end as usize)),
66                })
67                .collect();
68
69            Ok(ToolResult::Search(SearchResult {
70                matches,
71                total_files_searched: proto_result.total_files_searched as usize,
72                search_completed: proto_result.search_completed,
73            }))
74        }
75        Some(ProtoResult::FileListResult(proto_result)) => {
76            let entries = proto_result
77                .entries
78                .into_iter()
79                .map(|e| FileEntry {
80                    path: e.path,
81                    is_directory: e.is_directory,
82                    size: e.size,
83                    permissions: e.permissions,
84                })
85                .collect();
86
87            Ok(ToolResult::FileList(FileListResult {
88                entries,
89                base_path: proto_result.base_path,
90            }))
91        }
92        Some(ProtoResult::FileContentResult(proto_result)) => {
93            Ok(ToolResult::FileContent(FileContentResult {
94                content: proto_result.content,
95                file_path: proto_result.file_path,
96                line_count: proto_result.line_count as usize,
97                truncated: proto_result.truncated,
98            }))
99        }
100        Some(ProtoResult::EditResult(proto_result)) => Ok(ToolResult::Edit(EditResult {
101            file_path: proto_result.file_path,
102            changes_made: proto_result.changes_made as usize,
103            file_created: proto_result.file_created,
104            old_content: proto_result.old_content,
105            new_content: proto_result.new_content,
106        })),
107        Some(ProtoResult::BashResult(proto_result)) => Ok(ToolResult::Bash(BashResult {
108            stdout: proto_result.stdout,
109            stderr: proto_result.stderr,
110            exit_code: proto_result.exit_code,
111            command: proto_result.command,
112        })),
113        Some(ProtoResult::GlobResult(proto_result)) => Ok(ToolResult::Glob(GlobResult {
114            matches: proto_result.matches,
115            pattern: proto_result.pattern,
116        })),
117        Some(ProtoResult::TodoListResult(proto_result)) => {
118            let todos = proto_result
119                .todos
120                .into_iter()
121                .map(convert_proto_to_todo_item)
122                .collect();
123
124            Ok(ToolResult::TodoRead(TodoListResult { todos }))
125        }
126        Some(ProtoResult::TodoWriteResult(proto_result)) => {
127            let todos = proto_result
128                .todos
129                .into_iter()
130                .map(convert_proto_to_todo_item)
131                .collect();
132
133            Ok(ToolResult::TodoWrite(TodoWriteResult {
134                todos,
135                operation: convert_proto_to_todo_write_file_operation(
136                    steer_proto::common::v1::TodoWriteFileOperation::try_from(
137                        proto_result.operation,
138                    )
139                    .unwrap_or(steer_proto::common::v1::TodoWriteFileOperation::OperationUnset),
140                ),
141            }))
142        }
143        _ => Err(WorkspaceError::ToolExecution(
144            "No result returned from remote execution".to_string(),
145        )),
146    }
147}
148
149/// Cached environment information with TTL
150#[derive(Debug, Clone)]
151struct CachedEnvironment {
152    pub info: EnvironmentInfo,
153    pub cached_at: std::time::Instant,
154    pub ttl: Duration,
155}
156
157impl CachedEnvironment {
158    pub fn new(info: EnvironmentInfo, ttl: Duration) -> Self {
159        Self {
160            info,
161            cached_at: std::time::Instant::now(),
162            ttl,
163        }
164    }
165
166    pub fn is_expired(&self) -> bool {
167        self.cached_at.elapsed() > self.ttl
168    }
169}
170
/// Remote workspace that executes tools and collects environment info via gRPC
///
/// Every RPC is issued on a clone of `client`. Environment info fetched from the
/// remote side is cached in `environment_cache`.
pub struct RemoteWorkspace {
    /// gRPC client for the remote workspace service; cloned for each request.
    client: RemoteWorkspaceServiceClient<Channel>,
    /// Most recently fetched environment info, shared behind an async `RwLock`.
    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
    /// Id/type/location metadata built from the connect address in `new`.
    metadata: WorkspaceMetadata,
    // NOTE(review): `auth` is stored by `new` and only read by the `Debug` impl in
    // this file; no request attaches it. The allow keeps the compiler quiet about
    // that. TODO confirm whether credentials should be sent with requests.
    #[allow(dead_code)]
    auth: Option<RemoteAuth>,
}
179
180impl RemoteWorkspace {
181    pub async fn new(address: String, auth: Option<RemoteAuth>) -> Result<Self> {
182        // Create gRPC client
183        let client = RemoteWorkspaceServiceClient::connect(format!("http://{address}"))
184            .await
185            .map_err(|e| WorkspaceError::Transport(format!("Failed to connect: {e}")))?;
186
187        let metadata = WorkspaceMetadata {
188            id: format!("remote:{address}"),
189            workspace_type: WorkspaceType::Remote,
190            location: address.clone(),
191        };
192
193        Ok(Self {
194            client,
195            environment_cache: Arc::new(RwLock::new(None)),
196            metadata,
197            auth,
198        })
199    }
200
201    /// Collect environment information from the remote workspace
202    async fn collect_environment(&self) -> Result<EnvironmentInfo> {
203        let mut client = self.client.clone();
204
205        let request = tonic::Request::new(GetEnvironmentInfoRequest {
206            working_directory: None, // Use remote default
207        });
208
209        let response = client
210            .get_environment_info(request)
211            .await
212            .map_err(|e| WorkspaceError::Status(format!("Failed to get environment info: {e}")))?;
213        let env_response = response.into_inner();
214
215        Self::convert_environment_response(env_response)
216    }
217
218    /// Convert gRPC response to EnvironmentInfo
219    fn convert_environment_response(
220        response: GetEnvironmentInfoResponse,
221    ) -> Result<EnvironmentInfo> {
222        use std::path::PathBuf;
223
224        Ok(EnvironmentInfo {
225            working_directory: PathBuf::from(response.working_directory),
226            is_git_repo: response.is_git_repo,
227            platform: response.platform,
228            date: response.date,
229            directory_structure: response.directory_structure,
230            git_status: response.git_status,
231            readme_content: response.readme_content,
232            claude_md_content: response.claude_md_content,
233        })
234    }
235}
236
237impl std::fmt::Debug for RemoteWorkspace {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        f.debug_struct("RemoteWorkspace")
240            .field("metadata", &self.metadata)
241            .field("auth", &self.auth)
242            .finish_non_exhaustive()
243    }
244}
245
246#[async_trait]
247impl Workspace for RemoteWorkspace {
248    async fn environment(&self) -> Result<EnvironmentInfo> {
249        let mut cache = self.environment_cache.write().await;
250
251        // Check if we have valid cached data
252        if let Some(cached) = cache.as_ref() {
253            if !cached.is_expired() {
254                return Ok(cached.info.clone());
255            }
256        }
257
258        // Collect fresh environment info from remote
259        let env_info = self.collect_environment().await?;
260
261        // Cache it with 10 minute TTL (longer than local since remote calls are expensive)
262        *cache = Some(CachedEnvironment::new(
263            env_info.clone(),
264            Duration::from_secs(600), // 10 minutes
265        ));
266
267        Ok(env_info)
268    }
269
270    fn metadata(&self) -> WorkspaceMetadata {
271        self.metadata.clone()
272    }
273
274    async fn invalidate_environment_cache(&self) {
275        let mut cache = self.environment_cache.write().await;
276        *cache = None;
277    }
278
279    async fn list_files(
280        &self,
281        query: Option<&str>,
282        max_results: Option<usize>,
283    ) -> Result<Vec<String>> {
284        let mut client = self.client.clone();
285
286        let request = tonic::Request::new(ListFilesRequest {
287            query: query.unwrap_or("").to_string(),
288            max_results: max_results.unwrap_or(0) as u32,
289        });
290
291        let mut stream = client
292            .list_files(request)
293            .await
294            .map_err(|e| WorkspaceError::Status(format!("Failed to list files: {e}")))?
295            .into_inner();
296        let mut all_files = Vec::new();
297
298        // Collect all files from the stream
299        while let Some(response) = stream
300            .message()
301            .await
302            .map_err(|e| WorkspaceError::Status(format!("Stream error: {e}")))?
303        {
304            all_files.extend(response.paths);
305        }
306
307        Ok(all_files)
308    }
309
310    fn working_directory(&self) -> &std::path::Path {
311        // For remote workspaces, we return a placeholder path
312        // The actual working directory is on the remote machine
313        std::path::Path::new("/remote")
314    }
315
316    async fn execute_tool(
317        &self,
318        tool_call: &ToolCall,
319        context: steer_tools::ExecutionContext,
320    ) -> Result<ToolResult> {
321        let mut client = self.client.clone();
322
323        // Serialize the execution context
324        let context_json = serde_json::to_string(&SerializableExecutionContext::from(&context))
325            .map_err(|e| {
326                WorkspaceError::ToolExecution(format!("Failed to serialize context: {e}"))
327            })?;
328
329        // Serialize tool parameters
330        let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
331            WorkspaceError::ToolExecution(format!("Failed to serialize parameters: {e}"))
332        })?;
333
334        let request = tonic::Request::new(ExecuteToolRequest {
335            tool_call_id: tool_call.id.clone(),
336            tool_name: tool_call.name.clone(),
337            parameters_json,
338            context_json,
339            timeout_ms: Some(30000), // 30 second default
340        });
341
342        let response = client
343            .execute_tool(request)
344            .await
345            .map_err(|e| WorkspaceError::ToolExecution(format!("Failed to execute tool: {e}")))?
346            .into_inner();
347
348        if !response.success {
349            return Err(WorkspaceError::ToolExecution(format!(
350                "Tool execution failed: {}",
351                response.error
352            )));
353        }
354
355        convert_tool_response(response)
356    }
357
358    async fn available_tools(&self) -> Vec<ToolSchema> {
359        let mut client = self.client.clone();
360
361        let request = tonic::Request::new(GetToolSchemasRequest {});
362
363        match client.get_tool_schemas(request).await {
364            Ok(response) => {
365                response
366                    .into_inner()
367                    .tools
368                    .into_iter()
369                    .map(|schema| {
370                        // Parse the JSON input schema
371                        let input_schema = serde_json::from_str(&schema.input_schema_json)
372                            .unwrap_or_else(|_| steer_tools::InputSchema {
373                                properties: serde_json::Map::new(),
374                                required: Vec::new(),
375                                schema_type: "object".to_string(),
376                            });
377
378                        ToolSchema {
379                            name: schema.name,
380                            description: schema.description,
381                            input_schema,
382                        }
383                    })
384                    .collect()
385            }
386            Err(_) => Vec::new(),
387        }
388    }
389
390    async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
391        let mut client = self.client.clone();
392
393        let request = tonic::Request::new(GetToolApprovalRequirementsRequest {
394            tool_names: vec![tool_name.to_string()],
395        });
396
397        let response = client
398            .get_tool_approval_requirements(request)
399            .await
400            .map_err(|e| {
401                WorkspaceError::ToolExecution(format!("Failed to get approval requirements: {e}"))
402            })?
403            .into_inner();
404
405        response
406            .approval_requirements
407            .get(tool_name)
408            .copied()
409            .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
410    }
411}
412
413fn convert_proto_to_todo_item(item: steer_proto::common::v1::TodoItem) -> TodoItem {
414    TodoItem {
415        id: item.id.clone(),
416        content: item.content.clone(),
417        status: match steer_proto::common::v1::TodoStatus::try_from(item.status) {
418            Ok(steer_proto::common::v1::TodoStatus::Pending) => TodoStatus::Pending,
419            Ok(steer_proto::common::v1::TodoStatus::InProgress) => TodoStatus::InProgress,
420            Ok(steer_proto::common::v1::TodoStatus::Completed) => TodoStatus::Completed,
421            Ok(steer_proto::common::v1::TodoStatus::StatusUnset) => TodoStatus::Pending,
422            Err(_) => TodoStatus::Pending,
423        },
424        priority: match steer_proto::common::v1::TodoPriority::try_from(item.priority) {
425            Ok(steer_proto::common::v1::TodoPriority::High) => TodoPriority::High,
426            Ok(steer_proto::common::v1::TodoPriority::Medium) => TodoPriority::Medium,
427            Ok(steer_proto::common::v1::TodoPriority::Low) => TodoPriority::Low,
428            Ok(steer_proto::common::v1::TodoPriority::PriorityUnset) => TodoPriority::Low,
429            Err(_) => TodoPriority::Low,
430        },
431    }
432}
433
434fn convert_proto_to_todo_write_file_operation(
435    operation: steer_proto::common::v1::TodoWriteFileOperation,
436) -> TodoWriteFileOperation {
437    match operation {
438        steer_proto::common::v1::TodoWriteFileOperation::Created => TodoWriteFileOperation::Created,
439        steer_proto::common::v1::TodoWriteFileOperation::Modified => {
440            TodoWriteFileOperation::Modified
441        }
442        steer_proto::common::v1::TodoWriteFileOperation::OperationUnset => {
443            TodoWriteFileOperation::Created
444        }
445    }
446}
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata built the same way `RemoteWorkspace::new` builds it.
    ///
    /// No connection is made, so this does not need a running remote backend.
    #[test]
    fn test_remote_workspace_metadata() {
        let address = "localhost:50051".to_string();

        let metadata = WorkspaceMetadata {
            id: format!("remote:{address}"),
            workspace_type: WorkspaceType::Remote,
            location: address.clone(),
        };

        assert_eq!(metadata.id, "remote:localhost:50051");
        assert!(matches!(metadata.workspace_type, WorkspaceType::Remote));
        assert_eq!(metadata.location, address);
    }

    #[test]
    fn test_convert_environment_response() {
        use std::path::PathBuf;

        let response = GetEnvironmentInfoResponse {
            working_directory: "/home/user/project".to_string(),
            is_git_repo: true,
            platform: "linux".to_string(),
            date: "2025-06-17".to_string(),
            directory_structure: "project/\nsrc/\nmain.rs\n".to_string(),
            git_status: Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string()),
            readme_content: Some("# My Project".to_string()),
            claude_md_content: None,
        };

        // Test the static conversion function directly.
        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert_eq!(
            env_info.working_directory,
            PathBuf::from("/home/user/project")
        );
        assert!(env_info.is_git_repo);
        assert_eq!(env_info.platform, "linux");
        assert_eq!(env_info.date, "2025-06-17");
        assert_eq!(env_info.directory_structure, "project/\nsrc/\nmain.rs\n");
        assert_eq!(
            env_info.git_status,
            Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string())
        );
        assert_eq!(env_info.readme_content, Some("# My Project".to_string()));
        assert_eq!(env_info.claude_md_content, None);
    }

    /// A response with no result payload must be reported as an error.
    #[test]
    fn test_convert_tool_response_without_result_is_error() {
        let response = ExecuteToolResponse::default();
        assert!(convert_tool_response(response).is_err());
    }

    /// Legacy string results are wrapped as an External result from "remote".
    #[test]
    fn test_convert_tool_response_string_result_is_external() {
        use steer_proto::remote_workspace::v1::execute_tool_response::Result as ProtoResult;

        let response = ExecuteToolResponse {
            result: Some(ProtoResult::StringResult("hello".to_string())),
            ..Default::default()
        };

        match convert_tool_response(response).unwrap() {
            ToolResult::External(external) => {
                assert_eq!(external.tool_name, "remote");
                assert_eq!(external.payload, "hello");
            }
            _ => panic!("expected an External tool result"),
        }
    }

    /// Unset write operations are reported as Created; Modified is preserved.
    #[test]
    fn test_todo_write_operation_mapping() {
        use steer_proto::common::v1::TodoWriteFileOperation as Proto;

        assert!(matches!(
            convert_proto_to_todo_write_file_operation(Proto::OperationUnset),
            TodoWriteFileOperation::Created
        ));
        assert!(matches!(
            convert_proto_to_todo_write_file_operation(Proto::Created),
            TodoWriteFileOperation::Created
        ));
        assert!(matches!(
            convert_proto_to_todo_write_file_operation(Proto::Modified),
            TodoWriteFileOperation::Modified
        ));
    }

    /// Unknown enum codes fall back to Pending / Low instead of failing.
    #[test]
    fn test_todo_item_unknown_enums_fall_back() {
        let item = steer_proto::common::v1::TodoItem {
            id: "t1".to_string(),
            content: "write tests".to_string(),
            status: 999,
            priority: 999,
            ..Default::default()
        };

        let todo = convert_proto_to_todo_item(item);

        assert_eq!(todo.id, "t1");
        assert_eq!(todo.content, "write tests");
        assert!(matches!(todo.status, TodoStatus::Pending));
        assert!(matches!(todo.priority, TodoPriority::Low));
    }
}