1use arc_swap::ArcSwap;
2use async_trait::async_trait;
3use bytes::Bytes;
4use parking_lot::Mutex;
5use raw_sync::Timeout;
6use raw_sync::events::{Event, EventImpl, EventInit, EventState};
7use shared_memory::{Shmem, ShmemConf};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use crate::error::{TransportError, TransportResult};
13use crate::transport::utils::spawn_weak_loop;
14use crate::transport::{Transport, TransportStats};
15
/// Magic value ("XRPC" in ASCII) stamped into the control block so an
/// attaching peer can verify the region was initialized by this transport.
const SHM_MAGIC: u64 = 0x58525043;
/// Inclusive bounds on the per-direction ring-buffer capacity.
const MIN_BUFFER_SIZE: usize = 4096;
const MAX_BUFFER_SIZE: usize = 1024 * 1024 * 1024;
/// Default per-direction capacity (1 MiB).
const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024;

/// Header placed at the start of the shared-memory mapping.
///
/// `#[repr(C)]` fixes the field layout so both processes agree on offsets.
/// `write_pos`/`read_pos` are monotonically increasing byte counters; they
/// are reduced modulo `capacity` only when indexing into the data region.
#[repr(C)]
struct SharedMemoryControlBlock {
    // Must equal SHM_MAGIC for the region to be considered initialized.
    magic: AtomicU64,
    // Total bytes ever written / consumed (not wrapped).
    write_pos: AtomicU64,
    read_pos: AtomicU64,
    // Size of the data region in bytes, published by the creator.
    capacity: AtomicU64,
    // Cleared by the owning side on drop/close.
    connected: AtomicBool,
    // Diagnostics counter; incremented on failures, not read in this file.
    error_count: AtomicU64,
    // UNIX seconds of the owner's last heartbeat tick; 0 = never set.
    last_heartbeat: AtomicU64,
}

impl SharedMemoryControlBlock {
    /// Fresh control block for an empty buffer of `capacity` bytes.
    fn new(capacity: usize) -> Self {
        Self {
            magic: AtomicU64::new(SHM_MAGIC),
            write_pos: AtomicU64::new(0),
            read_pos: AtomicU64::new(0),
            capacity: AtomicU64::new(capacity as u64),
            connected: AtomicBool::new(true),
            error_count: AtomicU64::new(0),
            last_heartbeat: AtomicU64::new(0),
        }
    }

    /// True when the region carries the expected magic value.
    fn is_valid(&self) -> bool {
        self.magic.load(Ordering::Acquire) == SHM_MAGIC
    }

    /// Bytes that may currently be written.
    ///
    /// One byte of capacity is permanently reserved so a full buffer
    /// (`cap - 1` pending bytes) is distinguishable from an empty one
    /// (`write == read` modulo `cap`).
    fn available_write(&self) -> usize {
        let w = self.write_pos.load(Ordering::Acquire) as usize;
        let r = self.read_pos.load(Ordering::Acquire) as usize;
        let cap = self.capacity.load(Ordering::Acquire) as usize;
        let w_norm = w % cap;
        let r_norm = r % cap;
        if w_norm >= r_norm {
            cap - (w_norm - r_norm) - 1
        } else {
            (r_norm - w_norm) - 1
        }
    }

    /// Bytes written but not yet consumed.
    fn available_read(&self) -> usize {
        let w = self.write_pos.load(Ordering::Acquire) as usize;
        let r = self.read_pos.load(Ordering::Acquire) as usize;
        let cap = self.capacity.load(Ordering::Acquire) as usize;
        let w_norm = w % cap;
        let r_norm = r % cap;
        if w_norm >= r_norm {
            w_norm - r_norm
        } else {
            cap - (r_norm - w_norm)
        }
    }

    /// Bumps the shared error counter (diagnostics only).
    fn record_error(&self) {
        self.error_count.fetch_add(1, Ordering::Relaxed);
    }

