Skip to main content

steer_workspace_client/
lib.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::sync::RwLock;
5use tonic::transport::Channel;
6
7use steer_proto::remote_workspace::v1::{
8    ApplyEditsRequest as ProtoApplyEditsRequest, AstGrepRequest as ProtoAstGrepRequest,
9    EditOperation as ProtoEditOperation, GetEnvironmentInfoRequest, GetEnvironmentInfoResponse,
10    GlobRequest as ProtoGlobRequest, GrepRequest as ProtoGrepRequest,
11    ListDirectoryRequest as ProtoListDirectoryRequest, ListFilesRequest,
12    ReadFileRequest as ProtoReadFileRequest, WriteFileRequest as ProtoWriteFileRequest,
13    remote_workspace_service_client::RemoteWorkspaceServiceClient,
14};
15use steer_tools::result::{
16    EditResult, FileContentResult, FileEntry, FileListResult, GlobResult, SearchMatch, SearchResult,
17};
18use steer_workspace::{
19    ApplyEditsRequest, AstGrepRequest, EnvironmentInfo, GitCommitSummary, GitHead, GitStatus,
20    GitStatusEntry, GitStatusSummary, GlobRequest, GrepRequest, JjChange, JjChangeType,
21    JjCommitSummary, JjStatus, ListDirectoryRequest, ReadFileRequest, RemoteAuth, Result, VcsInfo,
22    VcsKind, VcsStatus, Workspace, WorkspaceError, WorkspaceMetadata, WorkspaceOpContext,
23    WorkspaceType, WriteFileRequest,
24};
25
/// Upper bound (32 MiB) applied to both encoded and decoded gRPC messages on the client.
const GRPC_MAX_MESSAGE_SIZE_BYTES: usize = 32 << 20;
27
28fn convert_search_result(proto_result: steer_proto::common::v1::SearchResult) -> SearchResult {
29    let matches = proto_result
30        .matches
31        .into_iter()
32        .map(|m| SearchMatch {
33            file_path: m.file_path,
34            line_number: m.line_number as usize,
35            line_content: m.line_content,
36            column_range: m
37                .column_range
38                .map(|cr| (cr.start as usize, cr.end as usize)),
39        })
40        .collect();
41
42    SearchResult {
43        matches,
44        total_files_searched: proto_result.total_files_searched as usize,
45        search_completed: proto_result.search_completed,
46    }
47}
48
49fn convert_file_list_result(
50    proto_result: steer_proto::common::v1::FileListResult,
51) -> FileListResult {
52    let entries = proto_result
53        .entries
54        .into_iter()
55        .map(|e| FileEntry {
56            path: e.path,
57            is_directory: e.is_directory,
58            size: e.size,
59            permissions: e.permissions,
60        })
61        .collect();
62
63    FileListResult {
64        entries,
65        base_path: proto_result.base_path,
66    }
67}
68
69fn convert_file_content_result(
70    proto_result: steer_proto::common::v1::FileContentResult,
71) -> FileContentResult {
72    FileContentResult {
73        content: proto_result.content,
74        file_path: proto_result.file_path,
75        line_count: proto_result.line_count as usize,
76        truncated: proto_result.truncated,
77    }
78}
79
80fn convert_edit_result(proto_result: steer_proto::common::v1::EditResult) -> EditResult {
81    EditResult {
82        file_path: proto_result.file_path,
83        changes_made: proto_result.changes_made as usize,
84        file_created: proto_result.file_created,
85        old_content: proto_result.old_content,
86        new_content: proto_result.new_content,
87    }
88}
89
90fn convert_glob_result(proto_result: steer_proto::common::v1::GlobResult) -> GlobResult {
91    GlobResult {
92        matches: proto_result.matches,
93        pattern: proto_result.pattern,
94    }
95}
96
97/// Cached environment information with TTL
98#[derive(Debug, Clone)]
99struct CachedEnvironment {
100    pub info: EnvironmentInfo,
101    pub cached_at: std::time::Instant,
102    pub ttl: Duration,
103}
104
105impl CachedEnvironment {
106    pub fn new(info: EnvironmentInfo, ttl: Duration) -> Self {
107        Self {
108            info,
109            cached_at: std::time::Instant::now(),
110            ttl,
111        }
112    }
113
114    pub fn is_expired(&self) -> bool {
115        self.cached_at.elapsed() > self.ttl
116    }
117}
118
119/// Remote workspace that executes tools and collects environment info via gRPC
120pub struct RemoteWorkspace {
121    client: RemoteWorkspaceServiceClient<Channel>,
122    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
123    metadata: WorkspaceMetadata,
124    #[allow(dead_code)]
125    auth: Option<RemoteAuth>,
126}
127
128impl RemoteWorkspace {
129    pub async fn new(address: String, auth: Option<RemoteAuth>) -> Result<Self> {
130        // Create gRPC client
131        let client = RemoteWorkspaceServiceClient::connect(format!("http://{address}"))
132            .await
133            .map_err(|e| WorkspaceError::Transport(format!("Failed to connect: {e}")))?
134            .max_decoding_message_size(GRPC_MAX_MESSAGE_SIZE_BYTES)
135            .max_encoding_message_size(GRPC_MAX_MESSAGE_SIZE_BYTES);
136
137        let metadata = WorkspaceMetadata {
138            id: format!("remote:{address}"),
139            workspace_type: WorkspaceType::Remote,
140            location: address.clone(),
141        };
142
143        Ok(Self {
144            client,
145            environment_cache: Arc::new(RwLock::new(None)),
146            metadata,
147            auth,
148        })
149    }
150
151    /// Collect environment information from the remote workspace
152    async fn collect_environment(&self) -> Result<EnvironmentInfo> {
153        let mut client = self.client.clone();
154
155        let request = tonic::Request::new(GetEnvironmentInfoRequest {
156            working_directory: None, // Use remote default
157        });
158
159        let response = client
160            .get_environment_info(request)
161            .await
162            .map_err(|e| WorkspaceError::Status(format!("Failed to get environment info: {e}")))?;
163        let env_response = response.into_inner();
164
165        Self::convert_environment_response(env_response)
166    }
167
168    /// Convert gRPC response to EnvironmentInfo
169    fn convert_environment_response(
170        response: GetEnvironmentInfoResponse,
171    ) -> Result<EnvironmentInfo> {
172        use std::path::PathBuf;
173        use steer_proto::remote_workspace::v1::{
174            GitHeadKind as ProtoGitHeadKind, GitStatusSummary as ProtoGitStatusSummary,
175            JjChangeType as ProtoJjChangeType, VcsKind as ProtoVcsKind, vcs_info,
176        };
177
178        let vcs = response.vcs.and_then(|vcs| {
179            let kind = match ProtoVcsKind::try_from(vcs.kind).ok()? {
180                ProtoVcsKind::Git => VcsKind::Git,
181                ProtoVcsKind::Jj => VcsKind::Jj,
182                ProtoVcsKind::Unspecified => return None,
183            };
184
185            let status = match vcs.status {
186                Some(vcs_info::Status::GitStatus(status)) => {
187                    let head = status.head.and_then(|head| {
188                        let kind = ProtoGitHeadKind::try_from(head.kind).ok()?;
189                        match kind {
190                            ProtoGitHeadKind::Branch => {
191                                Some(GitHead::Branch(head.branch.unwrap_or_default()))
192                            }
193                            ProtoGitHeadKind::Detached => Some(GitHead::Detached),
194                            ProtoGitHeadKind::Unborn => Some(GitHead::Unborn),
195                            ProtoGitHeadKind::Unspecified => None,
196                        }
197                    });
198
199                    let entries = status
200                        .entries
201                        .into_iter()
202                        .filter_map(|entry| {
203                            let summary = match ProtoGitStatusSummary::try_from(entry.summary)
204                                .ok()?
205                            {
206                                ProtoGitStatusSummary::Added => GitStatusSummary::Added,
207                                ProtoGitStatusSummary::Removed => GitStatusSummary::Removed,
208                                ProtoGitStatusSummary::Modified => GitStatusSummary::Modified,
209                                ProtoGitStatusSummary::TypeChange => GitStatusSummary::TypeChange,
210                                ProtoGitStatusSummary::Renamed => GitStatusSummary::Renamed,
211                                ProtoGitStatusSummary::Copied => GitStatusSummary::Copied,
212                                ProtoGitStatusSummary::IntentToAdd => GitStatusSummary::IntentToAdd,
213                                ProtoGitStatusSummary::Conflict => GitStatusSummary::Conflict,
214                                ProtoGitStatusSummary::Unspecified => return None,
215                            };
216                            Some(GitStatusEntry {
217                                summary,
218                                path: entry.path,
219                            })
220                        })
221                        .collect();
222
223                    let recent_commits = status
224                        .recent_commits
225                        .into_iter()
226                        .map(|commit| GitCommitSummary {
227                            id: commit.id,
228                            summary: commit.summary,
229                        })
230                        .collect();
231
232                    VcsStatus::Git(GitStatus {
233                        head,
234                        entries,
235                        recent_commits,
236                        error: status.error,
237                    })
238                }
239                Some(vcs_info::Status::JjStatus(status)) => {
240                    let changes = status
241                        .changes
242                        .into_iter()
243                        .filter_map(|change| {
244                            let change_type =
245                                match ProtoJjChangeType::try_from(change.change_type).ok()? {
246                                    ProtoJjChangeType::Added => JjChangeType::Added,
247                                    ProtoJjChangeType::Removed => JjChangeType::Removed,
248                                    ProtoJjChangeType::Modified => JjChangeType::Modified,
249                                    ProtoJjChangeType::Unspecified => return None,
250                                };
251                            Some(JjChange {
252                                change_type,
253                                path: change.path,
254                            })
255                        })
256                        .collect();
257
258                    let working_copy = status.working_copy.map(|commit| JjCommitSummary {
259                        change_id: commit.change_id,
260                        commit_id: commit.commit_id,
261                        description: commit.description,
262                    });
263
264                    let parents = status
265                        .parents
266                        .into_iter()
267                        .map(|commit| JjCommitSummary {
268                            change_id: commit.change_id,
269                            commit_id: commit.commit_id,
270                            description: commit.description,
271                        })
272                        .collect();
273
274                    VcsStatus::Jj(JjStatus {
275                        changes,
276                        working_copy,
277                        parents,
278                        error: status.error,
279                    })
280                }
281                None => match kind {
282                    VcsKind::Git => {
283                        VcsStatus::Git(GitStatus::unavailable("missing git status".to_string()))
284                    }
285                    VcsKind::Jj => {
286                        VcsStatus::Jj(JjStatus::unavailable("missing jj status".to_string()))
287                    }
288                },
289            };
290
291            Some(VcsInfo {
292                kind,
293                root: PathBuf::from(vcs.root),
294                status,
295            })
296        });
297
298        Ok(EnvironmentInfo {
299            working_directory: PathBuf::from(response.working_directory),
300            vcs,
301            platform: response.platform,
302            date: response.date,
303            directory_structure: response.directory_structure,
304            readme_content: response.readme_content,
305            memory_file_name: response.memory_file_name,
306            memory_file_content: response.memory_file_content,
307        })
308    }
309}
310
311impl std::fmt::Debug for RemoteWorkspace {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        f.debug_struct("RemoteWorkspace")
314            .field("metadata", &self.metadata)
315            .field("auth", &self.auth)
316            .finish_non_exhaustive()
317    }
318}
319
320#[async_trait]
321impl Workspace for RemoteWorkspace {
322    async fn environment(&self) -> Result<EnvironmentInfo> {
323        let mut cache = self.environment_cache.write().await;
324
325        // Check if we have valid cached data
326        if let Some(cached) = cache.as_ref()
327            && !cached.is_expired()
328        {
329            return Ok(cached.info.clone());
330        }
331
332        // Collect fresh environment info from remote
333        let env_info = self.collect_environment().await?;
334
335        // Cache it with 10 minute TTL (longer than local since remote calls are expensive)
336        *cache = Some(CachedEnvironment::new(
337            env_info.clone(),
338            Duration::from_secs(600), // 10 minutes
339        ));
340
341        Ok(env_info)
342    }
343
344    fn metadata(&self) -> WorkspaceMetadata {
345        self.metadata.clone()
346    }
347
348    async fn invalidate_environment_cache(&self) {
349        let mut cache = self.environment_cache.write().await;
350        *cache = None;
351    }
352
353    async fn list_files(
354        &self,
355        query: Option<&str>,
356        max_results: Option<usize>,
357    ) -> Result<Vec<String>> {
358        let mut client = self.client.clone();
359
360        let request = tonic::Request::new(ListFilesRequest {
361            query: query.unwrap_or("").to_string(),
362            max_results: max_results.unwrap_or(0) as u32,
363        });
364
365        let mut stream = client
366            .list_files(request)
367            .await
368            .map_err(|e| WorkspaceError::Status(format!("Failed to list files: {e}")))?
369            .into_inner();
370        let mut all_files = Vec::new();
371
372        // Collect all files from the stream
373        while let Some(response) = stream
374            .message()
375            .await
376            .map_err(|e| WorkspaceError::Status(format!("Stream error: {e}")))?
377        {
378            all_files.extend(response.paths);
379        }
380
381        Ok(all_files)
382    }
383
384    fn working_directory(&self) -> &std::path::Path {
385        // For remote workspaces, we return a placeholder path
386        // The actual working directory is on the remote machine
387        std::path::Path::new("/remote")
388    }
389
390    async fn read_file(
391        &self,
392        request: ReadFileRequest,
393        _ctx: &WorkspaceOpContext,
394    ) -> Result<FileContentResult> {
395        let mut client = self.client.clone();
396        let request = tonic::Request::new(ProtoReadFileRequest {
397            file_path: request.file_path,
398            offset: request.offset,
399            limit: request.limit,
400            raw: request.raw,
401        });
402        let response = client
403            .read_file(request)
404            .await
405            .map_err(|e| WorkspaceError::Status(format!("Failed to read file: {e}")))?
406            .into_inner();
407        Ok(convert_file_content_result(response))
408    }
409
410    async fn list_directory(
411        &self,
412        request: ListDirectoryRequest,
413        _ctx: &WorkspaceOpContext,
414    ) -> Result<FileListResult> {
415        let mut client = self.client.clone();
416        let request = tonic::Request::new(ProtoListDirectoryRequest {
417            path: request.path,
418            ignore: request.ignore.unwrap_or_default(),
419        });
420        let response = client
421            .list_directory(request)
422            .await
423            .map_err(|e| WorkspaceError::Status(format!("Failed to list directory: {e}")))?
424            .into_inner();
425        Ok(convert_file_list_result(response))
426    }
427
428    async fn glob(&self, request: GlobRequest, _ctx: &WorkspaceOpContext) -> Result<GlobResult> {
429        let mut client = self.client.clone();
430        let request = tonic::Request::new(ProtoGlobRequest {
431            pattern: request.pattern,
432            path: request.path,
433        });
434        let response = client
435            .glob(request)
436            .await
437            .map_err(|e| WorkspaceError::Status(format!("Failed to glob: {e}")))?
438            .into_inner();
439        Ok(convert_glob_result(response))
440    }
441
442    async fn grep(&self, request: GrepRequest, _ctx: &WorkspaceOpContext) -> Result<SearchResult> {
443        let mut client = self.client.clone();
444        let request = tonic::Request::new(ProtoGrepRequest {
445            pattern: request.pattern,
446            include: request.include,
447            path: request.path,
448        });
449        let response = client
450            .grep(request)
451            .await
452            .map_err(|e| WorkspaceError::Status(format!("Failed to grep: {e}")))?
453            .into_inner();
454        Ok(convert_search_result(response))
455    }
456
457    async fn astgrep(
458        &self,
459        request: AstGrepRequest,
460        _ctx: &WorkspaceOpContext,
461    ) -> Result<SearchResult> {
462        let mut client = self.client.clone();
463        let request = tonic::Request::new(ProtoAstGrepRequest {
464            pattern: request.pattern,
465            lang: request.lang,
466            include: request.include,
467            exclude: request.exclude,
468            path: request.path,
469        });
470        let response = client
471            .ast_grep(request)
472            .await
473            .map_err(|e| WorkspaceError::Status(format!("Failed to astgrep: {e}")))?
474            .into_inner();
475        Ok(convert_search_result(response))
476    }
477
478    async fn apply_edits(
479        &self,
480        request: ApplyEditsRequest,
481        _ctx: &WorkspaceOpContext,
482    ) -> Result<EditResult> {
483        let mut client = self.client.clone();
484        let edits = request
485            .edits
486            .into_iter()
487            .map(|edit| ProtoEditOperation {
488                old_string: edit.old_string,
489                new_string: edit.new_string,
490            })
491            .collect();
492        let request = tonic::Request::new(ProtoApplyEditsRequest {
493            file_path: request.file_path,
494            edits,
495        });
496        let response = client
497            .apply_edits(request)
498            .await
499            .map_err(|e| WorkspaceError::Status(format!("Failed to apply edits: {e}")))?
500            .into_inner();
501        Ok(convert_edit_result(response))
502    }
503
504    async fn write_file(
505        &self,
506        request: WriteFileRequest,
507        _ctx: &WorkspaceOpContext,
508    ) -> Result<EditResult> {
509        let mut client = self.client.clone();
510        let request = tonic::Request::new(ProtoWriteFileRequest {
511            file_path: request.file_path,
512            content: request.content,
513        });
514        let response = client
515            .write_file(request)
516            .await
517            .map_err(|e| WorkspaceError::Status(format!("Failed to write file: {e}")))?
518            .into_inner();
519        Ok(convert_edit_result(response))
520    }
521}
522
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use steer_proto::remote_workspace::v1 as proto;
    use steer_workspace::LlmStatus;

    /// Builds a `GetEnvironmentInfoResponse` around `vcs`, filling every other field with
    /// fixed sample values.
    fn sample_response(vcs: Option<proto::VcsInfo>) -> GetEnvironmentInfoResponse {
        GetEnvironmentInfoResponse {
            working_directory: "/home/user/project".to_string(),
            vcs,
            platform: "linux".to_string(),
            date: "2025-06-17".to_string(),
            directory_structure: "project/\nsrc/\nmain.rs\n".to_string(),
            readme_content: Some("# My Project".to_string()),
            memory_file_content: None,
            memory_file_name: None,
        }
    }

    #[test]
    fn test_remote_workspace_metadata() {
        let address = "localhost:50051".to_string();

        // `RemoteWorkspace::new` connects eagerly and would need a live backend, so this
        // only checks the metadata shape that `new` builds for an address.
        let metadata = WorkspaceMetadata {
            id: format!("remote:{address}"),
            workspace_type: WorkspaceType::Remote,
            location: address.clone(),
        };

        assert_eq!(metadata.id, "remote:localhost:50051");
        assert!(matches!(metadata.workspace_type, WorkspaceType::Remote));
        assert_eq!(metadata.location, address);
    }

    #[test]
    fn test_convert_environment_response() {
        let response = sample_response(Some(proto::VcsInfo {
            kind: proto::VcsKind::Git as i32,
            root: "/home/user/project".to_string(),
            status: Some(proto::vcs_info::Status::GitStatus(proto::GitStatus {
                head: Some(proto::GitHead {
                    kind: proto::GitHeadKind::Branch as i32,
                    branch: Some("main".to_string()),
                }),
                entries: Vec::new(),
                recent_commits: Vec::new(),
                error: None,
            })),
        }));

        // Test the static conversion function directly
        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert_eq!(
            env_info.working_directory,
            PathBuf::from("/home/user/project")
        );
        assert!(matches!(
            env_info.vcs,
            Some(VcsInfo {
                kind: VcsKind::Git,
                ..
            })
        ));
        assert_eq!(env_info.platform, "linux");
        assert_eq!(env_info.date, "2025-06-17");
        assert_eq!(env_info.directory_structure, "project/\nsrc/\nmain.rs\n");
        assert_eq!(
            env_info
                .vcs
                .as_ref()
                .map(|vcs| vcs.status.as_llm_string()),
            Some(
                "Current branch: main\n\nStatus:\nWorking tree clean\n\nRecent commits:\n<no commits>\n"
                    .to_string()
            )
        );
        assert_eq!(env_info.readme_content, Some("# My Project".to_string()));
        assert_eq!(env_info.memory_file_content, None);
        assert_eq!(env_info.memory_file_name, None);
    }

    #[test]
    fn test_convert_environment_response_missing_status_is_unavailable() {
        let response = sample_response(Some(proto::VcsInfo {
            kind: proto::VcsKind::Git as i32,
            root: "/home/user/project".to_string(),
            status: None,
        }));

        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert!(matches!(
            env_info.vcs,
            Some(VcsInfo {
                kind: VcsKind::Git,
                status: VcsStatus::Git(_),
                ..
            })
        ));
    }

    #[test]
    fn test_convert_environment_response_drops_unspecified_vcs_kind() {
        let response = sample_response(Some(proto::VcsInfo {
            kind: proto::VcsKind::Unspecified as i32,
            root: "/home/user/project".to_string(),
            status: None,
        }));

        let env_info = RemoteWorkspace::convert_environment_response(response).unwrap();

        assert!(env_info.vcs.is_none());
    }

    #[test]
    fn test_convert_environment_response_without_vcs() {
        let env_info = RemoteWorkspace::convert_environment_response(sample_response(None)).unwrap();

        assert!(env_info.vcs.is_none());
        assert_eq!(env_info.platform, "linux");
    }
}
605}