Skip to main content

vdsl_sync/infra/
shell.rs

1//! Remote shell abstraction for executing commands on different hosts.
2//!
3//! - [`LocalShell`]: runs via `tokio::process::Command` on the local machine
4//! - `PodShell`: runs via RunPod exec API on a GPU pod (downstream crate)
5//! - `SshShell`: runs via SSH (future)
6//!
7//! [`StorageBackend`](super::backend::StorageBackend) implementations compose
8//! a `RemoteShell` to run transfer commands (rclone, rsync, etc.) on the
9//! appropriate host.
10
11use async_trait::async_trait;
12
13use crate::infra::error::InfraError;
14
/// Output from a shell command execution.
///
/// Derives `PartialEq`/`Eq` so that whole outputs can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShellOutput {
    /// Everything the command wrote to standard output, as text.
    pub stdout: String,
    /// Everything the command wrote to standard error, as text.
    pub stderr: String,
    /// Whether the command reported success.
    pub success: bool,
    /// Exit code of the command, or `None` when no code is available.
    pub exit_code: Option<i32>,
}
23
/// Per-file inspection result from batch_inspect.
///
/// Derives `PartialEq`/`Eq`/`Hash` so results can be compared against
/// expected values and collected into sets or used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FileInspection {
    /// Relative path (same key as input).
    pub relative_path: String,
    /// SHA-256 hex hash of the file content.
    pub sha256: String,
    /// File size in bytes.
    pub size: u64,
}
34
35/// Abstract shell for executing commands on a location's host.
36#[async_trait]
37pub trait RemoteShell: Send + Sync {
38    /// Execute a command on this host.
39    ///
40    /// `args[0]` is the program name, `args[1..]` are arguments.
41    async fn exec(
42        &self,
43        args: &[&str],
44        timeout_secs: Option<u64>,
45    ) -> Result<ShellOutput, InfraError>;
46
47    /// Execute a shell script on this host.
48    ///
49    /// Default: `exec(&["sh", "-c", script])`.
50    /// Remote shells may override to use file-based transfer (SCP)
51    /// to avoid shell escaping issues with SSH.
52    async fn exec_script(
53        &self,
54        script: &str,
55        timeout_secs: Option<u64>,
56    ) -> Result<ShellOutput, InfraError> {
57        self.exec(&["sh", "-c", script], timeout_secs).await
58    }
59
60    /// Batch inspect files: get sha256 + size for ALL paths in one exec call.
61    ///
62    /// Constructs a single shell script that processes every file in the list
63    /// and outputs `<sha256> <size> <relative_path>` per line. Parsed on return.
64    ///
65    /// Timeout scales with file count: base 30s + 2s per file.
66    async fn batch_inspect(
67        &self,
68        root: &str,
69        relative_paths: &[String],
70    ) -> Result<Vec<FileInspection>, InfraError> {
71        if relative_paths.is_empty() {
72            return Ok(Vec::new());
73        }
74
75        // Build heredoc file list embedded in a single sh -c script.
76        // Each file is read line-by-line, sha256sum + stat in one pass.
77        let mut script = format!(
78            "cd '{}' && while IFS= read -r f; do \
79             h=$(sha256sum \"$f\" 2>/dev/null | cut -d' ' -f1); \
80             s=$(stat --format=%s \"$f\" 2>/dev/null || echo 0); \
81             [ -n \"$h\" ] && printf '%s %s %s\\n' \"$h\" \"$s\" \"$f\"; \
82             done <<'__VDSL_FILELIST__'\n",
83            root.replace('\'', "'\\''")
84        );
85        for rel in relative_paths {
86            script.push_str(rel);
87            script.push('\n');
88        }
89        script.push_str("__VDSL_FILELIST__");
90
91        let timeout = 30 + (relative_paths.len() as u64 * 2);
92        let output = self.exec(&["sh", "-c", &script], Some(timeout)).await?;
93
94        if !output.success {
95            return Err(InfraError::Transfer {
96                reason: format!("batch_inspect failed: {}", output.stderr.trim()),
97            });
98        }
99
100        let mut results = Vec::with_capacity(relative_paths.len());
101        for line in output.stdout.lines() {
102            // Format: <sha256_hex> <size> <relative_path>
103            let mut parts = line.splitn(3, ' ');
104            let sha256 = match parts.next() {
105                Some(h) if h.len() == 64 => h.to_string(),
106                _ => continue,
107            };
108            let size = parts
109                .next()
110                .and_then(|s| s.parse::<u64>().ok())
111                .unwrap_or(0);
112            let relative_path = match parts.next() {
113                Some(p) if !p.is_empty() => p.to_string(),
114                _ => continue,
115            };
116            results.push(FileInspection {
117                relative_path,
118                sha256,
119                size,
120            });
121        }
122
123        Ok(results)
124    }
125}
126
/// Execute commands on the local machine via `tokio::process::Command`.
///
/// Stateless unit type, so it is freely copyable and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct LocalShell;
129
/// Timeout, in seconds, used for local commands when the caller gives none
/// (ten minutes).
const LOCAL_DEFAULT_TIMEOUT_SECS: u64 = 10 * 60;
131
132#[async_trait]
133impl RemoteShell for LocalShell {
134    async fn exec(
135        &self,
136        args: &[&str],
137        timeout_secs: Option<u64>,
138    ) -> Result<ShellOutput, InfraError> {
139        if args.is_empty() {
140            return Err(InfraError::Transfer {
141                reason: "empty command".into(),
142            });
143        }
144
145        let mut cmd = tokio::process::Command::new(args[0]);
146        if args.len() > 1 {
147            cmd.args(&args[1..]);
148        }
149
150        let timeout =
151            std::time::Duration::from_secs(timeout_secs.unwrap_or(LOCAL_DEFAULT_TIMEOUT_SECS));
152
153        let output = tokio::time::timeout(timeout, cmd.output())
154            .await
155            .map_err(|_| -> InfraError {
156                InfraError::Transfer {
157                    reason: format!(
158                        "command timed out after {}s: {}",
159                        timeout.as_secs(),
160                        args.join(" ")
161                    ),
162                }
163            })?
164            .map_err(|e| -> InfraError {
165                InfraError::Transfer {
166                    reason: format!("exec failed ({}): {e}", args[0]),
167                }
168            })?;
169
170        Ok(ShellOutput {
171            stdout: String::from_utf8_lossy(&output.stdout).to_string(),
172            stderr: String::from_utf8_lossy(&output.stderr).to_string(),
173            success: output.status.success(),
174            exit_code: output.status.code(),
175        })
176    }
177}
178
/// In-memory stand-in for a real shell, intended for tests.
#[cfg(any(test, feature = "test-utils"))]
pub mod mock {
    use super::*;
    use std::collections::HashMap;
    use tokio::sync::Mutex;

