xrpc/transport/
shared_memory.rs

1use arc_swap::ArcSwap;
2use async_trait::async_trait;
3use bytes::Bytes;
4use parking_lot::Mutex;
5use raw_sync::Timeout;
6use raw_sync::events::{Event, EventImpl, EventInit, EventState};
7use shared_memory::{Shmem, ShmemConf};
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11
12use crate::error::{TransportError, TransportResult};
13use crate::transport::utils::spawn_weak_loop;
14use crate::transport::{Transport, TransportStats};
15
/// Tag stored in every control block so a mapping can be recognised as ours:
/// the four ASCII bytes of "XRPC" read as a big-endian integer (0x58525043).
const SHM_MAGIC: u64 = u32::from_be_bytes(*b"XRPC") as u64;
/// Smallest data capacity accepted when creating a ring buffer (4 KiB).
const MIN_BUFFER_SIZE: usize = 4 * 1024;
/// Largest data capacity accepted when creating a ring buffer (1 GiB).
const MAX_BUFFER_SIZE: usize = 1 << 30;
/// Data capacity used by the default configuration (1 MiB).
const DEFAULT_BUFFER_SIZE: usize = 1 << 20;
20
21#[repr(C)]
22struct SharedMemoryControlBlock {
23    magic: AtomicU64,
24    write_pos: AtomicU64,
25    read_pos: AtomicU64,
26    capacity: AtomicU64,
27    connected: AtomicBool,
28    error_count: AtomicU64,
29    last_heartbeat: AtomicU64,
30}
31
32impl SharedMemoryControlBlock {
33    fn new(capacity: usize) -> Self {
34        Self {
35            magic: AtomicU64::new(SHM_MAGIC),
36            write_pos: AtomicU64::new(0),
37            read_pos: AtomicU64::new(0),
38            capacity: AtomicU64::new(capacity as u64),
39            connected: AtomicBool::new(true),
40            error_count: AtomicU64::new(0),
41            last_heartbeat: AtomicU64::new(0),
42        }
43    }
44
45    fn is_valid(&self) -> bool {
46        self.magic.load(Ordering::Acquire) == SHM_MAGIC
47    }
48
49    fn available_write(&self) -> usize {
50        let w = self.write_pos.load(Ordering::Acquire) as usize;
51        let r = self.read_pos.load(Ordering::Acquire) as usize;
52        let cap = self.capacity.load(Ordering::Acquire) as usize;
53        let w_norm = w % cap;
54        let r_norm = r % cap;
55        if w_norm >= r_norm {
56            cap - (w_norm - r_norm) - 1
57        } else {
58            (r_norm - w_norm) - 1
59        }
60    }
61
62    fn available_read(&self) -> usize {
63        let w = self.write_pos.load(Ordering::Acquire) as usize;
64        let r = self.read_pos.load(Ordering::Acquire) as usize;
65        let cap = self.capacity.load(Ordering::Acquire) as usize;
66        let w_norm = w % cap;
67        let r_norm = r % cap;
68        if w_norm >= r_norm {
69            w_norm - r_norm
70        } else {
71            cap - (r_norm - w_norm)
72        }
73    }
74
75    fn record_error(&self) {
76        self.error_count.fetch_add(1, Ordering::Relaxed);
77    }
78
79    fn update_heartbeat(&self) {
80        let timestamp = std::time::SystemTime::now()
81            .duration_since(std::time::UNIX_EPOCH)
82            .unwrap()
83            .as_secs();
84        self.last_heartbeat.store(timestamp, Ordering::Release);
85    }
86
87    fn tick_heartbeat(&self) {
88        // Just update the timestamp without signaling
89        self.update_heartbeat();
90    }
91
92    fn is_healthy(&self, timeout_secs: u64) -> bool {
93        let now = std::time::SystemTime::now()
94            .duration_since(std::time::UNIX_EPOCH)
95            .unwrap()
96            .as_secs();
97        let last = self.last_heartbeat.load(Ordering::Acquire);
98
99        if last == 0 {
100            return true; // Not yet initialized
101        }
102
103        now - last < timeout_secs
104    }
105}
106
107struct SharedMemoryRingBuffer {
108    #[allow(dead_code)]
109    shmem: Shmem,
110    control: *mut SharedMemoryControlBlock,
111    data: *mut u8,
112    capacity: usize,
113    is_owner: bool,
114    #[allow(dead_code)]
115    name: String,
116    // raw-sync events for true IPC signaling
117    // write_signal: signaled when space is available (writer waits on this)
118    write_signal: Box<dyn EventImpl>,
119    // read_signal: signaled when data is available (reader waits on this)
120    read_signal: Box<dyn EventImpl>,
121}
122
123impl std::fmt::Debug for SharedMemoryRingBuffer {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        f.debug_struct("SharedMemoryRingBuffer")
126            .field("capacity", &self.capacity)
127            .field("is_owner", &self.is_owner)
128            .finish()
129    }
130}
131
132unsafe impl Send for SharedMemoryRingBuffer {}
133unsafe impl Sync for SharedMemoryRingBuffer {}
134
135impl SharedMemoryRingBuffer {
136    fn create(name: &str, capacity: usize) -> TransportResult<Self> {
137        if !(MIN_BUFFER_SIZE..=MAX_BUFFER_SIZE).contains(&capacity) {
138            return Err(TransportError::InvalidBufferState(format!(
139                "Invalid buffer size: {}",
140                capacity
141            )));
142        }
143
144        // Calculate total size: ControlBlock + 2 * Event + Buffer
145        let control_size = std::mem::size_of::<SharedMemoryControlBlock>();
146        let event_size = Event::size_of(None);
147        let total_size = control_size + (event_size * 2) + capacity;
148
149        let shmem = ShmemConf::new()
150            .size(total_size)
151            .os_id(name)
152            .create()
153            .map_err(|e| TransportError::SharedMemoryCreation {
154                name: name.to_string(),
155                reason: e.to_string(),
156            })?;
157
158        let base_ptr = shmem.as_ptr();
159
160        // Setup pointers
161        let control = base_ptr as *mut SharedMemoryControlBlock;
162        let write_signal_ptr = unsafe { base_ptr.add(control_size) };
163        let read_signal_ptr = unsafe { base_ptr.add(control_size + event_size) };
164        let data = unsafe { base_ptr.add(control_size + (event_size * 2)) };
165
166        // Initialize structures
167        unsafe {
168            std::ptr::write(control, SharedMemoryControlBlock::new(capacity));
169
170            // Initialize events (auto-reset = true)
171            let (write_signal, _) = Event::new(write_signal_ptr, true).map_err(|e| {
172                TransportError::SharedMemoryCreation {
173                    name: name.to_string(),
174                    reason: format!("Failed to create write event: {}", e),
175                }
176            })?;
177
178            let (read_signal, _) = Event::new(read_signal_ptr, true).map_err(|e| {
179                TransportError::SharedMemoryCreation {
180                    name: name.to_string(),
181                    reason: format!("Failed to create read event: {}", e),
182                }
183            })?;
184
185            // Initially set write signal because buffer is empty (writable)
186            write_signal.set(EventState::Signaled).map_err(|e| {
187                TransportError::SharedMemoryCreation {
188                    name: name.to_string(),
189                    reason: format!("Failed to set write event: {}", e),
190                }
191            })?;
192
193            Ok(Self {
194                shmem,
195                control,
196                data,
197                capacity,
198                is_owner: true,
199                name: name.to_string(),
200                write_signal,
201                read_signal,
202            })
203        }
204    }
205
206    fn connect(name: &str) -> TransportResult<Self> {
207        let shmem =
208            ShmemConf::new()
209                .os_id(name)
210                .open()
211                .map_err(|e| TransportError::ConnectionFailed {
212                    name: name.to_string(),
213                    attempts: 1,
214                    reason: e.to_string(),
215                })?;
216
217        let base_ptr = shmem.as_ptr();
218        let control = base_ptr as *mut SharedMemoryControlBlock;
219
220        let ctrl = unsafe { &*control };
221        if !ctrl.is_valid() {
222            return Err(TransportError::InvalidBufferState(
223                "Invalid shared memory region".to_string(),
224            ));
225        }
226
227        let capacity = ctrl.capacity.load(Ordering::Acquire) as usize;
228        let control_size = std::mem::size_of::<SharedMemoryControlBlock>();
229        let event_size = Event::size_of(None);
230
231        let write_signal_ptr = unsafe { base_ptr.add(control_size) };
232        let read_signal_ptr = unsafe { base_ptr.add(control_size + event_size) };
233        let data = unsafe { base_ptr.add(control_size + (event_size * 2)) };
234
235        unsafe {
236            let (write_signal, _) = Event::from_existing(write_signal_ptr).map_err(|e| {
237                TransportError::ConnectionFailed {
238                    name: name.to_string(),
239                    attempts: 1,
240                    reason: format!("Failed to open write event: {}", e),
241                }
242            })?;
243
244            let (read_signal, _) = Event::from_existing(read_signal_ptr).map_err(|e| {
245                TransportError::ConnectionFailed {
246                    name: name.to_string(),
247                    attempts: 1,
248                    reason: format!("Failed to open read event: {}", e),
249                }
250            })?;
251
252            Ok(Self {
253                shmem,
254                control,
255                data,
256                capacity,
257                is_owner: false,
258                name: name.to_string(),
259                write_signal,
260                read_signal,
261            })
262        }
263    }
264
265    fn write(&self, data: &[u8], timeout: Duration) -> TransportResult<()> {
266        let control = unsafe { &*self.control };
267
268        if !control.is_valid() {
269            control.record_error();
270            return Err(TransportError::InvalidBufferState(
271                "Bad control".to_string(),
272            ));
273        }
274
275        let msg_len = data.len();
276        let total_len = 4 + msg_len; // Length prefix (4 bytes) + Data
277
278        if total_len > self.capacity {
279            control.record_error();
280            return Err(TransportError::MessageTooLarge {
281                size: msg_len,
282                max: self.capacity - 4,
283            });
284        }
285
286        let start = Instant::now();
287
288        // Wait for space
289        loop {
290            if control.available_write() >= total_len {
291                break;
292            }
293
294            if start.elapsed() > timeout {
295                control.record_error();
296                return Err(TransportError::Timeout {
297                    duration_ms: timeout.as_millis() as u64,
298                    operation: "waiting for buffer space".into(),
299                });
300            }
301
302            // Wait for signal from reader (that they consumed data)
303            // Using wait_timeout to handle race conditions or lost signals
304            let remaining = timeout.saturating_sub(start.elapsed());
305            let _ = self.write_signal.wait(Timeout::Val(remaining));
306        }
307
308        let write_pos = control.write_pos.load(Ordering::Acquire) as usize;
309
310        // Optimized write using memcpy (std::ptr::copy_nonoverlapping)
311
312        // 1. Write Length (4 bytes)
313        let len_bytes = (msg_len as u32).to_le_bytes();
314        self.raw_write(write_pos, &len_bytes);
315
316        // 2. Write Data
317        self.raw_write(write_pos + 4, data);
318
319        // Update position and heartbeat
320        control
321            .write_pos
322            .store((write_pos + total_len) as u64, Ordering::Release);
323        control.update_heartbeat();
324
325        // Signal reader that data is available
326        let _ = self.read_signal.set(EventState::Signaled);
327
328        Ok(())
329    }
330
331    // Helper for circular buffer write
332    fn raw_write(&self, offset: usize, src: &[u8]) {
333        let offset = offset % self.capacity;
334        let len = src.len();
335        let first_chunk = std::cmp::min(len, self.capacity - offset);
336
337        unsafe {
338            // First part
339            std::ptr::copy_nonoverlapping(src.as_ptr(), self.data.add(offset), first_chunk);
340
341            // Second part (wrap around) if needed
342            if first_chunk < len {
343                std::ptr::copy_nonoverlapping(
344                    src.as_ptr().add(first_chunk),
345                    self.data,
346                    len - first_chunk,
347                );
348            }
349        }
350    }
351
352    fn read(&self, timeout: Duration) -> TransportResult<Bytes> {
353        let control = unsafe { &*self.control };
354
355        if !control.is_valid() {
356            control.record_error();
357            return Err(TransportError::InvalidBufferState(
358                "Bad control".to_string(),
359            ));
360        }
361
362        let start = Instant::now();
363
364        // Wait for data (at least 4 bytes for length)
365        loop {
366            if control.available_read() >= 4 {
367                break;
368            }
369
370            if start.elapsed() > timeout {
371                control.record_error();
372                return Err(TransportError::Timeout {
373                    duration_ms: timeout.as_millis() as u64,
374                    operation: "waiting for data".into(),
375                });
376            }
377
378            let remaining = timeout.saturating_sub(start.elapsed());
379            let _ = self.read_signal.wait(Timeout::Val(remaining));
380        }
381
382        let read_pos = control.read_pos.load(Ordering::Acquire) as usize;
383
384        // Read length prefix
385        let mut len_bytes = [0u8; 4];
386        self.raw_read(read_pos, &mut len_bytes);
387        let msg_len = u32::from_le_bytes(len_bytes) as usize;
388
389        if msg_len > self.capacity {
390            control.record_error();
391            return Err(TransportError::InvalidBufferState(
392                "Bad message length".to_string(),
393            ));
394        }
395
396        // Wait for full message
397        loop {
398            if control.available_read() >= 4 + msg_len {
399                break;
400            }
401
402            if start.elapsed() > timeout {
403                control.record_error();
404                return Err(TransportError::Timeout {
405                    duration_ms: timeout.as_millis() as u64,
406                    operation: "waiting for full message".into(),
407                });
408            }
409
410            let remaining = timeout.saturating_sub(start.elapsed());
411            let _ = self.read_signal.wait(Timeout::Val(remaining));
412        }
413
414        // Read payload
415        let mut buffer = vec![0u8; msg_len];
416        self.raw_read(read_pos + 4, &mut buffer);
417
418        // Update position and heartbeat
419        control
420            .read_pos
421            .store((read_pos + 4 + msg_len) as u64, Ordering::Release);
422        control.update_heartbeat();
423
424        // Signal writer that space is available
425        let _ = self.write_signal.set(EventState::Signaled);
426
427        Ok(Bytes::from(buffer))
428    }
429
430    // Helper for circular buffer read
431    fn raw_read(&self, offset: usize, dst: &mut [u8]) {
432        let offset = offset % self.capacity;
433        let len = dst.len();
434        let first_chunk = std::cmp::min(len, self.capacity - offset);
435
436        unsafe {
437            // First part
438            std::ptr::copy_nonoverlapping(self.data.add(offset), dst.as_mut_ptr(), first_chunk);
439
440            // Second part (wrap around) if needed
441            if first_chunk < len {
442                std::ptr::copy_nonoverlapping(
443                    self.data,
444                    dst.as_mut_ptr().add(first_chunk),
445                    len - first_chunk,
446                );
447            }
448        }
449    }
450
451    fn tick_heartbeat(&self) {
452        if self.is_owner {
453            let control = unsafe { &*self.control };
454            control.tick_heartbeat();
455        }
456    }
457}
458
459impl Drop for SharedMemoryRingBuffer {
460    fn drop(&mut self) {
461        if self.is_owner {
462            let control = unsafe { &*self.control };
463            control.connected.store(false, Ordering::Release);
464        }
465    }
466}
467
/// Strategy for spacing out retry attempts.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum RetryPolicy {
    /// Always wait `delay_ms`.
    Fixed {
        delay_ms: u64,
    },
    /// Wait `base_delay_ms * (attempt + 1)`.
    Linear {
        base_delay_ms: u64,
    },
    /// Wait `base_delay_ms * 2^attempt`, capped at `max_delay_ms`.
    Exponential {
        base_delay_ms: u64,
        max_delay_ms: u64,
    },
}

