Skip to main content

rust_pipe/transport/
docker.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
5use tokio::process::Command;
6use tokio::sync::{mpsc, RwLock};
7
8use super::{Message, Transport, TransportError};
9use crate::validation;
10
/// Configuration for a single Docker container worker.
///
/// Each field is turned into part of the `docker run` invocation that
/// launches the worker, or into the registration announced for it.
/// `Default` yields an empty configuration, which is convenient together
/// with struct-update syntax (`..Default::default()`).
#[derive(Debug, Clone, Default)]
pub struct DockerWorkerConfig {
    /// Image reference, passed as the final argument to `docker run`.
    pub image: String,
    /// Worker identifier; the container is named `rust-pipe-<worker_id>`.
    pub worker_id: String,
    /// Task names advertised in the worker's registration message.
    pub supported_tasks: Vec<String>,
    /// Environment variables, each passed as `-e KEY=VALUE`.
    pub env: HashMap<String, String>,
    /// Volume specifications, each passed verbatim after `-v`.
    pub volumes: Vec<String>,
    /// Optional value for `--network`.
    pub network: Option<String>,
    /// Optional value for `--memory`.
    pub memory_limit: Option<String>,
    /// Optional value for `--cpus`.
    pub cpu_limit: Option<String>,
}
22
23/// Transport that runs workers inside Docker containers.
24pub struct DockerTransport {
25    configs: Vec<DockerWorkerConfig>,
26    containers: Arc<RwLock<HashMap<String, DockerContainer>>>,
27    on_message: Arc<dyn Fn(String, Message) + Send + Sync>,
28}
29
30struct DockerContainer {
31    container_id: String,
32    stdin_tx: mpsc::UnboundedSender<String>,
33}
34
35impl DockerTransport {
36    pub fn new(
37        configs: Vec<DockerWorkerConfig>,
38        on_message: impl Fn(String, Message) + Send + Sync + 'static,
39    ) -> Self {
40        Self {
41            configs,
42            containers: Arc::new(RwLock::new(HashMap::new())),
43            on_message: Arc::new(on_message),
44        }
45    }
46
47    async fn start_container(config: &DockerWorkerConfig) -> Result<String, TransportError> {
48        // Validate inputs to prevent command injection
49        validation::validate_worker_id(&config.worker_id)
50            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
51        validation::validate_docker_image(&config.image)
52            .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
53
54        for (key, val) in &config.env {
55            validation::validate_no_shell_metacharacters(key, "env key")
56                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
57            validation::validate_no_shell_metacharacters(val, "env value")
58                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
59        }
60
61        for vol in &config.volumes {
62            validation::validate_no_shell_metacharacters(vol, "volume")
63                .map_err(|e| TransportError::ConnectionFailed(e.to_string()))?;
64        }
65
66        let mut args = vec![
67            "run".to_string(),
68            "-d".to_string(),
69            "-i".to_string(),
70            "--rm".to_string(),
71            "--name".to_string(),
72            format!("rust-pipe-{}", config.worker_id),
73        ];
74
75        if let Some(ref mem) = config.memory_limit {
76            args.push("--memory".to_string());
77            args.push(mem.clone());
78        }
79
80        if let Some(ref cpu) = config.cpu_limit {
81            args.push("--cpus".to_string());
82            args.push(cpu.clone());
83        }
84
85        if let Some(ref network) = config.network {
86            args.push("--network".to_string());
87            args.push(network.clone());
88        }
89
90        for vol in &config.volumes {
91            args.push("-v".to_string());
92            args.push(vol.clone());
93        }
94
95        for (key, val) in &config.env {
96            args.push("-e".to_string());
97            args.push(format!("{}={}", key, val));
98        }
99
100        args.push(config.image.clone());
101
102        let output = Command::new("docker")
103            .args(&args)
104            .output()
105            .await
106            .map_err(|e| TransportError::ConnectionFailed(format!("docker run failed: {}", e)))?;
107
108        if !output.status.success() {
109            let stderr = String::from_utf8_lossy(&output.stderr);
110            return Err(TransportError::ConnectionFailed(format!(
111                "docker run failed: {}",
112                stderr
113            )));
114        }
115
116        let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
117        Ok(container_id)
118    }
119}
120
121#[async_trait]
122impl Transport for DockerTransport {
123    async fn start(&self) -> Result<(), TransportError> {
124        for config in &self.configs {
125            let container_id = Self::start_container(config).await?;
126            let worker_id = config.worker_id.clone();
127            let on_message = self.on_message.clone();
128            let containers = self.containers.clone();
129
130            // Attach to container stdin/stdout
131            let mut attach = Command::new("docker")
132                .args(["attach", "--no-stdin=false", &container_id])
133                .stdin(std::process::Stdio::piped())
134                .stdout(std::process::Stdio::piped())
135                .stderr(std::process::Stdio::null())
136                .spawn()
137                .map_err(|e| {
138                    TransportError::ConnectionFailed(format!("docker attach failed: {}", e))
139                })?;
140
141            let stdin = attach.stdin.take().expect("stdin piped");
142            let stdout = attach.stdout.take().expect("stdout piped");
143            let (stdin_tx, mut stdin_rx) = mpsc::unbounded_channel::<String>();
144
145            // Register worker
146            let reg_msg = Message::WorkerRegister {
147                registration: super::WorkerRegistration {
148                    worker_id: worker_id.clone(),
149                    supported_tasks: config.supported_tasks.clone(),
150                    max_concurrency: 1,
151                    language: super::WorkerLanguage::Other("docker".to_string()),
152                    tags: None,
153                },
154            };
155            on_message(worker_id.clone(), reg_msg);
156
157            // Stdin writer
158            let wid = worker_id.clone();
159            tokio::spawn(async move {
160                let mut stdin = stdin;
161                while let Some(line) = stdin_rx.recv().await {
162                    if stdin.write_all(line.as_bytes()).await.is_err() {
163                        tracing::error!(worker_id = %wid, "Docker stdin write failed");
164                        break;
165                    }
166                    if stdin.write_all(b"\n").await.is_err() {
167                        break;
168                    }
169                    let _ = stdin.flush().await;
170                }
171            });
172
173            // Stdout reader
174            let wid = worker_id.clone();
175            tokio::spawn(async move {
176                let reader = BufReader::new(stdout);
177                let mut lines = reader.lines();
178                while let Ok(Some(line)) = lines.next_line().await {
179                    if line.trim().is_empty() {
180                        continue;
181                    }
182                    match serde_json::from_str::<Message>(&line) {
183                        Ok(msg) => on_message(wid.clone(), msg),
184                        Err(e) => {
185                            tracing::debug!(worker_id = %wid, error = %e, "Non-JSON from docker worker");
186                        }
187                    }
188                }
189            });
190
191            containers.write().await.insert(
192                worker_id.clone(),
193                DockerContainer {
194                    container_id,
195                    stdin_tx,
196                },
197            );
198
199            tracing::info!(
200                worker_id = %worker_id,
201                image = %config.image,
202                "Docker worker container started"
203            );
204        }
205
206        Ok(())
207    }
208
209    async fn stop(&self) -> Result<(), TransportError> {
210        let containers = self.containers.read().await;
211        for (worker_id, container) in containers.iter() {
212            let _ = container.stdin_tx.send(
213                serde_json::to_string(&Message::Shutdown { graceful: true }).unwrap_or_default(),
214            );
215
216            // Stop container
217            let _ = Command::new("docker")
218                .args(["stop", "-t", "10", &container.container_id])
219                .output()
220                .await;
221
222            tracing::info!(worker_id = %worker_id, "Docker container stopped");
223        }
224        Ok(())
225    }
226
227    async fn send(&self, worker_id: &str, message: Message) -> Result<(), TransportError> {
228        let containers = self.containers.read().await;
229        let container = containers
230            .get(worker_id)
231            .ok_or_else(|| TransportError::WorkerNotFound(worker_id.to_string()))?;
232
233        let json = serde_json::to_string(&message)
234            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
235
236        container
237            .stdin_tx
238            .send(json)
239            .map_err(|e| TransportError::SendFailed(e.to_string()))
240    }
241
242    async fn broadcast(&self, message: Message) -> Result<(), TransportError> {
243        let containers = self.containers.read().await;
244        let json = serde_json::to_string(&message)
245            .map_err(|e| TransportError::SendFailed(e.to_string()))?;
246
247        for (_, container) in containers.iter() {
248            let _ = container.stdin_tx.send(json.clone());
249        }
250        Ok(())
251    }
252}