    /// Stores the current UNIX time (seconds) as the liveness timestamp.
    /// A system clock before the epoch degrades to 0 instead of panicking.
    fn update_heartbeat(&self) {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_secs());
        self.last_heartbeat.store(timestamp, Ordering::Release);
    }

    fn tick_heartbeat(&self) {
        self.update_heartbeat();
    }

    /// True while the last heartbeat is younger than `timeout_secs`.
    /// A stored value of 0 means "no heartbeat yet" and counts as healthy.
    fn is_healthy(&self, timeout_secs: u64) -> bool {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_secs());
        let last = self.last_heartbeat.load(Ordering::Acquire);

        if last == 0 {
            return true;
        }

        // saturating_sub: a peer clock slightly ahead of ours must not
        // underflow (which would panic in debug builds); it simply reads
        // as a fresh heartbeat.
        now.saturating_sub(last) < timeout_secs
    }
}
106
/// One direction of the channel: a fixed-capacity ring buffer laid out in a
/// named shared-memory mapping, plus two `raw_sync` events used to wake the
/// peer (`write_signal` is set when space is freed by a reader,
/// `read_signal` when new data has been written).
struct SharedMemoryRingBuffer {
    // Held only to keep the OS mapping alive; all access goes through the
    // raw pointers below.
    #[allow(dead_code)]
    shmem: Shmem,
    // Points at the SharedMemoryControlBlock at the start of the mapping.
    control: *mut SharedMemoryControlBlock,
    // Start of the `capacity`-byte data region (after control block + events).
    data: *mut u8,
    capacity: usize,
    // True for the side that created the region (it also owns heartbeats
    // and clears `connected` on drop).
    is_owner: bool,
    #[allow(dead_code)]
    name: String,
    // Signaled by the reader when buffer space has been released.
    write_signal: Box<dyn EventImpl>,
    // Signaled by the writer when a new frame is available.
    read_signal: Box<dyn EventImpl>,
}
122
123impl std::fmt::Debug for SharedMemoryRingBuffer {
124 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125 f.debug_struct("SharedMemoryRingBuffer")
126 .field("capacity", &self.capacity)
127 .field("is_owner", &self.is_owner)
128 .finish()
129 }
130}
131
// SAFETY: the raw `control`/`data` pointers refer to a shared-memory mapping
// whose header is only mutated through atomics, and the data region is only
// touched at offsets guarded by those atomic positions. NOTE(review):
// soundness also relies on the `raw_sync` events being usable from multiple
// threads and on `shmem` keeping the mapping alive for the struct's whole
// lifetime — confirm against the crates' documentation.
unsafe impl Send for SharedMemoryRingBuffer {}
unsafe impl Sync for SharedMemoryRingBuffer {}
134
135impl SharedMemoryRingBuffer {
136 fn create(name: &str, capacity: usize) -> TransportResult<Self> {
137 if !(MIN_BUFFER_SIZE..=MAX_BUFFER_SIZE).contains(&capacity) {
138 return Err(TransportError::InvalidBufferState(format!(
139 "Invalid buffer size: {}",
140 capacity
141 )));
142 }
143
144 let control_size = std::mem::size_of::<SharedMemoryControlBlock>();
146 let event_size = Event::size_of(None);
147 let total_size = control_size + (event_size * 2) + capacity;
148
149 let shmem = ShmemConf::new()
150 .size(total_size)
151 .os_id(name)
152 .create()
153 .map_err(|e| TransportError::SharedMemoryCreation {
154 name: name.to_string(),
155 reason: e.to_string(),
156 })?;
157
158 let base_ptr = shmem.as_ptr();
159
160 let control = base_ptr as *mut SharedMemoryControlBlock;
162 let write_signal_ptr = unsafe { base_ptr.add(control_size) };
163 let read_signal_ptr = unsafe { base_ptr.add(control_size + event_size) };
164 let data = unsafe { base_ptr.add(control_size + (event_size * 2)) };
165
166 unsafe {
168 std::ptr::write(control, SharedMemoryControlBlock::new(capacity));
169
170 let (write_signal, _) = Event::new(write_signal_ptr, true).map_err(|e| {
172 TransportError::SharedMemoryCreation {
173 name: name.to_string(),
174 reason: format!("Failed to create write event: {}", e),
175 }
176 })?;
177
178 let (read_signal, _) = Event::new(read_signal_ptr, true).map_err(|e| {
179 TransportError::SharedMemoryCreation {
180 name: name.to_string(),
181 reason: format!("Failed to create read event: {}", e),
182 }
183 })?;
184
185 write_signal.set(EventState::Signaled).map_err(|e| {
187 TransportError::SharedMemoryCreation {
188 name: name.to_string(),
189 reason: format!("Failed to set write event: {}", e),
190 }
191 })?;
192
193 Ok(Self {
194 shmem,
195 control,
196 data,
197 capacity,
198 is_owner: true,
199 name: name.to_string(),
200 write_signal,
201 read_signal,
202 })
203 }
204 }
205
206 fn connect(name: &str) -> TransportResult<Self> {
207 let shmem =
208 ShmemConf::new()
209 .os_id(name)
210 .open()
211 .map_err(|e| TransportError::ConnectionFailed {
212 name: name.to_string(),
213 attempts: 1,
214 reason: e.to_string(),
215 })?;
216
217 let base_ptr = shmem.as_ptr();
218 let control = base_ptr as *mut SharedMemoryControlBlock;
219
220 let ctrl = unsafe { &*control };
221 if !ctrl.is_valid() {
222 return Err(TransportError::InvalidBufferState(
223 "Invalid shared memory region".to_string(),
224 ));
225 }
226
227 let capacity = ctrl.capacity.load(Ordering::Acquire) as usize;
228 let control_size = std::mem::size_of::<SharedMemoryControlBlock>();
229 let event_size = Event::size_of(None);
230
231 let write_signal_ptr = unsafe { base_ptr.add(control_size) };
232 let read_signal_ptr = unsafe { base_ptr.add(control_size + event_size) };
233 let data = unsafe { base_ptr.add(control_size + (event_size * 2)) };
234
235 unsafe {
236 let (write_signal, _) = Event::from_existing(write_signal_ptr).map_err(|e| {
237 TransportError::ConnectionFailed {
238 name: name.to_string(),
239 attempts: 1,
240 reason: format!("Failed to open write event: {}", e),
241 }
242 })?;
243
244 let (read_signal, _) = Event::from_existing(read_signal_ptr).map_err(|e| {
245 TransportError::ConnectionFailed {
246 name: name.to_string(),
247 attempts: 1,
248 reason: format!("Failed to open read event: {}", e),
249 }
250 })?;
251
252 Ok(Self {
253 shmem,
254 control,
255 data,
256 capacity,
257 is_owner: false,
258 name: name.to_string(),
259 write_signal,
260 read_signal,
261 })
262 }
263 }
264
265 fn write(&self, data: &[u8], timeout: Duration) -> TransportResult<()> {
266 let control = unsafe { &*self.control };
267
268 if !control.is_valid() {
269 control.record_error();
270 return Err(TransportError::InvalidBufferState(
271 "Bad control".to_string(),
272 ));
273 }
274
275 let msg_len = data.len();
276 let total_len = 4 + msg_len; if total_len > self.capacity {
279 control.record_error();
280 return Err(TransportError::MessageTooLarge {
281 size: msg_len,
282 max: self.capacity - 4,
283 });
284 }
285
286 let start = Instant::now();
287
288 loop {
290 if control.available_write() >= total_len {
291 break;
292 }
293
294 if start.elapsed() > timeout {
295 control.record_error();
296 return Err(TransportError::Timeout {
297 duration_ms: timeout.as_millis() as u64,
298 operation: "waiting for buffer space".into(),
299 });
300 }
301
302 let remaining = timeout.saturating_sub(start.elapsed());
305 let _ = self.write_signal.wait(Timeout::Val(remaining));
306 }
307
308 let write_pos = control.write_pos.load(Ordering::Acquire) as usize;
309
310 let len_bytes = (msg_len as u32).to_le_bytes();
314 self.raw_write(write_pos, &len_bytes);
315
316 self.raw_write(write_pos + 4, data);
318
319 control
321 .write_pos
322 .store((write_pos + total_len) as u64, Ordering::Release);
323 control.update_heartbeat();
324
325 let _ = self.read_signal.set(EventState::Signaled);
327
328 Ok(())
329 }
330
331 fn raw_write(&self, offset: usize, src: &[u8]) {
333 let offset = offset % self.capacity;
334 let len = src.len();
335 let first_chunk = std::cmp::min(len, self.capacity - offset);
336
337 unsafe {
338 std::ptr::copy_nonoverlapping(src.as_ptr(), self.data.add(offset), first_chunk);
340
341 if first_chunk < len {
343 std::ptr::copy_nonoverlapping(
344 src.as_ptr().add(first_chunk),
345 self.data,
346 len - first_chunk,
347 );
348 }
349 }
350 }
351
352 fn read(&self, timeout: Duration) -> TransportResult<Bytes> {
353 let control = unsafe { &*self.control };
354
355 if !control.is_valid() {
356 control.record_error();
357 return Err(TransportError::InvalidBufferState(
358 "Bad control".to_string(),
359 ));
360 }
361
362 let start = Instant::now();
363
364 loop {
366 if control.available_read() >= 4 {
367 break;
368 }
369
370 if start.elapsed() > timeout {
371 control.record_error();
372 return Err(TransportError::Timeout {
373 duration_ms: timeout.as_millis() as u64,
374 operation: "waiting for data".into(),
375 });
376 }
377
378 let remaining = timeout.saturating_sub(start.elapsed());
379 let _ = self.read_signal.wait(Timeout::Val(remaining));
380 }
381
382 let read_pos = control.read_pos.load(Ordering::Acquire) as usize;
383
384 let mut len_bytes = [0u8; 4];
386 self.raw_read(read_pos, &mut len_bytes);
387 let msg_len = u32::from_le_bytes(len_bytes) as usize;
388
389 if msg_len > self.capacity {
390 control.record_error();
391 return Err(TransportError::InvalidBufferState(
392 "Bad message length".to_string(),
393 ));
394 }
395
396 loop {
398 if control.available_read() >= 4 + msg_len {
399 break;
400 }
401
402 if start.elapsed() > timeout {
403 control.record_error();
404 return Err(TransportError::Timeout {
405 duration_ms: timeout.as_millis() as u64,
406 operation: "waiting for full message".into(),
407 });
408 }
409
410 let remaining = timeout.saturating_sub(start.elapsed());
411 let _ = self.read_signal.wait(Timeout::Val(remaining));
412 }
413
414 let mut buffer = vec![0u8; msg_len];
416 self.raw_read(read_pos + 4, &mut buffer);
417
418 control
420 .read_pos
421 .store((read_pos + 4 + msg_len) as u64, Ordering::Release);
422 control.update_heartbeat();
423
424 let _ = self.write_signal.set(EventState::Signaled);
426
427 Ok(Bytes::from(buffer))
428 }
429
430 fn raw_read(&self, offset: usize, dst: &mut [u8]) {
432 let offset = offset % self.capacity;
433 let len = dst.len();
434 let first_chunk = std::cmp::min(len, self.capacity - offset);
435
436 unsafe {
437 std::ptr::copy_nonoverlapping(self.data.add(offset), dst.as_mut_ptr(), first_chunk);
439
440 if first_chunk < len {
442 std::ptr::copy_nonoverlapping(
443 self.data,
444 dst.as_mut_ptr().add(first_chunk),
445 len - first_chunk,
446 );
447 }
448 }
449 }
450
451 fn tick_heartbeat(&self) {
452 if self.is_owner {
453 let control = unsafe { &*self.control };
454 control.tick_heartbeat();
455 }
456 }
457}
458
459impl Drop for SharedMemoryRingBuffer {
460 fn drop(&mut self) {
461 if self.is_owner {
462 let control = unsafe { &*self.control };
463 control.connected.store(false, Ordering::Release);
464 }
465 }
466}
467
/// Backoff strategy applied between retry attempts.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RetryPolicy {
    /// Wait the same `delay_ms` before every attempt.
    Fixed { delay_ms: u64 },
    /// Wait `base_delay_ms * (attempt + 1)`.
    Linear { base_delay_ms: u64 },
    /// Wait `base_delay_ms * 2^attempt`, capped at `max_delay_ms`.
    Exponential {
        base_delay_ms: u64,
        max_delay_ms: u64,
    },
}

