//! `rust_pipe/transport/ssh.rs` — transport that dispatches tasks to remote machines over SSH.

use async_trait::async_trait;
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::sync::{mpsc, RwLock};

use super::{Message, Transport, TransportError};
use crate::validation;

/// Configuration for a remote SSH worker.
///
/// `SshTransport::start` validates `worker_id`, `host`, `user`, `remote_command` and
/// `identity_file` before any of them reach the `ssh` command line.
#[derive(Debug, Clone)]
pub struct SshWorkerConfig {
    /// Remote host name; combined with `user` into the `user@host` ssh destination.
    pub host: String,
    /// Login user on the remote host.
    pub user: String,
    /// SSH port, passed as `ssh -p`.
    pub port: u16,
    /// Identifier used to address this worker and reported in its registration message.
    pub worker_id: String,
    /// Task names announced in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// Command run on the remote host. Its stdin/stdout carry newline-delimited JSON messages.
    /// It must not contain shell metacharacters.
    pub remote_command: String,
    /// Optional private key file, passed as `ssh -i`.
    pub identity_file: Option<String>,
    /// Connection timeout in seconds, passed as `ssh -o ConnectTimeout=…`.
    pub connect_timeout_secs: u32,
}

23/// Transport that dispatches tasks to remote machines over SSH.
24pub struct SshTransport {
25    configs: Vec<SshWorkerConfig>,
26    connections: Arc<RwLock<HashMap<String, SshConnection>>>,
27    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
28}
29
30struct SshConnection {
31    stdin_tx: mpsc::UnboundedSender<String>,
32}
33
34impl SshTransport {
35    pub fn new(
36        configs: Vec<SshWorkerConfig>,
37        on_message: impl Fn(String, Message) + Send + Sync + 'static,
38    ) -> Self {
39        Self {
40            configs,
41            connections: Arc::new(RwLock::new(HashMap::new())),
42            on_message: Arc::new(on_message),
43        }
44    }
45}
46
47#[async_trait]
48impl Transport for SshTransport {
49    async fn start(&self) -> Result<(), TransportError> {
50        for config in &self.configs {
51            // Validate inputs to prevent command injection
52            validation::validate_worker_id(&config.worker_id)
53                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
54            validation::validate_hostname(&config.host)
55                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
56            validation::validate_username(&config.user)
57                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
58            validation::validate_no_shell_metacharacters(&config.remote_command, "remote_command")
59                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
60
61            if let Some(ref key) = config.identity_file {
62                validation::validate_file_path(key, "identity_file")
63                    .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
64            }
65
66            let worker_id = config.worker_id.clone();
67            let on_message = self.on_message.clone();
68            let connections = self.connections.clone();
69
70            let mut ssh_args = vec![
71                "-o".to_string(),
72                "StrictHostKeyChecking=yes".to_string(),
73                "-o".to_string(),
74                format!("ConnectTimeout={}", config.connect_timeout_secs),
75                "-o".to_string(),
76                "ServerAliveInterval=5".to_string(),
77                "-o".to_string(),
78                "ServerAliveCountMax=3".to_string(),
79                "-p".to_string(),
80                config.port.to_string(),
81            ];
82
83            if let Some(ref key) = config.identity_file {
84                ssh_args.push("-i".to_string());
85                ssh_args.push(key.clone());
86            }
87
88            ssh_args.push("--".to_string());
89            ssh_args.push(format!("{}@{}", config.user, config.host));
90            ssh_args.push(config.remote_command.clone());
91
92            let mut child = Command::new("ssh")
93                .args(&ssh_args)
94                .stdin(std::process::Stdio::piped())
95                .stdout(std::process::Stdio::piped())
96                .stderr(std::process::Stdio::piped())
97                .spawn()
98                .map_err(|e| {
99                    TransportError::ConnectionFailed(format!(
100                        "SSH to {}@{}:{} failed: {}",
101                        config.user, config.host, config.port, e
102                    ))
103                })?;
104
105            let stdin = child.stdin.take().expect("stdin piped");
106            let stdout = child.stdout.take().expect("stdout piped");
107            let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
108
109            // Register worker
110            let reg_msg = Message::WorkerRegister {
111                registration: super::WorkerRegistration {
112                    worker_id: worker_id.clone(),
113                    supported_tasks: config.supported_tasks.clone(),
114                    max_concurrency: 1,
115                    language: super::WorkerLanguage::Other("ssh".to_string()),
116                    tags: None,
117                },
118            };
119            on_message(worker_id.clone(), reg_msg);
120
121            // Stdin writer
122            let wid = worker_id.clone();
123            tokio::spawn(async move {
124                let mut stdin = stdin;
125                while let Some(line) = stdin_rx.recv().await {
126                    if stdin.write_all(line.as_bytes()).await.is_err() {
127                        tracing::error!(worker_id = %wid, "SSH stdin write failed");
128                        break;
129                    }
130                    if stdin.write_all(b"\n").await.is_err() {
131                        break;
132                    }
133                    let _ = stdin.flush().await;
134                }
135            });
136
137            // Stdout reader
138            let wid = worker_id.clone();
139            tokio::spawn(async move {
140                let reader = BufReader::new(stdout);
141                let mut lines = reader.lines();
142                while let Ok(Some(line)) = lines.next_line().await {
143                    if line.trim().is_empty() {
144                        continue;
145                    }
146                    match serde_json::from_str::<Message>(&line) {
147                        Ok(msg) => on_message(wid.clone(), msg),
148                        Err(e) => {
149                            tracing::debug!(
150                                worker_id = %wid,
151                                error = %e,
152                                "Non-JSON from SSH worker"
153                            );
154                        }
155                    }
156                }
157                tracing::info!(worker_id = %wid, "SSH connection closed");
158            });
159
160            connections
161                .write()
162                .await
163                .insert(worker_id.clone(), SshConnection { stdin_tx });
164
165            tracing::info!(
166                worker_id = %worker_id,
167                host = %config.host,
168                command = %config.remote_command,
169                "SSH worker connected"
170            );
171        }
172
173        Ok(())
174    }
175
176    async fn stop(&self) -> Result<(), TransportError> {
177        let connections = self.connections.read().await;
178        for (worker_id, conn) in connections.iter() {
179            let shutdown =
180                serde_json::to_string(&Message::Shutdown { graceful: true }).unwrap_or_default();
181            let _ = conn.stdin_tx.send(shutdown);
182            tracing::info!(worker_id = %worker_id, "SSH worker shutdown sent");
183        }
184        Ok(())
185    }
186
187    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
188        let connections = self.connections.read().await;
189        let conn = connections
190            .get(worker_id)
191            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
192
193        let json = serde_json::to_string(&message)
194            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
195
196        conn.stdin_tx
197            .send(json)
198            .map_err(|e| TransportError::SendFailed(e.to_string()))
199    }
200
201    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
202        let connections = self.connections.read().await;
203        let json = serde_json::to_string(&message)
204            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
205
206        for (_, conn) in connections.iter() {
207            let _ = conn.stdin_tx.send(json.clone());
208        }
209        Ok(())
210    }
211}