Skip to main content

rust_pipe/transport/
ssh.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::validation;
10
/// Configuration for a remote SSH worker.
///
/// The transport's `start` validates `worker_id`, `host`, `user`, `remote_command` and
/// `identity_file` before spawning `ssh` for this worker.
#[derive(Debug, Clone)]
pub struct SshWorkerConfig {
    /// Remote host name; used as the host part of the `user@host` destination.
    pub host: String,
    /// Login user on the remote host; used as the user part of `user@host`.
    pub user: String,
    /// SSH port, passed to `ssh -p`.
    pub port: u16,
    /// Identifier the worker is registered and addressed under.
    pub worker_id: String,
    /// Task names reported in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// Command executed on the remote host; must not contain shell metacharacters.
    pub remote_command: String,
    /// Optional private-key path, passed to `ssh -i` when present.
    pub identity_file: Option<String>,
    /// Value for ssh's `ConnectTimeout` option, in seconds.
    pub connect_timeout_secs: u32,
}
22
23/// Transport that dispatches tasks to remote machines over SSH.
24pub struct SshTransport {
25    configs: Vec<SshWorkerConfig>,
26    connections: Arc<RwLock<HashMap<String, SshConnection>>>,
27    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
28}
29
30struct SshConnection {
31    stdin_tx: mpsc::UnboundedSender<String>,
32}
33
34impl SshTransport {
35    pub fn new(
36        configs: Vec<SshWorkerConfig>,
37        on_message: impl Fn(String, Message) + Send + Sync + 'static,
38    ) -> Self {
39        Self {
40            configs,
41            connections: Arc::new(RwLock::new(HashMap::new())),
42            on_message: Arc::new(on_message),
43        }
44    }
45}
46
47#[async_trait]
48impl Transport for SshTransport {
49    async fn start(&self) -> Result<(), TransportError> {
50        for config in &self.configs {
51            // Validate inputs to prevent command injection
52            validation::validate_worker_id(&config.worker_id)
53                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
54            validation::validate_hostname(&config.host)
55                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
56            validation::validate_username(&config.user)
57                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
58            validation::validate_no_shell_metacharacters(&config.remote_command, "remote_command")
59                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
60
61            if let Some(ref key) = config.identity_file {
62                validation::validate_file_path(key, "identity_file")
63                    .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
64            }
65
66            let worker_id = config.worker_id.clone();
67            let on_message = self.on_message.clone();
68            let connections = self.connections.clone();
69
70            let mut ssh_args = vec![
71                "-o".to_string(),
72                "StrictHostKeyChecking=yes".to_string(),
73                "-o".to_string(),
74                format!("ConnectTimeout={}", config.connect_timeout_secs),
75                "-o".to_string(),
76                "ServerAliveInterval=5".to_string(),
77                "-o".to_string(),
78                "ServerAliveCountMax=3".to_string(),
79                "-p".to_string(),
80                config.port.to_string(),
81            ];
82
83            if let Some(ref key) = config.identity_file {
84                ssh_args.push("-i".to_string());
85                ssh_args.push(key.clone());
86            }
87
88            ssh_args.push("--".to_string());
89            ssh_args.push(format!("{}@{}", config.user, config.host));
90            ssh_args.push(config.remote_command.clone());
91
92            let mut child = Command::new("ssh")
93                .args(&ssh_args)
94                .stdin(std::process::Stdio::piped())
95                .stdout(std::process::Stdio::piped())
96                .stderr(std::process::Stdio::piped())
97                .spawn()
98                .map_err(|e| {
99                    TransportError::ConnectionFailed(format!(
100                        "SSH to {}@{}:{} failed: {}",
101                        config.user, config.host, config.port, e
102                    ))
103                })?;
104
105            let stdin = child.stdin.take().expect("stdin piped");
106            let stdout = child.stdout.take().expect("stdout piped");
107            let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
108
109            // Register worker
110            let reg_msg = Message::WorkerRegister {
111                registration: super::WorkerRegistration {
112                    worker_id: worker_id.clone(),
113                    supported_tasks: config.supported_tasks.clone(),
114                    max_concurrency: 1,
115                    language: super::WorkerLanguage::Other("ssh".to_string()),
116                },
117            };
118            on_message(worker_id.clone(), reg_msg);
119
120            // Stdin writer
121            let wid = worker_id.clone();
122            tokio::spawn(async move {
123                let mut stdin = stdin;
124                while let Some(line) = stdin_rx.recv().await {
125                    if stdin.write_all(line.as_bytes()).await.is_err() {
126                        tracing::error!(worker_id = %wid, "SSH stdin write failed");
127                        break;
128                    }
129                    if stdin.write_all(b"\n").await.is_err() {
130                        break;
131                    }
132                    let _ = stdin.flush().await;
133                }
134            });
135
136            // Stdout reader
137            let wid = worker_id.clone();
138            tokio::spawn(async move {
139                let reader = BufReader::new(stdout);
140                let mut lines = reader.lines();
141                while let Ok(Some(line)) = lines.next_line().await {
142                    if line.trim().is_empty() {
143                        continue;
144                    }
145                    match serde_json::from_str::<Message>(&line) {
146                        Ok(msg) => on_message(wid.clone(), msg),
147                        Err(e) => {
148                            tracing::debug!(
149                                worker_id = %wid,
150                                error = %e,
151                                "Non-JSON from SSH worker"
152                            );
153                        }
154                    }
155                }
156                tracing::info!(worker_id = %wid, "SSH connection closed");
157            });
158
159            connections
160                .write()
161                .await
162                .insert(worker_id.clone(), SshConnection { stdin_tx });
163
164            tracing::info!(
165                worker_id = %worker_id,
166                host = %config.host,
167                command = %config.remote_command,
168                "SSH worker connected"
169            );
170        }
171
172        Ok(())
173    }
174
175    async fn stop(&self) -> Result<(), TransportError> {
176        let connections = self.connections.read().await;
177        for (worker_id, conn) in connections.iter() {
178            let shutdown =
179                serde_json::to_string(&Message::Shutdown { graceful: true }).unwrap_or_default();
180            let _ = conn.stdin_tx.send(shutdown);
181            tracing::info!(worker_id = %worker_id, "SSH worker shutdown sent");
182        }
183        Ok(())
184    }
185
186    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
187        let connections = self.connections.read().await;
188        let conn = connections
189            .get(worker_id)
190            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
191
192        let json = serde_json::to_string(&message)
193            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
194
195        conn.stdin_tx
196            .send(json)
197            .map_err(|e| TransportError::SendFailed(e.to_string()))
198    }
199
200    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
201        let connections = self.connections.read().await;
202        let json = serde_json::to_string(&message)
203            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
204
205        for (_, conn) in connections.iter() {
206            let _ = conn.stdin_tx.send(json.clone());
207        }
208        Ok(())
209    }
210}