impl RetryPolicy {
    /// Returns the pause to apply before retry number `attempt` (0-based).
    ///
    /// All arithmetic saturates: very large attempt counts yield a large
    /// (Fixed/Linear) or capped (Exponential) delay instead of overflowing —
    /// `2u64.pow(attempt)` would panic in debug builds for attempt >= 64.
    pub fn delay(&self, attempt: usize) -> Duration {
        match self {
            RetryPolicy::Fixed { delay_ms } => Duration::from_millis(*delay_ms),
            RetryPolicy::Linear { base_delay_ms } => Duration::from_millis(
                base_delay_ms.saturating_mul((attempt as u64).saturating_add(1)),
            ),
            RetryPolicy::Exponential {
                base_delay_ms,
                max_delay_ms,
            } => {
                let factor = 2u64
                    .checked_pow(attempt.min(u32::MAX as usize) as u32)
                    .unwrap_or(u64::MAX);
                let delay = base_delay_ms.saturating_mul(factor);
                Duration::from_millis(delay.min(*max_delay_ms))
            }
        }
    }
}

impl Default for RetryPolicy {
    /// Linear backoff starting at 100 ms.
    fn default() -> Self {
        RetryPolicy::Linear { base_delay_ms: 100 }
    }
}
505
/// Tunables for the shared-memory transport.
#[derive(Clone, Debug)]
pub struct SharedMemoryConfig {
    // Capacity in bytes of each ring buffer (one per direction).
    pub buffer_size: usize,
    // `None` falls back to the 30 s default used by send/recv.
    pub read_timeout: Option<Duration>,
    pub write_timeout: Option<Duration>,
    // Attempts made by connect, reconnect, send, and recv before giving up.
    pub max_retry_attempts: usize,
    // Backoff applied between those attempts.
    pub retry_policy: RetryPolicy,
    // When true, a transport marked disconnected re-attaches on next use.
    pub auto_reconnect: bool,
}
515
516impl Default for SharedMemoryConfig {
517 fn default() -> Self {
518 Self {
519 buffer_size: DEFAULT_BUFFER_SIZE,
520 read_timeout: Some(Duration::from_secs(5)),
521 write_timeout: Some(Duration::from_secs(5)),
522 max_retry_attempts: 3,
523 retry_policy: RetryPolicy::default(),
524 auto_reconnect: true,
525 }
526 }
527}
528
529impl SharedMemoryConfig {
530 pub fn new() -> Self {
531 Self::default()
532 }
533
534 pub fn with_buffer_size(mut self, size: usize) -> Self {
535 self.buffer_size = size;
536 self
537 }
538
539 pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
540 self.read_timeout = Some(timeout);
541 self
542 }
543
544 pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
545 self.write_timeout = Some(timeout);
546 self
547 }
548
549 pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
550 self.retry_policy = policy;
551 self
552 }
553
554 pub fn with_max_retries(mut self, max: usize) -> Self {
555 self.max_retry_attempts = max;
556 self
557 }
558
559 pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
560 self.auto_reconnect = enabled;
561 self
562 }
563}
564
/// Bidirectional transport built from two shared-memory ring buffers, one
/// per direction, with retry/backoff and optional auto-reconnect.
#[derive(Debug)]
pub struct SharedMemoryTransport {
    // ArcSwap lets try_reconnect() atomically swap in fresh buffers while
    // in-flight operations keep using the Arc they already loaded.
    send_buffer: ArcSwap<SharedMemoryRingBuffer>,
    recv_buffer: ArcSwap<SharedMemoryRingBuffer>,
    config: SharedMemoryConfig,
    stats: Arc<Mutex<TransportStats>>,
    // Display name, "<base>-server" or "<base>-client".
    name: String,
    // Failure counter; is_healthy() reports false once it reaches 10.
    // Reset to 0 on successful reconnect.
    error_count: Arc<Mutex<usize>>,
    // Cleared after a fully failed send/recv; checked on entry to both.
    connected: Arc<AtomicBool>,
    // Base name used to re-derive the buffer os_ids when reconnecting.
    reconnect_name: String,
}
577
impl SharedMemoryTransport {
    /// Creates the server side: allocates both ring buffers (`<name>_s2c`
    /// for sending, `<name>_c2s` for receiving) and starts a background loop
    /// ticking the send buffer's heartbeat every 5 s for as long as the
    /// buffer is alive (`spawn_weak_loop` holds only a `Weak`).
    ///
    /// # Errors
    /// Propagates any failure to create either shared-memory region.
    pub fn create_server(
        name: impl Into<String>,
        config: SharedMemoryConfig,
    ) -> TransportResult<Self> {
        let name = name.into();

        // Direction suffixes are from the server's point of view:
        // s2c = server-to-client (our send side), c2s = client-to-server.
        let send_name = format!("{}_s2c", name);
        let recv_name = format!("{}_c2s", name);

        let send_buffer = Arc::new(SharedMemoryRingBuffer::create(
            &send_name,
            config.buffer_size,
        )?);
        let recv_buffer = Arc::new(SharedMemoryRingBuffer::create(
            &recv_name,
            config.buffer_size,
        )?);

        spawn_weak_loop(
            Arc::downgrade(&send_buffer),
            Duration::from_secs(5),
            |buffer| {
                buffer.tick_heartbeat();
            },
        );

        Ok(Self {
            send_buffer: ArcSwap::new(send_buffer),
            recv_buffer: ArcSwap::new(recv_buffer),
            config,
            stats: Arc::new(Mutex::new(TransportStats::default())),
            name: format!("{}-server", name),
            error_count: Arc::new(Mutex::new(0)),
            connected: Arc::new(AtomicBool::new(true)),
            reconnect_name: name.clone(),
        })
    }

