Skip to main content

rust_pipe/transport/
docker.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::validation;
10
/// Configuration for a Docker container worker.
///
/// Every value is handed to `docker run` as its own argument rather than through a
/// shell, and `DockerTransport` validates the free-form strings before launching.
#[derive(Debug, Clone)]
pub struct DockerWorkerConfig {
    /// Image reference passed as the final argument to `docker run`.
    pub image: String,
    /// Worker identifier. The container is named `rust-pipe-{worker_id}`, and every
    /// message read from the container is reported under this id.
    pub worker_id: String,
    /// Task names advertised in the worker's `WorkerRegister` message.
    pub supported_tasks: Vec<String>,
    /// Environment variables, each passed as `-e KEY=VALUE`.
    pub env: HashMap<String, String>,
    /// Volume specs, each passed verbatim after `-v`.
    pub volumes: Vec<String>,
    /// Optional value for `--network`.
    pub network: Option<String>,
    /// Optional value for `--memory`.
    pub memory_limit: Option<String>,
    /// Optional value for `--cpus`.
    pub cpu_limit: Option<String>,
}
22
23/// Transport that runs workers inside Docker containers.
24pub struct DockerTransport {
25    configs: Vec<DockerWorkerConfig>,
26    containers: Arc<RwLock<HashMap<String, DockerContainer>>>,
27    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
28}
29
30struct DockerContainer {
31    container_id: String,
32    stdin_tx: mpsc::UnboundedSender<String>,
33}
34
35impl DockerTransport {
36    pub fn new(
37        configs: Vec<DockerWorkerConfig>,
38        on_message: impl Fn(String, Message) + Send + Sync + 'static,
39    ) -> Self {
40        Self {
41            configs,
42            containers: Arc::new(RwLock::new(HashMap::new())),
43            on_message: Arc::new(on_message),
44        }
45    }
46
47    async fn start_container(config: &DockerWorkerConfig) -> Result<String, TransportError> {
48        // Validate inputs to prevent command injection
49        validation::validate_worker_id(&config.worker_id)
50            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
51        validation::validate_docker_image(&config.image)
52            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
53
54        for (key, val) in &config.env {
55            validation::validate_no_shell_metacharacters(key, "env key")
56                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
57            validation::validate_no_shell_metacharacters(val, "env value")
58                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
59        }
60
61        for vol in &config.volumes {
62            validation::validate_no_shell_metacharacters(vol, "volume")
63                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
64        }
65
66        let mut args = vec![
67            "run".to_string(),
68            "-d".to_string(),
69            "-i".to_string(),
70            "--rm".to_string(),
71            "--name".to_string(),
72            format!("rust-pipe-{}", config.worker_id),
73        ];
74
75        if let Some(ref mem) = config.memory_limit {
76            args.push("--memory".to_string());
77            args.push(mem.clone());
78        }
79
80        if let Some(ref cpu) = config.cpu_limit {
81            args.push("--cpus".to_string());
82            args.push(cpu.clone());
83        }
84
85        if let Some(ref network) = config.network {
86            args.push("--network".to_string());
87            args.push(network.clone());
88        }
89
90        for vol in &config.volumes {
91            args.push("-v".to_string());
92            args.push(vol.clone());
93        }
94
95        for (key, val) in &config.env {
96            args.push("-e".to_string());
97            args.push(format!("{}={}", key, val));
98        }
99
100        args.push(config.image.clone());
101
102        let output = Command::new("docker")
103            .args(&args)
104            .output()
105            .await
106            .map_err(|e| TransportError::ConnectionFailed(format!("docker run failed: {}", e)))?;
107
108        if !output.status.success() {
109            let stderr = String::from_utf8_lossy(&output.stderr);
110            return Err(TransportError::ConnectionFailed(format!(
111                "docker run failed: {}",
112                stderr
113            )));
114        }
115
116        let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
117        Ok(container_id)
118    }
119}
120
121#[async_trait]
122impl Transport for DockerTransport {
123    async fn start(&self) -> Result<(), TransportError> {
124        for config in &self.configs {
125            let container_id = Self::start_container(config).await?;
126            let worker_id = config.worker_id.clone();
127            let on_message = self.on_message.clone();
128            let containers = self.containers.clone();
129
130            // Attach to container stdin/stdout
131            let mut attach = Command::new("docker")
132                .args(["attach", "--no-stdin=false", &container_id])
133                .stdin(std::process::Stdio::piped())
134                .stdout(std::process::Stdio::piped())
135                .stderr(std::process::Stdio::null())
136                .spawn()
137                .map_err(|e| {
138                    TransportError::ConnectionFailed(format!("docker attach failed: {}", e))
139                })?;
140
141            let stdin = attach.stdin.take().expect("stdin piped");
142            let stdout = attach.stdout.take().expect("stdout piped");
143            let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
144
145            // Register worker
146            let reg_msg = Message::WorkerRegister {
147                registration: super::WorkerRegistration {
148                    worker_id: worker_id.clone(),
149                    supported_tasks: config.supported_tasks.clone(),
150                    max_concurrency: 1,
151                    language: super::WorkerLanguage::Other("docker".to_string()),
152                },
153            };
154            on_message(worker_id.clone(), reg_msg);
155
156            // Stdin writer
157            let wid = worker_id.clone();
158            tokio::spawn(async move {
159                let mut stdin = stdin;
160                while let Some(line) = stdin_rx.recv().await {
161                    if stdin.write_all(line.as_bytes()).await.is_err() {
162                        tracing::error!(worker_id = %wid, "Docker stdin write failed");
163                        break;
164                    }
165                    if stdin.write_all(b"\n").await.is_err() {
166                        break;
167                    }
168                    let _ = stdin.flush().await;
169                }
170            });
171
172            // Stdout reader
173            let wid = worker_id.clone();
174            tokio::spawn(async move {
175                let reader = BufReader::new(stdout);
176                let mut lines = reader.lines();
177                while let Ok(Some(line)) = lines.next_line().await {
178                    if line.trim().is_empty() {
179                        continue;
180                    }
181                    match serde_json::from_str::<Message>(&line) {
182                        Ok(msg) => on_message(wid.clone(), msg),
183                        Err(e) => {
184                            tracing::debug!(worker_id = %wid, error = %e, "Non-JSON from docker worker");
185                        }
186                    }
187                }
188            });
189
190            containers.write().await.insert(
191                worker_id.clone(),
192                DockerContainer {
193                    container_id,
194                    stdin_tx,
195                },
196            );
197
198            tracing::info!(
199                worker_id = %worker_id,
200                image = %config.image,
201                "Docker worker container started"
202            );
203        }
204
205        Ok(())
206    }
207
208    async fn stop(&self) -> Result<(), TransportError> {
209        let containers = self.containers.read().await;
210        for (worker_id, container) in containers.iter() {
211            let _ = container.stdin_tx.send(
212                serde_json::to_string(&Message::Shutdown { graceful: true }).unwrap_or_default(),
213            );
214
215            // Stop container
216            let _ = Command::new("docker")
217                .args(["stop", "-t", "10", &container.container_id])
218                .output()
219                .await;
220
221            tracing::info!(worker_id = %worker_id, "Docker container stopped");
222        }
223        Ok(())
224    }
225
226    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
227        let containers = self.containers.read().await;
228        let container = containers
229            .get(worker_id)
230            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
231
232        let json = serde_json::to_string(&message)
233            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
234
235        container
236            .stdin_tx
237            .send(json)
238            .map_err(|e| TransportError::SendFailed(e.to_string()))
239    }
240
241    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
242        let containers = self.containers.read().await;
243        let json = serde_json::to_string(&message)
244            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
245
246        for (_, container) in containers.iter() {
247            let _ = container.stdin_tx.send(json.clone());
248        }
249        Ok(())
250    }
251}