1#![cfg_attr(docsrs, feature(doc_cfg))]
6
7use async_trait::async_trait;
13use std::ffi::{CStr, CString};
14use std::os::raw::{c_char, c_int, c_void};
15use std::ptr;
16use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, AtomicUsize, Ordering};
17use std::sync::Arc;
18use stdiobus_core::{Backend, BusMessage, BusState, BusStats, ConfigSource, Error, Result};
19use stdiobus_ffi::*;
20use tokio::sync::{mpsc, Mutex};
21
22struct BusPtr(AtomicUsize);
24
25impl BusPtr {
26 fn new() -> Self {
27 Self(AtomicUsize::new(0))
28 }
29
30 fn set(&self, ptr: *mut stdio_bus_t) {
31 self.0.store(ptr as usize, Ordering::SeqCst);
32 }
33
34 fn get(&self) -> Option<*mut stdio_bus_t> {
35 let ptr = self.0.load(Ordering::SeqCst);
36 if ptr == 0 {
37 None
38 } else {
39 Some(ptr as *mut stdio_bus_t)
40 }
41 }
42
43 fn take(&self) -> Option<*mut stdio_bus_t> {
44 let ptr = self.0.swap(0, Ordering::SeqCst);
45 if ptr == 0 {
46 None
47 } else {
48 Some(ptr as *mut stdio_bus_t)
49 }
50 }
51}
52
53unsafe impl Send for BusPtr {}
54unsafe impl Sync for BusPtr {}
55
/// Raw handle to the heap-allocated `CallbackContext` whose address is passed to
/// the C bus as `user_data`. It is created with `Box::into_raw` in `start` and
/// released with `Box::from_raw` in `stop` / `Drop`.
struct CtxPtr(*mut CallbackContext);
// SAFETY: the pointee holds only an `AtomicBool`, a tokio `mpsc::Sender` and an
// `Arc<Stats>`, all of which may be sent and shared across threads (assuming
// `BusMessage: Send` — TODO confirm). The pointer is only dereferenced while the
// allocation is alive; `stop`/`Drop` free it after the bus has been destroyed.
unsafe impl Send for CtxPtr {}
unsafe impl Sync for CtxPtr {}
62
63fn state_to_u8(s: BusState) -> u8 {
64 match s {
65 BusState::Created => 0,
66 BusState::Starting => 1,
67 BusState::Running => 2,
68 BusState::Stopping => 3,
69 BusState::Stopped => 4,
70 }
71}
72
73fn u8_to_state(v: u8) -> BusState {
74 match v {
75 0 => BusState::Created,
76 1 => BusState::Starting,
77 2 => BusState::Running,
78 3 => BusState::Stopping,
79 4 => BusState::Stopped,
80 _ => BusState::Created,
81 }
82}
83
/// State reachable from the C callbacks through the bus's `user_data` pointer.
struct CallbackContext {
    /// Cleared by `stop`/`Drop` before the bus is torn down; every callback
    /// returns early once this is `false`.
    alive: AtomicBool,
    /// Forwards decoded bus messages to the receiver handed out by `subscribe`.
    message_tx: mpsc::Sender<BusMessage>,
    /// Counters shared with the owning `NativeBackend`.
    stats: Arc<Stats>,
}
95
/// `Backend` implementation that drives the native `stdio_bus` C library via FFI.
pub struct NativeBackend {
    /// Native bus handle; empty until `start` succeeds and again after `stop`/`Drop`.
    bus: Arc<BusPtr>,
    /// Configuration source used each time the bus is created.
    config: InternalConfig,
    /// Current `BusState`, encoded with `state_to_u8`.
    state: Arc<AtomicU8>,
    /// Sender cloned into each `CallbackContext`.
    message_tx: mpsc::Sender<BusMessage>,
    /// Receiver handed out (once) by `subscribe`.
    message_rx: Mutex<Option<mpsc::Receiver<BusMessage>>>,
    /// Message/byte counters, shared with the callback context.
    stats: Arc<Stats>,
    /// The background step loop keeps running while this is `true`.
    running: Arc<AtomicBool>,
    /// Context registered as the bus `user_data`; set by `start`, taken and freed
    /// by `stop`/`Drop`.
    callback_ctx: Mutex<Option<CtxPtr>>,
}
108
/// Owned configuration source, resolved from the public constructors.
///
/// Exactly one of these is passed to `stdio_bus_create` (as `config_path` or
/// `config_json`) each time the bus is started.
#[derive(Debug, Clone)]
enum InternalConfig {
    /// Path to a configuration file.
    Path(String),
    /// Inline configuration serialized as JSON.
    Json(String),
}
114
/// Traffic counters shared between the backend and the C callbacks.
///
/// `Default` yields all-zero counters, the correct initial value for a new backend.
#[derive(Debug, Default)]
struct Stats {
    /// Messages successfully ingested through `send`.
    messages_in: AtomicU64,
    /// UTF-8 messages delivered by the bus message callback.
    messages_out: AtomicU64,
    /// Bytes successfully ingested through `send`.
    bytes_in: AtomicU64,
    /// Bytes of UTF-8 messages delivered by the bus message callback.
    bytes_out: AtomicU64,
    /// NOTE(review): not incremented anywhere visible in this file — confirm intent.
    worker_restarts: AtomicU64,
    /// NOTE(review): not incremented anywhere visible in this file — confirm intent.
    routing_errors: AtomicU64,
}
123
124impl NativeBackend {
125 pub fn new(config_path: &str) -> Result<Self> {
127 Self::create(InternalConfig::Path(config_path.to_string()))
128 }
129
130 pub fn from_config_source(source: &ConfigSource) -> Result<Self> {
132 let internal = match source {
133 ConfigSource::Path(p) => InternalConfig::Path(p.clone()),
134 ConfigSource::Config(cfg) => {
135 let json = cfg.to_json().map_err(|e| Error::InvalidArgument {
136 message: format!("Failed to serialize config: {}", e),
137 })?;
138 InternalConfig::Json(json)
139 }
140 };
141 Self::create(internal)
142 }
143
144 fn create(config: InternalConfig) -> Result<Self> {
145 let (tx, rx) = mpsc::channel(1000);
146
147 Ok(Self {
148 bus: Arc::new(BusPtr::new()),
149 config,
150 state: Arc::new(AtomicU8::new(0)),
151 message_tx: tx,
152 message_rx: Mutex::new(Some(rx)),
153 stats: Arc::new(Stats {
154 messages_in: AtomicU64::new(0),
155 messages_out: AtomicU64::new(0),
156 bytes_in: AtomicU64::new(0),
157 bytes_out: AtomicU64::new(0),
158 worker_restarts: AtomicU64::new(0),
159 routing_errors: AtomicU64::new(0),
160 }),
161 running: Arc::new(AtomicBool::new(false)),
162 callback_ctx: Mutex::new(None),
163 })
164 }
165
166 fn get_state(&self) -> BusState {
167 u8_to_state(self.state.load(Ordering::SeqCst))
168 }
169
170 fn set_state(&self, state: BusState) {
171 self.state.store(state_to_u8(state), Ordering::SeqCst);
172 }
173}
174
175impl Drop for NativeBackend {
176 fn drop(&mut self) {
177 self.running.store(false, Ordering::SeqCst);
178
179 if let Ok(guard) = self.callback_ctx.try_lock() {
181 if let Some(ref wrapper) = *guard {
182 unsafe { (*wrapper.0).alive.store(false, Ordering::SeqCst) };
183 }
184 }
185
186 if let Some(bus) = self.bus.take() {
188 unsafe {
189 stdio_bus_stop(bus, 1);
190 stdio_bus_destroy(bus);
191 }
192 }
193
194 if let Ok(mut guard) = self.callback_ctx.try_lock() {
196 if let Some(wrapper) = guard.take() {
197 unsafe { drop(Box::from_raw(wrapper.0)) };
198 }
199 }
200 }
201}
202
203
204#[async_trait]
205impl Backend for NativeBackend {
206 async fn start(&self) -> Result<()> {
207 let current_state = self.get_state();
208 if !current_state.can_start() {
209 return Err(Error::InvalidState {
210 expected: "CREATED or STOPPED".to_string(),
211 actual: current_state.to_string(),
212 });
213 }
214
215 self.set_state(BusState::Starting);
216
217 let ctx = Box::new(CallbackContext {
219 alive: AtomicBool::new(true),
220 message_tx: self.message_tx.clone(),
221 stats: self.stats.clone(),
222 });
223 let ctx_ptr = Box::into_raw(ctx);
224 let ctx_usize = ctx_ptr as usize;
225
226 *self.callback_ctx.lock().await = Some(CtxPtr(ctx_ptr));
228
229 let config = match &self.config {
231 InternalConfig::Path(p) => InternalConfig::Path(p.clone()),
232 InternalConfig::Json(j) => InternalConfig::Json(j.clone()),
233 };
234
235 let bus = tokio::task::spawn_blocking(move || {
236 let (path_ptr, json_ptr, _path_cstr, _json_cstr) = match &config {
238 InternalConfig::Path(p) => {
239 let cstr = CString::new(p.as_str()).map_err(|_| Error::InvalidArgument {
240 message: "Invalid config path".to_string(),
241 })?;
242 let ptr = cstr.as_ptr();
243 (ptr, ptr::null(), Some(cstr), None)
244 }
245 InternalConfig::Json(j) => {
246 let cstr = CString::new(j.as_str()).map_err(|_| Error::InvalidArgument {
247 message: "Invalid config JSON (contains null byte)".to_string(),
248 })?;
249 let ptr = cstr.as_ptr();
250 (ptr::null(), ptr, None, Some(cstr))
251 }
252 };
253
254 let listener = stdio_bus_listener_config_t {
255 mode: stdio_bus_listen_mode_t::STDIO_BUS_LISTEN_NONE,
256 tcp_host: ptr::null(),
257 tcp_port: 0,
258 unix_path: ptr::null(),
259 };
260
261 let options = stdio_bus_options_t {
262 config_path: path_ptr,
263 config_json: json_ptr,
264 listener,
265 on_message: Some(on_message_callback),
266 on_error: Some(on_error_callback),
267 on_log: Some(on_log_callback),
268 on_worker: None,
269 on_client_connect: None,
270 on_client_disconnect: None,
271 user_data: ctx_usize as *mut c_void,
272 log_level: 1,
273 };
274
275 let bus = unsafe { stdio_bus_create(&options) };
276 if bus.is_null() {
277 return Err(Error::InternalError {
278 message: "Failed to create bus".to_string(),
279 });
280 }
281
282 let result = unsafe { stdio_bus_start(bus) };
283 if result != STDIO_BUS_OK {
284 unsafe { stdio_bus_destroy(bus) };
285 return Err(Error::InternalError {
286 message: format!("Failed to start bus: error code {}", result),
287 });
288 }
289
290 Ok(bus as usize)
291 })
292 .await
293 .map_err(|e| Error::InternalError {
294 message: format!("Task join error: {}", e),
295 })??;
296
297 self.bus.set(bus as *mut stdio_bus_t);
298 self.set_state(BusState::Running);
299 self.running.store(true, Ordering::SeqCst);
300
301 let bus_ptr = self.bus.clone();
303 let running = self.running.clone();
304
305 tokio::spawn(async move {
306 while running.load(Ordering::SeqCst) {
307 if let Some(bus) = bus_ptr.get() {
308 let bus_usize = bus as usize;
309 let _ = tokio::task::spawn_blocking(move || {
310 unsafe { stdio_bus_step(bus_usize as *mut stdio_bus_t, 10) };
311 })
312 .await;
313 }
314 tokio::time::sleep(std::time::Duration::from_millis(1)).await;
315 }
316 });
317
318 Ok(())
319 }
320
321 async fn stop(&self, timeout_secs: u32) -> Result<()> {
322 self.running.store(false, Ordering::SeqCst);
323 self.set_state(BusState::Stopping);
324
325 {
327 let guard = self.callback_ctx.lock().await;
328 if let Some(ref wrapper) = *guard {
329 unsafe { (*wrapper.0).alive.store(false, Ordering::SeqCst) };
330 }
331 }
332
333 if let Some(bus) = self.bus.take() {
335 let bus_usize = bus as usize;
336 let timeout = timeout_secs as c_int;
337
338 tokio::task::spawn_blocking(move || {
339 unsafe {
340 stdio_bus_stop(bus_usize as *mut stdio_bus_t, timeout);
341 stdio_bus_destroy(bus_usize as *mut stdio_bus_t);
342 }
343 })
344 .await
345 .map_err(|e| Error::InternalError {
346 message: format!("Task join error: {}", e),
347 })?;
348 }
349
350 {
352 let mut guard = self.callback_ctx.lock().await;
353 if let Some(wrapper) = guard.take() {
354 unsafe { drop(Box::from_raw(wrapper.0)) };
355 }
356 }
357
358 self.set_state(BusState::Stopped);
359 Ok(())
360 }
361
362 async fn send(&self, message: &str) -> Result<()> {
363 let bus = self.bus.get().ok_or_else(|| Error::InvalidState {
364 expected: "RUNNING".to_string(),
365 actual: "not initialized".to_string(),
366 })?;
367
368 let bus_usize = bus as usize;
369 let msg = message.to_string();
370 let msg_len = msg.len();
371
372 let result = tokio::task::spawn_blocking(move || {
373 unsafe {
374 stdio_bus_ingest(
375 bus_usize as *mut stdio_bus_t,
376 msg.as_ptr() as *const c_char,
377 msg_len,
378 )
379 }
380 })
381 .await
382 .map_err(|e| Error::InternalError {
383 message: format!("Task join error: {}", e),
384 })?;
385
386 if result != STDIO_BUS_OK {
387 return Err(Error::TransportError {
388 message: format!("Failed to send message: error code {}", result),
389 });
390 }
391
392 self.stats.messages_in.fetch_add(1, Ordering::Relaxed);
393 self.stats.bytes_in.fetch_add(msg_len as u64, Ordering::Relaxed);
394
395 Ok(())
396 }
397
398 fn state(&self) -> BusState {
399 self.get_state()
400 }
401
402 fn stats(&self) -> BusStats {
403 BusStats {
404 messages_in: self.stats.messages_in.load(Ordering::Relaxed),
405 messages_out: self.stats.messages_out.load(Ordering::Relaxed),
406 bytes_in: self.stats.bytes_in.load(Ordering::Relaxed),
407 bytes_out: self.stats.bytes_out.load(Ordering::Relaxed),
408 worker_restarts: self.stats.worker_restarts.load(Ordering::Relaxed),
409 routing_errors: self.stats.routing_errors.load(Ordering::Relaxed),
410 ..Default::default()
411 }
412 }
413
414 fn worker_count(&self) -> i32 {
415 self.bus
416 .get()
417 .map(|bus| unsafe { stdio_bus_worker_count(bus) })
418 .unwrap_or(-1)
419 }
420
421 fn client_count(&self) -> i32 {
422 self.bus
423 .get()
424 .map(|bus| unsafe { stdio_bus_client_count(bus) })
425 .unwrap_or(0)
426 }
427
428 fn subscribe(&self) -> Option<mpsc::Receiver<BusMessage>> {
429 self.message_rx.try_lock().ok().and_then(|mut rx| rx.take())
430 }
431
432 fn backend_type(&self) -> &'static str {
433 "native"
434 }
435}
436
437
438extern "C" fn on_message_callback(
439 _bus: *mut stdio_bus_t,
440 msg: *const c_char,
441 len: usize,
442 user_data: *mut c_void,
443) {
444 let _ = std::panic::catch_unwind(|| {
446 if user_data.is_null() {
447 return;
448 }
449
450 let ctx = unsafe { &*(user_data as *const CallbackContext) };
451
452 if !ctx.alive.load(Ordering::SeqCst) {
454 return;
455 }
456
457 let slice = unsafe { std::slice::from_raw_parts(msg as *const u8, len) };
458 if let Ok(json) = std::str::from_utf8(slice) {
459 ctx.stats.messages_out.fetch_add(1, Ordering::Relaxed);
460 ctx.stats.bytes_out.fetch_add(len as u64, Ordering::Relaxed);
461
462 let message = BusMessage { json: json.to_string() };
463 if let Err(e) = ctx.message_tx.try_send(message) {
464 tracing::warn!("Message channel full: {}", e);
465 }
466 }
467 });
468}
469
470extern "C" fn on_error_callback(
471 _bus: *mut stdio_bus_t,
472 code: c_int,
473 msg: *const c_char,
474 user_data: *mut c_void,
475) {
476 let _ = std::panic::catch_unwind(|| {
477 if !user_data.is_null() {
478 let ctx = unsafe { &*(user_data as *const CallbackContext) };
479 if !ctx.alive.load(Ordering::SeqCst) {
480 return;
481 }
482 }
483 let msg = unsafe { CStr::from_ptr(msg) };
484 tracing::error!("Bus error {}: {:?}", code, msg);
485 });
486}
487
488extern "C" fn on_log_callback(
489 _bus: *mut stdio_bus_t,
490 level: c_int,
491 msg: *const c_char,
492 user_data: *mut c_void,
493) {
494 let _ = std::panic::catch_unwind(|| {
495 if !user_data.is_null() {
496 let ctx = unsafe { &*(user_data as *const CallbackContext) };
497 if !ctx.alive.load(Ordering::SeqCst) {
498 return;
499 }
500 }
501 let msg = unsafe { CStr::from_ptr(msg) };
502 match level {
503 0 => tracing::debug!("{:?}", msg),
504 1 => tracing::info!("{:?}", msg),
505 2 => tracing::warn!("{:?}", msg),
506 _ => tracing::error!("{:?}", msg),
507 }
508 });
509}
510
511
#[cfg(test)]
mod tests {
    //! Unit tests that exercise the backend without a running native bus.

