1#![cfg_attr(docsrs, feature(doc_cfg))]
6
7use async_trait::async_trait;
12use std::path::Path;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use stdiobus_core::{Backend, BusMessage, BusState, BusStats, DockerOptions, Error, Result};
16use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
17use tokio::net::tcp::OwnedWriteHalf;
18use tokio::net::TcpStream;
19use tokio::process::Command;
20use tokio::sync::{mpsc, Mutex, RwLock};
21
22pub struct DockerBackend {
24 config_path: String,
25 options: DockerOptions,
26 state: RwLock<BusState>,
27 container_id: RwLock<Option<String>>,
28 writer: RwLock<Option<OwnedWriteHalf>>,
29 message_tx: mpsc::Sender<BusMessage>,
30 message_rx: Mutex<Option<mpsc::Receiver<BusMessage>>>,
31 stats: Arc<Stats>,
32}
33
/// Lock-free traffic counters, updated with relaxed atomic adds.
struct Stats {
    /// Messages written to the container via `send`.
    messages_in: AtomicU64,
    /// Lines read back from the container by the reader task.
    messages_out: AtomicU64,
    /// Bytes written to the container, including the trailing newline.
    bytes_in: AtomicU64,
    /// Bytes read from the container, as reported by `read_line`.
    bytes_out: AtomicU64,
}
40
41impl DockerBackend {
42 pub fn new(config_path: &str, options: DockerOptions) -> Result<Self> {
44 if !Path::new(config_path).exists() {
46 return Err(Error::InvalidArgument {
47 message: format!("Config file not found: {}", config_path),
48 });
49 }
50
51 let (tx, rx) = mpsc::channel(1000);
52
53 Ok(Self {
54 config_path: config_path.to_string(),
55 options,
56 state: RwLock::new(BusState::Created),
57 container_id: RwLock::new(None),
58 writer: RwLock::new(None),
59 message_tx: tx,
60 message_rx: Mutex::new(Some(rx)),
61 stats: Arc::new(Stats {
62 messages_in: AtomicU64::new(0),
63 messages_out: AtomicU64::new(0),
64 bytes_in: AtomicU64::new(0),
65 bytes_out: AtomicU64::new(0),
66 }),
67 })
68 }
69
70 async fn pull_image(&self) -> Result<()> {
72 if self.options.pull_policy == "never" {
73 return Ok(());
74 }
75
76 if self.options.pull_policy == "if-missing" {
77 let output = Command::new(&self.options.engine_path)
78 .args(["image", "inspect", &self.options.image])
79 .output()
80 .await?;
81
82 if output.status.success() {
83 return Ok(());
84 }
85 }
86
87 tracing::info!("Pulling Docker image: {}", self.options.image);
88
89 let status = Command::new(&self.options.engine_path)
90 .args(["pull", &self.options.image])
91 .status()
92 .await?;
93
94 if !status.success() {
95 return Err(Error::TransportError {
96 message: format!("Failed to pull image: {}", self.options.image),
97 });
98 }
99
100 Ok(())
101 }
102
103 async fn start_container(&self) -> Result<String> {
105 let container_name = format!(
106 "{}-{}",
107 self.options.container_name_prefix,
108 uuid::Uuid::new_v4()
109 );
110
111 let config_path = std::fs::canonicalize(&self.config_path)?;
112 let config_mount = format!("{}:/config.json:ro", config_path.display());
113
114 let listener = std::net::TcpListener::bind("127.0.0.1:0").map_err(|e| {
116 Error::TransportError {
117 message: format!("Failed to find available port: {}", e),
118 }
119 })?;
120 let port = listener.local_addr().map_err(|e| {
121 Error::TransportError {
122 message: format!("Failed to get local address: {}", e),
123 }
124 })?.port();
125 drop(listener);
127
128 let mut args = vec![
129 "run".to_string(),
130 "-d".to_string(),
131 "--rm".to_string(),
132 "--name".to_string(),
133 container_name.clone(),
134 "-v".to_string(),
135 config_mount,
136 "-p".to_string(),
137 format!("127.0.0.1:{}:8080", port),
138 ];
139
140 for (key, value) in &self.options.env {
141 args.push("-e".to_string());
142 args.push(format!("{}={}", key, value));
143 }
144
145 args.extend(self.options.extra_args.clone());
146 args.push(self.options.image.clone());
147 args.push("--config".to_string());
148 args.push("/config.json".to_string());
149 args.push("--tcp".to_string());
150 args.push("0.0.0.0:8080".to_string());
151
152 tracing::info!("Starting container: {}", container_name);
153
154 let output = Command::new(&self.options.engine_path)
155 .args(&args)
156 .output()
157 .await?;
158
159 if !output.status.success() {
160 let stderr = String::from_utf8_lossy(&output.stderr);
161 return Err(Error::TransportError {
162 message: format!("Failed to start container: {}", stderr),
163 });
164 }
165
166 let container_id = String::from_utf8_lossy(&output.stdout).trim().to_string();
167 tracing::info!("Container started: {}", container_id);
168
169 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
170
171 let addr = format!("127.0.0.1:{}", port);
172 let stream = tokio::time::timeout(
173 self.options.startup_timeout,
174 Self::connect_with_retry(&addr),
175 )
176 .await
177 .map_err(|_| Error::Timeout {
178 timeout_ms: self.options.startup_timeout.as_millis() as u64,
179 })??;
180
181 let (reader, writer) = stream.into_split();
182 *self.writer.write().await = Some(writer);
183
184 let tx = self.message_tx.clone();
185 let stats = self.stats.clone();
186 tokio::spawn(async move {
187 let mut reader = BufReader::new(reader);
188 let mut line = String::new();
189
190 loop {
191 line.clear();
192 match reader.read_line(&mut line).await {
193 Ok(0) => break,
194 Ok(n) => {
195 stats.messages_out.fetch_add(1, Ordering::Relaxed);
196 stats.bytes_out.fetch_add(n as u64, Ordering::Relaxed);
197
198 let json = line.trim().to_string();
199 if !json.is_empty() {
200 let _ = tx.send(BusMessage { json }).await;
201 }
202 }
203 Err(e) => {
204 tracing::error!("Read error: {}", e);
205 break;
206 }
207 }
208 }
209 });
210
211 Ok(container_id)
212 }
213
214 async fn connect_with_retry(addr: &str) -> Result<TcpStream> {
215 let mut attempts = 0;
216 loop {
217 match TcpStream::connect(addr).await {
218 Ok(stream) => return Ok(stream),
219 Err(_e) if attempts < 30 => {
220 attempts += 1;
221 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
222 }
223 Err(e) => {
224 return Err(Error::TransportError {
225 message: format!("Failed to connect to {}: {}", addr, e),
226 });
227 }
228 }
229 }
230 }
231
232 async fn stop_container(&self, timeout_secs: u32) -> Result<()> {
233 let container_id = self.container_id.read().await.clone();
234
235 if let Some(id) = container_id {
236 tracing::info!("Stopping container: {}", id);
237 let _ = Command::new(&self.options.engine_path)
238 .args(["stop", "-t", &timeout_secs.to_string(), &id])
239 .status()
240 .await;
241 }
242
243 Ok(())
244 }
245}
246
247#[async_trait]
248impl Backend for DockerBackend {
249 async fn start(&self) -> Result<()> {
250 {
251 let state = self.state.read().await;
252 if !state.can_start() {
253 return Err(Error::InvalidState {
254 expected: "CREATED or STOPPED".to_string(),
255 actual: state.to_string(),
256 });
257 }
258 }
259
260 *self.state.write().await = BusState::Starting;
261 self.pull_image().await?;
262
263 let container_id = self.start_container().await?;
264 *self.container_id.write().await = Some(container_id);
265
266 *self.state.write().await = BusState::Running;
267 Ok(())
268 }
269
270 async fn stop(&self, timeout_secs: u32) -> Result<()> {
271 *self.state.write().await = BusState::Stopping;
272 self.stop_container(timeout_secs).await?;
273 *self.state.write().await = BusState::Stopped;
274 Ok(())
275 }
276
277 async fn send(&self, message: &str) -> Result<()> {
278 let mut writer_guard = self.writer.write().await;
279 if let Some(ref mut writer) = *writer_guard {
280 let msg = format!("{}\n", message);
281 writer.write_all(msg.as_bytes()).await?;
282 self.stats.messages_in.fetch_add(1, Ordering::Relaxed);
283 self.stats.bytes_in.fetch_add(msg.len() as u64, Ordering::Relaxed);
284 Ok(())
285 } else {
286 Err(Error::InvalidState {
287 expected: "RUNNING with active connection".to_string(),
288 actual: "no connection".to_string(),
289 })
290 }
291 }
292
293 fn state(&self) -> BusState {
294 self.state.try_read().map(|s| *s).unwrap_or(BusState::Created)
295 }
296
297 fn stats(&self) -> BusStats {
298 BusStats {
299 messages_in: self.stats.messages_in.load(Ordering::Relaxed),
300 messages_out: self.stats.messages_out.load(Ordering::Relaxed),
301 bytes_in: self.stats.bytes_in.load(Ordering::Relaxed),
302 bytes_out: self.stats.bytes_out.load(Ordering::Relaxed),
303 ..Default::default()
304 }
305 }
306
307 fn worker_count(&self) -> i32 {
308 -1
309 }
310
311 fn client_count(&self) -> i32 {
312 -1
313 }
314
315 fn subscribe(&self) -> Option<mpsc::Receiver<BusMessage>> {
316 self.message_rx.try_lock().ok().and_then(|mut rx| rx.take())
317 }
318
319 fn backend_type(&self) -> &'static str {
320 "docker"
321 }
322}
323
324
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a minimal one-pool bus config to a new temporary file.
    ///
    /// Keep the returned handle alive for the whole test; dropping it deletes
    /// the file.
    fn create_test_config() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, r#"{{"pools": [{{"id": "test", "command": "echo", "args": ["hello"], "instances": 1}}]}}"#).unwrap();
        file
    }

    /// Path of `config` as the `&str` that `DockerBackend::new` expects.
    fn path_of(config: &NamedTempFile) -> &str {
        config.path().to_str().unwrap()
    }

    /// Backend over `config` using default Docker options.
    fn default_backend(config: &NamedTempFile) -> DockerBackend {
        DockerBackend::new(path_of(config), DockerOptions::default()).unwrap()
    }

    #[test]
    fn test_docker_backend_new_missing_config() {
        let result = DockerBackend::new("/nonexistent/path.json", DockerOptions::default());
        assert!(result.is_err());

        match result {
            Err(Error::InvalidArgument { message }) => assert!(message.contains("not found")),
            _ => panic!("Expected InvalidArgument error"),
        }
    }

    #[test]
    fn test_docker_backend_new_valid_config() {
        let config = create_test_config();
        let result = DockerBackend::new(path_of(&config), DockerOptions::default());
        assert!(result.is_ok());
    }

    #[test]
    fn test_docker_backend_initial_state() {
        let config = create_test_config();
        let backend = default_backend(&config);

        assert_eq!(backend.state(), BusState::Created);
        assert_eq!(backend.worker_count(), -1);
        assert_eq!(backend.client_count(), -1);
        assert_eq!(backend.backend_type(), "docker");
    }

    #[test]
    fn test_docker_backend_stats_initial() {
        let config = create_test_config();
        let stats = default_backend(&config).stats();

        assert_eq!(stats.messages_in, 0);
        assert_eq!(stats.messages_out, 0);
        assert_eq!(stats.bytes_in, 0);
        assert_eq!(stats.bytes_out, 0);
    }

    #[test]
    fn test_docker_backend_subscribe() {
        let config = create_test_config();
        let backend = default_backend(&config);

        // The receiver can be taken exactly once.
        let first = backend.subscribe();
        assert!(first.is_some());

        let second = backend.subscribe();
        assert!(second.is_none());
    }

    #[tokio::test]
    async fn test_docker_backend_start_invalid_state() {
        let config = create_test_config();
        let backend = default_backend(&config);
        *backend.state.write().await = BusState::Running;

        let result = backend.start().await;
        assert!(result.is_err());

        match result {
            Err(Error::InvalidState { expected, actual }) => {
                assert!(expected.contains("CREATED"));
                assert!(actual.contains("RUNNING"));
            }
            _ => panic!("Expected InvalidState error"),
        }
    }

    #[tokio::test]
    async fn test_docker_backend_send_not_connected() {
        let config = create_test_config();
        let backend = default_backend(&config);

        let result = backend.send(r#"{"test": true}"#).await;
        assert!(result.is_err());

        if !matches!(result, Err(Error::InvalidState { .. })) {
            panic!("Expected InvalidState error");
        }
    }

    #[tokio::test]
    async fn test_docker_backend_stop_from_created() {
        let config = create_test_config();
        let backend = default_backend(&config);

        let result = backend.stop(1).await;
        assert!(result.is_ok());
        assert_eq!(backend.state(), BusState::Stopped);
    }

    #[test]
    fn test_docker_options_in_backend() {
        let config = create_test_config();
        let opts = DockerOptions {
            image: "custom:latest".to_string(),
            pull_policy: "never".to_string(),
            ..Default::default()
        };

        let backend = DockerBackend::new(path_of(&config), opts).unwrap();
        assert_eq!(backend.backend_type(), "docker");
    }
}