    /// Connects the client side, retrying with the configured backoff.
    /// Direction names are mirrored relative to the server: the client sends
    /// on `<name>_c2s` and receives on `<name>_s2c`.
    ///
    /// # Errors
    /// Returns the last connection error once all attempts are exhausted.
    pub fn connect_client_with_config(
        name: impl Into<String>,
        config: SharedMemoryConfig,
    ) -> TransportResult<Self> {
        let name = name.into();
        let send_name = format!("{}_c2s", name);
        let recv_name = format!("{}_s2c", name);
        let mut last_error = None;

        for attempt in 0..config.max_retry_attempts {
            if attempt > 0 {
                // Blocking sleep: this constructor is synchronous by design.
                std::thread::sleep(config.retry_policy.delay(attempt - 1));
            }

            match (
                SharedMemoryRingBuffer::connect(&send_name),
                SharedMemoryRingBuffer::connect(&recv_name),
            ) {
                (Ok(send_buffer), Ok(recv_buffer)) => {
                    let send_arc = Arc::new(send_buffer);
                    let recv_arc = Arc::new(recv_buffer);

                    // Heartbeat loop holds only a Weak, so it stops once the
                    // send buffer is dropped.
                    spawn_weak_loop(
                        Arc::downgrade(&send_arc),
                        Duration::from_secs(5),
                        |buffer| {
                            buffer.tick_heartbeat();
                        },
                    );

                    return Ok(Self {
                        send_buffer: ArcSwap::new(send_arc),
                        recv_buffer: ArcSwap::new(recv_arc),
                        config: config.clone(),
                        stats: Arc::new(Mutex::new(TransportStats::default())),
                        name: format!("{}-client", name),
                        error_count: Arc::new(Mutex::new(0)),
                        connected: Arc::new(AtomicBool::new(true)),
                        reconnect_name: name.clone(),
                    });
                }
                // If only one side attached, its buffer is dropped here and
                // the whole attempt is retried.
                (Err(e), _) | (_, Err(e)) => {
                    last_error = Some(e);
                }
            }
        }

        Err(
            last_error.unwrap_or_else(|| TransportError::ConnectionFailed {
                name: name.clone(),
                attempts: config.max_retry_attempts,
                reason: "max retries exceeded".into(),
            }),
        )
    }