    use super::*;

    #[test]
    fn test_native_backend_new() {
        let result = NativeBackend::new("./test-config.json");
        assert!(result.is_ok());
    }

    #[test]
    fn test_native_backend_from_config_source_path() {
        let source = ConfigSource::Path("./test-config.json".to_string());
        let backend = NativeBackend::from_config_source(&source).unwrap();
        assert_eq!(backend.state(), BusState::Created);
        assert_eq!(backend.backend_type(), "native");
    }

    #[test]
    fn test_native_backend_initial_state() {
        let backend = NativeBackend::new("./test-config.json").unwrap();
        assert_eq!(backend.state(), BusState::Created);
    }

    #[test]
    fn test_native_backend_stats_initial() {
        let backend = NativeBackend::new("./test-config.json").unwrap();
        let stats = backend.stats();

        assert_eq!(stats.messages_in, 0);
        assert_eq!(stats.messages_out, 0);
        assert_eq!(stats.bytes_in, 0);
        assert_eq!(stats.bytes_out, 0);
    }

    #[test]
    fn test_native_backend_type() {
        let backend = NativeBackend::new("./test-config.json").unwrap();
        assert_eq!(backend.backend_type(), "native");
    }

    #[test]
    fn test_native_backend_worker_count_not_started() {
        let backend = NativeBackend::new("./test-config.json").unwrap();
        assert_eq!(backend.worker_count(), -1);
    }