    /// Canned metadata reported for one simulated file.
    #[derive(Clone)]
    pub struct MockFile {
        /// Hash string echoed back by the simulated `sha256sum`.
        pub sha256: String,
        /// Byte count echoed back by the simulated `stat`.
        pub size: u64,
    }

    impl MockFile {
        /// Build an entry from a hash string and a size in bytes.
        pub fn new(sha256: impl Into<String>, size: u64) -> Self {
            let sha256 = sha256.into();
            Self { sha256, size }
        }
    }

    /// A [`RemoteShell`] backed by an in-memory table of files.
    ///
    /// Recognised commands:
    /// - `test -f <path>` — succeeds when the path is in the table
    /// - `sha256sum <path>` — prints the configured hash
    /// - `stat <flag> <path>` — prints the configured size
    ///
    /// Anything else succeeds with empty output. Every invocation, recognised
    /// or not, is appended to `exec_log` so tests can assert on it.
    pub struct MockShell {
        files: Mutex<HashMap<String, MockFile>>,
        pub exec_log: Mutex<Vec<Vec<String>>>,
    }

    impl MockShell {
        /// Create a shell whose file table holds the given `(path, file)` pairs.
        pub fn with_files(files: impl IntoIterator<Item = (impl Into<String>, MockFile)>) -> Self {
            let table: HashMap<String, MockFile> = files
                .into_iter()
                .map(|(path, entry)| (path.into(), entry))
                .collect();
            Self {
                files: Mutex::new(table),
                exec_log: Mutex::new(Vec::new()),
            }
        }

