// steer_remote_workspace/remote_workspace_service.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use tokio::sync::mpsc;
5use tokio_stream::wrappers::ReceiverStream;
6use tonic::{Request, Response, Status};
7
8use steer_tools::tools::workspace_tools;
9use steer_tools::traits::ExecutableTool;
10use steer_tools::{ExecutionContext, ToolError};
11use steer_workspace::utils::{
12    DirectoryStructureUtils, EnvironmentUtils, FileListingUtils, GitStatusUtils,
13};
14
15use crate::proto::{
16    ExecuteToolRequest, ExecuteToolResponse, GetAgentInfoRequest, GetAgentInfoResponse,
17    GetToolApprovalRequirementsRequest, GetToolApprovalRequirementsResponse, GetToolSchemasRequest,
18    GetToolSchemasResponse, HealthRequest, HealthResponse, HealthStatus, ListFilesRequest,
19    ListFilesResponse, ToolSchema as GrpcToolSchema, execute_tool_response::Result as ProtoResult,
20    remote_workspace_service_server::RemoteWorkspaceService as RemoteWorkspaceServiceServer,
21};
22use steer_proto::common::v1::{
23    BashResult as ProtoBashResult, ColumnRange as ProtoColumnRange, EditResult as ProtoEditResult,
24    FileContentResult as ProtoFileContentResult, FileEntry as ProtoFileEntry,
25    FileListResult as ProtoFileListResult, GlobResult as ProtoGlobResult,
26    SearchMatch as ProtoSearchMatch, SearchResult as ProtoSearchResult,
27    TodoListResult as ProtoTodoListResult, TodoWriteResult as ProtoTodoWriteResult,
28};
29
30use steer_grpc::grpc::conversions::{
31    convert_todo_item_to_proto, convert_todo_write_file_operation_to_proto,
32};
33
/// Agent service implementation that executes tools locally
///
/// This service receives tool execution requests via gRPC and executes them
/// using the standard tool executor. It's designed to run on remote machines,
/// VMs, or containers to provide remote tool execution capabilities.
pub struct RemoteWorkspaceService {
    /// Root directory that tool execution and environment queries are
    /// relative to.
    working_dir: PathBuf,
    /// Registered tools, keyed by `tool.name()`; wrapped in `Arc` so the
    /// map can be shared without copying.
    tools: Arc<HashMap<String, Box<dyn ExecutableTool>>>,
    /// Crate version (from CARGO_PKG_VERSION), reported by `get_agent_info`.
    version: String,
}
44
impl RemoteWorkspaceService {
    /// Create a new RemoteWorkspaceService with the standard tool set
    /// returned by `workspace_tools()`, rooted at `working_dir`.
    pub fn new(working_dir: PathBuf) -> Result<Self, ToolError> {
        Self::with_tools(workspace_tools(), working_dir)
    }

    /// Create a new RemoteWorkspaceService with a custom set of tools.
    ///
    /// Tools are keyed by `tool.name()`; if two tools in `tools_list` share
    /// a name, the later one silently replaces the earlier one.
    pub fn with_tools(
        tools_list: Vec<Box<dyn ExecutableTool>>,
        working_dir: PathBuf,
    ) -> Result<Self, ToolError> {
        let mut tools: HashMap<String, Box<dyn ExecutableTool>> = HashMap::new();

        // Register the provided tools under their self-reported names
        for tool in tools_list {
            tools.insert(tool.name().to_string(), tool);
        }

        Ok(Self {
            working_dir,
            tools: Arc::new(tools),
            // Crate version baked in at compile time
            version: env!("CARGO_PKG_VERSION").to_string(),
        })
    }

    /// Get the names of all registered tools (iteration order unspecified)
    pub fn get_supported_tools(&self) -> Vec<String> {
        self.tools.keys().cloned().collect()
    }