    #[test]
    fn test_native_backend_client_count_not_started() {
        let backend = NativeBackend::new("./test-config.json").unwrap();
        assert_eq!(backend.client_count(), 0);
    }

    #[test]
    fn test_native_backend_subscribe() {
        let backend = NativeBackend::new("./test-config.json").unwrap();

        // The receiver can be taken exactly once.
        let rx = backend.subscribe();
        assert!(rx.is_some());

        let rx2 = backend.subscribe();
        assert!(rx2.is_none());
    }

    #[test]
    fn test_state_conversion() {
        assert_eq!(u8_to_state(0), BusState::Created);
        assert_eq!(u8_to_state(1), BusState::Starting);
        assert_eq!(u8_to_state(2), BusState::Running);
        assert_eq!(u8_to_state(3), BusState::Stopping);
        assert_eq!(u8_to_state(4), BusState::Stopped);
        assert_eq!(u8_to_state(255), BusState::Created);

        assert_eq!(state_to_u8(BusState::Created), 0);
        assert_eq!(state_to_u8(BusState::Starting), 1);
        assert_eq!(state_to_u8(BusState::Running), 2);
        assert_eq!(state_to_u8(BusState::Stopping), 3);
        assert_eq!(state_to_u8(BusState::Stopped), 4);
    }

