rust_pipe/transport/
ssh.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::validation;
10
/// Connection settings for one worker that is reached by running a command over `ssh`.
#[derive(Debug, Clone)]
pub struct SshWorkerConfig {
    /// Remote host name; combined with `user` into the `user@host` ssh destination.
    pub host: String,
    /// Login user on the remote host.
    pub user: String,
    /// SSH port, passed to `ssh -p`.
    pub port: u16,
    /// Identifier that keys this worker's connection and tags the messages it produces.
    pub worker_id: String,
    /// Task names advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// Command executed on the remote host; it must not contain shell metacharacters.
    pub remote_command: String,
    /// Optional private-key path, passed to `ssh -i`.
    pub identity_file: Option<String>,
    /// Value of ssh's `ConnectTimeout` option, in seconds.
    pub connect_timeout_secs: u32,
}
22
23pub struct SshTransport {
25 configs: Vec<SshWorkerConfig>,
26 connections: Arc<RwLock<HashMap<String, SshConnection>>>,
27 on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
28}
29
30struct SshConnection {
31 stdin_tx: mpsc::UnboundedSender<String>,
32}
33
34impl SshTransport {
35 pub fn new(
36 configs: Vec<SshWorkerConfig>,
37 on_message: impl Fn(String, Message) + Send + Sync + 'static,
38 ) -> Self {
39 Self {
40 configs,
41 connections: Arc::new(RwLock::new(HashMap::new())),
42 on_message: Arc::new(on_message),
43 }
44 }
45}
46
47#[async_trait]
48impl Transport for SshTransport {
49 async fn start(&self) -> Result<(), TransportError> {
50 for config in &self.configs {
51 validation::validate_worker_id(&config.worker_id)
53 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
54 validation::validate_hostname(&config.host)
55 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
56 validation::validate_username(&config.user)
57 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
58 validation::validate_no_shell_metacharacters(&config.remote_command, "remote_command")
59 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
60
61 if let Some(ref key) = config.identity_file {
62 validation::validate_file_path(key, "identity_file")
63 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
64 }
65
66 let worker_id = config.worker_id.clone();
67 let on_message = self.on_message.clone();
68 let connections = self.connections.clone();
69
70 let mut ssh_args = vec![
71 "-o".to_string(),
72 "StrictHostKeyChecking=yes".to_string(),
73 "-o".to_string(),
74 format!("ConnectTimeout={}", config.connect_timeout_secs),
75 "-o".to_string(),
76 "ServerAliveInterval=5".to_string(),
77 "-o".to_string(),
78 "ServerAliveCountMax=3".to_string(),
79 "-p".to_string(),
80 config.port.to_string(),
81 ];
82
83 if let Some(ref key) = config.identity_file {
84 ssh_args.push("-i".to_string());
85 ssh_args.push(key.clone());
86 }
87
88 ssh_args.push("--".to_string());
89 ssh_args.push(format!("{}@{}", config.user, config.host));
90 ssh_args.push(config.remote_command.clone());
91
92 let mut child = Command::new("ssh")
93 .args(&ssh_args)
94 .stdin(std::process::Stdio::piped())
95 .stdout(std::process::Stdio::piped())
96 .stderr(std::process::Stdio::piped())
97 .spawn()
98 .map_err(|e| {
99 TransportError::ConnectionFailed(format!(
100 "SSH to {}@{}:{} failed: {}",
101 config.user, config.host, config.port, e
102 ))
103 })?;
104
105 let stdin = child.stdin.take().expect("stdin piped");
106 let stdout = child.stdout.take().expect("stdout piped");
107 let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
108
109 let reg_msg = Message::WorkerRegister {
111 registration: super::WorkerRegistration {
112 worker_id: worker_id.clone(),
113 supported_tasks: config.supported_tasks.clone(),
114 max_concurrency: 1,
115 language: super::WorkerLanguage::Other("ssh".to_string()),
116 tags: None,
117 },
118 };
119 on_message(worker_id.clone(), reg_msg);
120
121 let wid = worker_id.clone();
123 tokio::spawn(async move {
124 let mut stdin = stdin;
125 while let Some(line) = stdin_rx.recv().await {
126 if stdin.write_all(line.as_bytes()).await.is_err() {
127 tracing::error!(worker_id = %wid, "SSH stdin write failed");
128 break;
129 }
130 if stdin.write_all(b"\n").await.is_err() {
131 break;
132 }
133 let _ = stdin.flush().await;
134 }
135 });
136
137 let wid = worker_id.clone();
139 tokio::spawn(async move {
140 let reader = BufReader::new(stdout);
141 let mut lines = reader.lines();
142 while let Ok(Some(line)) = lines.next_line().await {
143 if line.trim().is_empty() {
144 continue;
145 }
146 match serde_json::from_str::<Message>(&line) {
147 Ok(msg) => on_message(wid.clone(), msg),
148 Err(e) => {
149 tracing::debug!(
150 worker_id = %wid,
151 error = %e,
152 "Non-JSON from SSH worker"
153 );
154 }
155 }
156 }
157 tracing::info!(worker_id = %wid, "SSH connection closed");
158 });
159
160 connections
161 .write()
162 .await
163 .insert(worker_id.clone(), SshConnection { stdin_tx });
164
165 tracing::info!(
166 worker_id = %worker_id,
167 host = %config.host,
168 command = %config.remote_command,
169 "SSH worker connected"
170 );
171 }
172
173 Ok(())
174 }
175
176 async fn stop(&self) -> Result<(), TransportError> {
177 let connections = self.connections.read().await;
178 for (worker_id, conn) in connections.iter() {
179 let shutdown =
180 serde_json::to_string(&Message::Shutdown { graceful: true }).unwrap_or_default();
181 let _ = conn.stdin_tx.send(shutdown);
182 tracing::info!(worker_id = %worker_id, "SSH worker shutdown sent");
183 }
184 Ok(())
185 }
186
187 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
188 let connections = self.connections.read().await;
189 let conn = connections
190 .get(worker_id)
191 .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
192
193 let json = serde_json::to_string(&message)
194 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
195
196 conn.stdin_tx
197 .send(json)
198 .map_err(|e| TransportError::SendFailed(e.to_string()))
199 }
200
201 async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
202 let connections = self.connections.read().await;
203 let json = serde_json::to_string(&message)
204 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
205
206 for (_, conn) in connections.iter() {
207 let _ = conn.stdin_tx.send(json.clone());
208 }
209 Ok(())
210 }
211}