rust_pipe/transport/
ssh.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::validation;
10
/// Connection settings for one remote worker reached through the system `ssh` client.
#[derive(Debug, Clone)]
pub struct SshWorkerConfig {
    /// Remote host; validated as a hostname and used as the `user@host` destination.
    pub host: String,
    /// Login user; validated as a username and used as the `user@host` destination.
    pub user: String,
    /// SSH port, passed to `ssh` via `-p`.
    pub port: u16,
    /// Identifier this worker is registered and addressed under.
    pub worker_id: String,
    /// Task names reported in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// Command executed on the remote host; must not contain shell metacharacters.
    pub remote_command: String,
    /// Optional private key path, passed to `ssh` via `-i`.
    pub identity_file: Option<String>,
    /// Value for the `ConnectTimeout` ssh option, in seconds.
    pub connect_timeout_secs: u32,
}
22
23pub struct SshTransport {
25 configs: Vec<SshWorkerConfig>,
26 connections: Arc<RwLock<HashMap<String, SshConnection>>>,
27 on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
28}
29
30struct SshConnection {
31 stdin_tx: mpsc::UnboundedSender<String>,
32}
33
34impl SshTransport {
35 pub fn new(
36 configs: Vec<SshWorkerConfig>,
37 on_message: impl Fn(String, Message) + Send + Sync + 'static,
38 ) -> Self {
39 Self {
40 configs,
41 connections: Arc::new(RwLock::new(HashMap::new())),
42 on_message: Arc::new(on_message),
43 }
44 }
45}
46
47#[async_trait]
48impl Transport for SshTransport {
49 async fn start(&self) -> Result<(), TransportError> {
50 for config in &self.configs {
51 validation::validate_worker_id(&config.worker_id)
53 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
54 validation::validate_hostname(&config.host)
55 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
56 validation::validate_username(&config.user)
57 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
58 validation::validate_no_shell_metacharacters(&config.remote_command, "remote_command")
59 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
60
61 if let Some(ref key) = config.identity_file {
62 validation::validate_file_path(key, "identity_file")
63 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
64 }
65
66 let worker_id = config.worker_id.clone();
67 let on_message = self.on_message.clone();
68 let connections = self.connections.clone();
69
70 let mut ssh_args = vec![
71 "-o".to_string(),
72 "StrictHostKeyChecking=yes".to_string(),
73 "-o".to_string(),
74 format!("ConnectTimeout={}", config.connect_timeout_secs),
75 "-o".to_string(),
76 "ServerAliveInterval=5".to_string(),
77 "-o".to_string(),
78 "ServerAliveCountMax=3".to_string(),
79 "-p".to_string(),
80 config.port.to_string(),
81 ];
82
83 if let Some(ref key) = config.identity_file {
84 ssh_args.push("-i".to_string());
85 ssh_args.push(key.clone());
86 }
87
88 ssh_args.push("--".to_string());
89 ssh_args.push(format!("{}@{}", config.user, config.host));
90 ssh_args.push(config.remote_command.clone());
91
92 let mut child = Command::new("ssh")
93 .args(&ssh_args)
94 .stdin(std::process::Stdio::piped())
95 .stdout(std::process::Stdio::piped())
96 .stderr(std::process::Stdio::piped())
97 .spawn()
98 .map_err(|e| {
99 TransportError::ConnectionFailed(format!(
100 "SSH to {}@{}:{} failed: {}",
101 config.user, config.host, config.port, e
102 ))
103 })?;
104
105 let stdin = child.stdin.take().expect("stdin piped");
106 let stdout = child.stdout.take().expect("stdout piped");
107 let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
108
109 let reg_msg = Message::WorkerRegister {
111 registration: super::WorkerRegistration {
112 worker_id: worker_id.clone(),
113 supported_tasks: config.supported_tasks.clone(),
114 max_concurrency: 1,
115 language: super::WorkerLanguage::Other("ssh".to_string()),
116 },
117 };
118 on_message(worker_id.clone(), reg_msg);
119
120 let wid = worker_id.clone();
122 tokio::spawn(async move {
123 let mut stdin = stdin;
124 while let Some(line) = stdin_rx.recv().await {
125 if stdin.write_all(line.as_bytes()).await.is_err() {
126 tracing::error!(worker_id = %wid, "SSH stdin write failed");
127 break;
128 }
129 if stdin.write_all(b"\n").await.is_err() {
130 break;
131 }
132 let _ = stdin.flush().await;
133 }
134 });
135
136 let wid = worker_id.clone();
138 tokio::spawn(async move {
139 let reader = BufReader::new(stdout);
140 let mut lines = reader.lines();
141 while let Ok(Some(line)) = lines.next_line().await {
142 if line.trim().is_empty() {
143 continue;
144 }
145 match serde_json::from_str::<Message>(&line) {
146 Ok(msg) => on_message(wid.clone(), msg),
147 Err(e) => {
148 tracing::debug!(
149 worker_id = %wid,
150 error = %e,
151 "Non-JSON from SSH worker"
152 );
153 }
154 }
155 }
156 tracing::info!(worker_id = %wid, "SSH connection closed");
157 });
158
159 connections
160 .write()
161 .await
162 .insert(worker_id.clone(), SshConnection { stdin_tx });
163
164 tracing::info!(
165 worker_id = %worker_id,
166 host = %config.host,
167 command = %config.remote_command,
168 "SSH worker connected"
169 );
170 }
171
172 Ok(())
173 }
174
175 async fn stop(&self) -> Result<(), TransportError> {
176 let connections = self.connections.read().await;
177 for (worker_id, conn) in connections.iter() {
178 let shutdown =
179 serde_json::to_string(&Message::Shutdown { graceful: true }).unwrap_or_default();
180 let _ = conn.stdin_tx.send(shutdown);
181 tracing::info!(worker_id = %worker_id, "SSH worker shutdown sent");
182 }
183 Ok(())
184 }
185
186 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
187 let connections = self.connections.read().await;
188 let conn = connections
189 .get(worker_id)
190 .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
191
192 let json = serde_json::to_string(&message)
193 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
194
195 conn.stdin_tx
196 .send(json)
197 .map_err(|e| TransportError::SendFailed(e.to_string()))
198 }
199
200 async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
201 let connections = self.connections.read().await;
202 let json = serde_json::to_string(&message)
203 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
204
205 for (_, conn) in connections.iter() {
206 let _ = conn.stdin_tx.send(json.clone());
207 }
208 Ok(())
209 }
210}