Skip to main content

stdiobus_backend_docker/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2026-present Raman Marozau <raman@worktif.com>
3// Copyright (c) 2026-present stdiobus contributors
4
5#![cfg_attr(docsrs, feature(doc_cfg))]
6
7//! Docker backend for stdio_bus
8//!
9//! Runs stdio_bus in a Docker container and communicates via TCP.
10
11use async_trait::async_trait;
12use std::path::Path;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use stdiobus_core::{Backend, BusMessage, BusState, BusStats, DockerOptions, Error, Result};
16use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
17use tokio::net::tcp::OwnedWriteHalf;
18use tokio::net::TcpStream;
19use tokio::process::Command;
20use tokio::sync::{mpsc, Mutex, RwLock};
21
/// Docker backend implementation.
///
/// Runs `stdio_bus` inside a container started through `options.engine_path`
/// and exchanges newline-delimited JSON messages with it over a loopback TCP
/// connection.
pub struct DockerBackend {
    /// Host path of the bus config file; bind-mounted read-only into the
    /// container as `/config.json`.
    config_path: String,
    /// Engine binary, image, pull policy and container settings.
    options: DockerOptions,
    /// Current lifecycle state (CREATED, STARTING, RUNNING, STOPPING, STOPPED).
    state: RwLock<BusState>,
    /// ID printed by `run -d`; `None` until a container has been started.
    container_id: RwLock<Option<String>>,
    /// Write half of the TCP connection to the container; `None` until a
    /// connection has been established.
    writer: RwLock<Option<OwnedWriteHalf>>,
    /// Sender cloned into the reader task, which forwards every received line.
    message_tx: mpsc::Sender<BusMessage>,
    /// Receiving end of the message channel; handed out once by `subscribe`.
    message_rx: Mutex<Option<mpsc::Receiver<BusMessage>>>,
    /// Traffic counters, shared with the reader task.
    stats: Arc<Stats>,
}
33
/// Traffic counters shared between the backend and its reader task.
///
/// Every counter is updated independently with relaxed atomics; there is no
/// cross-counter consistency guarantee.
#[derive(Debug, Default)]
struct Stats {
    /// Messages written to the container connection by `send`.
    messages_in: AtomicU64,
    /// Lines read from the container connection by the reader task.
    messages_out: AtomicU64,
    /// Bytes (including the trailing newline) written by `send`.
    bytes_in: AtomicU64,
    /// Bytes (including line terminators) read by the reader task.
    bytes_out: AtomicU64,
}
40
41impl DockerBackend {
42    /// Create a new Docker backend
43    pub fn new(config_path: &str, options: DockerOptions) -> Result<Self> {
44        // Verify config file exists
45        if !Path::new(config_path).exists() {
46            return Err(Error::InvalidArgument {
47                message: format!("Config file not found: {}", config_path),
48            });
49        }
50
51        let (tx, rx) = mpsc::channel(1000);
52
53        Ok(Self {
54            config_path: config_path.to_string(),
55            options,
56            state: RwLock::new(BusState::Created),
57            container_id: RwLock::new(None),
58            writer: RwLock::new(None),
59            message_tx: tx,
60            message_rx: Mutex::new(Some(rx)),
61            stats: Arc::new(Stats {
62                messages_in: AtomicU64::new(0),
63                messages_out: AtomicU64::new(0),
64                bytes_in: AtomicU64::new(0),
65                bytes_out: AtomicU64::new(0),
66            }),
67        })
68    }
69
70    /// Pull Docker image if needed
71    async fn pull_image(&self) -> Result<()> {
72        if self.options.pull_policy == "never" {
73            return Ok(());
74        }
75
76        if self.options.pull_policy == "if-missing" {
77            let output = Command::new(&self.options.engine_path)
78                .args(["image", "inspect", &self.options.image])
79                .output()
80                .await?;
81
82            if output.status.success() {
83                return Ok(());
84            }
85        }
86
87        tracing::info!("Pulling Docker image: {}", self.options.image);
88
89        let status = Command::new(&self.options.engine_path)
90            .args(["pull", &self.options.image])
91            .status()
92            .await?;
93
94        if !status.success() {
95            return Err(Error::TransportError {
96                message: format!("Failed to pull image: {}", self.options.image),
97            });
98        }
99
100        Ok(())
101    }
102
103    /// Start the Docker container
104    async fn start_container(&self) -> Result<String> {
105        let container_name = format!(
106            "{}-{}",
107            self.options.container_name_prefix,
108            uuid::Uuid::new_v4()
109        );
110
111        let config_path = std::fs::canonicalize(&self.config_path)?;
112        let config_mount = format!("{}:/config.json:ro", config_path.display());
113
114        // Bind to port 0 to let the OS assign an available port, then read it back
115        let listener = std::net::TcpListener::bind("127.0.0.1:0").map_err(|e| {
116            Error::TransportError {
117                message: format!("Failed to find available port: {}", e),
118            }
119        })?;
120        let port = listener.local_addr().map_err(|e| {
121            Error::TransportError {
122                message: format!("Failed to get local address: {}", e),
123            }
124        })?.port();
125        // Drop the listener to free the port before Docker binds to it
126        drop(listener);
127
128        let mut args = vec![
129            "run".to_string(),
130            "-d".to_string(),
131            "--rm".to_string(),
132            "--name".to_string(),
133            container_name.clone(),
134            "-v".to_string(),
135            config_mount,
136            "-p".to_string(),
137            format!("127.0.0.1:{}:8080", port),
138        ];
139
140        for (key, value) in &self.options.env {
141            args.push("-e".to_string());
142            args.push(format!("{}={}", key, value));
143        }
144
145        args.extend(self.options.extra_args.clone());
146        args.push(self.options.image.clone());
147        args.push("--config".to_string());
148        args.push("/config.json".to_string());
149        args.push("--tcp".to_string());
150        args.push("0.0.0.0:8080".to_string());
151
152        tracing::info!("Starting container: {}", container_name);
153
154        let output = Command::new(&self.options.engine_path)
155            .args(&args)
156            .output()
157            .await?;
158
159        if !output.status.success() {
160            let stderr = String::from_utf8_lossy(&output.stderr);
161            return Err(Error::TransportError {
162                message: format!("Failed to start container: {}", stderr),
163            });
164        }
165
166        let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
167        tracing::info!("Container started: {}", container_id);
168
169        tokio::time::sleep(std::time::Duration::from_millis(500)).await;
170
171        let addr = format!("127.0.0.1:{}", port);
172        let stream = tokio::time::timeout(
173            self.options.startup_timeout,
174            Self::connect_with_retry(&addr),
175        )
176        .await
177        .map_err(|_| Error::Timeout {
178            timeout_ms: self.options.startup_timeout.as_millis() as u64,
179        })??;
180
181        let (reader, writer) = stream.into_split();
182        *self.writer.write().await = Some(writer);
183
184        let tx = self.message_tx.clone();
185        let stats = self.stats.clone();
186        tokio::spawn(async move {
187            let mut reader = BufReader::new(reader);
188            let mut line = String::new();
189
190            loop {
191                line.clear();
192                match reader.read_line(&mut line).await {
193                    Ok(0) => break,
194                    Ok(n) => {
195                        stats.messages_out.fetch_add(1, Ordering::Relaxed);
196                        stats.bytes_out.fetch_add(n as u64, Ordering::Relaxed);
197
198                        let json = line.trim().to_string();
199                        if !json.is_empty() {
200                            let _ = tx.send(BusMessage { json }).await;
201                        }
202                    }
203                    Err(e) => {
204                        tracing::error!("Read error: {}", e);
205                        break;
206                    }
207                }
208            }
209        });
210
211        Ok(container_id)
212    }
213
214    async fn connect_with_retry(addr: &str) -> Result<TcpStream> {
215        let mut attempts = 0;
216        loop {
217            match TcpStream::connect(addr).await {
218                Ok(stream) => return Ok(stream),
219                Err(_e) if attempts < 30 => {
220                    attempts += 1;
221                    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
222                }
223                Err(e) => {
224                    return Err(Error::TransportError {
225                        message: format!("Failed to connect to {}: {}", addr, e),
226                    });
227                }
228            }
229        }
230    }
231
232    async fn stop_container(&self, timeout_secs: u32) -> Result<()> {
233        let container_id = self.container_id.read().await.clone();
234
235        if let Some(id) = container_id {
236            tracing::info!("Stopping container: {}", id);
237            let _ = Command::new(&self.options.engine_path)
238                .args(["stop", "-t", &timeout_secs.to_string(), &id])
239                .status()
240                .await;
241        }
242
243        Ok(())
244    }
245}
246
247#[async_trait]
248impl Backend for DockerBackend {
249    async fn start(&self) -> Result<()> {
250        {
251            let state = self.state.read().await;
252            if !state.can_start() {
253                return Err(Error::InvalidState {
254                    expected: "CREATED or STOPPED".to_string(),
255                    actual: state.to_string(),
256                });
257            }
258        }
259
260        *self.state.write().await = BusState::Starting;
261        self.pull_image().await?;
262
263        let container_id = self.start_container().await?;
264        *self.container_id.write().await = Some(container_id);
265
266        *self.state.write().await = BusState::Running;
267        Ok(())
268    }
269
270    async fn stop(&self, timeout_secs: u32) -> Result<()> {
271        *self.state.write().await = BusState::Stopping;
272        self.stop_container(timeout_secs).await?;
273        *self.state.write().await = BusState::Stopped;
274        Ok(())
275    }
276
277    async fn send(&self, message: &str) -> Result<()> {
278        let mut writer_guard = self.writer.write().await;
279        if let Some(ref mut writer) = *writer_guard {
280            let msg = format!("{}\n", message);
281            writer.write_all(msg.as_bytes()).await?;
282            self.stats.messages_in.fetch_add(1, Ordering::Relaxed);
283            self.stats.bytes_in.fetch_add(msg.len() as u64, Ordering::Relaxed);
284            Ok(())
285        } else {
286            Err(Error::InvalidState {
287                expected: "RUNNING with active connection".to_string(),
288                actual: "no connection".to_string(),
289            })
290        }
291    }
292
293    fn state(&self) -> BusState {
294        self.state.try_read().map(|s| *s).unwrap_or(BusState::Created)
295    }
296
297    fn stats(&self) -> BusStats {
298        BusStats {
299            messages_in: self.stats.messages_in.load(Ordering::Relaxed),
300            messages_out: self.stats.messages_out.load(Ordering::Relaxed),
301            bytes_in: self.stats.bytes_in.load(Ordering::Relaxed),
302            bytes_out: self.stats.bytes_out.load(Ordering::Relaxed),
303            ..Default::default()
304        }
305    }
306
307    fn worker_count(&self) -> i32 {
308        -1
309    }
310
311    fn client_count(&self) -> i32 {
312        -1
313    }
314
315    fn subscribe(&self) -> Option<mpsc::Receiver<BusMessage>> {
316        self.message_rx.try_lock().ok().and_then(|mut rx| rx.take())
317    }
318
319    fn backend_type(&self) -> &'static str {
320        "docker"
321    }
322}
323
324
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write a minimal single-pool config to a fresh temp file.
    ///
    /// The file is deleted when the returned handle drops, so callers keep it alive.
    fn create_test_config() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, r#"{{"pools": [{{"id": "test", "command": "echo", "args": ["hello"], "instances": 1}}]}}"#).unwrap();
        file
    }

    /// Build a backend over a fresh temp config.
    ///
    /// Returns the temp file too so it outlives the backend.
    fn backend_with(options: DockerOptions) -> (NamedTempFile, DockerBackend) {
        let config = create_test_config();
        let backend = DockerBackend::new(config.path().to_str().unwrap(), options).unwrap();
        (config, backend)
    }

    #[test]
    fn test_docker_backend_new_missing_config() {
        match DockerBackend::new("/nonexistent/path.json", DockerOptions::default()) {
            Err(Error::InvalidArgument { message }) => assert!(message.contains("not found")),
            Err(_) | Ok(_) => panic!("Expected InvalidArgument error"),
        }
    }

    #[test]
    fn test_docker_backend_new_valid_config() {
        let config = create_test_config();
        let path = config.path().to_str().unwrap();
        assert!(DockerBackend::new(path, DockerOptions::default()).is_ok());
    }

    #[test]
    fn test_docker_backend_initial_state() {
        let (_config, backend) = backend_with(DockerOptions::default());

        assert_eq!(backend.state(), BusState::Created);
        // Pool/client counts are not observable for docker: -1 means "unknown".
        assert_eq!(backend.worker_count(), -1);
        assert_eq!(backend.client_count(), -1);
        assert_eq!(backend.backend_type(), "docker");
    }

    #[test]
    fn test_docker_backend_stats_initial() {
        let (_config, backend) = backend_with(DockerOptions::default());

        let snapshot = backend.stats();
        let counters = (
            snapshot.messages_in,
            snapshot.messages_out,
            snapshot.bytes_in,
            snapshot.bytes_out,
        );
        assert_eq!(counters, (0, 0, 0, 0));
    }

    #[test]
    fn test_docker_backend_subscribe() {
        let (_config, backend) = backend_with(DockerOptions::default());

        let first = backend.subscribe();
        assert!(first.is_some(), "first subscribe takes the receiver");

        let second = backend.subscribe();
        assert!(second.is_none(), "receiver can only be taken once");
    }

    #[tokio::test]
    async fn test_docker_backend_start_invalid_state() {
        let (_config, backend) = backend_with(DockerOptions::default());

        // Pretend the backend has already been started.
        *backend.state.write().await = BusState::Running;

        match backend.start().await {
            Err(Error::InvalidState { expected, actual }) => {
                assert!(expected.contains("CREATED"));
                assert!(actual.contains("RUNNING"));
            }
            _ => panic!("Expected InvalidState error"),
        }
    }

    #[tokio::test]
    async fn test_docker_backend_send_not_connected() {
        let (_config, backend) = backend_with(DockerOptions::default());

        let outcome = backend.send(r#"{"test": true}"#).await;
        assert!(
            matches!(outcome, Err(Error::InvalidState { .. })),
            "Expected InvalidState error"
        );
    }

    #[tokio::test]
    async fn test_docker_backend_stop_from_created() {
        let (_config, backend) = backend_with(DockerOptions::default());

        // No container was ever started, so stopping is a no-op ending in STOPPED.
        assert!(backend.stop(1).await.is_ok());
        assert_eq!(backend.state(), BusState::Stopped);
    }

    #[test]
    fn test_docker_options_in_backend() {
        let opts = DockerOptions {
            image: "custom:latest".to_string(),
            pull_policy: "never".to_string(),
            ..Default::default()
        };

        let (_config, backend) = backend_with(opts);
        assert_eq!(backend.backend_type(), "docker");
    }
}