xrpc/transport/
shared_memory.rs

1use arc_swap::ArcSwap;
2use async_trait::async_trait;
3use bytes::Bytes;
4use parking_lot::Mutex;
5use raw_sync::Timeout;
6use raw_sync::events::{Event, EventImpl, EventInit, EventState};
7use shared_memory::{Shmem, ShmemConf};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use crate::error::{TransportError, TransportResult};
13use crate::transport::utils::spawn_weak_loop;
14use crate::transport::{FrameTransport, TransportStats};
15
/// Marker stored at the start of every control block; the bytes 0x58 0x52 0x50 0x43 spell "XRPC".
const SHM_MAGIC: u64 = 0x5852_5043;
/// Smallest ring capacity accepted when creating a buffer (4 KiB).
const MIN_BUFFER_SIZE: usize = 4 * 1024;
/// Largest ring capacity accepted when creating a buffer (1 GiB).
const MAX_BUFFER_SIZE: usize = 1 << 30;
/// Ring capacity used by the default configuration (1 MiB).
const DEFAULT_BUFFER_SIZE: usize = 1 << 20;
20
21#[repr(C)]
22struct SharedMemoryControlBlock {
23    magic: AtomicU64,
24    write_pos: AtomicU64,
25    read_pos: AtomicU64,
26    capacity: AtomicU64,
27    connected: AtomicBool,
28    error_count: AtomicU64,
29    last_heartbeat: AtomicU64,
30}
31
32impl SharedMemoryControlBlock {
33    fn new(capacity: usize) -> Self {
34        Self {
35            magic: AtomicU64::new(SHM_MAGIC),
36            write_pos: AtomicU64::new(0),
37            read_pos: AtomicU64::new(0),
38            capacity: AtomicU64::new(capacity as u64),
39            connected: AtomicBool::new(true),
40            error_count: AtomicU64::new(0),
41            last_heartbeat: AtomicU64::new(0),
42        }
43    }
44
45    fn is_valid(&self) -> bool {
46        self.magic.load(Ordering::Acquire) == SHM_MAGIC
47    }
48
49    fn available_write(&self) -> usize {
50        let w = self.write_pos.load(Ordering::Acquire) as usize;
51        let r = self.read_pos.load(Ordering::Acquire) as usize;
52        let cap = self.capacity.load(Ordering::Acquire) as usize;
53        let w_norm = w % cap;
54        let r_norm = r % cap;
55        if w_norm >= r_norm {
56            cap - (w_norm - r_norm) - 1
57        } else {
58            (r_norm - w_norm) - 1
59        }
60    }
61
62    fn available_read(&self) -> usize {
63        let w = self.write_pos.load(Ordering::Acquire) as usize;
64        let r = self.read_pos.load(Ordering::Acquire) as usize;
65        let cap = self.capacity.load(Ordering::Acquire) as usize;
66        let w_norm = w % cap;
67        let r_norm = r % cap;
68        if w_norm >= r_norm {
69            w_norm - r_norm
70        } else {
71            cap - (r_norm - w_norm)
72        }
73    }
74
75    fn record_error(&self) {
76        self.error_count.fetch_add(1, Ordering::Relaxed);
77    }
78
79    fn update_heartbeat(&self) {
80        let timestamp = std::time::SystemTime::now()
81            .duration_since(std::time::UNIX_EPOCH)
82            .unwrap()
83            .as_secs();
84        self.last_heartbeat.store(timestamp, Ordering::Release);
85    }
86
87    fn tick_heartbeat(&self) {
88        // Just update the timestamp without signaling
89        self.update_heartbeat();
90    }
91
92    fn is_healthy(&self, timeout_secs: u64) -> bool {
93        let now = std::time::SystemTime::now()
94            .duration_since(std::time::UNIX_EPOCH)
95            .unwrap()
96            .as_secs();
97        let last = self.last_heartbeat.load(Ordering::Acquire);
98
99        if last == 0 {
100            return true; // Not yet initialized
101        }
102
103        now - last < timeout_secs
104    }
105}
106
/// One direction of the transport: a byte ring stored in a named shared-memory region.
///
/// As laid out by `create`, the region holds a `SharedMemoryControlBlock`, then the write
/// event, then the read event, then `capacity` bytes of ring data.
struct SharedMemoryRingBuffer {
    // Never read after construction, but `control` and `data` are derived from
    // `shmem.as_ptr()`, so the handle is stored alongside them.
    #[allow(dead_code)]
    shmem: Shmem,
    // Points at the control block at the start of the region.
    control: *mut SharedMemoryControlBlock,
    // Points at the first byte of ring data, just past the two events.
    data: *mut u8,
    // Ring size in bytes (copied from the control block when connecting).
    capacity: usize,
    // `true` for buffers built by `create`, `false` for buffers built by `connect`.
    is_owner: bool,
    // OS identifier the region was created/opened with; kept for reference only.
    #[allow(dead_code)]
    name: String,
    // raw-sync events for true IPC signaling
    // write_signal: signaled when space is available (writer waits on this)
    write_signal: Box<dyn EventImpl>,
    // read_signal: signaled when data is available (reader waits on this)
    read_signal: Box<dyn EventImpl>,
}
122
123impl std::fmt::Debug for SharedMemoryRingBuffer {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        f.debug_struct("SharedMemoryRingBuffer")
126            .field("capacity", &self.capacity)
127            .field("is_owner", &self.is_owner)
128            .finish()
129    }
130}
131
// The struct holds raw pointers (`control`, `data`) and boxed events, so it is not
// automatically `Send`/`Sync`; these impls assert that manually.
// NOTE(review): nothing in this type enforces a single writer or a single reader per ring.
// Soundness presumably relies on callers using each ring from one side only — TODO confirm.
unsafe impl Send for SharedMemoryRingBuffer {}
unsafe impl Sync for SharedMemoryRingBuffer {}
134
135impl SharedMemoryRingBuffer {
136    fn create(name: &str, capacity: usize) -> TransportResult<Self> {
137        if !(MIN_BUFFER_SIZE..=MAX_BUFFER_SIZE).contains(&capacity) {
138            return Err(TransportError::InvalidBufferState(format!(
139                "Invalid buffer size: {}",
140                capacity
141            )));
142        }
143
144        // Calculate total size: ControlBlock + 2 * Event + Buffer
145        let control_size = std::mem::size_of::<SharedMemoryControlBlock>();
146        let event_size = Event::size_of(None);
147        let total_size = control_size + (event_size * 2) + capacity;
148
149        let shmem = ShmemConf::new()
150            .size(total_size)
151            .os_id(name)
152            .create()
153            .map_err(|e| TransportError::SharedMemoryCreation {
154                name: name.to_string(),
155                reason: e.to_string(),
156            })?;
157
158        let base_ptr = shmem.as_ptr();
159
160        // Setup pointers
161        let control = base_ptr as *mut SharedMemoryControlBlock;
162        let write_signal_ptr = unsafe { base_ptr.add(control_size) };
163        let read_signal_ptr = unsafe { base_ptr.add(control_size + event_size) };
164        let data = unsafe { base_ptr.add(control_size + (event_size * 2)) };
165
166        // Initialize structures
167        unsafe {
168            std::ptr::write(control, SharedMemoryControlBlock::new(capacity));
169
170            // Initialize events (auto-reset = true)
171            let (write_signal, _) = Event::new(write_signal_ptr, true).map_err(|e| {
172                TransportError::SharedMemoryCreation {
173                    name: name.to_string(),
174                    reason: format!("Failed to create write event: {}", e),
175                }
176            })?;
177
178            let (read_signal, _) = Event::new(read_signal_ptr, true).map_err(|e| {
179                TransportError::SharedMemoryCreation {
180                    name: name.to_string(),
181                    reason: format!("Failed to create read event: {}", e),
182                }
183            })?;
184
185            // Initially set write signal because buffer is empty (writable)
186            write_signal.set(EventState::Signaled).map_err(|e| {
187                TransportError::SharedMemoryCreation {
188                    name: name.to_string(),
189                    reason: format!("Failed to set write event: {}", e),
190                }
191            })?;
192
193            Ok(Self {
194                shmem,
195                control,
196                data,
197                capacity,
198                is_owner: true,
199                name: name.to_string(),
200                write_signal,
201                read_signal,
202            })
203        }
204    }
205
206    fn connect(name: &str) -> TransportResult<Self> {
207        let shmem =
208            ShmemConf::new()
209                .os_id(name)
210                .open()
211                .map_err(|e| TransportError::ConnectionFailed {
212                    name: name.to_string(),
213                    attempts: 1,
214                    reason: e.to_string(),
215                })?;
216
217        let base_ptr = shmem.as_ptr();
218        let control = base_ptr as *mut SharedMemoryControlBlock;
219
220        let ctrl = unsafe { &*control };
221        if !ctrl.is_valid() {
222            return Err(TransportError::InvalidBufferState(
223                "Invalid shared memory region".to_string(),
224            ));
225        }
226
227        let capacity = ctrl.capacity.load(Ordering::Acquire) as usize;
228        let control_size = std::mem::size_of::<SharedMemoryControlBlock>();
229        let event_size = Event::size_of(None);
230
231        let write_signal_ptr = unsafe { base_ptr.add(control_size) };
232        let read_signal_ptr = unsafe { base_ptr.add(control_size + event_size) };
233        let data = unsafe { base_ptr.add(control_size + (event_size * 2)) };
234
235        unsafe {
236            let (write_signal, _) = Event::from_existing(write_signal_ptr).map_err(|e| {
237                TransportError::ConnectionFailed {
238                    name: name.to_string(),
239                    attempts: 1,
240                    reason: format!("Failed to open write event: {}", e),
241                }
242            })?;
243
244            let (read_signal, _) = Event::from_existing(read_signal_ptr).map_err(|e| {
245                TransportError::ConnectionFailed {
246                    name: name.to_string(),
247                    attempts: 1,
248                    reason: format!("Failed to open read event: {}", e),
249                }
250            })?;
251
252            Ok(Self {
253                shmem,
254                control,
255                data,
256                capacity,
257                is_owner: false,
258                name: name.to_string(),
259                write_signal,
260                read_signal,
261            })
262        }
263    }
264
265    fn write(&self, data: &[u8], timeout: Duration) -> TransportResult<()> {
266        let control = unsafe { &*self.control };
267
268        if !control.is_valid() {
269            control.record_error();
270            return Err(TransportError::InvalidBufferState(
271                "Bad control".to_string(),
272            ));
273        }
274
275        let msg_len = data.len();
276        let total_len = 4 + msg_len; // Length prefix (4 bytes) + Data
277
278        if total_len > self.capacity {
279            control.record_error();
280            return Err(TransportError::MessageTooLarge {
281                size: msg_len,
282                max: self.capacity - 4,
283            });
284        }
285
286        let start = Instant::now();
287
288        // Wait for space
289        loop {
290            if control.available_write() >= total_len {
291                break;
292            }
293
294            if start.elapsed() > timeout {
295                control.record_error();
296                return Err(TransportError::Timeout {
297                    duration_ms: timeout.as_millis() as u64,
298                    operation: "waiting for buffer space".into(),
299                });
300            }
301
302            // Wait for signal from reader (that they consumed data)
303            // Using wait_timeout to handle race conditions or lost signals
304            let remaining = timeout.saturating_sub(start.elapsed());
305            let _ = self.write_signal.wait(Timeout::Val(remaining));
306        }
307
308        let write_pos = control.write_pos.load(Ordering::Acquire) as usize;
309
310        // Optimized write using memcpy (std::ptr::copy_nonoverlapping)
311
312        // 1. Write Length (4 bytes)
313        let len_bytes = (msg_len as u32).to_le_bytes();
314        self.raw_write(write_pos, &len_bytes);
315
316        // 2. Write Data
317        self.raw_write(write_pos + 4, data);
318
319        // Update position and heartbeat
320        control
321            .write_pos
322            .store((write_pos + total_len) as u64, Ordering::Release);
323        control.update_heartbeat();
324
325        // Signal reader that data is available
326        let _ = self.read_signal.set(EventState::Signaled);
327
328        Ok(())
329    }
330
331    // Helper for circular buffer write
332    fn raw_write(&self, offset: usize, src: &[u8]) {
333        let offset = offset % self.capacity;
334        let len = src.len();
335        let first_chunk = std::cmp::min(len, self.capacity - offset);
336
337        unsafe {
338            // First part
339            std::ptr::copy_nonoverlapping(src.as_ptr(), self.data.add(offset), first_chunk);
340
341            // Second part (wrap around) if needed
342            if first_chunk < len {
343                std::ptr::copy_nonoverlapping(
344                    src.as_ptr().add(first_chunk),
345                    self.data,
346                    len - first_chunk,
347                );
348            }
349        }
350    }
351
352    fn read(&self, timeout: Duration) -> TransportResult<Bytes> {
353        let control = unsafe { &*self.control };
354
355        if !control.is_valid() {
356            control.record_error();
357            return Err(TransportError::InvalidBufferState(
358                "Bad control".to_string(),
359            ));
360        }
361
362        let start = Instant::now();
363
364        // Wait for data (at least 4 bytes for length)
365        loop {
366            if control.available_read() >= 4 {
367                break;
368            }
369
370            if start.elapsed() > timeout {
371                control.record_error();
372                return Err(TransportError::Timeout {
373                    duration_ms: timeout.as_millis() as u64,
374                    operation: "waiting for data".into(),
375                });
376            }
377
378            let remaining = timeout.saturating_sub(start.elapsed());
379            let _ = self.read_signal.wait(Timeout::Val(remaining));
380        }
381
382        let read_pos = control.read_pos.load(Ordering::Acquire) as usize;
383
384        // Read length prefix
385        let mut len_bytes = [0u8; 4];
386        self.raw_read(read_pos, &mut len_bytes);
387        let msg_len = u32::from_le_bytes(len_bytes) as usize;
388
389        if msg_len > self.capacity {
390            control.record_error();
391            return Err(TransportError::InvalidBufferState(
392                "Bad message length".to_string(),
393            ));
394        }
395
396        // Wait for full message
397        loop {
398            if control.available_read() >= 4 + msg_len {
399                break;
400            }
401
402            if start.elapsed() > timeout {
403                control.record_error();
404                return Err(TransportError::Timeout {
405                    duration_ms: timeout.as_millis() as u64,
406                    operation: "waiting for full message".into(),
407                });
408            }
409
410            let remaining = timeout.saturating_sub(start.elapsed());
411            let _ = self.read_signal.wait(Timeout::Val(remaining));
412        }
413
414        // Read payload
415        let mut buffer = vec![0u8; msg_len];
416        self.raw_read(read_pos + 4, &mut buffer);
417
418        // Update position and heartbeat
419        control
420            .read_pos
421            .store((read_pos + 4 + msg_len) as u64, Ordering::Release);
422        control.update_heartbeat();
423
424        // Signal writer that space is available
425        let _ = self.write_signal.set(EventState::Signaled);
426
427        Ok(Bytes::from(buffer))
428    }
429
430    // Helper for circular buffer read
431    fn raw_read(&self, offset: usize, dst: &mut [u8]) {
432        let offset = offset % self.capacity;
433        let len = dst.len();
434        let first_chunk = std::cmp::min(len, self.capacity - offset);
435
436        unsafe {
437            // First part
438            std::ptr::copy_nonoverlapping(self.data.add(offset), dst.as_mut_ptr(), first_chunk);
439
440            // Second part (wrap around) if needed
441            if first_chunk < len {
442                std::ptr::copy_nonoverlapping(
443                    self.data,
444                    dst.as_mut_ptr().add(first_chunk),
445                    len - first_chunk,
446                );
447            }
448        }
449    }
450
451    fn tick_heartbeat(&self) {
452        if self.is_owner {
453            let control = unsafe { &*self.control };
454            control.tick_heartbeat();
455        }
456    }
457}
458
459impl Drop for SharedMemoryRingBuffer {
460    fn drop(&mut self) {
461        if self.is_owner {
462            let control = unsafe { &*self.control };
463            control.connected.store(false, Ordering::Release);
464        }
465    }
466}
467
/// Retry policy for reconnection attempts.
///
/// Determines how long to wait before retry number `attempt` (zero-based).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RetryPolicy {
    /// Always wait `delay_ms` milliseconds.
    Fixed {
        delay_ms: u64,
    },
    /// Wait `base_delay_ms * (attempt + 1)` milliseconds.
    Linear {
        base_delay_ms: u64,
    },
    /// Wait `base_delay_ms * 2^attempt` milliseconds, capped at `max_delay_ms`.
    Exponential {
        base_delay_ms: u64,
        max_delay_ms: u64,
    },
}