    /// Client connection with the default configuration.
    pub fn connect_client(name: impl Into<String>) -> TransportResult<Self> {
        Self::connect_client_with_config(name, SharedMemoryConfig::default())
    }

    /// Re-attaches both buffers after a failure, swapping them in atomically
    /// and resetting the error counter on success. Rejected immediately when
    /// `auto_reconnect` is disabled.
    ///
    /// NOTE(review): the name mapping (`_c2s` = send, `_s2c` = receive)
    /// assumes the caller is the client; a server invoking this would attach
    /// its send side to the client-to-server buffer. Also, no new heartbeat
    /// loop is spawned for the reconnected send buffer. Confirm reconnect is
    /// intended to be client-only and whether heartbeats should resume.
    fn try_reconnect(&self) -> TransportResult<()> {
        if !self.config.auto_reconnect {
            return Err(TransportError::NotConnected);
        }

        let send_name = format!("{}_c2s", self.reconnect_name);
        let recv_name = format!("{}_s2c", self.reconnect_name);

        for attempt in 0..self.config.max_retry_attempts {
            if attempt > 0 {
                std::thread::sleep(self.config.retry_policy.delay(attempt - 1));
            }

            match (
                SharedMemoryRingBuffer::connect(&send_name),
                SharedMemoryRingBuffer::connect(&recv_name),
            ) {
                (Ok(send), Ok(recv)) => {
                    self.send_buffer.store(Arc::new(send));
                    self.recv_buffer.store(Arc::new(recv));
                    *self.error_count.lock() = 0;
                    self.connected.store(true, Ordering::Release);
                    return Ok(());
                }
                _ => continue,
            }
        }

        Err(TransportError::ConnectionFailed {
            name: self.reconnect_name.clone(),
            attempts: self.config.max_retry_attempts,
            reason: "reconnect failed".into(),
        })
    }