    #[test]
    fn test_bus_ptr_operations() {
        let ptr = BusPtr::new();
        assert!(ptr.get().is_none());

        let fake_ptr = 0x12345678 as *mut stdio_bus_t;
        ptr.set(fake_ptr);

        assert!(ptr.get().is_some());
        assert_eq!(ptr.get().unwrap() as usize, 0x12345678);

        let taken = ptr.take();
        assert!(taken.is_some());
        assert!(ptr.get().is_none());
        assert!(ptr.take().is_none());
    }

    #[tokio::test]
    async fn test_native_backend_start_invalid_state() {
        let backend = NativeBackend::new("./test-config.json").unwrap();

        backend
            .state
            .store(state_to_u8(BusState::Running), Ordering::SeqCst);

        // Fail on any other outcome instead of silently passing.
        match backend.start().await {
            Err(Error::InvalidState { expected, actual }) => {
                assert!(expected.contains("CREATED"));
                assert!(actual.contains("RUNNING"));
            }
            other => panic!("Expected InvalidState error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_native_backend_send_not_started() {
        let backend = NativeBackend::new("./test-config.json").unwrap();

        let result = backend.send(r#"{"test": true}"#).await;
        assert!(
            matches!(result, Err(Error::InvalidState { .. })),
            "Expected InvalidState error"
        );
    }

    #[tokio::test]
    async fn test_native_backend_stop_not_started() {
        let backend = NativeBackend::new("./test-config.json").unwrap();

        let result = backend.stop(1).await;
        assert!(result.is_ok());
        assert_eq!(backend.state(), BusState::Stopped);

        // Stopping again is harmless.
        assert!(backend.stop(1).await.is_ok());
        assert_eq!(backend.state(), BusState::Stopped);
    }
}