Skip to main content

steer_workspace_client/
lib.rs

1use async_trait::async_trait;
2use std::sync::Arc;
3use std::time::Duration;
4use tokio::sync::RwLock;
5use tonic::transport::Channel;
6
7use steer_proto::remote_workspace::v1::{
8    ApplyEditsRequest as ProtoApplyEditsRequest, AstGrepRequest as ProtoAstGrepRequest,
9    EditOperation as ProtoEditOperation, GetEnvironmentInfoRequest, GetEnvironmentInfoResponse,
10    GlobRequest as ProtoGlobRequest, GrepRequest as ProtoGrepRequest,
11    ListDirectoryRequest as ProtoListDirectoryRequest, ListFilesRequest,
12    ReadFileRequest as ProtoReadFileRequest, WriteFileRequest as ProtoWriteFileRequest,
13    remote_workspace_service_client::RemoteWorkspaceServiceClient,
14};
15use steer_tools::result::{
16    EditResult, FileContentResult, FileEntry, FileListResult, GlobResult, SearchMatch, SearchResult,
17};
18use steer_workspace::{
19    ApplyEditsRequest, AstGrepRequest, EnvironmentInfo, GitCommitSummary, GitHead, GitStatus,
20    GitStatusEntry, GitStatusSummary, GlobRequest, GrepRequest, JjChange, JjChangeType,
21    JjCommitSummary, JjStatus, ListDirectoryRequest, ReadFileRequest, RemoteAuth, Result, VcsInfo,
22    VcsKind, VcsStatus, Workspace, WorkspaceError, WorkspaceMetadata, WorkspaceOpContext,
23    WorkspaceType, WriteFileRequest,
24};
25
26fn convert_search_result(proto_result: steer_proto::common::v1::SearchResult) -> SearchResult {
27    let matches = proto_result
28        .matches
29        .into_iter()
30        .map(|m| SearchMatch {
31            file_path: m.file_path,
32            line_number: m.line_number as usize,
33            line_content: m.line_content,
34            column_range: m
35                .column_range
36                .map(|cr| (cr.start as usize, cr.end as usize)),
37        })
38        .collect();
39
40    SearchResult {
41        matches,
42        total_files_searched: proto_result.total_files_searched as usize,
43        search_completed: proto_result.search_completed,
44    }
45}
46
47fn convert_file_list_result(
48    proto_result: steer_proto::common::v1::FileListResult,
49) -> FileListResult {
50    let entries = proto_result
51        .entries
52        .into_iter()
53        .map(|e| FileEntry {
54            path: e.path,
55            is_directory: e.is_directory,
56            size: e.size,
57            permissions: e.permissions,
58        })
59        .collect();
60
61    FileListResult {
62        entries,
63        base_path: proto_result.base_path,
64    }
65}
66
67fn convert_file_content_result(
68    proto_result: steer_proto::common::v1::FileContentResult,
69) -> FileContentResult {
70    FileContentResult {
71        content: proto_result.content,
72        file_path: proto_result.file_path,
73        line_count: proto_result.line_count as usize,
74        truncated: proto_result.truncated,
75    }
76}
77
78fn convert_edit_result(proto_result: steer_proto::common::v1::EditResult) -> EditResult {
79    EditResult {
80        file_path: proto_result.file_path,
81        changes_made: proto_result.changes_made as usize,
82        file_created: proto_result.file_created,
83        old_content: proto_result.old_content,
84        new_content: proto_result.new_content,
85    }
86}
87
88fn convert_glob_result(proto_result: steer_proto::common::v1::GlobResult) -> GlobResult {
89    GlobResult {
90        matches: proto_result.matches,
91        pattern: proto_result.pattern,
92    }
93}
94
95/// Cached environment information with TTL
96#[derive(Debug, Clone)]
97struct CachedEnvironment {
98    pub info: EnvironmentInfo,
99    pub cached_at: std::time::Instant,
100    pub ttl: Duration,
101}
102
103impl CachedEnvironment {
104    pub fn new(info: EnvironmentInfo, ttl: Duration) -> Self {
105        Self {
106            info,
107            cached_at: std::time::Instant::now(),
108            ttl,
109        }
110    }
111
112    pub fn is_expired(&self) -> bool {
113        self.cached_at.elapsed() > self.ttl
114    }
115}
116
117/// Remote workspace that executes tools and collects environment info via gRPC
118pub struct RemoteWorkspace {
119    client: RemoteWorkspaceServiceClient<Channel>,
120    environment_cache: Arc<RwLock<Option<CachedEnvironment>>>,
121    metadata: WorkspaceMetadata,
122    #[allow(dead_code)]
123    auth: Option<RemoteAuth>,
124}
125
126impl RemoteWorkspace {
127    pub async fn new(address: String, auth: Option<RemoteAuth>) -> Result<Self> {
128        // Create gRPC client
129        let client = RemoteWorkspaceServiceClient::connect(format!("http://{address}"))
130            .await
131            .map_err(|e| WorkspaceError::Transport(format!("Failed to connect: {e}")))?;
132
133        let metadata = WorkspaceMetadata {
134            id: format!("remote:{address}"),
135            workspace_type: WorkspaceType::Remote,
136            location: address.clone(),
137        };
138
139        Ok(Self {
140            client,
141            environment_cache: Arc::new(RwLock::new(None)),
142            metadata,
143            auth,
144        })
145    }
146
147    /// Collect environment information from the remote workspace
148    async fn collect_environment(&self) -> Result<EnvironmentInfo> {
149        let mut client = self.client.clone();
150
151        let request = tonic::Request::new(GetEnvironmentInfoRequest {
152            working_directory: None, // Use remote default
153        });
154
155        let response = client
156            .get_environment_info(request)
157            .await
158            .map_err(|e| WorkspaceError::Status(format!("Failed to get environment info: {e}")))?;
159        let env_response = response.into_inner();
160
161        Self::convert_environment_response(env_response)
162    }
163
164    /// Convert gRPC response to EnvironmentInfo
165    fn convert_environment_response(
166        response: GetEnvironmentInfoResponse,
167    ) -> Result<EnvironmentInfo> {
168        use std::path::PathBuf;
169        use steer_proto::remote_workspace::v1::{
170            GitHeadKind as ProtoGitHeadKind, GitStatusSummary as ProtoGitStatusSummary,
171            JjChangeType as ProtoJjChangeType, VcsKind as ProtoVcsKind, vcs_info,
172        };
173
174        let vcs = response.vcs.and_then(|vcs| {
175            let kind = match ProtoVcsKind::try_from(vcs.kind).ok()? {
176                ProtoVcsKind::Git => VcsKind::Git,
177                ProtoVcsKind::Jj => VcsKind::Jj,
178                ProtoVcsKind::Unspecified => return None,
179            };
180
181            let status = match vcs.status {
182                Some(vcs_info::Status::GitStatus(status)) => {
183                    let head = status.head.and_then(|head| {
184                        let kind = ProtoGitHeadKind::try_from(head.kind).ok()?;
185                        match kind {
186                            ProtoGitHeadKind::Branch => {
187                                Some(GitHead::Branch(head.branch.unwrap_or_default()))
188                            }
189                            ProtoGitHeadKind::Detached => Some(GitHead::Detached),
190                            ProtoGitHeadKind::Unborn => Some(GitHead::Unborn),
191                            ProtoGitHeadKind::Unspecified => None,
192                        }
193                    });
194
195                    let entries = status
196                        .entries
197                        .into_iter()
198                        .filter_map(|entry| {
199                            let summary = match ProtoGitStatusSummary::try_from(entry.summary)
200                                .ok()?
201                            {
202                                ProtoGitStatusSummary::Added => GitStatusSummary::Added,
203                                ProtoGitStatusSummary::Removed => GitStatusSummary::Removed,
204                                ProtoGitStatusSummary::Modified => GitStatusSummary::Modified,
205                                ProtoGitStatusSummary::TypeChange => GitStatusSummary::TypeChange,
206                                ProtoGitStatusSummary::Renamed => GitStatusSummary::Renamed,
207                                ProtoGitStatusSummary::Copied => GitStatusSummary::Copied,
208                                ProtoGitStatusSummary::IntentToAdd => GitStatusSummary::IntentToAdd,
209                                ProtoGitStatusSummary::Conflict => GitStatusSummary::Conflict,
210                                ProtoGitStatusSummary::Unspecified => return None,
211                            };
212                            Some(GitStatusEntry {
213                                summary,
214                                path: entry.path,
215                            })
216                        })
217                        .collect();
218
219                    let recent_commits = status
220                        .recent_commits
221                        .into_iter()
222                        .map(|commit| GitCommitSummary {
223                            id: commit.id,
224                            summary: commit.summary,
225                        })
226                        .collect();
227
228                    VcsStatus::Git(GitStatus {
229                        head,
230                        entries,
231                        recent_commits,
232                        error: status.error,
233                    })
234                }
235                Some(vcs_info::Status::JjStatus(status)) => {
236                    let changes = status
237                        .changes
238                        .into_iter()
239                        .filter_map(|change| {
240                            let change_type =
241                                match ProtoJjChangeType::try_from(change.change_type).ok()? {
242                                    ProtoJjChangeType::Added => JjChangeType::Added,
243                                    ProtoJjChangeType::Removed => JjChangeType::Removed,
244                                    ProtoJjChangeType::Modified => JjChangeType::Modified,
245                                    ProtoJjChangeType::Unspecified => return None,
246                                };
247                            Some(JjChange {
248                                change_type,
249                                path: change.path,
250                            })
251                        })
252                        .collect();
253
254                    let working_copy = status.working_copy.map(|commit| JjCommitSummary {
255                        change_id: commit.change_id,
256                        commit_id: commit.commit_id,
257                        description: commit.description,
258                    });
259
260                    let parents = status
261                        .parents
262                        .into_iter()
263                        .map(|commit| JjCommitSummary {
264                            change_id: commit.change_id,
265                            commit_id: commit.commit_id,
266                            description: commit.description,
267                        })
268                        .collect();
269
270                    VcsStatus::Jj(JjStatus {
271                        changes,
272                        working_copy,
273                        parents,
274                        error: status.error,
275                    })
276                }
277                None => match kind {
278                    VcsKind::Git => {
279                        VcsStatus::Git(GitStatus::unavailable("missing git status".to_string()))
280                    }
281                    VcsKind::Jj => {
282                        VcsStatus::Jj(JjStatus::unavailable("missing jj status".to_string()))
283                    }
284                },
285            };
286
287            Some(VcsInfo {
288                kind,
289                root: PathBuf::from(vcs.root),
290                status,
291            })
292        });
293
294        Ok(EnvironmentInfo {
295            working_directory: PathBuf::from(response.working_directory),
296            vcs,
297            platform: response.platform,
298            date: response.date,
299            directory_structure: response.directory_structure,
300            readme_content: response.readme_content,
301            memory_file_name: response.memory_file_name,
302            memory_file_content: response.memory_file_content,
303        })
304    }
305}
306
307impl std::fmt::Debug for RemoteWorkspace {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        f.debug_struct("RemoteWorkspace")
310            .field("metadata", &self.metadata)
311            .field("auth", &self.auth)
312            .finish_non_exhaustive()
313    }
314}
315
316#[async_trait]
317impl Workspace for RemoteWorkspace {
318    async fn environment(&self) -> Result<EnvironmentInfo> {
319        let mut cache = self.environment_cache.write().await;
320
321        // Check if we have valid cached data
322        if let Some(cached) = cache.as_ref()
323            && !cached.is_expired()
324        {
325            return Ok(cached.info.clone());
326        }
327
328        // Collect fresh environment info from remote
329        let env_info = self.collect_environment().await?;
330
331        // Cache it with 10 minute TTL (longer than local since remote calls are expensive)
332        *cache = Some(CachedEnvironment::new(
333            env_info.clone(),
334            Duration::from_secs(600), // 10 minutes
335        ));
336
337        Ok(env_info)
338    }
339
340    fn metadata(&self) -> WorkspaceMetadata {
341        self.metadata.clone()
342    }
343
344    async fn invalidate_environment_cache(&self) {
345        let mut cache = self.environment_cache.write().await;
346        *cache = None;
347    }
348
349    async fn list_files(
350        &self,
351        query: Option<&str>,
352        max_results: Option<usize>,
353    ) -> Result<Vec<String>> {
354        let mut client = self.client.clone();
355
356        let request = tonic::Request::new(ListFilesRequest {
357            query: query.unwrap_or("").to_string(),
358            max_results: max_results.unwrap_or(0) as u32,
359        });
360
361        let mut stream = client
362            .list_files(request)
363            .await
364            .map_err(|e| WorkspaceError::Status(format!("Failed to list files: {e}")))?
365            .into_inner();
366        let mut all_files = Vec::new();
367
368        // Collect all files from the stream
369        while let Some(response) = stream
370            .message()
371            .await
372            .map_err(|e| WorkspaceError::Status(format!("Stream error: {e}")))?
373        {
374            all_files.extend(response.paths);
375        }
376
377        Ok(all_files)
378    }
379
380    fn working_directory(&self) -> &std::path::Path {
381        // For remote workspaces, we return a placeholder path
382        // The actual working directory is on the remote machine
383        std::path::Path::new("/remote")
384    }
385
386    async fn read_file(
387        &self,
388        request: ReadFileRequest,
389        _ctx: &WorkspaceOpContext,
390    ) -> Result<FileContentResult> {
391        let mut client = self.client.clone();
392        let request = tonic::Request::new(ProtoReadFileRequest {
393            file_path: request.file_path,
394            offset: request.offset,
395            limit: request.limit,
396        });
397        let response = client
398            .read_file(request)
399            .await
400            .map_err(|e| WorkspaceError::Status(format!("Failed to read file: {e}")))?
401            .into_inner();
402        Ok(convert_file_content_result(response))
403    }
404
405    async fn list_directory(
406        &self,
407        request: ListDirectoryRequest,
408        _ctx: &WorkspaceOpContext,
409    ) -> Result<FileListResult> {
410        let mut client = self.client.clone();
411        let request = tonic::Request::new(ProtoListDirectoryRequest {
412            path: request.path,
413            ignore: request.ignore.unwrap_or_default(),
414        });
415        let response = client
416            .list_directory(request)
417            .await
418            .map_err(|e| WorkspaceError::Status(format!("Failed to list directory: {e}")))?
419            .into_inner();
420        Ok(convert_file_list_result(response))
421    }
422
423    async fn glob(&self, request: GlobRequest, _ctx: &WorkspaceOpContext) -> Result<GlobResult> {
424        let mut client = self.client.clone();
425        let request = tonic::Request::new(ProtoGlobRequest {
426            pattern: request.pattern,
427            path: request.path,
428        });
429        let response = client
430            .glob(request)
431            .await
432            .map_err(|e| WorkspaceError::Status(format!("Failed to glob: {e}")))?
433            .into_inner();
434        Ok(convert_glob_result(response))
435    }
436
437    async fn grep(&self, request: GrepRequest, _ctx: &WorkspaceOpContext) -> Result<SearchResult> {
438        let mut client = self.client.clone();
439        let request = tonic::Request::new(ProtoGrepRequest {
440            pattern: request.pattern,
441            include: request.include,
442            path: request.path,
443        });
444        let response = client
445            .grep(request)
446            .await
447            .map_err(|e| WorkspaceError::Status(format!("Failed to grep: {e}")))?
448            .into_inner();
449        Ok(convert_search_result(response))
450    }
451
452    async fn astgrep(
453        &self,
454        request: AstGrepRequest,
455        _ctx: &WorkspaceOpContext,
456    ) -> Result<SearchResult> {
457        let mut client = self.client.clone();
458        let request = tonic::Request::new(ProtoAstGrepRequest {
459            pattern: request.pattern,
460            lang: request.lang,
461            include: request.include,
462            exclude: request.exclude,
463            path: request.path,
464        });
465        let response = client
466            .ast_grep(request)
467            .await
468            .map_err(|e| WorkspaceError::Status(format!("Failed to astgrep: {e}")))?
469            .into_inner();
470        Ok(convert_search_result(response))
471    }
472
473    async fn apply_edits(
474        &self,
475        request: ApplyEditsRequest,
476        _ctx: &WorkspaceOpContext,
477    ) -> Result<EditResult> {
478        let mut client = self.client.clone();
479        let edits = request
480            .edits
481            .into_iter()
482            .map(|edit| ProtoEditOperation {
483                old_string: edit.old_string,
484                new_string: edit.new_string,
485            })
486            .collect();
487        let request = tonic::Request::new(ProtoApplyEditsRequest {
488            file_path: request.file_path,
489            edits,
490        });
491        let response = client
492            .apply_edits(request)
493            .await
494            .map_err(|e| WorkspaceError::Status(format!("Failed to apply edits: {e}")))?
495            .into_inner();
496        Ok(convert_edit_result(response))
497    }
498
499    async fn write_file(
500        &self,
501        request: WriteFileRequest,
502        _ctx: &WorkspaceOpContext,
503    ) -> Result<EditResult> {
504        let mut client = self.client.clone();
505        let request = tonic::Request::new(ProtoWriteFileRequest {
506            file_path: request.file_path,
507            content: request.content,
508        });
509        let response = client
510            .write_file(request)
511            .await
512            .map_err(|e| WorkspaceError::Status(format!("Failed to write file: {e}")))?
513            .into_inner();
514        Ok(convert_edit_result(response))
515    }
516}
517
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use steer_proto::remote_workspace::v1 as proto;
    use steer_workspace::LlmStatus;

    #[tokio::test]
    async fn test_remote_workspace_metadata() {
        // Building a real `RemoteWorkspace` needs a running backend, so this only checks
        // the metadata shape that the constructor produces.
        let address = String::from("localhost:50051");

        let metadata = WorkspaceMetadata {
            id: format!("remote:{address}"),
            workspace_type: WorkspaceType::Remote,
            location: address.clone(),
        };

        assert_eq!(metadata.location, address);
        assert!(matches!(metadata.workspace_type, WorkspaceType::Remote));
    }

    #[test]
    fn test_convert_environment_response() {
        let git_status = proto::GitStatus {
            head: Some(proto::GitHead {
                kind: proto::GitHeadKind::Branch as i32,
                branch: Some("main".to_string()),
            }),
            entries: Vec::new(),
            recent_commits: Vec::new(),
            error: None,
        };

        let response = GetEnvironmentInfoResponse {
            working_directory: "/home/user/project".to_string(),
            vcs: Some(proto::VcsInfo {
                kind: proto::VcsKind::Git as i32,
                root: "/home/user/project".to_string(),
                status: Some(proto::vcs_info::Status::GitStatus(git_status)),
            }),
            platform: "linux".to_string(),
            date: "2025-06-17".to_string(),
            directory_structure: "project/\nsrc/\nmain.rs\n".to_string(),
            readme_content: Some("# My Project".to_string()),
            memory_file_content: None,
            memory_file_name: None,
        };

        // Exercise the static conversion directly; no network involved.
        let env_info = RemoteWorkspace::convert_environment_response(response)
            .expect("conversion of a well-formed response should succeed");

        assert_eq!(
            env_info.working_directory,
            PathBuf::from("/home/user/project")
        );
        assert_eq!(env_info.platform, "linux");
        assert_eq!(env_info.date, "2025-06-17");
        assert_eq!(env_info.directory_structure, "project/\nsrc/\nmain.rs\n");
        assert_eq!(env_info.readme_content, Some("# My Project".to_string()));
        assert_eq!(env_info.memory_file_content, None);
        assert_eq!(env_info.memory_file_name, None);

        let Some(vcs) = env_info.vcs.as_ref() else {
            panic!("expected VCS info to be present");
        };
        assert!(matches!(vcs.kind, VcsKind::Git));
        assert_eq!(
            vcs.status.as_llm_string(),
            "Current branch: main\n\nStatus:\nWorking tree clean\n\nRecent commits:\n<no commits>\n"
        );
    }
}