        /// Create a shell that only knows which paths exist; each path gets a
        /// placeholder hash and a size of zero.
        pub fn new(existing: impl IntoIterator<Item = impl Into<String>>) -> Self {
            let entries = existing
                .into_iter()
                .map(|path| (path, MockFile::new("0000000000000000", 0)));
            Self::with_files(entries)
        }
    }

    /// Assemble a [`ShellOutput`] with exit code 0 on success and 1 otherwise.
    fn reply(success: bool, stdout: String, stderr: String) -> ShellOutput {
        ShellOutput {
            stdout,
            stderr,
            success,
            exit_code: Some(if success { 0 } else { 1 }),
        }
    }

    #[async_trait]
    impl RemoteShell for MockShell {
        /// Log the invocation, then answer it from the file table.
        ///
        /// The timeout is ignored; the call never fails.
        async fn exec(
            &self,
            args: &[&str],
            _timeout_secs: Option<u64>,
        ) -> Result<ShellOutput, InfraError> {
            let invocation = args
                .iter()
                .map(|arg| (*arg).to_owned())
                .collect::<Vec<String>>();
            self.exec_log.lock().await.push(invocation);

            let response = match args {
                // `test -f <path>`
                ["test", "-f", path, ..] => {
                    let present = self.files.lock().await.contains_key(*path);
                    reply(present, String::new(), String::new())
                }
                // `sha256sum <path>`
                ["sha256sum", path, ..] => {
                    let files = self.files.lock().await;
                    match files.get(*path) {
                        Some(entry) => reply(
                            true,
                            format!("{}  {}\n", entry.sha256, path),
                            String::new(),
                        ),
                        None => reply(
                            false,
                            String::new(),
                            format!("sha256sum: {path}: No such file or directory\n"),
                        ),
                    }
                }
                // `stat --format=%s <path>` (GNU) or `stat -f%z <path>` (BSD):
                // the path is always the final argument.
                ["stat", _, .., path] => {
                    let files = self.files.lock().await;
                    match files.get(*path) {
                        Some(entry) => reply(true, format!("{}\n", entry.size), String::new()),
                        None => reply(
                            false,
                            String::new(),
                            format!("stat: cannot stat '{path}': No such file or directory\n"),
                        ),
                    }
                }
                // Unrecognised commands succeed silently.
                _ => reply(true, String::new(), String::new()),
            };

            Ok(response)
        }
    }
}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful command yields its stdout and a zero exit code.
    #[tokio::test]
    async fn local_shell_echo() {
        let ShellOutput {
            stdout,
            success,
            exit_code,
            ..
        } = LocalShell.exec(&["echo", "hello"], None).await.unwrap();

        assert!(success);
        assert_eq!(stdout.trim(), "hello");
        assert_eq!(exit_code, Some(0));
    }

    /// An empty argument list is rejected before anything is spawned.
    #[tokio::test]
    async fn local_shell_empty_args() {
        let shell = LocalShell;
        assert!(shell.exec(&[], None).await.is_err());
    }

    /// A program that cannot be found surfaces as an error.
    #[tokio::test]
    async fn local_shell_nonexistent_command() {
        let shell = LocalShell;
        let outcome = shell.exec(&["__nonexistent_command_12345__"], None).await;
        assert!(outcome.is_err());
    }

    /// A failing command is returned as output, not as an error.
    #[tokio::test]
    async fn local_shell_exit_code() {
        let ran = LocalShell.exec(&["false"], None).await.unwrap();
        assert!(!ran.success);
        assert_ne!(ran.exit_code, Some(0));
    }
}