steer_workspace_client/
lib.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::RwLock;
6use tonic::transport::Channel;
7
8use steer_proto::remote_workspace::v1::{
9    ExecuteToolRequest, ExecuteToolResponse, GetEnvironmentInfoRequest, GetEnvironmentInfoResponse,
10    GetToolApprovalRequirementsRequest, GetToolSchemasRequest, ListFilesRequest,
11    remote_workspace_service_client::RemoteWorkspaceServiceClient,
12};
13use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
14use steer_workspace::{
15    EnvironmentInfo, RemoteAuth, Result, Workspace, WorkspaceError, WorkspaceMetadata,
16    WorkspaceType,
17};
18
19/// Serializable version of ExecutionContext for remote transmission
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct SerializableExecutionContext {
22    pub tool_call_id: String,
23    pub working_directory: std::path::PathBuf,
24}
25
26impl From<&steer_tools::ExecutionContext> for SerializableExecutionContext {
27    fn from(context: &steer_tools::ExecutionContext) -> Self {
28        Self {
29            tool_call_id: context.tool_call_id.clone(),
30            working_directory: context.working_directory.clone(),
31        }
32    }
33}
34
35/// Convert gRPC tool response to ToolResult
36fn convert_tool_response(response: ExecuteToolResponse) -> Result<ToolResult> {
37    use steer_proto::remote_workspace::v1::execute_tool_response::Result as ProtoResult;
38    use steer_tools::result::{
39        BashResult, EditResult, ExternalResult, FileContentResult, FileEntry, FileListResult,
40        GlobResult, SearchMatch, SearchResult, TodoItem, TodoListResult, TodoWriteResult,
41    };
42
43    match response.result {
44        Some(ProtoResult::StringResult(s)) => {
45            // Legacy string result - treat as External
46            Ok(ToolResult::External(ExternalResult {
47                tool_name: "remote".to_string(),
48                payload: s,
49            }))
50        }
51        Some(ProtoResult::SearchResult(proto_result)) => {
52            let matches = proto_result
53                .matches
54                .into_iter()
55                .map(|m| SearchMatch {
56                    file_path: m.file_path,
57                    line_number: m.line_number as usize,
58                    line_content: m.line_content,
59                    column_range: m
60                        .column_range
61                        .map(|cr| (cr.start as usize, cr.end as usize)),
62                })
63                .collect();
64
65            Ok(ToolResult::Search(SearchResult {
66                matches,
67                total_files_searched: proto_result.total_files_searched as usize,
68                search_completed: proto_result.search_completed,
69            }))
70        }
71        Some(ProtoResult::FileListResult(proto_result)) => {
72            let entries = proto_result
73                .entries
74                .into_iter()
75                .map(|e| FileEntry {
76                    path: e.path,
77                    is_directory: e.is_directory,
78                    size: e.size,
79                    permissions: e.permissions,
80                })
81                .collect();
82
83            Ok(ToolResult::FileList(FileListResult {
84                entries,
85                base_path: proto_result.base_path,
86            }))
87        }
88        Some(ProtoResult::FileContentResult(proto_result)) => {
89            Ok(ToolResult::FileContent(FileContentResult {
90                content: proto_result.content,
91                file_path: proto_result.file_path,
92                line_count: proto_result.line_count as usize,
93                truncated: proto_result.truncated,
94            }))
95        }
96        Some(ProtoResult::EditResult(proto_result)) => Ok(ToolResult::Edit(EditResult {
97            file_path: proto_result.file_path,
98            changes_made: proto_result.changes_made as usize,
99            file_created: proto_result.file_created,
100            old_content: proto_result.old_content,
101            new_content: proto_result.new_content,
102        })),
103        Some(ProtoResult::BashResult(proto_result)) => Ok(ToolResult::Bash(BashResult {
104            stdout: proto_result.stdout,
105            stderr: proto_result.stderr,
106            exit_code: proto_result.exit_code,
107            command: proto_result.command,
108        })),
109        Some(ProtoResult::GlobResult(proto_result)) => Ok(ToolResult::Glob(GlobResult {
110            matches: proto_result.matches,
111            pattern: proto_result.pattern,
112        })),
113        Some(ProtoResult::TodoListResult(proto_result)) => {
114            let todos = proto_result
115                .todos
116                .into_iter()
117                .map(|t| TodoItem {
118                    id: t.id,
119                    content: t.content,
120                    status: t.status,
121                    priority: t.priority,
122                })
123                .collect();
124
125            Ok(ToolResult::TodoRead(TodoListResult { todos }))
126        }
127        Some(ProtoResult::TodoWriteResult(proto_result)) => {
128            let todos = proto_result
129                .todos
130                .into_iter()
131                .map(|t| TodoItem {
132                    id: t.id,
133                    content: t.content,
134                    status: t.status,
135                    priority: t.priority,
136                })
137                .collect();
138
139            Ok(ToolResult::TodoWrite(TodoWriteResult {
140                todos,
141                operation: proto_result.operation,
142            }))
143        }
144        _ => Err(WorkspaceError::ToolExecution(
145            "No result returned from remote execution".to_string(),
146        )),
147    }
148}
149
150/// Cached environment information with TTL
151#[derive(Debug, Clone)]
152struct CachedEnvironment {
153    pub info: EnvironmentInfo,
154    pub cached_at: std::time::Instant,
155    pub ttl: Duration,
156}
157
158impl CachedEnvironment {
159    pub fn new(info: EnvironmentInfo, ttl: Duration) -> Self {
160        Self {
161            info,
162            cached_at: std::time::Instant::now(),
163            ttl,
164        }
165    }
166
167    pub fn is_expired(&self) -> bool {
168        self.cached_at.elapsed() > self.ttl
169    }
170}
171
172/// Remote workspace that executes tools and collects environment info via gRPC
173pub struct RemoteWorkspace {
174    client: RemoteWorkspaceServiceClient<Channel>,
175    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
176    metadata: WorkspaceMetadata,
177    #[allow(dead_code)]
178    auth: Option<RemoteAuth>,
179}
180
181impl RemoteWorkspace {
182    pub async fn new(address: String, auth: Option<RemoteAuth>) -> Result<Self> {
183        // Create gRPC client
184        let client = RemoteWorkspaceServiceClient::connect(format!("http://{address}"))
185            .await
186            .map_err(|e| WorkspaceError::Transport(format!("Failed to connect: {e}")))?;
187
188        let metadata = WorkspaceMetadata {
189            id: format!("remote:{address}"),
190            workspace_type: WorkspaceType::Remote,
191            location: address.clone(),
192        };
193
194        Ok(Self {
195            client,
196            environment_cache: Arc::new(RwLock::new(None)),
197            metadata,
198            auth,
199        })
200    }
201
202    /// Collect environment information from the remote workspace
203    async fn collect_environment(&self) -> Result<EnvironmentInfo> {
204        let mut client = self.client.clone();
205
206        let request = tonic::Request::new(GetEnvironmentInfoRequest {
207            working_directory: None, // Use remote default
208        });
209
210        let response = client
211            .get_environment_info(request)
212            .await
213            .map_err(|e| WorkspaceError::Status(format!("Failed to get environment info: {e}")))?;
214        let env_response = response.into_inner();
215
216        Self::convert_environment_response(env_response)
217    }
218
219    /// Convert gRPC response to EnvironmentInfo
220    fn convert_environment_response(
221        response: GetEnvironmentInfoResponse,
222    ) -> Result<EnvironmentInfo> {
223        use std::path::PathBuf;
224
225        Ok(EnvironmentInfo {
226            working_directory: PathBuf::from(response.working_directory),
227            is_git_repo: response.is_git_repo,
228            platform: response.platform,
229            date: response.date,
230            directory_structure: response.directory_structure,
231            git_status: response.git_status,
232            readme_content: response.readme_content,
233            claude_md_content: response.claude_md_content,
234        })
235    }
236}
237
238impl std::fmt::Debug for RemoteWorkspace {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        f.debug_struct("RemoteWorkspace")
241            .field("metadata", &self.metadata)
242            .field("auth", &self.auth)
243            .finish_non_exhaustive()
244    }
245}
246
247#[async_trait]
248impl Workspace for RemoteWorkspace {
249    async fn environment(&self) -> Result<EnvironmentInfo> {
250        let mut cache = self.environment_cache.write().await;
251
252        // Check if we have valid cached data
253        if let Some(cached) = cache.as_ref() {
254            if !cached.is_expired() {
255                return Ok(cached.info.clone());
256            }
257        }
258
259        // Collect fresh environment info from remote
260        let env_info = self.collect_environment().await?;
261
262        // Cache it with 10 minute TTL (longer than local since remote calls are expensive)
263        *cache = Some(CachedEnvironment::new(
264            env_info.clone(),
265            Duration::from_secs(600), // 10 minutes
266        ));
267
268        Ok(env_info)
269    }
270
271    fn metadata(&self) -> WorkspaceMetadata {
272        self.metadata.clone()
273    }
274
275    async fn invalidate_environment_cache(&self) {
276        let mut cache = self.environment_cache.write().await;
277        *cache = None;
278    }
279
280    async fn list_files(
281        &self,
282        query: Option<&str>,
283        max_results: Option<usize>,
284    ) -> Result<Vec<String>> {
285        let mut client = self.client.clone();
286
287        let request = tonic::Request::new(ListFilesRequest {
288            query: query.unwrap_or("").to_string(),
289            max_results: max_results.unwrap_or(0) as u32,
290        });
291
292        let mut stream = client
293            .list_files(request)
294            .await
295            .map_err(|e| WorkspaceError::Status(format!("Failed to list files: {e}")))?
296            .into_inner();
297        let mut all_files = Vec::new();
298
299        // Collect all files from the stream
300        while let Some(response) = stream
301            .message()
302            .await
303            .map_err(|e| WorkspaceError::Status(format!("Stream error: {e}")))?
304        {
305            all_files.extend(response.paths);
306        }
307
308        Ok(all_files)
309    }
310
311    fn working_directory(&self) -> &std::path::Path {
312        // For remote workspaces, we return a placeholder path
313        // The actual working directory is on the remote machine
314        std::path::Path::new("/remote")
315    }
316
317    async fn execute_tool(
318        &self,
319        tool_call: &ToolCall,
320        context: steer_tools::ExecutionContext,
321    ) -> Result<ToolResult> {
322        let mut client = self.client.clone();
323
324        // Serialize the execution context
325        let context_json = serde_json::to_string(&SerializableExecutionContext::from(&context))
326            .map_err(|e| {
327                WorkspaceError::ToolExecution(format!("Failed to serialize context: {e}"))
328            })?;
329
330        // Serialize tool parameters
331        let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
332            WorkspaceError::ToolExecution(format!("Failed to serialize parameters: {e}"))
333        })?;
334
335        let request = tonic::Request::new(ExecuteToolRequest {
336            tool_call_id: tool_call.id.clone(),
337            tool_name: tool_call.name.clone(),
338            parameters_json,
339            context_json,
340            timeout_ms: Some(30000), // 30 second default
341        });
342
343        let response = client
344            .execute_tool(request)
345            .await
346            .map_err(|e| WorkspaceError::ToolExecution(format!("Failed to execute tool: {e}")))?
347            .into_inner();
348
349        if !response.success {
350            return Err(WorkspaceError::ToolExecution(format!(
351                "Tool execution failed: {}",
352                response.error
353            )));
354        }
355
356        // Convert the response to ToolResult
357        convert_tool_response(response)
358    }
359
360    async fn available_tools(&self) -> Vec<ToolSchema> {
361        let mut client = self.client.clone();
362
363        let request = tonic::Request::new(GetToolSchemasRequest {});
364
365        match client.get_tool_schemas(request).await {
366            Ok(response) => {
367                response
368                    .into_inner()
369                    .tools
370                    .into_iter()
371                    .map(|schema| {
372                        // Parse the JSON input schema
373                        let input_schema = serde_json::from_str(&schema.input_schema_json)
374                            .unwrap_or_else(|_| steer_tools::InputSchema {
375                                properties: serde_json::Map::new(),
376                                required: Vec::new(),
377                                schema_type: "object".to_string(),
378                            });
379
380                        ToolSchema {
381                            name: schema.name,
382                            description: schema.description,
383                            input_schema,
384                        }
385                    })
386                    .collect()
387            }
388            Err(_) => Vec::new(),
389        }
390    }
391
392    async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
393        let mut client = self.client.clone();
394
395        let request = tonic::Request::new(GetToolApprovalRequirementsRequest {
396            tool_names: vec![tool_name.to_string()],
397        });
398
399        let response = client
400            .get_tool_approval_requirements(request)
401            .await
402            .map_err(|e| {
403                WorkspaceError::ToolExecution(format!("Failed to get approval requirements: {e}"))
404            })?
405            .into_inner();
406
407        response
408            .approval_requirements
409            .get(tool_name)
410            .copied()
411            .ok_or_else(|| WorkspaceError::ToolExecution(format!("Unknown tool: {tool_name}")))
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the metadata shape built for a remote address.
    ///
    /// This builds `WorkspaceMetadata` directly and never connects, so it
    /// needs neither a running backend nor an async runtime.
    #[test]
    fn test_remote_workspace_metadata() {
        let address = "localhost:50051".to_string();

        let metadata = WorkspaceMetadata {
            id: format!("remote:{address}"),
            workspace_type: WorkspaceType::Remote,
            location: address.clone(),
        };

        assert!(matches!(metadata.workspace_type, WorkspaceType::Remote));
        assert_eq!(metadata.id, "remote:localhost:50051");
        assert_eq!(metadata.location, address);
    }

    /// Checks that every field of the gRPC response is carried over unchanged.
    #[test]
    fn test_convert_environment_response() {
        use std::path::PathBuf;

        let response = GetEnvironmentInfoResponse {
            working_directory: "/home/user/project".to_string(),
            is_git_repo: true,
            platform: "linux".to_string(),
            date: "2025-06-17".to_string(),
            directory_structure: "project/\nsrc/\nmain.rs\n".to_string(),
            git_status: Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string()),
            readme_content: Some("# My Project".to_string()),
            claude_md_content: None,
        };

        // Test the static conversion function directly
        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert_eq!(
            env_info.working_directory,
            PathBuf::from("/home/user/project")
        );
        assert!(env_info.is_git_repo);
        assert_eq!(env_info.platform, "linux");
        assert_eq!(env_info.date, "2025-06-17");
        assert_eq!(env_info.directory_structure, "project/\nsrc/\nmain.rs\n");
        assert_eq!(
            env_info.git_status,
            Some("Current branch: main\n\nStatus:\nWorking tree clean\n".to_string())
        );
        assert_eq!(env_info.readme_content, Some("# My Project".to_string()));
        assert_eq!(env_info.claude_md_content, None);
    }
}