use arc_swap::ArcSwap;
use async_trait::async_trait;
use bytes::Bytes;
use parking_lot::Mutex;
use raw_sync::Timeout;
use raw_sync::events::{Event, EventImpl, EventInit, EventState};
use shared_memory::{Shmem, ShmemConf};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};

use crate::error::{TransportError, TransportResult};
use crate::transport::utils::spawn_weak_loop;
use crate::transport::{FrameTransport, TransportStats};

const SHM_MAGIC: u64 = 0x58525043;
const MIN_BUFFER_SIZE: usize = 4096;
const MAX_BUFFER_SIZE: usize = 1024 * 1024 * 1024;
const DEFAULT_BUFFER_SIZE: usize = 1024 * 1024;
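/// Control block placed at the start of the shared-memory region.
///
/// `write_pos` and `read_pos` are monotonically increasing byte offsets; the
/// ring-buffer indices are derived by reducing them modulo `capacity`.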
#[repr(C)]
struct SharedMemoryControlBlock {
    magic: AtomicU64,
    write_pos: AtomicU64,
    read_pos: AtomicU64,
    capacity: AtomicU64,
    connected: AtomicBool,
    error_count: AtomicU64,
    last_heartbeat: AtomicU64,
}

impl SharedMemoryControlBlock {
    fn new(capacity: usize) -> Self {
        Self {
            magic: AtomicU64::new(SHM_MAGIC),
            write_pos: AtomicU64::new(0),
            read_pos: AtomicU64::new(0),
            capacity: AtomicU64::new(capacity as u64),
            connected: AtomicBool::new(true),
            error_count: AtomicU64::new(0),
            last_heartbeat: AtomicU64::new(0),
        }
    }

    fn is_valid(&self) -> bool {
        self.magic.load(Ordering::Acquire) == SHM_MAGIC
    }

    /// Bytes that can currently be written; one byte is always kept free to
    /// distinguish a full buffer from an empty one.
    fn available_write(&self) -> usize {
        let w = self.write_pos.load(Ordering::Acquire) as usize;
        let r = self.read_pos.load(Ordering::Acquire) as usize;
        let cap = self.capacity.load(Ordering::Acquire) as usize;
        let w_norm = w % cap;
        let r_norm = r % cap;
        if w_norm >= r_norm {
            cap - (w_norm - r_norm) - 1
        } else {
            (r_norm - w_norm) - 1
        }
    }

    fn available_read(&self) -> usize {
        let w = self.write_pos.load(Ordering::Acquire) as usize;
        let r = self.read_pos.load(Ordering::Acquire) as usize;
        let cap = self.capacity.load(Ordering::Acquire) as usize;
        let w_norm = w % cap;
        let r_norm = r % cap;
        if w_norm >= r_norm {
            w_norm - r_norm
        } else {
            cap - (r_norm - w_norm)
        }
    }

    fn record_error(&self) {
        self.error_count.fetch_add(1, Ordering::Relaxed);
    }

    fn update_heartbeat(&self) {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        self.last_heartbeat.store(timestamp, Ordering::Release);
    }

    fn tick_heartbeat(&self) {
        self.update_heartbeat();
    }

    fn is_healthy(&self, timeout_secs: u64) -> bool {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let last = self.last_heartbeat.load(Ordering::Acquire);

        if last == 0 {
            return true;
        }

        // `saturating_sub` guards against clock skew between processes.
        now.saturating_sub(last) < timeout_secs
    }
}
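/// One direction of the shared-memory channel.
///
/// The region is laid out as `[control block | write event | read event | data]`.
/// `write_signal` is raised by the reader when space is freed, and
/// `read_signal` is raised by the writer when data becomes available.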
struct SharedMemoryRingBuffer {
    #[allow(dead_code)]
    shmem: Shmem,
    control: *mut SharedMemoryControlBlock,
    data: *mut u8,
    capacity: usize,
    is_owner: bool,
    #[allow(dead_code)]
    name: String,
    write_signal: Box<dyn EventImpl>,
    read_signal: Box<dyn EventImpl>,
}

impl std::fmt::Debug for SharedMemoryRingBuffer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SharedMemoryRingBuffer")
            .field("capacity", &self.capacity)
            .field("is_owner", &self.is_owner)
            .finish()
    }
}
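// SAFETY: `control` and `data` point into the mapping owned by `shmem`, which
// lives as long as `self`; cross-thread access to the control block goes
// through atomics, and the events are process-shared synchronization primitives.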
unsafe impl Send for SharedMemoryRingBuffer {}
unsafe impl Sync for SharedMemoryRingBuffer {}

impl SharedMemoryRingBuffer {
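    /// Creates and owns a new buffer with `capacity` data bytes under the OS id `name`.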
    fn create(name: &str, capacity: usize) -> TransportResult<Self> {
        if !(MIN_BUFFER_SIZE..=MAX_BUFFER_SIZE).contains(&capacity) {
            return Err(TransportError::InvalidBufferState(format!(
                "Invalid buffer size: {}",
                capacity
            )));
        }

        let control_size = std::mem::size_of::<SharedMemoryControlBlock>();
        let event_size = Event::size_of(None);
        let total_size = control_size + (event_size * 2) + capacity;

        let shmem = ShmemConf::new()
            .size(total_size)
            .os_id(name)
            .create()
            .map_err(|e| TransportError::SharedMemoryCreation {
                name: name.to_string(),
                reason: e.to_string(),
            })?;

        let base_ptr = shmem.as_ptr();

        let control = base_ptr as *mut SharedMemoryControlBlock;
        let write_signal_ptr = unsafe { base_ptr.add(control_size) };
        let read_signal_ptr = unsafe { base_ptr.add(control_size + event_size) };
        let data = unsafe { base_ptr.add(control_size + (event_size * 2)) };

        unsafe {
            std::ptr::write(control, SharedMemoryControlBlock::new(capacity));

            let (write_signal, _) = Event::new(write_signal_ptr, true).map_err(|e| {
                TransportError::SharedMemoryCreation {
                    name: name.to_string(),
                    reason: format!("Failed to create write event: {}", e),
                }
            })?;

            let (read_signal, _) = Event::new(read_signal_ptr, true).map_err(|e| {
                TransportError::SharedMemoryCreation {
                    name: name.to_string(),
                    reason: format!("Failed to create read event: {}", e),
                }
            })?;

            write_signal.set(EventState::Signaled).map_err(|e| {
                TransportError::SharedMemoryCreation {
                    name: name.to_string(),
                    reason: format!("Failed to set write event: {}", e),
                }
            })?;

            Ok(Self {
                shmem,
                control,
                data,
                capacity,
                is_owner: true,
                name: name.to_string(),
                write_signal,
                read_signal,
            })
        }
    }
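    /// Opens an existing buffer created by a peer and validates its control block.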
    fn connect(name: &str) -> TransportResult<Self> {
        let shmem =
            ShmemConf::new()
                .os_id(name)
                .open()
                .map_err(|e| TransportError::ConnectionFailed {
                    name: name.to_string(),
                    attempts: 1,
                    reason: e.to_string(),
                })?;

        let base_ptr = shmem.as_ptr();
        let control = base_ptr as *mut SharedMemoryControlBlock;

        let ctrl = unsafe { &*control };
        if !ctrl.is_valid() {
            return Err(TransportError::InvalidBufferState(
                "Invalid shared memory region".to_string(),
            ));
        }

        let capacity = ctrl.capacity.load(Ordering::Acquire) as usize;
        let control_size = std::mem::size_of::<SharedMemoryControlBlock>();
        let event_size = Event::size_of(None);

        let write_signal_ptr = unsafe { base_ptr.add(control_size) };
        let read_signal_ptr = unsafe { base_ptr.add(control_size + event_size) };
        let data = unsafe { base_ptr.add(control_size + (event_size * 2)) };

        unsafe {
            let (write_signal, _) = Event::from_existing(write_signal_ptr).map_err(|e| {
                TransportError::ConnectionFailed {
                    name: name.to_string(),
                    attempts: 1,
                    reason: format!("Failed to open write event: {}", e),
                }
            })?;

            let (read_signal, _) = Event::from_existing(read_signal_ptr).map_err(|e| {
                TransportError::ConnectionFailed {
                    name: name.to_string(),
                    attempts: 1,
                    reason: format!("Failed to open read event: {}", e),
                }
            })?;

            Ok(Self {
                shmem,
                control,
                data,
                capacity,
                is_owner: false,
                name: name.to_string(),
                write_signal,
                read_signal,
            })
        }
    }
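    /// Writes one frame as a 4-byte little-endian length prefix followed by the
    /// payload, blocking up to `timeout` for enough free space.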
    fn write(&self, data: &[u8], timeout: Duration) -> TransportResult<()> {
        let control = unsafe { &*self.control };

        if !control.is_valid() {
            control.record_error();
            return Err(TransportError::InvalidBufferState(
                "Bad control".to_string(),
            ));
        }

        let msg_len = data.len();
        let total_len = 4 + msg_len;

        // One byte is always kept free, so a frame may occupy at most
        // `capacity - 1` bytes (`capacity - 5` of payload).
        if total_len >= self.capacity {
            control.record_error();
            return Err(TransportError::MessageTooLarge {
                size: msg_len,
                max: self.capacity - 5,
            });
        }

        let start = Instant::now();

        loop {
            if control.available_write() >= total_len {
                break;
            }

            if start.elapsed() > timeout {
                control.record_error();
                return Err(TransportError::Timeout {
                    duration_ms: timeout.as_millis() as u64,
                    operation: "waiting for buffer space".into(),
                });
            }

            let remaining = timeout.saturating_sub(start.elapsed());
            let _ = self.write_signal.wait(Timeout::Val(remaining));
        }

        let write_pos = control.write_pos.load(Ordering::Acquire) as usize;

        let len_bytes = (msg_len as u32).to_le_bytes();
        self.raw_write(write_pos, &len_bytes);

        self.raw_write(write_pos + 4, data);

        control
            .write_pos
            .store((write_pos + total_len) as u64, Ordering::Release);
        control.update_heartbeat();

        let _ = self.read_signal.set(EventState::Signaled);

        Ok(())
    }
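    /// Copies `src` into the data region at `offset % capacity`, splitting the
    /// copy in two when it wraps past the end of the buffer.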
    fn raw_write(&self, offset: usize, src: &[u8]) {
        let offset = offset % self.capacity;
        let len = src.len();
        let first_chunk = std::cmp::min(len, self.capacity - offset);

        unsafe {
            std::ptr::copy_nonoverlapping(src.as_ptr(), self.data.add(offset), first_chunk);

            if first_chunk < len {
                std::ptr::copy_nonoverlapping(
                    src.as_ptr().add(first_chunk),
                    self.data,
                    len - first_chunk,
                );
            }
        }
    }
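    /// Reads one length-prefixed frame, blocking up to `timeout` first for the
    /// 4-byte header and then for the full payload.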
    fn read(&self, timeout: Duration) -> TransportResult<Bytes> {
        let control = unsafe { &*self.control };

        if !control.is_valid() {
            control.record_error();
            return Err(TransportError::InvalidBufferState(
                "Bad control".to_string(),
            ));
        }

        let start = Instant::now();

        loop {
            if control.available_read() >= 4 {
                break;
            }

            if start.elapsed() > timeout {
                control.record_error();
                return Err(TransportError::Timeout {
                    duration_ms: timeout.as_millis() as u64,
                    operation: "waiting for data".into(),
                });
            }

            let remaining = timeout.saturating_sub(start.elapsed());
            let _ = self.read_signal.wait(Timeout::Val(remaining));
        }

        let read_pos = control.read_pos.load(Ordering::Acquire) as usize;

        let mut len_bytes = [0u8; 4];
        self.raw_read(read_pos, &mut len_bytes);
        let msg_len = u32::from_le_bytes(len_bytes) as usize;

        if msg_len > self.capacity {
            control.record_error();
            return Err(TransportError::InvalidBufferState(
                "Bad message length".to_string(),
            ));
        }

        loop {
            if control.available_read() >= 4 + msg_len {
                break;
            }

            if start.elapsed() > timeout {
                control.record_error();
                return Err(TransportError::Timeout {
                    duration_ms: timeout.as_millis() as u64,
                    operation: "waiting for full message".into(),
                });
            }

            let remaining = timeout.saturating_sub(start.elapsed());
            let _ = self.read_signal.wait(Timeout::Val(remaining));
        }

        let mut buffer = vec![0u8; msg_len];
        self.raw_read(read_pos + 4, &mut buffer);

        control
            .read_pos
            .store((read_pos + 4 + msg_len) as u64, Ordering::Release);
        control.update_heartbeat();

        let _ = self.write_signal.set(EventState::Signaled);

        Ok(Bytes::from(buffer))
    }
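    /// Copies from the data region at `offset % capacity` into `dst`, handling
    /// wrap-around the same way as `raw_write`.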
    fn raw_read(&self, offset: usize, dst: &mut [u8]) {
        let offset = offset % self.capacity;
        let len = dst.len();
        let first_chunk = std::cmp::min(len, self.capacity - offset);

        unsafe {
            std::ptr::copy_nonoverlapping(self.data.add(offset), dst.as_mut_ptr(), first_chunk);

            if first_chunk < len {
                std::ptr::copy_nonoverlapping(
                    self.data,
                    dst.as_mut_ptr().add(first_chunk),
                    len - first_chunk,
                );
            }
        }
    }

    fn tick_heartbeat(&self) {
        if self.is_owner {
            let control = unsafe { &*self.control };
            control.tick_heartbeat();
        }
    }
}

impl Drop for SharedMemoryRingBuffer {
    fn drop(&mut self) {
        if self.is_owner {
            let control = unsafe { &*self.control };
            control.connected.store(false, Ordering::Release);
        }
    }
}
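/// Backoff strategy used when connecting, reconnecting, and retrying I/O.
///
/// For example, `Exponential { base_delay_ms: 100, max_delay_ms: 1000 }` yields
/// delays of 100 ms, 200 ms, 400 ms, 800 ms, and then 1000 ms for every later
/// attempt.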
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RetryPolicy {
    Fixed {
        delay_ms: u64,
    },
    Linear {
        base_delay_ms: u64,
    },
    Exponential {
        base_delay_ms: u64,
        max_delay_ms: u64,
    },
}

impl RetryPolicy {
    pub fn delay(&self, attempt: usize) -> Duration {
        match self {
            RetryPolicy::Fixed { delay_ms } => Duration::from_millis(*delay_ms),
            RetryPolicy::Linear { base_delay_ms } => {
                Duration::from_millis(base_delay_ms * (attempt as u64 + 1))
            }
            RetryPolicy::Exponential {
                base_delay_ms,
                max_delay_ms,
            } => {
                let delay = base_delay_ms * 2u64.pow(attempt as u32);
                Duration::from_millis(delay.min(*max_delay_ms))
            }
        }
    }
}

impl Default for RetryPolicy {
    fn default() -> Self {
        RetryPolicy::Linear { base_delay_ms: 100 }
    }
}
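/// Tuning knobs for a shared-memory transport endpoint. The `Default` uses a
/// 1 MiB buffer, 5-second read/write timeouts, three retry attempts with
/// linear backoff, and automatic reconnection.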
#[derive(Clone, Debug)]
pub struct SharedMemoryConfig {
    pub buffer_size: usize,
    pub read_timeout: Option<Duration>,
    pub write_timeout: Option<Duration>,
    pub max_retry_attempts: usize,
    pub retry_policy: RetryPolicy,
    pub auto_reconnect: bool,
}

impl Default for SharedMemoryConfig {
    fn default() -> Self {
        Self {
            buffer_size: DEFAULT_BUFFER_SIZE,
            read_timeout: Some(Duration::from_secs(5)),
            write_timeout: Some(Duration::from_secs(5)),
            max_retry_attempts: 3,
            retry_policy: RetryPolicy::default(),
            auto_reconnect: true,
        }
    }
}

impl SharedMemoryConfig {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = size;
        self
    }

    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
        self.read_timeout = Some(timeout);
        self
    }

    pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
        self.write_timeout = Some(timeout);
        self
    }

    pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
        self.retry_policy = policy;
        self
    }

    pub fn with_max_retries(mut self, max: usize) -> Self {
        self.max_retry_attempts = max;
        self
    }

    pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
        self.auto_reconnect = enabled;
        self
    }
}
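/// Bidirectional frame transport built on two shared-memory ring buffers:
/// `{name}_s2c` carries server-to-client frames and `{name}_c2s` carries
/// client-to-server frames.
///
/// Sketch of typical usage; the channel name is arbitrary, and the block is
/// marked `ignore` because the import paths depend on how the crate re-exports
/// these types:
///
/// ```ignore
/// let server = SharedMemoryFrameTransport::create_server("my-channel", SharedMemoryConfig::default())?;
/// let client = SharedMemoryFrameTransport::connect_client("my-channel")?;
///
/// client.send_frame(b"hello").await?;
/// let frame = server.recv_frame().await?;
/// assert_eq!(frame.as_ref(), b"hello");
/// ```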
#[derive(Debug)]
pub struct SharedMemoryFrameTransport {
    send_buffer: ArcSwap<SharedMemoryRingBuffer>,
    recv_buffer: ArcSwap<SharedMemoryRingBuffer>,
    config: SharedMemoryConfig,
    stats: Arc<Mutex<TransportStats>>,
    name: String,
    error_count: Arc<Mutex<usize>>,
    connected: Arc<AtomicBool>,
    reconnect_name: String,
}

impl SharedMemoryFrameTransport {
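    /// Creates the server endpoint, allocating both ring buffers and spawning a
    /// heartbeat loop on the outgoing buffer.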
    pub fn create_server(
        name: impl Into<String>,
        config: SharedMemoryConfig,
    ) -> TransportResult<Self> {
        let name = name.into();

        let send_name = format!("{}_s2c", name);
        let recv_name = format!("{}_c2s", name);

        let send_buffer = Arc::new(SharedMemoryRingBuffer::create(
            &send_name,
            config.buffer_size,
        )?);
        let recv_buffer = Arc::new(SharedMemoryRingBuffer::create(
            &recv_name,
            config.buffer_size,
        )?);

        spawn_weak_loop(
            Arc::downgrade(&send_buffer),
            Duration::from_secs(5),
            |buffer| {
                buffer.tick_heartbeat();
            },
        );

        Ok(Self {
            send_buffer: ArcSwap::new(send_buffer),
            recv_buffer: ArcSwap::new(recv_buffer),
            config,
            stats: Arc::new(Mutex::new(TransportStats::default())),
            name: format!("{}-server", name),
            error_count: Arc::new(Mutex::new(0)),
            connected: Arc::new(AtomicBool::new(true)),
            reconnect_name: name.clone(),
        })
    }
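    /// Connects the client endpoint, retrying up to `max_retry_attempts` times
    /// with delays taken from the configured `retry_policy`.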
    pub fn connect_client_with_config(
        name: impl Into<String>,
        config: SharedMemoryConfig,
    ) -> TransportResult<Self> {
        let name = name.into();
        let send_name = format!("{}_c2s", name);
        let recv_name = format!("{}_s2c", name);
        let mut last_error = None;

        for attempt in 0..config.max_retry_attempts {
            if attempt > 0 {
                std::thread::sleep(config.retry_policy.delay(attempt - 1));
            }

            match (
                SharedMemoryRingBuffer::connect(&send_name),
                SharedMemoryRingBuffer::connect(&recv_name),
            ) {
                (Ok(send_buffer), Ok(recv_buffer)) => {
                    let send_arc = Arc::new(send_buffer);
                    let recv_arc = Arc::new(recv_buffer);

                    spawn_weak_loop(
                        Arc::downgrade(&send_arc),
                        Duration::from_secs(5),
                        |buffer| {
                            buffer.tick_heartbeat();
                        },
                    );

                    return Ok(Self {
                        send_buffer: ArcSwap::new(send_arc),
                        recv_buffer: ArcSwap::new(recv_arc),
                        config: config.clone(),
                        stats: Arc::new(Mutex::new(TransportStats::default())),
                        name: format!("{}-client", name),
                        error_count: Arc::new(Mutex::new(0)),
                        connected: Arc::new(AtomicBool::new(true)),
                        reconnect_name: name.clone(),
                    });
                }
                (Err(e), _) | (_, Err(e)) => {
                    last_error = Some(e);
                }
            }
        }

        Err(
            last_error.unwrap_or_else(|| TransportError::ConnectionFailed {
                name: name.clone(),
                attempts: config.max_retry_attempts,
                reason: "max retries exceeded".into(),
            }),
        )
    }

    pub fn connect_client(name: impl Into<String>) -> TransportResult<Self> {
        Self::connect_client_with_config(name, SharedMemoryConfig::default())
    }
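    /// Attempts to reopen the `_c2s`/`_s2c` buffer pair (client orientation)
    /// after a failure, resetting the error count on success.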
    fn try_reconnect(&self) -> TransportResult<()> {
        if !self.config.auto_reconnect {
            return Err(TransportError::NotConnected);
        }

        let send_name = format!("{}_c2s", self.reconnect_name);
        let recv_name = format!("{}_s2c", self.reconnect_name);

        for attempt in 0..self.config.max_retry_attempts {
            if attempt > 0 {
                std::thread::sleep(self.config.retry_policy.delay(attempt - 1));
            }

            match (
                SharedMemoryRingBuffer::connect(&send_name),
                SharedMemoryRingBuffer::connect(&recv_name),
            ) {
                (Ok(send), Ok(recv)) => {
                    self.send_buffer.store(Arc::new(send));
                    self.recv_buffer.store(Arc::new(recv));
                    *self.error_count.lock() = 0;
                    self.connected.store(true, Ordering::Release);
                    return Ok(());
                }
                _ => continue,
            }
        }

        Err(TransportError::ConnectionFailed {
            name: self.reconnect_name.clone(),
            attempts: self.config.max_retry_attempts,
            reason: "reconnect failed".into(),
        })
    }

    pub fn shm_is_healthy(&self) -> bool {
        let recv_buf = self.recv_buffer.load();
        let control = unsafe { &*recv_buf.control };
        control.is_healthy(30) && *self.error_count.lock() < 10
    }

    pub fn shm_stats(&self) -> TransportStats {
        self.stats.lock().clone()
    }
}
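// Blocking ring-buffer I/O is offloaded to `tokio::task::spawn_blocking` so the
// async executor is never stalled while waiting on the shared-memory events.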
#[async_trait]
impl FrameTransport for SharedMemoryFrameTransport {
    fn is_healthy(&self) -> bool {
        self.shm_is_healthy()
    }

    async fn send_frame(&self, data: &[u8]) -> TransportResult<()> {
        if !self.connected.load(Ordering::Acquire) {
            self.try_reconnect()?;
        }

        let mut last_error = None;

        for attempt in 0..self.config.max_retry_attempts {
            if attempt > 0 {
                tokio::time::sleep(self.config.retry_policy.delay(attempt - 1)).await;
            }

            let result = tokio::task::spawn_blocking({
                let buffer = self.send_buffer.load_full();
                let data = data.to_vec();
                let timeout = self.config.write_timeout.unwrap_or(Duration::from_secs(30));
                move || buffer.write(&data, timeout)
            })
            .await;

            match result {
                Ok(Ok(())) => {
                    let mut stats = self.stats.lock();
                    stats.messages_sent += 1;
                    stats.bytes_sent += data.len() as u64;
                    return Ok(());
                }
                Ok(Err(e)) => {
                    last_error = Some(e);
                    *self.error_count.lock() += 1;
                }
                Err(e) => {
                    last_error = Some(TransportError::SendFailed {
                        attempts: attempt + 1,
                        reason: e.to_string(),
                    });
                    *self.error_count.lock() += 1;
                }
            }
        }

        self.connected.store(false, Ordering::Release);
        Err(last_error.unwrap_or_else(|| TransportError::SendFailed {
            attempts: self.config.max_retry_attempts,
            reason: "max retries exceeded".into(),
        }))
    }

    async fn recv_frame(&self) -> TransportResult<Bytes> {
        if !self.connected.load(Ordering::Acquire) {
            self.try_reconnect()?;
        }

        let mut last_error = None;

        for attempt in 0..self.config.max_retry_attempts {
            if attempt > 0 {
                tokio::time::sleep(self.config.retry_policy.delay(attempt - 1)).await;
            }

            let result = tokio::task::spawn_blocking({
                let buffer = self.recv_buffer.load_full();
                let timeout = self.config.read_timeout.unwrap_or(Duration::from_secs(30));
                move || buffer.read(timeout)
            })
            .await;

            match result {
                Ok(Ok(bytes)) => {
                    let mut stats = self.stats.lock();
                    stats.messages_received += 1;
                    stats.bytes_received += bytes.len() as u64;
                    return Ok(bytes);
                }
                Ok(Err(e)) => {
                    last_error = Some(e);
                    *self.error_count.lock() += 1;
                }
                Err(e) => {
                    last_error = Some(TransportError::ReceiveFailed {
                        attempts: attempt + 1,
                        reason: e.to_string(),
                    });
                    *self.error_count.lock() += 1;
                }
            }
        }

        self.connected.store(false, Ordering::Release);
        Err(last_error.unwrap_or_else(|| TransportError::ReceiveFailed {
            attempts: self.config.max_retry_attempts,
            reason: "max retries exceeded".into(),
        }))
    }

    fn is_connected(&self) -> bool {
        self.connected.load(Ordering::Acquire)
    }

    async fn close(&self) -> TransportResult<()> {
        let send_buf = self.send_buffer.load();
        let control = unsafe { &*send_buf.control };
        control.connected.store(false, Ordering::Release);
        Ok(())
    }

    fn stats(&self) -> Option<TransportStats> {
        Some(self.stats.lock().clone())
    }

    fn name(&self) -> &str {
        &self.name
    }
}
#[deprecated(since = "0.2.0", note = "Use SharedMemoryFrameTransport instead")]
pub type SharedMemoryTransport = SharedMemoryFrameTransport;
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_cross_process_basic() {
        let config = SharedMemoryConfig::default();
        let server = SharedMemoryFrameTransport::create_server("test-basic", config).unwrap();
        let client = SharedMemoryFrameTransport::connect_client("test-basic").unwrap();

        client.send_frame(b"Hello").await.unwrap();
        let msg = server.recv_frame().await.unwrap();
        assert_eq!(msg.as_ref(), b"Hello");
    }

    #[tokio::test]
    async fn test_ring_buffer_wrapping() {
        let config = SharedMemoryConfig::default().with_buffer_size(8192);
        let server = SharedMemoryFrameTransport::create_server("test-wrap", config).unwrap();
        let client = SharedMemoryFrameTransport::connect_client("test-wrap").unwrap();

        let chunk_size = 1000;
        let num_chunks = 20;

        let send_handle = tokio::spawn(async move {
            for i in 0..num_chunks {
                let data = vec![i as u8; chunk_size];
                client.send_frame(&data).await.unwrap();
            }
        });

        for i in 0..num_chunks {
            let msg = server.recv_frame().await.unwrap();
            assert_eq!(msg.len(), chunk_size);
            assert!(msg.iter().all(|&b| b == i as u8), "chunk {} corrupted", i);
        }

        send_handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_auto_reconnect() {
        let config = SharedMemoryConfig::default().with_auto_reconnect(true);
        let server =
            SharedMemoryFrameTransport::create_server("test-reconnect", config.clone()).unwrap();
        let client =
            SharedMemoryFrameTransport::connect_client_with_config("test-reconnect", config)
                .unwrap();

        client.send_frame(b"before").await.unwrap();
        let msg = server.recv_frame().await.unwrap();
        assert_eq!(msg.as_ref(), b"before");

        client.connected.store(false, Ordering::Release);
        assert!(!client.is_connected());

        client.send_frame(b"after").await.unwrap();
        assert!(client.is_connected());

        let msg = server.recv_frame().await.unwrap();
        assert_eq!(msg.as_ref(), b"after");
    }

    #[tokio::test]
    async fn test_configurable_timeout() {
        let config = SharedMemoryConfig::default()
            .with_read_timeout(Duration::from_millis(100))
            .with_write_timeout(Duration::from_millis(100));

        let _server =
            SharedMemoryFrameTransport::create_server("test-timeout", config.clone()).unwrap();
        let client =
            SharedMemoryFrameTransport::connect_client_with_config("test-timeout", config).unwrap();

        let start = Instant::now();
        let result = client.recv_frame().await;
        let elapsed = start.elapsed();

        assert!(result.is_err());
        assert!(
            elapsed < Duration::from_secs(1),
            "timeout took too long: {:?}",
            elapsed
        );
    }
}