// rust_pipe/transport/docker.rs
use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tokio::sync::{mpsc, RwLock};

use super::{Message, Transport, TransportError};
use crate::validation;

/// Configuration for a single worker that runs inside its own Docker container.
///
/// Every field is turned into arguments for `docker run`. `image`, `worker_id`,
/// `env` and `volumes` are validated before the container is started.
///
/// `Clone` and `Default` are derived so configs can be copied and built with
/// struct-update syntax (`DockerWorkerConfig { image, worker_id, ..Default::default() }`).
/// `Debug` is deliberately not derived: `env` values are handed to the container
/// verbatim and may hold credentials that should not end up in logs.
#[derive(Clone, Default)]
pub struct DockerWorkerConfig {
    /// Image reference passed as the final argument of `docker run`.
    pub image: String,
    /// Worker identifier; the container is named `rust-pipe-<worker_id>`.
    pub worker_id: String,
    /// Task types advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// Environment variables, each passed as `-e KEY=VALUE`.
    pub env: HashMap<String, String>,
    /// Volume specifications, each passed unchanged to `-v`.
    pub volumes: Vec<String>,
    /// Optional network name, passed unchanged to `--network`.
    pub network: Option<String>,
    /// Optional memory limit, passed unchanged to `--memory`.
    pub memory_limit: Option<String>,
    /// Optional CPU limit, passed unchanged to `--cpus`.
    pub cpu_limit: Option<String>,
}