    /// Convert a ToolError to a gRPC Status with a matching canonical code
    /// (Cancelled -> CANCELLED, UnknownTool -> NOT_FOUND, InvalidParams ->
    /// INVALID_ARGUMENT, DeniedByUser -> PERMISSION_DENIED, Timeout ->
    /// DEADLINE_EXCEEDED, everything else -> INTERNAL)
    fn tool_error_to_status(error: ToolError) -> Status {
        match error {
            ToolError::Cancelled(_) => Status::cancelled("Tool execution was cancelled"),
            ToolError::UnknownTool(tool_name) => {
                Status::not_found(format!("Unknown tool: {tool_name}"))
            }
            ToolError::InvalidParams(tool_name, message) => Status::invalid_argument(format!(
                "Invalid parameters for tool {tool_name}: {message}"
            )),
            ToolError::Execution { tool_name, message } => {
                Status::internal(format!("Tool {tool_name} execution failed: {message}"))
            }
            ToolError::Io { tool_name, message } => {
                Status::internal(format!("IO error in tool {tool_name}: {message}"))
            }
            ToolError::DeniedByUser(tool_name) => {
                Status::permission_denied(format!("Tool execution denied: {tool_name}"))
            }
            ToolError::Timeout(tool_name) => {
                Status::deadline_exceeded(format!("Tool {tool_name} execution timed out"))
            }
            ToolError::InternalError(message) => {
                Status::internal(format!("Internal error: {message}"))
            }
        }
    }

    /// Convert a typed tool output to a proto result.
    ///
    /// Returns `None` for variants this service does not encode (Fetch,
    /// Agent, External, Error); callers fall back to a string representation
    /// in that case. Numeric fields are widened with `as` casts to match the
    /// proto field types.
    fn tool_result_to_proto_result(
        result: &steer_tools::result::ToolResult,
    ) -> Option<ProtoResult> {
        // Match on the ToolResult enum variants
        match result {
            steer_tools::result::ToolResult::Search(search_result) => {
                let proto_matches = search_result
                    .matches
                    .iter()
                    .map(|m| ProtoSearchMatch {
                        file_path: m.file_path.clone(),
                        line_number: m.line_number as u64,
                        line_content: m.line_content.clone(),
                        // Optional (start, end) column span within the line
                        column_range: m.column_range.map(|(start, end)| ProtoColumnRange {
                            start: start as u64,
                            end: end as u64,
                        }),
                    })
                    .collect();

                Some(ProtoResult::SearchResult(ProtoSearchResult {
                    matches: proto_matches,
                    total_files_searched: search_result.total_files_searched as u64,
                    search_completed: search_result.search_completed,
                }))
            }

            steer_tools::result::ToolResult::FileList(file_list) => {
                let proto_entries = file_list
                    .entries
                    .iter()
                    .map(|e| ProtoFileEntry {
                        path: e.path.clone(),
                        is_directory: e.is_directory,
                        size: e.size,
                        permissions: e.permissions.clone(),
                    })
                    .collect();

                Some(ProtoResult::FileListResult(ProtoFileListResult {
                    entries: proto_entries,
                    base_path: file_list.base_path.clone(),
                }))
            }

            steer_tools::result::ToolResult::FileContent(file_content) => {
                Some(ProtoResult::FileContentResult(ProtoFileContentResult {
                    content: file_content.content.clone(),
                    file_path: file_content.file_path.clone(),
                    line_count: file_content.line_count as u64,
                    truncated: file_content.truncated,
                }))
            }

            steer_tools::result::ToolResult::Edit(edit_result) => {
                Some(ProtoResult::EditResult(ProtoEditResult {
                    file_path: edit_result.file_path.clone(),
                    changes_made: edit_result.changes_made as u64,
                    file_created: edit_result.file_created,
                    old_content: edit_result.old_content.clone(),
                    new_content: edit_result.new_content.clone(),
                }))
            }

            steer_tools::result::ToolResult::Bash(bash_result) => {
                Some(ProtoResult::BashResult(ProtoBashResult {
                    stdout: bash_result.stdout.clone(),
                    stderr: bash_result.stderr.clone(),
                    exit_code: bash_result.exit_code,
                    command: bash_result.command.clone(),
                }))
            }

            steer_tools::result::ToolResult::Glob(glob_result) => {
                Some(ProtoResult::GlobResult(ProtoGlobResult {
                    matches: glob_result.matches.clone(),
                    pattern: glob_result.pattern.clone(),
                }))
            }

            steer_tools::result::ToolResult::TodoRead(todo_list) => {
                let proto_todos = todo_list
                    .todos
                    .iter()
                    .map(convert_todo_item_to_proto)
                    .collect();

                Some(ProtoResult::TodoListResult(ProtoTodoListResult {
                    todos: proto_todos,
                }))
            }

            steer_tools::result::ToolResult::TodoWrite(todo_write_result) => {
                let proto_todos = todo_write_result
                    .todos
                    .iter()
                    .map(convert_todo_item_to_proto)
                    .collect();

                Some(ProtoResult::TodoWriteResult(ProtoTodoWriteResult {
                    todos: proto_todos,
                    // Proto enums are carried as i32 on the wire
                    operation: convert_todo_write_file_operation_to_proto(
                        &todo_write_result.operation,
                    ) as i32,
                }))
            }

            steer_tools::result::ToolResult::Fetch(_) => {
                // Fetch results are not handled in the remote workspace
                None
            }

            steer_tools::result::ToolResult::Agent(_) => {
                // Agent results are not handled in the remote workspace
                None
            }

            steer_tools::result::ToolResult::External(_) => {
                // External results are not handled in the remote workspace
                None
            }

            steer_tools::result::ToolResult::Error(_) => {
                // Errors are handled differently
                None
            }
        }
    }