    /// Healthy = the peer's heartbeat (in our receive buffer) is younger
    /// than 30 s and fewer than 10 errors have accumulated.
    pub fn is_healthy(&self) -> bool {
        let recv_buf = self.recv_buffer.load();
        // SAFETY: `control` points into the mapping kept alive by `recv_buf`.
        let control = unsafe { &*recv_buf.control };
        control.is_healthy(30) && *self.error_count.lock() < 10
    }

    /// Snapshot of the send/receive counters.
    pub fn stats(&self) -> TransportStats {
        self.stats.lock().clone()
    }
}
726
#[async_trait]
impl Transport for SharedMemoryTransport {
    fn is_healthy(&self) -> bool {
        // Delegates to the inherent method (heartbeat freshness + error count).
        self.is_healthy()
    }

    /// Sends one message, retrying up to `max_retry_attempts` with the
    /// configured backoff. The blocking ring-buffer write runs on the
    /// blocking thread pool so the async runtime is never stalled.
    async fn send(&self, data: &[u8]) -> TransportResult<()> {
        if !self.connected.load(Ordering::Acquire) {
            self.try_reconnect()?;
        }

        let mut last_error = None;

        for attempt in 0..self.config.max_retry_attempts {
            if attempt > 0 {
                tokio::time::sleep(self.config.retry_policy.delay(attempt - 1)).await;
            }

            let result = tokio::task::spawn_blocking({
                // load_full() pins the current buffer for this attempt even
                // if a concurrent reconnect swaps in a new one.
                let buffer = self.send_buffer.load_full();
                // NOTE(review): copies the payload on every attempt; the Vec
                // could be hoisted out of the retry loop.
                let data = data.to_vec();
                let timeout = self.config.write_timeout.unwrap_or(Duration::from_secs(30));
                move || buffer.write(&data, timeout)
            })
            .await;

            match result {
                Ok(Ok(())) => {
                    let mut stats = self.stats.lock();
                    stats.messages_sent += 1;
                    stats.bytes_sent += data.len() as u64;
                    return Ok(());
                }
                Ok(Err(e)) => {
                    last_error = Some(e);
                    *self.error_count.lock() += 1;
                }
                // spawn_blocking itself failed (the task panicked or was
                // cancelled), not the ring-buffer write.
                Err(e) => {
                    last_error = Some(TransportError::SendFailed {
                        attempts: attempt + 1,
                        reason: e.to_string(),
                    });
                    *self.error_count.lock() += 1;
                }
            }
        }

        // All attempts failed: mark disconnected so the next call attempts a
        // reconnect (when auto_reconnect is enabled).
        self.connected.store(false, Ordering::Release);
        Err(last_error.unwrap_or_else(|| TransportError::SendFailed {
            attempts: self.config.max_retry_attempts,
            reason: "max retries exceeded".into(),
        }))
    }