impl RetryPolicy {
    /// Returns the pause to insert after the zero-based `attempt`.
    ///
    /// All arithmetic saturates, so large attempt numbers or base delays yield
    /// the largest representable delay (or the configured cap for
    /// `Exponential`) instead of panicking or wrapping around.
    pub fn delay(&self, attempt: usize) -> Duration {
        let millis = match *self {
            RetryPolicy::Fixed { delay_ms } => delay_ms,
            RetryPolicy::Linear { base_delay_ms } => {
                let factor = (attempt as u64).saturating_add(1);
                base_delay_ms.saturating_mul(factor)
            }
            RetryPolicy::Exponential {
                base_delay_ms,
                max_delay_ms,
            } => {
                // 2^64 no longer fits in a u64; clamping the exponent to 64
                // makes `checked_pow` report that as `None`.
                let factor = 2u64
                    .checked_pow(attempt.min(64) as u32)
                    .unwrap_or(u64::MAX);
                base_delay_ms.saturating_mul(factor).min(max_delay_ms)
            }
        };
        Duration::from_millis(millis)
    }
}

impl Default for RetryPolicy {
    /// Linear back-off starting at 100 ms.
    fn default() -> Self {
        RetryPolicy::Linear { base_delay_ms: 100 }
    }
}
505
506#[derive(Clone, Debug)]
507pub struct SharedMemoryConfig {
508    pub buffer_size: usize,
509    pub read_timeout: Option<Duration>,
510    pub write_timeout: Option<Duration>,
511    pub max_retry_attempts: usize,
512    pub retry_policy: RetryPolicy,
513    pub auto_reconnect: bool,
514}
515
516impl Default for SharedMemoryConfig {
517    fn default() -> Self {
518        Self {
519            buffer_size: DEFAULT_BUFFER_SIZE,
520            read_timeout: Some(Duration::from_secs(5)),
521            write_timeout: Some(Duration::from_secs(5)),
522            max_retry_attempts: 3,
523            retry_policy: RetryPolicy::default(),
524            auto_reconnect: true,
525        }
526    }
527}
528
529impl SharedMemoryConfig {
530    pub fn new() -> Self {
531        Self::default()
532    }
533
534    pub fn with_buffer_size(mut self, size: usize) -> Self {
535        self.buffer_size = size;
536        self
537    }
538
539    pub fn with_read_timeout(mut self, timeout: Duration) -> Self {
540        self.read_timeout = Some(timeout);
541        self
542    }
543
544    pub fn with_write_timeout(mut self, timeout: Duration) -> Self {
545        self.write_timeout = Some(timeout);
546        self
547    }
548
549    pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
550        self.retry_policy = policy;
551        self
552    }
553
554    pub fn with_max_retries(mut self, max: usize) -> Self {
555        self.max_retry_attempts = max;
556        self
557    }
558
559    pub fn with_auto_reconnect(mut self, enabled: bool) -> Self {
560        self.auto_reconnect = enabled;
561        self
562    }
563}
564
565#[derive(Debug)]
566pub struct SharedMemoryTransport {
567    // ArcSwap allows lock-free reads with occasional swap on reconnect
568    send_buffer: ArcSwap<SharedMemoryRingBuffer>,
569    recv_buffer: ArcSwap<SharedMemoryRingBuffer>,
570    config: SharedMemoryConfig,
571    stats: Arc<Mutex<TransportStats>>,
572    name: String,
573    error_count: Arc<Mutex<usize>>,
574    connected: Arc<AtomicBool>,
575    reconnect_name: String,
576}
577
578impl SharedMemoryTransport {
579    pub fn create_server(
580        name: impl Into<String>,
581        config: SharedMemoryConfig,
582    ) -> TransportResult<Self> {
583        let name = name.into();
584
585        let send_name = format!("{}_s2c", name);
586        let recv_name = format!("{}_c2s", name);
587
588        let send_buffer = Arc::new(SharedMemoryRingBuffer::create(
589            &send_name,
590            config.buffer_size,
591        )?);
592        let recv_buffer = Arc::new(SharedMemoryRingBuffer::create(
593            &recv_name,
594            config.buffer_size,
595        )?);
596
597        // Spawn heartbeat task
598        spawn_weak_loop(
599            Arc::downgrade(&send_buffer),
600            Duration::from_secs(5),
601            |buffer| {
602                buffer.tick_heartbeat();
603            },
604        );
605
606        Ok(Self {
607            send_buffer: ArcSwap::new(send_buffer),
608            recv_buffer: ArcSwap::new(recv_buffer),
609            config,
610            stats: Arc::new(Mutex::new(TransportStats::default())),
611            name: format!("{}-server", name),
612            error_count: Arc::new(Mutex::new(0)),
613            connected: Arc::new(AtomicBool::new(true)),
614            reconnect_name: name.clone(),
615        })
616    }
617
618    pub fn connect_client_with_config(
619        name: impl Into<String>,
620        config: SharedMemoryConfig,
621    ) -> TransportResult<Self> {
622        let name = name.into();
623        let send_name = format!("{}_c2s", name);
624        let recv_name = format!("{}_s2c", name);
625        let mut last_error = None;
626
627        for attempt in 0..config.max_retry_attempts {
628            if attempt > 0 {
629                std::thread::sleep(config.retry_policy.delay(attempt - 1));
630            }
631
632            match (
633                SharedMemoryRingBuffer::connect(&send_name),
634                SharedMemoryRingBuffer::connect(&recv_name),
635            ) {
636                (Ok(send_buffer), Ok(recv_buffer)) => {
637                    let send_arc = Arc::new(send_buffer);
638                    let recv_arc = Arc::new(recv_buffer);
639
640                    // Spawn heartbeat task
641                    spawn_weak_loop(
642                        Arc::downgrade(&send_arc),
643                        Duration::from_secs(5),
644                        |buffer| {
645                            buffer.tick_heartbeat();
646                        },
647                    );
648
649                    return Ok(Self {
650                        send_buffer: ArcSwap::new(send_arc),
651                        recv_buffer: ArcSwap::new(recv_arc),
652                        config: config.clone(),
653                        stats: Arc::new(Mutex::new(TransportStats::default())),
654                        name: format!("{}-client", name),
655                        error_count: Arc::new(Mutex::new(0)),
656                        connected: Arc::new(AtomicBool::new(true)),
657                        reconnect_name: name.clone(),
658                    });
659                }
660                (Err(e), _) | (_, Err(e)) => {
661                    last_error = Some(e);
662                }
663            }
664        }
665
666        Err(
667            last_error.unwrap_or_else(|| TransportError::ConnectionFailed {
668                name: name.clone(),
669                attempts: config.max_retry_attempts,
670                reason: "max retries exceeded".into(),
671            }),
672        )
673    }
674
675    pub fn connect_client(name: impl Into<String>) -> TransportResult<Self> {
676        Self::connect_client_with_config(name, SharedMemoryConfig::default())
677    }
678
679    // Auto-reconnect if connection lost
680    fn try_reconnect(&self) -> TransportResult<()> {
681        if !self.config.auto_reconnect {
682            return Err(TransportError::NotConnected);
683        }
684
685        let send_name = format!("{}_c2s", self.reconnect_name);
686        let recv_name = format!("{}_s2c", self.reconnect_name);
687
688        for attempt in 0..self.config.max_retry_attempts {
689            if attempt > 0 {
690                std::thread::sleep(self.config.retry_policy.delay(attempt - 1));
691            }
692
693            match (
694                SharedMemoryRingBuffer::connect(&send_name),
695                SharedMemoryRingBuffer::connect(&recv_name),
696            ) {
697                (Ok(send), Ok(recv)) => {
698                    // Atomic swap with newly connected buffers
699                    self.send_buffer.store(Arc::new(send));
700                    self.recv_buffer.store(Arc::new(recv));
701                    *self.error_count.lock() = 0;
702                    self.connected.store(true, Ordering::Release);
703                    return Ok(());
704                }
705                _ => continue,
706            }
707        }
708
709        Err(TransportError::ConnectionFailed {
710            name: self.reconnect_name.clone(),
711            attempts: self.config.max_retry_attempts,
712            reason: "reconnect failed".into(),
713        })
714    }
715
716    pub fn is_healthy(&self) -> bool {
717        let recv_buf = self.recv_buffer.load();
718        let control = unsafe { &*recv_buf.control };
719        control.is_healthy(30) && *self.error_count.lock() < 10
720    }
721
722    pub fn stats(&self) -> TransportStats {
723        self.stats.lock().clone()
724    }
725}
726
727#[async_trait]
728impl Transport for SharedMemoryTransport {
729    fn is_healthy(&self) -> bool {
730        self.is_healthy()
731    }
732
733    async fn send(&self, data: &[u8]) -> TransportResult<()> {
734        // Try reconnect if disconnected
735        if !self.connected.load(Ordering::Acquire) {
736            self.try_reconnect()?;
737        }
738
739        let mut last_error = None;
740
741        for attempt in 0..self.config.max_retry_attempts {
742            if attempt > 0 {
743                tokio::time::sleep(self.config.retry_policy.delay(attempt - 1)).await;
744            }
745
746            let result = tokio::task::spawn_blocking({
747                let buffer = self.send_buffer.load_full();
748                let data = data.to_vec();
749                let timeout = self.config.write_timeout.unwrap_or(Duration::from_secs(30));
750                move || buffer.write(&data, timeout)
751            })
752            .await;
753
754            match result {
755                Ok(Ok(())) => {
756                    let mut stats = self.stats.lock();
757                    stats.messages_sent += 1;
758                    stats.bytes_sent += data.len() as u64;
759                    return Ok(());
760                }
761                Ok(Err(e)) => {
762                    last_error = Some(e);
763                    *self.error_count.lock() += 1;
764                }
765                Err(e) => {
766                    last_error = Some(TransportError::SendFailed {
767                        attempts: attempt + 1,
768                        reason: e.to_string(),
769                    });
770                    *self.error_count.lock() += 1;
771                }
772            }
773        }
774
775        self.connected.store(false, Ordering::Release);
776        Err(last_error.unwrap_or_else(|| TransportError::SendFailed {
777            attempts: self.config.max_retry_attempts,
778            reason: "max retries exceeded".into(),
779        }))
780    }
781
782    async fn recv(&self) -> TransportResult<Bytes> {
783        // Try reconnect if disconnected
784        if !self.connected.load(Ordering::Acquire) {
785            self.try_reconnect()?;
786        }
787
788        let mut last_error = None;
789
790        for attempt in 0..self.config.max_retry_attempts {
791            if attempt > 0 {
792                tokio::time::sleep(self.config.retry_policy.delay(attempt - 1)).await;
793            }
794
795            let result = tokio::task::spawn_blocking({
796                let buffer = self.recv_buffer.load_full();
797                let timeout = self.config.read_timeout.unwrap_or(Duration::from_secs(30));
798                move || buffer.read(timeout)
799            })
800            .await;
801
802            match result {
803                Ok(Ok(bytes)) => {
804                    let mut stats = self.stats.lock();
805                    stats.messages_received += 1;
806                    stats.bytes_received += bytes.len() as u64;
807                    return Ok(bytes);
808                }
809                Ok(Err(e)) => {
810                    last_error = Some(e);
811                    *self.error_count.lock() += 1;
812                }
813                Err(e) => {
814                    last_error = Some(TransportError::ReceiveFailed {
815                        attempts: attempt + 1,
816                        reason: e.to_string(),
817                    });
818                    *self.error_count.lock() += 1;
819                }
820            }
821        }
822
823        self.connected.store(false, Ordering::Release);
824        Err(last_error.unwrap_or_else(|| TransportError::ReceiveFailed {
825            attempts: self.config.max_retry_attempts,
826            reason: "max retries exceeded".into(),
827        }))
828    }
829
830    fn is_connected(&self) -> bool {
831        self.connected.load(Ordering::Acquire)
832    }
833
834    async fn close(&self) -> TransportResult<()> {
835        let send_buf = self.send_buffer.load();
836        let control = unsafe { &*send_buf.control };
837        control.connected.store(false, Ordering::Release);
838        Ok(())
839    }
840
841    fn stats(&self) -> Option<TransportStats> {
842        Some(self.stats.lock().clone())
843    }
844
845    fn name(&self) -> &str {
846        &self.name
847    }
848}
849
#[cfg(test)]
mod tests {
    use super::*;