    /// Get directory structure for environment info (depth limited to 3)
    fn get_directory_structure(&self) -> Result<String, std::io::Error> {
        DirectoryStructureUtils::get_directory_structure(&self.working_dir, 3)
    }

    /// Get git status information for the configured working directory.
    // NOTE(review): GitStatusUtils::get_git_status appears synchronous — this
    // async fn awaits nothing and may block the executor; TODO confirm and
    // consider spawn_blocking if it shells out to git.
    async fn get_git_status(&self) -> Result<String, std::io::Error> {
        GitStatusUtils::get_git_status(&self.working_dir)
    }
}
243
244#[tonic::async_trait]
245impl RemoteWorkspaceServiceServer for RemoteWorkspaceService {
246    type ListFilesStream = ReceiverStream<Result<ListFilesResponse, Status>>;
247    /// Get tool schemas
248    async fn get_tool_schemas(
249        &self,
250        _request: Request<GetToolSchemasRequest>,
251    ) -> Result<Response<GetToolSchemasResponse>, Status> {
252        let mut schemas = Vec::new();
253
254        for (name, tool) in self.tools.iter() {
255            let input_schema = tool.input_schema();
256            let input_schema_json = serde_json::to_string(&input_schema)
257                .map_err(|e| Status::internal(format!("Failed to serialize schema: {e}")))?;
258
259            schemas.push(GrpcToolSchema {
260                name: name.clone(),
261                description: tool.description(),
262                input_schema_json,
263            });
264        }
265
266        Ok(Response::new(GetToolSchemasResponse { tools: schemas }))
267    }
268
269    /// Execute a tool call on the agent
270    async fn execute_tool(
271        &self,
272        request: Request<ExecuteToolRequest>,
273    ) -> Result<Response<ExecuteToolResponse>, Status> {
274        let start_time = std::time::Instant::now();
275        let req = request.into_inner();
276
277        // Parse the tool parameters
278        let parameters: serde_json::Value =
279            serde_json::from_str(&req.parameters_json).map_err(|e| {
280                Status::invalid_argument(format!("Failed to parse tool parameters: {e}"))
281            })?;
282
283        // Look up the tool
284        let tool = self
285            .tools
286            .get(&req.tool_name)
287            .ok_or_else(|| Status::not_found(format!("Unknown tool: {}", req.tool_name)))?;
288
289        // Create a cancellation token and a drop guard. When the gRPC request is cancelled,
290        // this async function will be dropped, which triggers the drop guard to cancel the token.
291        // This ensures that long-running tools (like bash commands) are properly cancelled.
292        let cancellation_token = tokio_util::sync::CancellationToken::new();
293        let _guard = cancellation_token.clone().drop_guard();
294
295        // Create execution context
296        let context = ExecutionContext::new(req.tool_call_id.clone())
297            .with_cancellation_token(cancellation_token);
298
299        let result = tool.run(parameters, &context).await;
300
301        let end_time = std::time::Instant::now();
302        let duration = end_time - start_time;
303
304        // Convert result to response
305        let response = match result {
306            Ok(tool_result) => {
307                // Convert to a typed result
308                let proto_result = Self::tool_result_to_proto_result(&tool_result);
309
310                ExecuteToolResponse {
311                    success: true,
312                    result: proto_result.or_else(|| {
313                        // Fallback to string result
314                        Some(ProtoResult::StringResult(tool_result.llm_format()))
315                    }),
316                    error: String::new(),
317                    started_at: Some(prost_types::Timestamp {
318                        seconds: start_time.elapsed().as_secs() as i64,
319                        nanos: 0,
320                    }),
321                    completed_at: Some(prost_types::Timestamp {
322                        seconds: duration.as_secs() as i64,
323                        nanos: duration.subsec_nanos() as i32,
324                    }),
325                    metadata: std::collections::HashMap::new(),
326                }
327            }
328            Err(error) => {
329                // For some errors, we want to return them as successful responses
330                // with the error in the error field, rather than failing the gRPC call
331                match &error {
332                    ToolError::Cancelled(_) => {
333                        return Err(Status::cancelled("Tool execution was cancelled"));
334                    }
335                    ToolError::UnknownTool(_) => {
336                        return Err(Self::tool_error_to_status(error));
337                    }
338                    _ => ExecuteToolResponse {
339                        success: false,
340                        result: None,
341                        error: error.to_string(),
342                        started_at: Some(prost_types::Timestamp {
343                            seconds: start_time.elapsed().as_secs() as i64,
344                            nanos: 0,
345                        }),
346                        completed_at: Some(prost_types::Timestamp {
347                            seconds: duration.as_secs() as i64,
348                            nanos: duration.subsec_nanos() as i32,
349                        }),
350                        metadata: std::collections::HashMap::new(),
351                    },
352                }
353            }
354        };
355
356        Ok(Response::new(response))
357    }
358
359    /// Get information about the agent and available tools
360    async fn get_agent_info(
361        &self,
362        _request: Request<GetAgentInfoRequest>,
363    ) -> Result<Response<GetAgentInfoResponse>, Status> {
364        let supported_tools = self.get_supported_tools();
365
366        let info = GetAgentInfoResponse {
367            version: self.version.clone(),
368            supported_tools,
369            metadata: std::collections::HashMap::from([
370                (
371                    "hostname".to_string(),
372                    gethostname::gethostname().to_string_lossy().to_string(),
373                ),
374                (
375                    "working_directory".to_string(),
376                    self.working_dir.to_string_lossy().to_string(),
377                ),
378            ]),
379        };
380
381        Ok(Response::new(info))
382    }
383
384    /// Health check
385    async fn health(
386        &self,
387        _request: Request<HealthRequest>,
388    ) -> Result<Response<HealthResponse>, Status> {
389        // Simple health check - we could add more sophisticated checks here
390        let response = HealthResponse {
391            status: HealthStatus::Serving as i32,
392            message: "Agent is healthy and ready to execute tools".to_string(),
393            details: std::collections::HashMap::from([(
394                "tool_count".to_string(),
395                self.get_supported_tools().len().to_string(),
396            )]),
397        };
398
399        Ok(Response::new(response))
400    }
401
402    /// Get tool approval requirements
403    async fn get_tool_approval_requirements(
404        &self,
405        request: Request<GetToolApprovalRequirementsRequest>,
406    ) -> Result<Response<GetToolApprovalRequirementsResponse>, Status> {
407        let req = request.into_inner();
408        let mut approval_requirements = std::collections::HashMap::new();
409
410        for tool_name in req.tool_names {
411            if let Some(tool) = self.tools.get(&tool_name) {
412                approval_requirements.insert(tool_name, tool.requires_approval());
413            } else {
414                // Unknown tools are not included in the response
415                // This matches the behavior of the local backend which returns UnknownTool error
416            }
417        }
418
419        Ok(Response::new(GetToolApprovalRequirementsResponse {
420            approval_requirements,
421        }))
422    }
423
424    /// Get environment information for the remote workspace
425    async fn get_environment_info(
426        &self,
427        request: Request<crate::proto::GetEnvironmentInfoRequest>,
428    ) -> Result<Response<crate::proto::GetEnvironmentInfoResponse>, Status> {
429        let req = request.into_inner();
430
431        // Use the provided working directory or current directory
432        let working_directory = if let Some(dir) = req.working_directory {
433            dir
434        } else {
435            self.working_dir.to_string_lossy().to_string()
436        };
437
438        // Check if it's a git repo
439        let is_git_repo = EnvironmentUtils::is_git_repo(Path::new(&working_directory));
440
441        // Get platform information
442        let platform = EnvironmentUtils::get_platform().to_string();
443
444        // Get current date
445        let date = EnvironmentUtils::get_current_date();
446
447        // Get directory structure (simplified for now)
448        let directory_structure = self.get_directory_structure().unwrap_or_else(|_| {
449            format!("Failed to read directory structure from {working_directory}")
450        });
451
452        // Get git status if it's a git repo
453        let git_status = if is_git_repo {
454            self.get_git_status().await.ok()
455        } else {
456            None
457        };
458
459        // Read README.md if it exists
460        let readme_content = EnvironmentUtils::read_readme(Path::new(&working_directory));
461
462        // Read CLAUDE.md if it exists
463        let claude_md_content = EnvironmentUtils::read_claude_md(Path::new(&working_directory));
464
465        let response = crate::proto::GetEnvironmentInfoResponse {
466            working_directory,
467            is_git_repo,
468            platform,
469            date,
470            directory_structure,
471            git_status,
472            readme_content,
473            claude_md_content,
474        };
475
476        Ok(Response::new(response))
477    }
478
479    /// List files in the workspace for fuzzy finding
480    async fn list_files(
481        &self,
482        request: Request<ListFilesRequest>,
483    ) -> Result<Response<Self::ListFilesStream>, Status> {
484        let req = request.into_inner();
485
486        // Create the response stream
487        let (tx, rx) = mpsc::channel(100);
488
489        // Spawn task to stream the files
490        tokio::spawn(async move {
491            // Use the shared file listing utility
492            let query = if req.query.is_empty() {
493                None
494            } else {
495                Some(req.query.as_str())
496            };
497            let max_results = if req.max_results > 0 {
498                Some(req.max_results as usize)
499            } else {
500                None
501            };
502
503            let files = match FileListingUtils::list_files(Path::new("."), query, max_results) {
504                Ok(files) => files,
505                Err(e) => {
506                    tracing::error!("Error listing files: {}", e);
507                    return;
508                }
509            };
510
511            // Stream files in chunks of 1000
512            for chunk in files.chunks(1000) {
513                let response = ListFilesResponse {
514                    paths: chunk.to_vec(),
515                };
516
517                // If the receiver is dropped (client cancelled), this will fail and exit the loop
518                if let Err(e) = tx.send(Ok(response)).await {
519                    tracing::debug!("Client cancelled file list stream: {}", e);
520                    break;
521                }
522            }
523        });
524
525        // Note: The task handle is stored but not awaited here because the gRPC
526        // streaming response will consume the receiver. The task will complete
527        // when either all files are sent or the receiver is dropped (client cancellation).
528
529        Ok(Response::new(ReceiverStream::new(rx)))
530    }
531}