Skip to main content

sozu_command_lib/
channel.rs

//! Bidirectional length-delimited unix-socket channel.
//!
//! Implements the master ↔ worker / master ↔ CLI message channel: each
//! payload is preceded by a `usize`-wide (native width, little-endian) length
//! prefix whose value counts the whole frame, prefix included (NOT a NUL
//! separator — that scheme belongs to the state-file save format in
//! `command/src/state.rs:1613`-`1630`). The declared length is bounded by the
//! per-channel `Channel::max_buffer_size` and checked before the front buffer
//! is grown to hold the payload.
8
9use std::{
10    cmp::min,
11    fmt::Debug,
12    io::{self, ErrorKind, Read, Write},
13    marker::PhantomData,
14    os::unix::{
15        io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
16        net::UnixStream as StdUnixStream,
17    },
18    time::Duration,
19};
20
21use mio::{event::Source, net::UnixStream as MioUnixStream};
22use prost::{DecodeError, Message as ProstMessage};
23
24use crate::{buffer::growable::Buffer, ready::Ready};
25
/// Fraction of `max_buffer_size` at which a buffer is considered close to its
/// ceiling; crossing it logs a one-time warning per buffer.
const HIGH_WATERMARK_RATIO: f64 = 0.8;
28
29#[derive(thiserror::Error, Debug)]
30pub enum ChannelError {
31    #[error("io read error")]
32    Read(std::io::Error),
33    #[error("no byte written on the channel")]
34    NoByteWritten,
35    #[error("no byte left to read on the channel")]
36    NoByteToRead,
37    #[error(
38        "message ({message_len} bytes) too large for back buffer capacity ({capacity} bytes, max {max} bytes)"
39    )]
40    MessageTooLarge {
41        message_len: usize,
42        capacity: usize,
43        max: usize,
44    },
45    #[error(
46        "declared message length ({message_len} bytes) is shorter than the {delimiter_size}-byte length prefix"
47    )]
48    MessageLengthUnderDelimiter {
49        message_len: usize,
50        delimiter_size: usize,
51    },
52    #[error("channel could not write on the back buffer")]
53    Write(std::io::Error),
54    #[error("channel buffer is full ({capacity} bytes, max {max} bytes), cannot grow more")]
55    BufferFull { capacity: usize, max: usize },
56    #[error("Timeout is reached: {0:?}")]
57    TimeoutReached(Duration),
58    #[error("Could not read anything on the channel")]
59    NothingRead,
60    #[error("invalid char set in command message, ignoring: {0}")]
61    InvalidCharSet(String),
62    #[error("could not set the timeout of the unix stream with file descriptor {fd}: {error}")]
63    SetTimeout { fd: i32, error: String },
64    #[error(
65        "Could not change the blocking status ef the unix stream with file descriptor {fd}: {error}"
66    )]
67    BlockingStatus { fd: i32, error: String },
68    #[error("Connection error: {0:?}")]
69    Connection(Option<std::io::Error>),
70    #[error("Invalid protobuf message: {0}")]
71    InvalidProtobufMessage(DecodeError),
72    #[error("This should never happen (index out of bound on a tested buffer)")]
73    MismatchBufferSize,
74}
75
76/// Channel meant for communication between Sōzu processes over a UNIX socket.
77/// It wraps a unix socket using the mio crate, and transmit prost messages
78/// by serializing them in a binary format, with a fix-sized delimiter.
79/// To function, channels must come in pairs, one for each agent.
80/// They can function in a blocking or non-blocking way.
81pub struct Channel<Tx, Rx> {
82    pub sock: MioUnixStream,
83    pub front_buf: Buffer,
84    pub back_buf: Buffer,
85    initial_buffer_size: usize,
86    max_buffer_size: usize,
87    pub readiness: Ready,
88    pub interest: Ready,
89    blocking: bool,
90    /// true if a high watermark warning has been logged for the front buffer
91    front_high_watermark_logged: bool,
92    /// true if a high watermark warning has been logged for the back buffer
93    back_high_watermark_logged: bool,
94    phantom_tx: PhantomData<Tx>,
95    phantom_rx: PhantomData<Rx>,
96}
97
98impl<Tx, Rx> std::fmt::Debug for Channel<Tx, Rx> {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct(&format!(
101            "Channel<{}, {}>",
102            std::any::type_name::<Tx>(),
103            std::any::type_name::<Rx>()
104        ))
105        .field("sock", &self.sock.as_raw_fd())
106        // .field("front_buf", &self.front_buf)
107        // .field("back_buf", &self.back_buf)
108        // .field("max_buffer_size", &self.max_buffer_size)
109        .field("readiness", &self.readiness)
110        .field("interest", &self.interest)
111        .field("blocking", &self.blocking)
112        .finish()
113    }
114}
115
116impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
117    /// Creates a nonblocking channel on a given socket path
118    pub fn from_path(
119        path: &str,
120        buffer_size: u64,
121        max_buffer_size: u64,
122    ) -> Result<Channel<Tx, Rx>, ChannelError> {
123        let unix_stream = MioUnixStream::connect(path)
124            .map_err(|io_error| ChannelError::Connection(Some(io_error)))?;
125        Ok(Channel::new(unix_stream, buffer_size, max_buffer_size))
126    }
127
128    /// Creates a nonblocking channel, using a unix stream
129    pub fn new(sock: MioUnixStream, buffer_size: u64, max_buffer_size: u64) -> Channel<Tx, Rx> {
130        let buffer_size = buffer_size as usize;
131        let max_buffer_size = max_buffer_size as usize;
132        Channel {
133            sock,
134            front_buf: Buffer::with_capacity(buffer_size),
135            back_buf: Buffer::with_capacity(buffer_size),
136            initial_buffer_size: buffer_size,
137            max_buffer_size,
138            readiness: Ready::EMPTY,
139            interest: Ready::READABLE,
140            blocking: false,
141            front_high_watermark_logged: false,
142            back_high_watermark_logged: false,
143            phantom_tx: PhantomData,
144            phantom_rx: PhantomData,
145        }
146    }
147
148    pub fn into<Tx2: Debug + ProstMessage + Default, Rx2: Debug + ProstMessage + Default>(
149        self,
150    ) -> Channel<Tx2, Rx2> {
151        Channel {
152            sock: self.sock,
153            front_buf: self.front_buf,
154            back_buf: self.back_buf,
155            initial_buffer_size: self.initial_buffer_size,
156            max_buffer_size: self.max_buffer_size,
157            readiness: self.readiness,
158            interest: self.interest,
159            blocking: self.blocking,
160            front_high_watermark_logged: self.front_high_watermark_logged,
161            back_high_watermark_logged: self.back_high_watermark_logged,
162            phantom_tx: PhantomData,
163            phantom_rx: PhantomData,
164        }
165    }
166
167    // Since MioUnixStream does not have a set_nonblocking method, we have to use the standard library.
168    // We get the file descriptor of the MioUnixStream socket, create a standard library UnixStream,
169    // set it to nonblocking, let go of the file descriptor
170    fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), ChannelError> {
171        // SAFETY: `fd` is borrowed from `self.sock` for the duration of this
172        // block. We wrap it in a `StdUnixStream` to call `set_nonblocking`,
173        // then immediately release ownership again with `into_raw_fd` so the
174        // descriptor is not closed by `Drop`. `self.sock` retains the
175        // original ownership.
176        unsafe {
177            let fd = self.sock.as_raw_fd();
178            let stream = StdUnixStream::from_raw_fd(fd);
179            stream
180                .set_nonblocking(nonblocking)
181                .map_err(|error| ChannelError::BlockingStatus {
182                    fd,
183                    error: error.to_string(),
184                })?;
185            let _fd = stream.into_raw_fd();
186        }
187        self.blocking = !nonblocking;
188        Ok(())
189    }
190
191    /// set the read_timeout of the unix stream. This works only temporary, be sure to set the timeout to None afterwards.
192    fn set_timeout(&mut self, timeout: Option<Duration>) -> Result<(), ChannelError> {
193        // SAFETY: `fd` is borrowed from `self.sock` for the duration of this
194        // block. We wrap it in a `StdUnixStream` to call `set_read_timeout`,
195        // then immediately release ownership again with `into_raw_fd` so the
196        // descriptor is not closed by `Drop`. `self.sock` retains the
197        // original ownership.
198        unsafe {
199            let fd = self.sock.as_raw_fd();
200            let stream = StdUnixStream::from_raw_fd(fd);
201            stream
202                .set_read_timeout(timeout)
203                .map_err(|error| ChannelError::SetTimeout {
204                    fd,
205                    error: error.to_string(),
206                })?;
207            let _fd = stream.into_raw_fd();
208        }
209        Ok(())
210    }
211
212    /// set the channel to be blocking
213    pub fn blocking(&mut self) -> Result<(), ChannelError> {
214        self.set_nonblocking(false)
215    }
216
217    /// set the channel to be nonblocking
218    pub fn nonblocking(&mut self) -> Result<(), ChannelError> {
219        self.set_nonblocking(true)
220    }
221
222    pub fn is_blocking(&self) -> bool {
223        self.blocking
224    }
225
226    /// Get the raw file descriptor of the UNIX socket
227    pub fn fd(&self) -> RawFd {
228        self.sock.as_raw_fd()
229    }
230
231    pub fn handle_events(&mut self, events: Ready) {
232        self.readiness |= events;
233    }
234
235    pub fn readiness(&self) -> Ready {
236        self.readiness & self.interest
237    }
238
239    /// Compute the next buffer size using a doubling strategy, capped at max_buffer_size.
240    /// Returns None if the buffer is already at max capacity.
241    fn grow_size(&self, current_capacity: usize) -> Option<usize> {
242        if current_capacity >= self.max_buffer_size {
243            return None;
244        }
245        // double the capacity, but don't exceed max
246        let new_size = min(current_capacity.saturating_mul(2), self.max_buffer_size);
247        // ensure we grow by at least something (in case current_capacity is 0)
248        let new_size = new_size.max(current_capacity + 1);
249        Some(min(new_size, self.max_buffer_size))
250    }
251
252    /// Check if a buffer has exceeded the high watermark and log a warning once
253    fn check_high_watermark(
254        buffer_name: &str,
255        capacity: usize,
256        max: usize,
257        already_logged: &mut bool,
258    ) {
259        if *already_logged {
260            return;
261        }
262        let threshold = (max as f64 * HIGH_WATERMARK_RATIO) as usize;
263        if capacity >= threshold {
264            warn!(
265                "channel {} buffer reached high watermark: {} bytes ({:.0}% of {} max)",
266                buffer_name,
267                capacity,
268                (capacity as f64 / max as f64) * 100.0,
269                max,
270            );
271            *already_logged = true;
272        }
273    }
274
275    /// Check wether we want and can read or write, and calls the appropriate handler.
276    pub fn run(&mut self) -> Result<(), ChannelError> {
277        let interest = self.interest & self.readiness;
278
279        if interest.is_readable() {
280            let _ = self.readable()?;
281        }
282
283        if interest.is_writable() {
284            let _ = self.writable()?;
285        }
286        Ok(())
287    }
288
289    /// Handle readability by filling the front buffer with the socket data.
290    /// Grows the front buffer when full using a doubling strategy, up to max_buffer_size.
291    pub fn readable(&mut self) -> Result<usize, ChannelError> {
292        if !(self.interest & self.readiness).is_readable() {
293            return Err(ChannelError::Connection(None));
294        }
295
296        let mut count = 0usize;
297        loop {
298            let size = self.front_buf.available_space();
299            trace!("channel available space: {}", size);
300            if size == 0 {
301                // try to grow the buffer before giving up
302                if let Some(new_size) = self.grow_size(self.front_buf.capacity()) {
303                    Self::check_high_watermark(
304                        "front",
305                        new_size,
306                        self.max_buffer_size,
307                        &mut self.front_high_watermark_logged,
308                    );
309                    self.front_buf.grow(new_size);
310                } else {
311                    self.interest.remove(Ready::READABLE);
312                    break;
313                }
314            }
315
316            match self.sock.read(self.front_buf.space()) {
317                Ok(0) => {
318                    self.interest = Ready::EMPTY;
319                    self.readiness.remove(Ready::READABLE);
320                    self.readiness.insert(Ready::HUP);
321                    return Err(ChannelError::NoByteToRead);
322                }
323                Err(read_error) => match read_error.kind() {
324                    ErrorKind::WouldBlock => {
325                        self.readiness.remove(Ready::READABLE);
326                        break;
327                    }
328                    _ => {
329                        self.interest = Ready::EMPTY;
330                        self.readiness = Ready::EMPTY;
331                        return Err(ChannelError::Read(read_error));
332                    }
333                },
334                Ok(bytes_read) => {
335                    count += bytes_read;
336                    self.front_buf.fill(bytes_read);
337                }
338            };
339        }
340
341        Ok(count)
342    }
343
344    /// Handle writability by writing the content of the back buffer onto the socket.
345    /// Shrinks the back buffer back toward initial size once fully drained.
346    pub fn writable(&mut self) -> Result<usize, ChannelError> {
347        if !(self.interest & self.readiness).is_writable() {
348            return Err(ChannelError::Connection(None));
349        }
350
351        let mut count = 0usize;
352        loop {
353            let size = self.back_buf.available_data();
354            if size == 0 {
355                self.interest.remove(Ready::WRITABLE);
356                self.try_shrink_back_buf();
357                break;
358            }
359
360            match self.sock.write(self.back_buf.data()) {
361                Ok(0) => {
362                    self.interest = Ready::EMPTY;
363                    self.readiness.insert(Ready::HUP);
364                    return Err(ChannelError::NoByteWritten);
365                }
366                Ok(bytes_written) => {
367                    count += bytes_written;
368                    self.back_buf.consume(bytes_written);
369                }
370                Err(write_error) => match write_error.kind() {
371                    ErrorKind::WouldBlock => {
372                        self.readiness.remove(Ready::WRITABLE);
373                        break;
374                    }
375                    _ => {
376                        self.interest = Ready::EMPTY;
377                        self.readiness = Ready::EMPTY;
378                        return Err(ChannelError::Read(write_error));
379                    }
380                },
381            }
382        }
383
384        Ok(count)
385    }
386
387    /// Depending on the blocking status:
388    ///
389    /// Blocking: wait for the front buffer to be filled, and parse a message from it
390    ///
391    /// Nonblocking: parse a message from the front buffer, without waiting.
392    /// Prefer using `channel.readable()` before
393    pub fn read_message(&mut self) -> Result<Rx, ChannelError> {
394        if self.blocking {
395            self.read_message_blocking()
396        } else {
397            self.read_message_nonblocking()
398        }
399    }
400
401    fn read_message_blocking(&mut self) -> Result<Rx, ChannelError> {
402        self.read_message_blocking_timeout(None)
403    }
404
405    /// Parse a message from the front buffer, without waiting
406    fn read_message_nonblocking(&mut self) -> Result<Rx, ChannelError> {
407        if let Some(message) = self.try_read_delimited_message()? {
408            self.try_shrink_front_buf();
409            return Ok(message);
410        }
411
412        self.interest.insert(Ready::READABLE);
413        Err(ChannelError::NothingRead)
414    }
415
416    /// Wait for the front buffer to be filled, and parses a message from it.
417    pub fn read_message_blocking_timeout(
418        &mut self,
419        timeout: Option<Duration>,
420    ) -> Result<Rx, ChannelError> {
421        let now = std::time::Instant::now();
422
423        // 10 ms = 100 syscalls/sec on idle WouldBlock, pinning a CPU on
424        // long blocking waits with no payload. 100 ms is
425        // a usability-acceptable resolution for the outer `timeout`
426        // deadline check (the wait is bounded by `timeout`, not by this
427        // value) and drops the steady-state read syscall rate to 10/sec.
428        self.set_timeout(Some(Duration::from_millis(100)))?;
429
430        let status = loop {
431            if let Some(timeout) = timeout {
432                if now.elapsed() >= timeout {
433                    break Err(ChannelError::TimeoutReached(timeout));
434                }
435            }
436
437            if let Some(message) = self.try_read_delimited_message()? {
438                self.try_shrink_front_buf();
439                return Ok(message);
440            }
441
442            match self.sock.read(self.front_buf.space()) {
443                Ok(0) => return Err(ChannelError::NoByteToRead),
444                Ok(bytes_read) => self.front_buf.fill(bytes_read),
445                Err(io_error) => match io_error.kind() {
446                    ErrorKind::WouldBlock => continue, // ignore 10 millisecond timeouts
447                    _ => break Err(ChannelError::Read(io_error)),
448                },
449            };
450        };
451
452        self.set_timeout(None)?;
453
454        status
455    }
456
457    /// parse a prost message from the front buffer, grow it if necessary
458    fn try_read_delimited_message(&mut self) -> Result<Option<Rx>, ChannelError> {
459        let buffer = self.front_buf.data();
460        if buffer.len() >= delimiter_size() {
461            let delimiter = buffer[..delimiter_size()]
462                .try_into()
463                .map_err(|_| ChannelError::MismatchBufferSize)?;
464            let message_len = usize::from_le_bytes(delimiter);
465
466            // Defense in depth: bound the parser-side length up-front.
467            // Without this an attacker who controls the
468            // first 8 bytes of a frame can declare an arbitrarily large
469            // message and drive `Buffer::grow` toward the
470            // `max_buffer_size` ceiling before any byte of payload has
471            // been read. Reject as `MessageTooLarge` so the read loop
472            // disconnects cleanly instead of running the doubling growth
473            // strategy on attacker-supplied numbers.
474            if message_len > self.max_buffer_size {
475                return Err(ChannelError::MessageTooLarge {
476                    message_len,
477                    capacity: self.front_buf.capacity(),
478                    max: self.max_buffer_size,
479                });
480            }
481
482            // A length-delimited frame is `[delimiter][payload]`. The declared
483            // `message_len` is the total frame size and MUST therefore be at
484            // least `delimiter_size()`. A peer-controlled value below that
485            // ceiling makes `&buffer[delimiter_size()..message_len]` slice
486            // backwards and panic; reject it the same way as oversized frames.
487            //
488            // Drop the bogus delimiter bytes before returning so the channel
489            // can re-sync on the peer's next frame. Without this, every
490            // subsequent `read_message()` re-reads the same bad header from
491            // the front buffer and the worker burns CPU on the same error
492            // until the peer disconnects.
493            if message_len < delimiter_size() {
494                self.front_buf.consume(delimiter_size());
495                return Err(ChannelError::MessageLengthUnderDelimiter {
496                    message_len,
497                    delimiter_size: delimiter_size(),
498                });
499            }
500
501            if buffer.len() >= message_len {
502                let message = Rx::decode(&buffer[delimiter_size()..message_len])
503                    .map_err(ChannelError::InvalidProtobufMessage)?;
504                self.front_buf.consume(message_len);
505                return Ok(Some(message));
506            }
507        }
508
509        if self.front_buf.available_space() == 0 {
510            if self.front_buf.capacity() >= self.max_buffer_size {
511                return Err(ChannelError::BufferFull {
512                    capacity: self.front_buf.capacity(),
513                    max: self.max_buffer_size,
514                });
515            }
516            let new_size = self
517                .grow_size(self.front_buf.capacity())
518                .unwrap_or(self.max_buffer_size);
519            Self::check_high_watermark(
520                "front",
521                new_size,
522                self.max_buffer_size,
523                &mut self.front_high_watermark_logged,
524            );
525            self.front_buf.grow(new_size);
526        }
527        Ok(None)
528    }
529
530    /// Checks whether the channel is blocking or nonblocking, writes the message.
531    ///
532    /// If the channel is nonblocking, you have to flush using `channel.run()` afterwards
533    pub fn write_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
534        if self.blocking {
535            self.write_message_blocking(message)
536        } else {
537            self.write_message_nonblocking(message)
538        }
539    }
540
541    /// Writes the message in the buffer, but NOT on the socket.
542    /// you have to call channel.run() afterwards
543    fn write_message_nonblocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
544        self.write_delimited_message(message)?;
545
546        self.interest.insert(Ready::WRITABLE);
547
548        Ok(())
549    }
550
551    /// fills the back buffer with data AND writes on the socket
552    fn write_message_blocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
553        self.write_delimited_message(message)?;
554
555        loop {
556            let size = self.back_buf.available_data();
557            if size == 0 {
558                break;
559            }
560
561            match self.sock.write(self.back_buf.data()) {
562                Ok(0) => return Err(ChannelError::NoByteWritten),
563                Ok(bytes_written) => {
564                    self.back_buf.consume(bytes_written);
565                }
566                Err(_) => return Ok(()), // are we sure?
567            }
568        }
569        Ok(())
570    }
571
572    /// write a message on the back buffer, using our own delimiter (the delimiter of prost
573    /// is not trustworthy since its size may change)
574    pub fn write_delimited_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
575        let payload = message.encode_to_vec();
576
577        let payload_len = payload.len() + delimiter_size();
578
579        let delimiter = payload_len.to_le_bytes();
580
581        if payload_len > self.back_buf.available_space() {
582            self.back_buf.shift();
583        }
584
585        if payload_len > self.back_buf.available_space() {
586            let needed = payload_len - self.back_buf.available_space() + self.back_buf.capacity();
587            if needed > self.max_buffer_size {
588                return Err(ChannelError::MessageTooLarge {
589                    message_len: payload_len,
590                    capacity: self.back_buf.capacity(),
591                    max: self.max_buffer_size,
592                });
593            }
594
595            // use doubling strategy to reach at least `needed`, amortizing future writes
596            let mut new_length = self.back_buf.capacity();
597            while new_length < needed {
598                new_length = new_length.saturating_mul(2).max(new_length + 1);
599            }
600            new_length = min(new_length, self.max_buffer_size);
601            Self::check_high_watermark(
602                "back",
603                new_length,
604                self.max_buffer_size,
605                &mut self.back_high_watermark_logged,
606            );
607            self.back_buf.grow(new_length);
608        }
609
610        self.back_buf
611            .write_all(&delimiter)
612            .map_err(ChannelError::Write)?;
613        self.back_buf
614            .write_all(&payload)
615            .map_err(ChannelError::Write)?;
616
617        Ok(())
618    }
619
620    /// Shrink the front buffer back toward initial_buffer_size when it is
621    /// mostly empty (data consumed) and was previously grown.
622    fn try_shrink_front_buf(&mut self) {
623        let capacity = self.front_buf.capacity();
624        if capacity <= self.initial_buffer_size {
625            return;
626        }
627        // only shrink when the buffer has little pending data
628        if self.front_buf.available_data() * 4 < self.initial_buffer_size {
629            self.front_buf.shrink(self.initial_buffer_size);
630            self.front_high_watermark_logged = false;
631            trace!(
632                "front buffer shrunk from {} to {} bytes",
633                capacity, self.initial_buffer_size
634            );
635        }
636    }
637
638    /// Shrink the back buffer back toward initial_buffer_size when fully drained.
639    fn try_shrink_back_buf(&mut self) {
640        let capacity = self.back_buf.capacity();
641        if capacity <= self.initial_buffer_size {
642            return;
643        }
644        if self.back_buf.available_data() == 0 {
645            self.back_buf.shrink(self.initial_buffer_size);
646            self.back_high_watermark_logged = false;
647            trace!(
648                "back buffer shrunk from {} to {} bytes",
649                capacity, self.initial_buffer_size
650            );
651        }
652    }
653}
654
/// Width in bytes of the length prefix written before every payload: exactly
/// one `usize`.
pub const fn delimiter_size() -> usize {
    usize::BITS as usize / 8
}
659
660type ChannelResult<Tx, Rx> = Result<(Channel<Tx, Rx>, Channel<Rx, Tx>), ChannelError>;
661
662impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
663    /// creates a channel pair: `(blocking_channel, nonblocking_channel)`
664    pub fn generate(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
665        let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
666        let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
667        let mut command_channel = Channel::new(command, buffer_size, max_buffer_size);
668        command_channel.blocking()?;
669        Ok((command_channel, proxy_channel))
670    }
671
672    /// creates a pair of nonblocking channels
673    pub fn generate_nonblocking(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
674        let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
675        let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
676        let command_channel = Channel::new(command, buffer_size, max_buffer_size);
677        Ok((command_channel, proxy_channel))
678    }
679}
680
681impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Iterator
682    for Channel<Tx, Rx>
683{
684    type Item = Rx;
685    fn next(&mut self) -> Option<Self::Item> {
686        self.read_message().ok()
687    }
688}
689
690use mio::{Interest, Registry, Token};
691impl<Tx, Rx> Source for Channel<Tx, Rx> {
692    fn register(
693        &mut self,
694        registry: &Registry,
695        token: Token,
696        interests: Interest,
697    ) -> io::Result<()> {
698        self.sock.register(registry, token, interests)
699    }
700
701    fn reregister(
702        &mut self,
703        registry: &Registry,
704        token: Token,
705        interests: Interest,
706    ) -> io::Result<()> {
707        self.sock.reregister(registry, token, interests)
708    }
709
710    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
711        self.sock.deregister(registry)
712    }
713}
714
715#[cfg(test)]
716mod tests {
717    use std::{thread, time::Duration};
718
719    use super::*;
720
721    #[derive(Clone, PartialEq, prost::Message)]
722    pub struct ProtobufMessage {
723        #[prost(uint32, required, tag = "1")]
724        inner: u32,
725    }
726
727    fn test_channels() -> (
728        Channel<ProtobufMessage, ProtobufMessage>,
729        Channel<ProtobufMessage, ProtobufMessage>,
730    ) {
731        Channel::generate(1000, 10000).expect("could not generate blocking channels for testing")
732    }
733
734    #[test]
735    fn unblock_a_channel() {
736        let (mut blocking, _nonblocking) = test_channels();
737        assert!(blocking.nonblocking().is_ok())
738    }
739
740    #[test]
741    fn generate_blocking_and_nonblocking_channels() {
742        let (blocking_channel, nonblocking_channel) = test_channels();
743
744        assert!(blocking_channel.is_blocking());
745        assert!(!nonblocking_channel.is_blocking());
746
747        let (nonblocking_channel_1, nonblocking_channel_2): (
748            Channel<ProtobufMessage, ProtobufMessage>,
749            Channel<ProtobufMessage, ProtobufMessage>,
750        ) = Channel::generate_nonblocking(1000, 10000)
751            .expect("could not generatie nonblocking channels");
752
753        assert!(!nonblocking_channel_1.is_blocking());
754        assert!(!nonblocking_channel_2.is_blocking());
755    }
756
757    #[test]
758    fn write_and_read_message_blocking() {
759        let (mut blocking_channel, mut nonblocking_channel) = test_channels();
760
761        let message_to_send = ProtobufMessage { inner: 42 };
762
763        nonblocking_channel
764            .blocking()
765            .expect("Could not block channel");
766        nonblocking_channel
767            .write_message(&message_to_send)
768            .expect("Could not write message on channel");
769
770        trace!("we wrote a message!");
771
772        trace!("reading message..");
773        // blocking_channel.readable();
774        let message = blocking_channel
775            .read_message()
776            .expect("Could not read message on channel");
777        trace!("read message!");
778
779        assert_eq!(message, ProtobufMessage { inner: 42 });
780    }
781
782    #[test]
783    fn read_message_blocking_with_timeout_fails() {
784        let (mut reading_channel, mut writing_channel) = test_channels();
785        writing_channel.blocking().expect("Could not block channel");
786
787        trace!("reading message in a detached thread, with a timeout of 100 milliseconds...");
788        let awaiting_with_timeout = thread::spawn(move || {
789            let message =
790                reading_channel.read_message_blocking_timeout(Some(Duration::from_millis(100)));
791            trace!("read message!");
792            message
793        });
794
795        trace!("Waiting 200 milliseconds…");
796        thread::sleep(std::time::Duration::from_millis(200));
797
798        writing_channel
799            .write_message(&ProtobufMessage { inner: 200 })
800            .expect("Could not write message on channel");
801        trace!("we wrote a message that should arrive too late!");
802
803        let arrived_too_late = awaiting_with_timeout
804            .join()
805            .expect("error with receiving message from awaiting thread");
806
807        assert!(arrived_too_late.is_err());
808    }
809
810    #[test]
811    fn read_message_blocking_with_timeout_succeeds() {
812        let (mut reading_channel, mut writing_channel) = test_channels();
813        writing_channel.blocking().expect("Could not block channel");
814
815        trace!("reading message in a detached thread, with a timeout of 200 milliseconds...");
816        let awaiting_with_timeout = thread::spawn(move || {
817            let message = reading_channel
818                .read_message_blocking_timeout(Some(Duration::from_millis(200)))
819                .expect("Could not read message with timeout on blocking channel");
820            trace!("read message!");
821            message
822        });
823
824        trace!("Waiting 100 milliseconds…");
825        thread::sleep(std::time::Duration::from_millis(100));
826
827        writing_channel
828            .write_message(&ProtobufMessage { inner: 100 })
829            .expect("Could not write message on channel");
830        trace!("we wrote a message that should arrive on time!");
831
832        let arrived_on_time = awaiting_with_timeout
833            .join()
834            .expect("error with receiving message from awaiting thread");
835
836        assert_eq!(arrived_on_time, ProtobufMessage { inner: 100 });
837    }
838
839    #[test]
840    fn exhaustive_use_of_nonblocking_channels() {
841        // - two nonblocking channels A and B, identical
842        let (mut channel_a, mut channel_b) = test_channels();
843        channel_a.nonblocking().expect("Could not block channel");
844
845        // write on A
846        channel_a
847            .write_message(&ProtobufMessage { inner: 1 })
848            .expect("Could not write message on channel");
849
850        // set B as readable, normally mio tells when to, by giving events
851        channel_b.handle_events(Ready::READABLE);
852
853        // read on B
854        let should_err = channel_b.read_message();
855        assert!(should_err.is_err());
856
857        // write another message on A
858        channel_a
859            .write_message(&ProtobufMessage { inner: 2 })
860            .expect("Could not write message on channel");
861
862        // insert a handle_events Ready::writable on A
863        channel_a.handle_events(Ready::WRITABLE);
864
865        // flush A with run()
866        channel_a.run().expect("Failed to run the channel");
867
868        // maybe a thread sleep
869        thread::sleep(std::time::Duration::from_millis(100));
870
871        // receive with B using run()
872        channel_b.run().expect("Failed to run the channel");
873
874        // use read_message() twice on B, check them
875        let message_1 = channel_b
876            .read_message()
877            .expect("Could not read message on channel");
878        assert_eq!(message_1, ProtobufMessage { inner: 1 });
879
880        let message_2 = channel_b
881            .read_message()
882            .expect("Could not read message on channel");
883        assert_eq!(message_2, ProtobufMessage { inner: 2 });
884    }
885
886    #[test]
887    fn buffer_grows_with_doubling_strategy() {
888        let (writing_channel, _reading_channel): (
889            Channel<ProtobufMessage, ProtobufMessage>,
890            Channel<ProtobufMessage, ProtobufMessage>,
891        ) = Channel::generate(100, 10000).expect("could not generate channels");
892
893        assert_eq!(writing_channel.back_buf.capacity(), 100);
894
895        assert_eq!(writing_channel.grow_size(100), Some(200));
896        assert_eq!(writing_channel.grow_size(200), Some(400));
897        assert_eq!(writing_channel.grow_size(5000), Some(10000));
898        assert_eq!(writing_channel.grow_size(10000), None);
899    }
900
901    #[test]
902    fn buffer_cap_returns_error() {
903        let (mut writing_channel, _reading_channel): (
904            Channel<ProtobufMessage, ProtobufMessage>,
905            Channel<ProtobufMessage, ProtobufMessage>,
906        ) = Channel::generate(50, 50).expect("could not generate channels");
907
908        writing_channel.blocking().expect("Could not block channel");
909
910        let mut i = 0u32;
911        let result = loop {
912            let msg = ProtobufMessage { inner: i };
913            match writing_channel.write_delimited_message(&msg) {
914                Ok(()) => i += 1,
915                Err(e) => break Err(e),
916            }
917            if i > 10000 {
918                break Ok(());
919            }
920        };
921
922        assert!(result.is_err());
923        let err = result.unwrap_err();
924        let err_msg = format!("{err}");
925        assert!(
926            err_msg.contains("too large") || err_msg.contains("cannot grow"),
927            "unexpected error: {err_msg}"
928        );
929    }
930
931    #[test]
932    fn back_buffer_shrinks_after_drain() {
933        let (mut channel, _other): (
934            Channel<ProtobufMessage, ProtobufMessage>,
935            Channel<ProtobufMessage, ProtobufMessage>,
936        ) = Channel::generate(100, 10000).expect("could not generate channels");
937
938        // Write directly to the back buffer (without draining to socket)
939        // to force growth. Each message is ~10 bytes (delimiter + varint).
940        for i in 0..20 {
941            channel
942                .write_delimited_message(&ProtobufMessage { inner: i })
943                .expect("Could not write message");
944        }
945
946        let grown_capacity = channel.back_buf.capacity();
947        assert!(
948            grown_capacity > 100,
949            "expected buffer growth, got capacity {grown_capacity}"
950        );
951
952        // Simulate full drain by consuming all data
953        let data_len = channel.back_buf.available_data();
954        channel.back_buf.consume(data_len);
955        assert_eq!(channel.back_buf.available_data(), 0);
956
957        channel.try_shrink_back_buf();
958        assert_eq!(
959            channel.back_buf.capacity(),
960            100,
961            "back buffer should shrink to initial size after drain"
962        );
963    }
964
965    #[test]
966    fn back_buffer_grows_with_doubling_on_write() {
967        let (mut channel, _other): (
968            Channel<ProtobufMessage, ProtobufMessage>,
969            Channel<ProtobufMessage, ProtobufMessage>,
970        ) = Channel::generate(32, 10000).expect("could not generate channels");
971
972        assert_eq!(channel.back_buf.capacity(), 32);
973
974        // Write enough messages to force growth beyond initial capacity.
975        // Each ProtobufMessage encodes to ~4 bytes + 8-byte delimiter = ~12 bytes.
976        for i in 0..10 {
977            channel
978                .write_delimited_message(&ProtobufMessage { inner: i })
979                .expect("Could not write message");
980        }
981
982        let grown = channel.back_buf.capacity();
983        assert!(grown > 32, "expected buffer growth beyond 32, got {grown}");
984        // doubling from 32 should yield a power-of-two-like size (64, 128, 256, ...)
985        // rather than the exact needed amount
986        assert!(
987            grown.is_power_of_two() || grown == 10000,
988            "expected doubling growth pattern, got {grown}"
989        );
990    }
991
992    /// Regression: a peer that writes a length-delimited frame whose
993    /// declared length is *less than* the delimiter itself must be
994    /// rejected with `MessageLengthUnderDelimiter`, never panic the
995    /// reader with `slice index starts at N but ends at M`.
996    ///
997    /// Without the bounds check, `&buffer[delimiter_size()..message_len]`
998    /// at `try_read_delimited_message` panics for any peer-controlled
999    /// `message_len < delimiter_size()` (= 8 on 64-bit) — a one-packet
1000    /// denial-of-service against the master command socket.
1001    #[test]
1002    fn rejects_declared_length_below_delimiter() {
1003        let (mut reader, mut writer): (
1004            Channel<ProtobufMessage, ProtobufMessage>,
1005            Channel<ProtobufMessage, ProtobufMessage>,
1006        ) = Channel::generate(1000, 10000).expect("could not generate channels");
1007        writer.blocking().expect("writer to block");
1008        reader.blocking().expect("reader to block");
1009
1010        // Craft a delimiter that lies: message_len = 5 (< delimiter_size() = 8).
1011        // Send it as raw bytes, bypassing write_delimited_message.
1012        let bogus: usize = 5;
1013        let bytes = bogus.to_le_bytes();
1014        std::io::Write::write_all(&mut writer.sock, &bytes).expect("raw write of bogus delimiter");
1015
1016        match reader.read_message() {
1017            Err(ChannelError::MessageLengthUnderDelimiter {
1018                message_len,
1019                delimiter_size,
1020            }) => {
1021                assert_eq!(message_len, 5);
1022                assert_eq!(delimiter_size, std::mem::size_of::<usize>());
1023            }
1024            other => panic!(
1025                "expected MessageLengthUnderDelimiter, got {other:?}\n\
1026                 NOTE: a panic here means the slice-OOB hardening was reverted",
1027            ),
1028        }
1029    }
1030}