23pub struct DockerTransport {
25 configs: Vec<DockerWorkerConfig>,
26 containers: Arc<RwLock<HashMap<String, DockerContainer>>>,
27 on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
28}
29
30struct DockerContainer {
31 container_id: String,
32 stdin_tx: mpsc::UnboundedSender<String>,
33}
34
35impl DockerTransport {
36 pub fn new(
37 configs: Vec<DockerWorkerConfig>,
38 on_message: impl Fn(String, Message) + Send + Sync + 'static,
39 ) -> Self {
40 Self {
41 configs,
42 containers: Arc::new(RwLock::new(HashMap::new())),
43 on_message: Arc::new(on_message),
44 }
45 }
46
47 async fn start_container(config: &DockerWorkerConfig) -> Result<String, TransportError> {
48 validation::validate_worker_id(&config.worker_id)
50 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
51 validation::validate_docker_image(&config.image)
52 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
53
54 for (key, val) in &config.env {
55 validation::validate_no_shell_metacharacters(key, "env key")
56 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
57 validation::validate_no_shell_metacharacters(val, "env value")
58 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
59 }
60
61 for vol in &config.volumes {
62 validation::validate_no_shell_metacharacters(vol, "volume")
63 .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
64 }
65
66 let mut args = vec![
67 "run".to_string(),
68 "-d".to_string(),
69 "-i".to_string(),
70 "--rm".to_string(),
71 "--name".to_string(),
72 format!("rust-pipe-{}", config.worker_id),
73 ];
74
75 if let Some(ref mem) = config.memory_limit {
76 args.push("--memory".to_string());
77 args.push(mem.clone());
78 }
79
80 if let Some(ref cpu) = config.cpu_limit {
81 args.push("--cpus".to_string());
82 args.push(cpu.clone());
83 }
84
85 if let Some(ref network) = config.network {
86 args.push("--network".to_string());
87 args.push(network.clone());
88 }
89
90 for vol in &config.volumes {
91 args.push("-v".to_string());
92 args.push(vol.clone());
93 }
94
95 for (key, val) in &config.env {
96 args.push("-e".to_string());
97 args.push(format!("{}={}", key, val));
98 }
99
100 args.push(config.image.clone());
101
102 let output = Command::new("docker")
103 .args(&args)
104 .output()
105 .await
106 .map_err(|e| TransportError::ConnectionFailed(format!("docker run failed: {}", e)))?;
107
108 if !output.status.success() {
109 let stderr = String::from_utf8_lossy(&output.stderr);
110 return Err(TransportError::ConnectionFailed(format!(
111 "docker run failed: {}",
112 stderr
113 )));
114 }
115
116 let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
117 Ok(container_id)
118 }
119}
120
121#[async_trait]
122impl Transport for DockerTransport {
123 async fn start(&self) -> Result<(), TransportError> {
124 for config in &self.configs {
125 let container_id = Self::start_container(config).await?;
126 let worker_id = config.worker_id.clone();
127 let on_message = self.on_message.clone();
128 let containers = self.containers.clone();
129
130 let mut attach = Command::new("docker")
132 .args(["attach", "--no-stdin=false", &container_id])
133 .stdin(std::process::Stdio::piped())
134 .stdout(std::process::Stdio::piped())
135 .stderr(std::process::Stdio::null())
136 .spawn()
137 .map_err(|e| {
138 TransportError::ConnectionFailed(format!("docker attach failed: {}", e))
139 })?;
140
141 let stdin = attach.stdin.take().expect("stdin piped");
142 let stdout = attach.stdout.take().expect("stdout piped");
143 let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
144
145 let reg_msg = Message::WorkerRegister {
147 registration: super::WorkerRegistration {
148 worker_id: worker_id.clone(),
149 supported_tasks: config.supported_tasks.clone(),
150 max_concurrency: 1,
151 language: super::WorkerLanguage::Other("docker".to_string()),
152 tags: None,
153 },
154 };
155 on_message(worker_id.clone(), reg_msg);
156
157 let wid = worker_id.clone();
159 tokio::spawn(async move {
160 let mut stdin = stdin;
161 while let Some(line) = stdin_rx.recv().await {
162 if stdin.write_all(line.as_bytes()).await.is_err() {
163 tracing::error!(worker_id = %wid, "Docker stdin write failed");
164 break;
165 }
166 if stdin.write_all(b"\n").await.is_err() {
167 break;
168 }
169 let _ = stdin.flush().await;
170 }
171 });
172
173 let wid = worker_id.clone();
175 tokio::spawn(async move {
176 let reader = BufReader::new(stdout);
177 let mut lines = reader.lines();
178 while let Ok(Some(line)) = lines.next_line().await {
179 if line.trim().is_empty() {
180 continue;
181 }
182 match serde_json::from_str::<Message>(&line) {
183 Ok(msg) => on_message(wid.clone(), msg),
184 Err(e) => {
185 tracing::debug!(worker_id = %wid, error = %e, "Non-JSON from docker worker");
186 }
187 }
188 }
189 });
190
191 containers.write().await.insert(
192 worker_id.clone(),
193 DockerContainer {
194 container_id,
195 stdin_tx,
196 },
197 );
198
199 tracing::info!(
200 worker_id = %worker_id,
201 image = %config.image,
202 "Docker worker container started"
203 );
204 }
205
206 Ok(())
207 }
208
209 async fn stop(&self) -> Result<(), TransportError> {
210 let containers = self.containers.read().await;
211 for (worker_id, container) in containers.iter() {
212 let _ = container.stdin_tx.send(
213 serde_json::to_string(&Message::Shutdown { graceful: true }).unwrap_or_default(),
214 );
215
216 let _ = Command::new("docker")
218 .args(["stop", "-t", "10", &container.container_id])
219 .output()
220 .await;
221
222 tracing::info!(worker_id = %worker_id, "Docker container stopped");
223 }
224 Ok(())
225 }
226
227 async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
228 let containers = self.containers.read().await;
229 let container = containers
230 .get(worker_id)
231 .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
232
233 let json = serde_json::to_string(&message)
234 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
235
236 container
237 .stdin_tx
238 .send(json)
239 .map_err(|e| TransportError::SendFailed(e.to_string()))
240 }
241
242 async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
243 let containers = self.containers.read().await;
244 let json = serde_json::to_string(&message)
245 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
246
247 for (_, container) in containers.iter() {
248 let _ = container.stdin_tx.send(json.clone());
249 }
250 Ok(())
251 }
252}