    /// A single message sent by the client arrives unchanged at the server.
    #[tokio::test]
    async fn test_cross_process_basic() {
        let server =
            SharedMemoryTransport::create_server("test-basic", SharedMemoryConfig::default())
                .unwrap();
        let client = SharedMemoryTransport::connect_client("test-basic").unwrap();

        client.send(b"Hello").await.unwrap();

        let received = server.recv().await.unwrap();
        assert_eq!(received.as_ref(), b"Hello");
    }

    /// Pushes more data through the ring than it can hold at once, so the
    /// positions have to wrap around the end of the data area.
    #[tokio::test]
    async fn test_ring_buffer_wrapping() {
        let config = SharedMemoryConfig::default().with_buffer_size(8192);
        let server = SharedMemoryTransport::create_server("test-wrap", config).unwrap();
        let client = SharedMemoryTransport::connect_client("test-wrap").unwrap();

        let chunk_size = 1000;
        let num_chunks = 20;

        // 20 chunks of 1000 bytes against an 8 KiB buffer.
        let sender = tokio::spawn(async move {
            for index in 0..num_chunks {
                let payload = vec![index as u8; chunk_size];
                client.send(&payload).await.unwrap();
            }
        });

        for i in 0..num_chunks {
            let chunk = server.recv().await.unwrap();
            assert_eq!(chunk.len(), chunk_size);
            let intact = chunk.iter().all(|&b| b == i as u8);
            assert!(intact, "chunk {} corrupted", i);
        }

        sender.await.unwrap();
    }