    /// Receives one message with the same retry/backoff scheme as `send`;
    /// the blocking read runs on the blocking thread pool.
    async fn recv(&self) -> TransportResult<Bytes> {
        if !self.connected.load(Ordering::Acquire) {
            self.try_reconnect()?;
        }

        let mut last_error = None;

        for attempt in 0..self.config.max_retry_attempts {
            if attempt > 0 {
                tokio::time::sleep(self.config.retry_policy.delay(attempt - 1)).await;
            }

            let result = tokio::task::spawn_blocking({
                let buffer = self.recv_buffer.load_full();
                let timeout = self.config.read_timeout.unwrap_or(Duration::from_secs(30));
                move || buffer.read(timeout)
            })
            .await;

            match result {
                Ok(Ok(bytes)) => {
                    let mut stats = self.stats.lock();
                    stats.messages_received += 1;
                    stats.bytes_received += bytes.len() as u64;
                    return Ok(bytes);
                }
                Ok(Err(e)) => {
                    last_error = Some(e);
                    *self.error_count.lock() += 1;
                }
                // spawn_blocking failure, as in `send`.
                Err(e) => {
                    last_error = Some(TransportError::ReceiveFailed {
                        attempts: attempt + 1,
                        reason: e.to_string(),
                    });
                    *self.error_count.lock() += 1;
                }
            }
        }

        self.connected.store(false, Ordering::Release);
        Err(last_error.unwrap_or_else(|| TransportError::ReceiveFailed {
            attempts: self.config.max_retry_attempts,
            reason: "max retries exceeded".into(),
        }))
    }