impl RetryPolicy {
    /// Returns the delay to apply before the retry with zero-based index `attempt`.
    ///
    /// All arithmetic saturates at `u64::MAX` milliseconds, so very large attempt numbers
    /// or base delays can neither panic nor wrap around to a short delay. For
    /// `Exponential`, the saturated value is then capped at `max_delay_ms`.
    pub fn delay(&self, attempt: usize) -> Duration {
        let millis = match *self {
            RetryPolicy::Fixed { delay_ms } => delay_ms,
            RetryPolicy::Linear { base_delay_ms } => {
                let steps = (attempt as u64).saturating_add(1);
                base_delay_ms.saturating_mul(steps)
            }
            RetryPolicy::Exponential {
                base_delay_ms,
                max_delay_ms,
            } => {
                // 2^64 already overflows u64, so clamping the exponent to 64 loses nothing
                // and makes the cast to u32 lossless.
                let exponent = attempt.min(64) as u32;
                let factor = 2u64.checked_pow(exponent).unwrap_or(u64::MAX);
                base_delay_ms.saturating_mul(factor).min(max_delay_ms)
            }
        };
        Duration::from_millis(millis)
    }
}

impl Default for RetryPolicy {
    /// Linear back-off starting at 100 ms.
    fn default() -> Self {
        RetryPolicy::Linear { base_delay_ms: 100 }
    }
}
506
507/// Configuration for shared memory frame transport.
508#[derive(Clone, Debug)]
509pub struct SharedMemoryConfig {
510    pub buffer_size: usize,
511    pub read_timeout: Option<Duration>,
512    pub write_timeout: Option<Duration>,
513    pub max_retry_attempts: usize,
514    pub retry_policy: RetryPolicy,
515    pub auto_reconnect: bool,
516}
517
518impl Default for SharedMemoryConfig {
519    fn default() -> Self {
520        Self {
521            buffer_size: DEFAULT_BUFFER_SIZE,
522            read_timeout: Some(Duration::from_secs(5)),
523            write_timeout: Some(Duration::from_secs(5)),
524            max_retry_attempts: 3,
525            retry_policy: RetryPolicy::default(),
526            auto_reconnect: true,
527        }
528    }
529}
530
531impl SharedMemoryConfig {
532    pub fn new() -> Self {
533        Self::default()
534    }
535
536    pub fn with_buffer_size(mut self, size: usize) -> Self {
537        self.buffer_size = size;
538        self
539    }
540
541    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
542        self.read_timeout = Some(timeout);
543        self
544    }
545
546    pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
547        self.write_timeout = Some(timeout);
548        self
549    }
550
551    pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
552        self.retry_policy = policy;
553        self
554    }
555
556    pub fn with_max_retries(mut self, max: usize) -> Self {
557        self.max_retry_attempts = max;
558        self
559    }
560
561    pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
562        self.auto_reconnect = enabled;
563        self
564    }
565}
566
/// Shared memory frame transport with ring buffer (Layer 1).
///
/// Uses two rings per connection, named `<name>_s2c` and `<name>_c2s`. The server sends on
/// `_s2c` and receives on `_c2s`; the client does the opposite.
#[derive(Debug)]
pub struct SharedMemoryFrameTransport {
    // Ring this side writes to; wrapped in `ArcSwap` so reconnect can replace it via `&self`.
    send_buffer: ArcSwap<SharedMemoryRingBuffer>,
    // Ring this side reads from; replaceable in the same way.
    recv_buffer: ArcSwap<SharedMemoryRingBuffer>,
    // Timeouts, retry settings and reconnect behaviour.
    config: SharedMemoryConfig,
    // Message/byte counters updated on every successful send and receive.
    stats: Arc<Mutex<TransportStats>>,
    // Display name: the base name suffixed with `-server` or `-client`.
    name: String,
    // Failed attempts counted locally; reset on a successful reconnect.
    error_count: Arc<Mutex<usize>>,
    // Cleared when all send/receive attempts fail; set again by a successful reconnect.
    connected: Arc<AtomicBool>,
    // Base name (without direction suffix) used to rebuild the ring names on reconnect.
    reconnect_name: String,
}
579
580impl SharedMemoryFrameTransport {
581    /// Create a server-side transport.
582    pub fn create_server(
583        name: impl Into<String>,
584        config: SharedMemoryConfig,
585    ) -> TransportResult<Self> {
586        let name = name.into();
587
588        let send_name = format!("{}_s2c", name);
589        let recv_name = format!("{}_c2s", name);
590
591        let send_buffer = Arc::new(SharedMemoryRingBuffer::create(
592            &send_name,
593            config.buffer_size,
594        )?);
595        let recv_buffer = Arc::new(SharedMemoryRingBuffer::create(
596            &recv_name,
597            config.buffer_size,
598        )?);
599
600        // Spawn heartbeat task
601        spawn_weak_loop(
602            Arc::downgrade(&send_buffer),
603            Duration::from_secs(5),
604            |buffer| {
605                buffer.tick_heartbeat();
606            },
607        );
608
609        Ok(Self {
610            send_buffer: ArcSwap::new(send_buffer),
611            recv_buffer: ArcSwap::new(recv_buffer),
612            config,
613            stats: Arc::new(Mutex::new(TransportStats::default())),
614            name: format!("{}-server", name),
615            error_count: Arc::new(Mutex::new(0)),
616            connected: Arc::new(AtomicBool::new(true)),
617            reconnect_name: name.clone(),
618        })
619    }
620
621    /// Connect as a client with custom config.
622    pub fn connect_client_with_config(
623        name: impl Into<String>,
624        config: SharedMemoryConfig,
625    ) -> TransportResult<Self> {
626        let name = name.into();
627        let send_name = format!("{}_c2s", name);
628        let recv_name = format!("{}_s2c", name);
629        let mut last_error = None;
630
631        for attempt in 0..config.max_retry_attempts {
632            if attempt > 0 {
633                std::thread::sleep(config.retry_policy.delay(attempt - 1));
634            }
635
636            match (
637                SharedMemoryRingBuffer::connect(&send_name),
638                SharedMemoryRingBuffer::connect(&recv_name),
639            ) {
640                (Ok(send_buffer), Ok(recv_buffer)) => {
641                    let send_arc = Arc::new(send_buffer);
642                    let recv_arc = Arc::new(recv_buffer);
643
644                    // Spawn heartbeat task
645                    spawn_weak_loop(
646                        Arc::downgrade(&send_arc),
647                        Duration::from_secs(5),
648                        |buffer| {
649                            buffer.tick_heartbeat();
650                        },
651                    );
652
653                    return Ok(Self {
654                        send_buffer: ArcSwap::new(send_arc),
655                        recv_buffer: ArcSwap::new(recv_arc),
656                        config: config.clone(),
657                        stats: Arc::new(Mutex::new(TransportStats::default())),
658                        name: format!("{}-client", name),
659                        error_count: Arc::new(Mutex::new(0)),
660                        connected: Arc::new(AtomicBool::new(true)),
661                        reconnect_name: name.clone(),
662                    });
663                }
664                (Err(e), _) | (_, Err(e)) => {
665                    last_error = Some(e);
666                }
667            }
668        }
669
670        Err(
671            last_error.unwrap_or_else(|| TransportError::ConnectionFailed {
672                name: name.clone(),
673                attempts: config.max_retry_attempts,
674                reason: "max retries exceeded".into(),
675            }),
676        )
677    }
678
679    /// Connect as a client with default config.
680    pub fn connect_client(name: impl Into<String>) -> TransportResult<Self> {
681        Self::connect_client_with_config(name, SharedMemoryConfig::default())
682    }
683
684    // Auto-reconnect if connection lost
685    fn try_reconnect(&self) -> TransportResult<()> {
686        if !self.config.auto_reconnect {
687            return Err(TransportError::NotConnected);
688        }
689
690        let send_name = format!("{}_c2s", self.reconnect_name);
691        let recv_name = format!("{}_s2c", self.reconnect_name);
692
693        for attempt in 0..self.config.max_retry_attempts {
694            if attempt > 0 {
695                std::thread::sleep(self.config.retry_policy.delay(attempt - 1));
696            }
697
698            match (
699                SharedMemoryRingBuffer::connect(&send_name),
700                SharedMemoryRingBuffer::connect(&recv_name),
701            ) {
702                (Ok(send), Ok(recv)) => {
703                    // Atomic swap with newly connected buffers
704                    self.send_buffer.store(Arc::new(send));
705                    self.recv_buffer.store(Arc::new(recv));
706                    *self.error_count.lock() = 0;
707                    self.connected.store(true, Ordering::Release);
708                    return Ok(());
709                }
710                _ => continue,
711            }
712        }
713
714        Err(TransportError::ConnectionFailed {
715            name: self.reconnect_name.clone(),
716            attempts: self.config.max_retry_attempts,
717            reason: "reconnect failed".into(),
718        })
719    }
720
721    pub fn shm_is_healthy(&self) -> bool {
722        let recv_buf = self.recv_buffer.load();
723        let control = unsafe { &*recv_buf.control };
724        control.is_healthy(30) && *self.error_count.lock() < 10
725    }
726
727    pub fn shm_stats(&self) -> TransportStats {
728        self.stats.lock().clone()
729    }
730}
731
732#[async_trait]
733impl FrameTransport for SharedMemoryFrameTransport {
734    fn is_healthy(&self) -> bool {
735        self.shm_is_healthy()
736    }
737
738    async fn send_frame(&self, data: &[u8]) -> TransportResult<()> {
739        if !self.connected.load(Ordering::Acquire) {
740            self.try_reconnect()?;
741        }
742
743        let mut last_error = None;
744
745        for attempt in 0..self.config.max_retry_attempts {
746            if attempt > 0 {
747                tokio::time::sleep(self.config.retry_policy.delay(attempt - 1)).await;
748            }
749
750            let result = tokio::task::spawn_blocking({
751                let buffer = self.send_buffer.load_full();
752                let data = data.to_vec();
753                let timeout = self.config.write_timeout.unwrap_or(Duration::from_secs(30));
754                move || buffer.write(&data, timeout)
755            })
756            .await;
757
758            match result {
759                Ok(Ok(())) => {
760                    let mut stats = self.stats.lock();
761                    stats.messages_sent += 1;
762                    stats.bytes_sent += data.len() as u64;
763                    return Ok(());
764                }
765                Ok(Err(e)) => {
766                    last_error = Some(e);
767                    *self.error_count.lock() += 1;
768                }
769                Err(e) => {
770                    last_error = Some(TransportError::SendFailed {
771                        attempts: attempt + 1,
772                        reason: e.to_string(),
773                    });
774                    *self.error_count.lock() += 1;
775                }
776            }
777        }
778
779        self.connected.store(false, Ordering::Release);
780        Err(last_error.unwrap_or_else(|| TransportError::SendFailed {
781            attempts: self.config.max_retry_attempts,
782            reason: "max retries exceeded".into(),
783        }))
784    }
785
786    async fn recv_frame(&self) -> TransportResult<Bytes> {
787        if !self.connected.load(Ordering::Acquire) {
788            self.try_reconnect()?;
789        }
790
791        let mut last_error = None;
792
793        for attempt in 0..self.config.max_retry_attempts {
794            if attempt > 0 {
795                tokio::time::sleep(self.config.retry_policy.delay(attempt - 1)).await;
796            }
797
798            let result = tokio::task::spawn_blocking({
799                let buffer = self.recv_buffer.load_full();
800                let timeout = self.config.read_timeout.unwrap_or(Duration::from_secs(30));
801                move || buffer.read(timeout)
802            })
803            .await;
804
805            match result {
806                Ok(Ok(bytes)) => {
807                    let mut stats = self.stats.lock();
808                    stats.messages_received += 1;
809                    stats.bytes_received += bytes.len() as u64;
810                    return Ok(bytes);
811                }
812                Ok(Err(e)) => {
813                    last_error = Some(e);
814                    *self.error_count.lock() += 1;
815                }
816                Err(e) => {
817                    last_error = Some(TransportError::ReceiveFailed {
818                        attempts: attempt + 1,
819                        reason: e.to_string(),
820                    });
821                    *self.error_count.lock() += 1;
822                }
823            }
824        }
825
826        self.connected.store(false, Ordering::Release);
827        Err(last_error.unwrap_or_else(|| TransportError::ReceiveFailed {
828            attempts: self.config.max_retry_attempts,
829            reason: "max retries exceeded".into(),
830        }))
831    }
832
833    fn is_connected(&self) -> bool {
834        self.connected.load(Ordering::Acquire)
835    }
836
837    async fn close(&self) -> TransportResult<()> {
838        let send_buf = self.send_buffer.load();
839        let control = unsafe { &*send_buf.control };
840        control.connected.store(false, Ordering::Release);
841        Ok(())
842    }
843
844    fn stats(&self) -> Option<TransportStats> {
845        Some(self.stats.lock().clone())
846    }
847
848    fn name(&self) -> &str {
849        &self.name
850    }
851}
852
/// Former name of `SharedMemoryFrameTransport`, kept as a deprecated alias for backward
/// compatibility.
#[deprecated(since = "0.2.0", note = "Use SharedMemoryFrameTransport instead")]
pub type SharedMemoryTransport = SharedMemoryFrameTransport;
856
857#[cfg(test)]
858mod tests {
859    use super::*;
860
861    #[tokio::test]
862    async fn test_cross_process_basic() {
863        let config = SharedMemoryConfig::default();
864        let server = SharedMemoryFrameTransport::create_server("test-basic", config).unwrap();
865        let client = SharedMemoryFrameTransport::connect_client("test-basic").unwrap();
866
867        client.send_frame(b"Hello").await.unwrap();
868        let msg = server.recv_frame().await.unwrap();
869        assert_eq!(msg.as_ref(), b"Hello");
870    }
871
872    #[tokio::test]
873    async fn test_ring_buffer_wrapping() {
874        let config = SharedMemoryConfig::default().with_buffer_size(8192);
875        let server = SharedMemoryFrameTransport::create_server("test-wrap", config).unwrap();
876        let client = SharedMemoryFrameTransport::connect_client("test-wrap").unwrap();
877
878        let chunk_size = 1000;
879        let num_chunks = 20;
880
881        let send_handle = tokio::spawn(async move {
882            for i in 0..num_chunks {
883                let data = vec![i as u8; chunk_size];
884                client.send_frame(&data).await.unwrap();
885            }
886        });
887
888        for i in 0..num_chunks {
889            let msg = server.recv_frame().await.unwrap();
890            assert_eq!(msg.len(), chunk_size);
891            assert!(msg.iter().all(|&b| b == i as u8), "chunk {} corrupted", i);
892        }
893
894        send_handle.await.unwrap();
895    }
896
897    #[tokio::test]
898    async fn test_auto_reconnect() {
899        let config = SharedMemoryConfig::default().with_auto_reconnect(true);
900        let server =
901            SharedMemoryFrameTransport::create_server("test-reconnect", config.clone()).unwrap();
902        let client =
903            SharedMemoryFrameTransport::connect_client_with_config("test-reconnect", config)
904                .unwrap();
905
906        client.send_frame(b"before").await.unwrap();
907        let msg = server.recv_frame().await.unwrap();
908        assert_eq!(msg.as_ref(), b"before");
909
910        client.connected.store(false, Ordering::Release);
911        assert!(!client.is_connected());
912
913        client.send_frame(b"after").await.unwrap();
914        assert!(client.is_connected());
915
916        let msg = server.recv_frame().await.unwrap();
917        assert_eq!(msg.as_ref(), b"after");
918    }
919
920    #[tokio::test]
921    async fn test_configurable_timeout() {
922        let config = SharedMemoryConfig::default()
923            .with_read_timeout(Duration::from_millis(100))
924            .with_write_timeout(Duration::from_millis(100));
925
926        let _server =
927            SharedMemoryFrameTransport::create_server("test-timeout", config.clone()).unwrap();
928        let client =
929            SharedMemoryFrameTransport::connect_client_with_config("test-timeout", config).unwrap();
930
931        let start = Instant::now();
932        let result = client.recv_frame().await;
933        let elapsed = start.elapsed();
934
935        assert!(result.is_err());
936        assert!(
937            elapsed < Duration::from_secs(1),
938            "timeout took too long: {:?}",
939            elapsed
940        );
941    }
942}