    /// A client whose local flag says "disconnected" reconnects on the next
    /// send and keeps delivering messages.
    #[tokio::test]
    async fn test_auto_reconnect() {
        let config = SharedMemoryConfig::default().with_auto_reconnect(true);
        let server =
            SharedMemoryTransport::create_server("test-reconnect", config.clone()).unwrap();
        let client =
            SharedMemoryTransport::connect_client_with_config("test-reconnect", config).unwrap();

        // The fresh connection works.
        client.send(b"before").await.unwrap();
        let first = server.recv().await.unwrap();
        assert_eq!(first.as_ref(), b"before");

        // Simulate a lost connection by clearing the client's local flag.
        client.connected.store(false, Ordering::Release);
        assert!(!client.is_connected());

        // The next send has to reconnect first, then deliver.
        client.send(b"after").await.unwrap();
        assert!(client.is_connected());

        let second = server.recv().await.unwrap();
        assert_eq!(second.as_ref(), b"after");
    }

    /// A short configured read timeout makes an idle `recv` fail quickly.
    #[tokio::test]
    async fn test_configurable_timeout() {
        let short = Duration::from_millis(100);
        let config = SharedMemoryConfig::default()
            .with_read_timeout(short)
            .with_write_timeout(short);

        let _server = SharedMemoryTransport::create_server("test-timeout", config.clone()).unwrap();
        let client =
            SharedMemoryTransport::connect_client_with_config("test-timeout", config).unwrap();

        // Nothing is ever sent, so recv must give up well within a second.
        let started = Instant::now();
        let outcome = client.recv().await;
        let elapsed = started.elapsed();

        assert!(outcome.is_err());
        assert!(
            elapsed < Duration::from_secs(1),
            "timeout took too long: {:?}",
            elapsed
        );
    }
}
939}