    fn is_connected(&self) -> bool {
        self.connected.load(Ordering::Acquire)
    }

    /// Marks the send-side control block as disconnected; the mappings
    /// themselves are released when the buffers are dropped.
    async fn close(&self) -> TransportResult<()> {
        let send_buf = self.send_buffer.load();
        // SAFETY: `control` points into the mapping kept alive by `send_buf`.
        let control = unsafe { &*send_buf.control };
        control.connected.store(false, Ordering::Release);
        Ok(())
    }

    fn stats(&self) -> Option<TransportStats> {
        Some(self.stats.lock().clone())
    }

    fn name(&self) -> &str {
        &self.name
    }
}
849
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the os_ids below are fixed strings, so two concurrent
    // runs of this suite on the same machine would collide on the
    // shared-memory names — confirm the harness isolates or serializes them.

    #[tokio::test]
    async fn test_cross_process_basic() {
        // Round-trip one message client -> server (both ends in-process).
        let config = SharedMemoryConfig::default();
        let server = SharedMemoryTransport::create_server("test-basic", config).unwrap();
        let client = SharedMemoryTransport::connect_client("test-basic").unwrap();

        client.send(b"Hello").await.unwrap();
        let msg = server.recv().await.unwrap();
        assert_eq!(msg.as_ref(), b"Hello");
    }

    #[tokio::test]
    async fn test_ring_buffer_wrapping() {
        // 20 x 1000-byte chunks through an 8 KiB buffer force the positions
        // to wrap multiple times; each chunk is checked for corruption.
        let config = SharedMemoryConfig::default().with_buffer_size(8192);
        let server = SharedMemoryTransport::create_server("test-wrap", config).unwrap();
        let client = SharedMemoryTransport::connect_client("test-wrap").unwrap();

        let chunk_size = 1000;
        let num_chunks = 20;

        // Sender runs concurrently so the reader can drain while it fills.
        let send_handle = tokio::spawn(async move {
            for i in 0..num_chunks {
                let data = vec![i as u8; chunk_size];
                client.send(&data).await.unwrap();
            }
        });

        for i in 0..num_chunks {
            let msg = server.recv().await.unwrap();
            assert_eq!(msg.len(), chunk_size);
            assert!(msg.iter().all(|&b| b == i as u8), "chunk {} corrupted", i);
        }

        send_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_auto_reconnect() {
        let config = SharedMemoryConfig::default().with_auto_reconnect(true);
        let server =
            SharedMemoryTransport::create_server("test-reconnect", config.clone()).unwrap();
        let client =
            SharedMemoryTransport::connect_client_with_config("test-reconnect", config).unwrap();

        client.send(b"before").await.unwrap();
        let msg = server.recv().await.unwrap();
        assert_eq!(msg.as_ref(), b"before");

        // Simulate a detected failure; the next send must transparently
        // reconnect before transmitting.
        client.connected.store(false, Ordering::Release);
        assert!(!client.is_connected());

        client.send(b"after").await.unwrap();
        assert!(client.is_connected());

        let msg = server.recv().await.unwrap();
        assert_eq!(msg.as_ref(), b"after");
    }

    #[tokio::test]
    async fn test_configurable_timeout() {
        let config = SharedMemoryConfig::default()
            .with_read_timeout(Duration::from_millis(100))
            .with_write_timeout(Duration::from_millis(100));

        let _server = SharedMemoryTransport::create_server("test-timeout", config.clone()).unwrap();
        let client =
            SharedMemoryTransport::connect_client_with_config("test-timeout", config).unwrap();

        // Nothing was sent, so recv must fail; with 100 ms timeouts the
        // retries plus backoff should still finish well under a second.
        let start = Instant::now();
        let result = client.recv().await;
        let elapsed = start.elapsed();

        assert!(result.is_err());
        assert!(
            elapsed < Duration::from_secs(1),
            "timeout took too long: {:?}",
            elapsed
        );
    }
}