1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use tokio::sync::mpsc;
5use tokio_stream::wrappers::ReceiverStream;
6use tonic::{Request, Response, Status};
7
8use steer_tools::tools::workspace_tools;
9use steer_tools::traits::ExecutableTool;
10use steer_tools::{ExecutionContext, ToolError};
11use steer_workspace::utils::{
12 DirectoryStructureUtils, EnvironmentUtils, FileListingUtils, GitStatusUtils,
13};
14
15use crate::proto::{
16 ExecuteToolRequest, ExecuteToolResponse, GetAgentInfoRequest, GetAgentInfoResponse,
17 GetToolApprovalRequirementsRequest, GetToolApprovalRequirementsResponse, GetToolSchemasRequest,
18 GetToolSchemasResponse, HealthRequest, HealthResponse, HealthStatus, ListFilesRequest,
19 ListFilesResponse, ToolSchema as GrpcToolSchema, execute_tool_response::Result as ProtoResult,
20 remote_workspace_service_server::RemoteWorkspaceService as RemoteWorkspaceServiceServer,
21};
22use steer_proto::common::v1::{
23 BashResult as ProtoBashResult, ColumnRange as ProtoColumnRange, EditResult as ProtoEditResult,
24 FileContentResult as ProtoFileContentResult, FileEntry as ProtoFileEntry,
25 FileListResult as ProtoFileListResult, GlobResult as ProtoGlobResult,
26 SearchMatch as ProtoSearchMatch, SearchResult as ProtoSearchResult,
27 TodoListResult as ProtoTodoListResult, TodoWriteResult as ProtoTodoWriteResult,
28};
29
30use steer_grpc::grpc::conversions::{
31 convert_todo_item_to_proto, convert_todo_write_file_operation_to_proto,
32};
33
/// gRPC service that exposes a local workspace's tools and environment to a
/// remote client over the `RemoteWorkspaceService` proto interface.
pub struct RemoteWorkspaceService {
    // Root directory that tool executions and environment queries operate on.
    working_dir: PathBuf,
    // Tool registry keyed by tool name; `Arc` lets handlers share it cheaply.
    tools: Arc<HashMap<String, Box<dyn ExecutableTool>>>,
    // Crate version reported via `get_agent_info`, captured at compile time.
    version: String,
}
44
45impl RemoteWorkspaceService {
46 pub fn new(working_dir: PathBuf) -> Result<Self, ToolError> {
48 Self::with_tools(workspace_tools(), working_dir)
49 }
50
51 pub fn with_tools(
53 tools_list: Vec<Box<dyn ExecutableTool>>,
54 working_dir: PathBuf,
55 ) -> Result<Self, ToolError> {
56 let mut tools: HashMap<String, Box<dyn ExecutableTool>> = HashMap::new();
57
58 for tool in tools_list {
60 tools.insert(tool.name().to_string(), tool);
61 }
62
63 Ok(Self {
64 working_dir,
65 tools: Arc::new(tools),
66 version: env!("CARGO_PKG_VERSION").to_string(),
67 })
68 }
69
70 pub fn get_supported_tools(&self) -> Vec<String> {
72 self.tools.keys().cloned().collect()
73 }
74
75 fn tool_error_to_status(error: ToolError) -> Status {
77 match error {
78 ToolError::Cancelled(_) => Status::cancelled("Tool execution was cancelled"),
79 ToolError::UnknownTool(tool_name) => {
80 Status::not_found(format!("Unknown tool: {tool_name}"))
81 }
82 ToolError::InvalidParams(tool_name, message) => Status::invalid_argument(format!(
83 "Invalid parameters for tool {tool_name}: {message}"
84 )),
85 ToolError::Execution { tool_name, message } => {
86 Status::internal(format!("Tool {tool_name} execution failed: {message}"))
87 }
88 ToolError::Io { tool_name, message } => {
89 Status::internal(format!("IO error in tool {tool_name}: {message}"))
90 }
91 ToolError::DeniedByUser(tool_name) => {
92 Status::permission_denied(format!("Tool execution denied: {tool_name}"))
93 }
94 ToolError::Timeout(tool_name) => {
95 Status::deadline_exceeded(format!("Tool {tool_name} execution timed out"))
96 }
97 ToolError::InternalError(message) => {
98 Status::internal(format!("Internal error: {message}"))
99 }
100 }
101 }
102
103 fn tool_result_to_proto_result(
105 result: &steer_tools::result::ToolResult,
106 ) -> Option<ProtoResult> {
107 match result {
109 steer_tools::result::ToolResult::Search(search_result) => {
110 let proto_matches = search_result
111 .matches
112 .iter()
113 .map(|m| ProtoSearchMatch {
114 file_path: m.file_path.clone(),
115 line_number: m.line_number as u64,
116 line_content: m.line_content.clone(),
117 column_range: m.column_range.map(|(start, end)| ProtoColumnRange {
118 start: start as u64,
119 end: end as u64,
120 }),
121 })
122 .collect();
123
124 Some(ProtoResult::SearchResult(ProtoSearchResult {
125 matches: proto_matches,
126 total_files_searched: search_result.total_files_searched as u64,
127 search_completed: search_result.search_completed,
128 }))
129 }
130
131 steer_tools::result::ToolResult::FileList(file_list) => {
132 let proto_entries = file_list
133 .entries
134 .iter()
135 .map(|e| ProtoFileEntry {
136 path: e.path.clone(),
137 is_directory: e.is_directory,
138 size: e.size,
139 permissions: e.permissions.clone(),
140 })
141 .collect();
142
143 Some(ProtoResult::FileListResult(ProtoFileListResult {
144 entries: proto_entries,
145 base_path: file_list.base_path.clone(),
146 }))
147 }
148
149 steer_tools::result::ToolResult::FileContent(file_content) => {
150 Some(ProtoResult::FileContentResult(ProtoFileContentResult {
151 content: file_content.content.clone(),
152 file_path: file_content.file_path.clone(),
153 line_count: file_content.line_count as u64,
154 truncated: file_content.truncated,
155 }))
156 }
157
158 steer_tools::result::ToolResult::Edit(edit_result) => {
159 Some(ProtoResult::EditResult(ProtoEditResult {
160 file_path: edit_result.file_path.clone(),
161 changes_made: edit_result.changes_made as u64,
162 file_created: edit_result.file_created,
163 old_content: edit_result.old_content.clone(),
164 new_content: edit_result.new_content.clone(),
165 }))
166 }
167
168 steer_tools::result::ToolResult::Bash(bash_result) => {
169 Some(ProtoResult::BashResult(ProtoBashResult {
170 stdout: bash_result.stdout.clone(),
171 stderr: bash_result.stderr.clone(),
172 exit_code: bash_result.exit_code,
173 command: bash_result.command.clone(),
174 }))
175 }
176
177 steer_tools::result::ToolResult::Glob(glob_result) => {
178 Some(ProtoResult::GlobResult(ProtoGlobResult {
179 matches: glob_result.matches.clone(),
180 pattern: glob_result.pattern.clone(),
181 }))
182 }
183
184 steer_tools::result::ToolResult::TodoRead(todo_list) => {
185 let proto_todos = todo_list
186 .todos
187 .iter()
188 .map(convert_todo_item_to_proto)
189 .collect();
190
191 Some(ProtoResult::TodoListResult(ProtoTodoListResult {
192 todos: proto_todos,
193 }))
194 }
195
196 steer_tools::result::ToolResult::TodoWrite(todo_write_result) => {
197 let proto_todos = todo_write_result
198 .todos
199 .iter()
200 .map(convert_todo_item_to_proto)
201 .collect();
202
203 Some(ProtoResult::TodoWriteResult(ProtoTodoWriteResult {
204 todos: proto_todos,
205 operation: convert_todo_write_file_operation_to_proto(
206 &todo_write_result.operation,
207 ) as i32,
208 }))
209 }
210
211 steer_tools::result::ToolResult::Fetch(_) => {
212 None
214 }
215
216 steer_tools::result::ToolResult::Agent(_) => {
217 None
219 }
220
221 steer_tools::result::ToolResult::External(_) => {
222 None
224 }
225
226 steer_tools::result::ToolResult::Error(_) => {
227 None
229 }
230 }
231 }
232
233 fn get_directory_structure(&self) -> Result<String, std::io::Error> {
235 DirectoryStructureUtils::get_directory_structure(&self.working_dir, 3)
236 }
237
238 async fn get_git_status(&self) -> Result<String, std::io::Error> {
240 GitStatusUtils::get_git_status(&self.working_dir)
241 }
242}
243
244#[tonic::async_trait]
245impl RemoteWorkspaceServiceServer for RemoteWorkspaceService {
246 type ListFilesStream = ReceiverStream<Result<ListFilesResponse, Status>>;
247 async fn get_tool_schemas(
249 &self,
250 _request: Request<GetToolSchemasRequest>,
251 ) -> Result<Response<GetToolSchemasResponse>, Status> {
252 let mut schemas = Vec::new();
253
254 for (name, tool) in self.tools.iter() {
255 let input_schema = tool.input_schema();
256 let input_schema_json = serde_json::to_string(&input_schema)
257 .map_err(|e| Status::internal(format!("Failed to serialize schema: {e}")))?;
258
259 schemas.push(GrpcToolSchema {
260 name: name.clone(),
261 description: tool.description(),
262 input_schema_json,
263 });
264 }
265
266 Ok(Response::new(GetToolSchemasResponse { tools: schemas }))
267 }
268
269 async fn execute_tool(
271 &self,
272 request: Request<ExecuteToolRequest>,
273 ) -> Result<Response<ExecuteToolResponse>, Status> {
274 let start_time = std::time::Instant::now();
275 let req = request.into_inner();
276
277 let parameters: serde_json::Value =
279 serde_json::from_str(&req.parameters_json).map_err(|e| {
280 Status::invalid_argument(format!("Failed to parse tool parameters: {e}"))
281 })?;
282
283 let tool = self
285 .tools
286 .get(&req.tool_name)
287 .ok_or_else(|| Status::not_found(format!("Unknown tool: {}", req.tool_name)))?;
288
289 let cancellation_token = tokio_util::sync::CancellationToken::new();
293 let _guard = cancellation_token.clone().drop_guard();
294
295 let context = ExecutionContext::new(req.tool_call_id.clone())
297 .with_cancellation_token(cancellation_token);
298
299 let result = tool.run(parameters, &context).await;
300
301 let end_time = std::time::Instant::now();
302 let duration = end_time - start_time;
303
304 let response = match result {
306 Ok(tool_result) => {
307 let proto_result = Self::tool_result_to_proto_result(&tool_result);
309
310 ExecuteToolResponse {
311 success: true,
312 result: proto_result.or_else(|| {
313 Some(ProtoResult::StringResult(tool_result.llm_format()))
315 }),
316 error: String::new(),
317 started_at: Some(prost_types::Timestamp {
318 seconds: start_time.elapsed().as_secs() as i64,
319 nanos: 0,
320 }),
321 completed_at: Some(prost_types::Timestamp {
322 seconds: duration.as_secs() as i64,
323 nanos: duration.subsec_nanos() as i32,
324 }),
325 metadata: std::collections::HashMap::new(),
326 }
327 }
328 Err(error) => {
329 match &error {
332 ToolError::Cancelled(_) => {
333 return Err(Status::cancelled("Tool execution was cancelled"));
334 }
335 ToolError::UnknownTool(_) => {
336 return Err(Self::tool_error_to_status(error));
337 }
338 _ => ExecuteToolResponse {
339 success: false,
340 result: None,
341 error: error.to_string(),
342 started_at: Some(prost_types::Timestamp {
343 seconds: start_time.elapsed().as_secs() as i64,
344 nanos: 0,
345 }),
346 completed_at: Some(prost_types::Timestamp {
347 seconds: duration.as_secs() as i64,
348 nanos: duration.subsec_nanos() as i32,
349 }),
350 metadata: std::collections::HashMap::new(),
351 },
352 }
353 }
354 };
355
356 Ok(Response::new(response))
357 }
358
359 async fn get_agent_info(
361 &self,
362 _request: Request<GetAgentInfoRequest>,
363 ) -> Result<Response<GetAgentInfoResponse>, Status> {
364 let supported_tools = self.get_supported_tools();
365
366 let info = GetAgentInfoResponse {
367 version: self.version.clone(),
368 supported_tools,
369 metadata: std::collections::HashMap::from([
370 (
371 "hostname".to_string(),
372 gethostname::gethostname().to_string_lossy().to_string(),
373 ),
374 (
375 "working_directory".to_string(),
376 self.working_dir.to_string_lossy().to_string(),
377 ),
378 ]),
379 };
380
381 Ok(Response::new(info))
382 }
383
384 async fn health(
386 &self,
387 _request: Request<HealthRequest>,
388 ) -> Result<Response<HealthResponse>, Status> {
389 let response = HealthResponse {
391 status: HealthStatus::Serving as i32,
392 message: "Agent is healthy and ready to execute tools".to_string(),
393 details: std::collections::HashMap::from([(
394 "tool_count".to_string(),
395 self.get_supported_tools().len().to_string(),
396 )]),
397 };
398
399 Ok(Response::new(response))
400 }
401
402 async fn get_tool_approval_requirements(
404 &self,
405 request: Request<GetToolApprovalRequirementsRequest>,
406 ) -> Result<Response<GetToolApprovalRequirementsResponse>, Status> {
407 let req = request.into_inner();
408 let mut approval_requirements = std::collections::HashMap::new();
409
410 for tool_name in req.tool_names {
411 if let Some(tool) = self.tools.get(&tool_name) {
412 approval_requirements.insert(tool_name, tool.requires_approval());
413 } else {
414 }
417 }
418
419 Ok(Response::new(GetToolApprovalRequirementsResponse {
420 approval_requirements,
421 }))
422 }
423
424 async fn get_environment_info(
426 &self,
427 request: Request<crate::proto::GetEnvironmentInfoRequest>,
428 ) -> Result<Response<crate::proto::GetEnvironmentInfoResponse>, Status> {
429 let req = request.into_inner();
430
431 let working_directory = if let Some(dir) = req.working_directory {
433 dir
434 } else {
435 self.working_dir.to_string_lossy().to_string()
436 };
437
438 let is_git_repo = EnvironmentUtils::is_git_repo(Path::new(&working_directory));
440
441 let platform = EnvironmentUtils::get_platform().to_string();
443
444 let date = EnvironmentUtils::get_current_date();
446
447 let directory_structure = self.get_directory_structure().unwrap_or_else(|_| {
449 format!("Failed to read directory structure from {working_directory}")
450 });
451
452 let git_status = if is_git_repo {
454 self.get_git_status().await.ok()
455 } else {
456 None
457 };
458
459 let readme_content = EnvironmentUtils::read_readme(Path::new(&working_directory));
461
462 let claude_md_content = EnvironmentUtils::read_claude_md(Path::new(&working_directory));
464
465 let response = crate::proto::GetEnvironmentInfoResponse {
466 working_directory,
467 is_git_repo,
468 platform,
469 date,
470 directory_structure,
471 git_status,
472 readme_content,
473 claude_md_content,
474 };
475
476 Ok(Response::new(response))
477 }
478
479 async fn list_files(
481 &self,
482 request: Request<ListFilesRequest>,
483 ) -> Result<Response<Self::ListFilesStream>, Status> {
484 let req = request.into_inner();
485
486 let (tx, rx) = mpsc::channel(100);
488
489 tokio::spawn(async move {
491 let query = if req.query.is_empty() {
493 None
494 } else {
495 Some(req.query.as_str())
496 };
497 let max_results = if req.max_results > 0 {
498 Some(req.max_results as usize)
499 } else {
500 None
501 };
502
503 let files = match FileListingUtils::list_files(Path::new("."), query, max_results) {
504 Ok(files) => files,
505 Err(e) => {
506 tracing::error!("Error listing files: {}", e);
507 return;
508 }
509 };
510
511 for chunk in files.chunks(1000) {
513 let response = ListFilesResponse {
514 paths: chunk.to_vec(),
515 };
516
517 if let Err(e) = tx.send(Ok(response)).await {
519 tracing::debug!("Client cancelled file list stream: {}", e);
520 break;
521 }
522 }
523 });
524
525 Ok(Response::new(ReceiverStream::new(rx)))
530 }
531}