Skip to main content

sozu_command_lib/
channel.rs

//! Bidirectional length-delimited unix-socket channel.
//!
//! Implements the master ↔ worker / master ↔ CLI message channel: each
//! frame starts with a native-width `usize` length prefix, encoded
//! little-endian, that holds the total frame length (prefix included). This
//! is NOT a NUL separator — that scheme belongs to the state-file save format
//! in `command/src/state.rs:1613`-`1630`. On the read side, the declared
//! length is checked against the per-channel `max_buffer_size` (a field of
//! `Channel`) before the front buffer is grown to hold the payload.
8
9use std::{
10    cmp::min,
11    fmt::Debug,
12    io::{self, ErrorKind, Read, Write},
13    marker::PhantomData,
14    os::unix::{
15        io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
16        net::UnixStream as StdUnixStream,
17    },
18    time::Duration,
19};
20
21use mio::{event::Source, net::UnixStream as MioUnixStream};
22use prost::{DecodeError, Message as ProstMessage};
23
24use crate::{buffer::growable::Buffer, ready::Ready};
25
/// High-watermark ratio for the channel buffers.
///
/// When a buffer's *capacity* grows to at least this fraction of the
/// channel's `max_buffer_size`, a one-shot warning is logged (guarded by a
/// per-buffer flag). The comparison is against capacity, not against the
/// number of bytes currently buffered.
const HIGH_WATERMARK_RATIO: f64 = 0.8;
28
29#[derive(thiserror::Error, Debug)]
30pub enum ChannelError {
31    #[error("io read error")]
32    Read(std::io::Error),
33    #[error("no byte written on the channel")]
34    NoByteWritten,
35    #[error("no byte left to read on the channel")]
36    NoByteToRead,
37    #[error(
38        "message ({message_len} bytes) too large for back buffer capacity ({capacity} bytes, max {max} bytes)"
39    )]
40    MessageTooLarge {
41        message_len: usize,
42        capacity: usize,
43        max: usize,
44    },
45    #[error(
46        "declared message length ({message_len} bytes) is shorter than the {delimiter_size}-byte length prefix"
47    )]
48    MessageLengthUnderDelimiter {
49        message_len: usize,
50        delimiter_size: usize,
51    },
52    #[error("channel could not write on the back buffer")]
53    Write(std::io::Error),
54    #[error("channel buffer is full ({capacity} bytes, max {max} bytes), cannot grow more")]
55    BufferFull { capacity: usize, max: usize },
56    #[error("Timeout is reached: {0:?}")]
57    TimeoutReached(Duration),
58    #[error("Could not read anything on the channel")]
59    NothingRead,
60    #[error("invalid char set in command message, ignoring: {0}")]
61    InvalidCharSet(String),
62    #[error("could not set the timeout of the unix stream with file descriptor {fd}: {error}")]
63    SetTimeout { fd: i32, error: String },
64    #[error(
65        "Could not change the blocking status ef the unix stream with file descriptor {fd}: {error}"
66    )]
67    BlockingStatus { fd: i32, error: String },
68    #[error("Connection error: {0:?}")]
69    Connection(Option<std::io::Error>),
70    #[error("Invalid protobuf message: {0}")]
71    InvalidProtobufMessage(DecodeError),
72    #[error("This should never happen (index out of bound on a tested buffer)")]
73    MismatchBufferSize,
74}
75
76/// Channel meant for communication between Sōzu processes over a UNIX socket.
77/// It wraps a unix socket using the mio crate, and transmit prost messages
78/// by serializing them in a binary format, with a fix-sized delimiter.
79/// To function, channels must come in pairs, one for each agent.
80/// They can function in a blocking or non-blocking way.
81pub struct Channel<Tx, Rx> {
82    pub sock: MioUnixStream,
83    pub front_buf: Buffer,
84    pub back_buf: Buffer,
85    initial_buffer_size: usize,
86    max_buffer_size: usize,
87    pub readiness: Ready,
88    pub interest: Ready,
89    blocking: bool,
90    /// true if a high watermark warning has been logged for the front buffer
91    front_high_watermark_logged: bool,
92    /// true if a high watermark warning has been logged for the back buffer
93    back_high_watermark_logged: bool,
94    phantom_tx: PhantomData<Tx>,
95    phantom_rx: PhantomData<Rx>,
96}
97
98impl<Tx, Rx> std::fmt::Debug for Channel<Tx, Rx> {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct(&format!(
101            "Channel<{}, {}>",
102            std::any::type_name::<Tx>(),
103            std::any::type_name::<Rx>()
104        ))
105        .field("sock", &self.sock.as_raw_fd())
106        // .field("front_buf", &self.front_buf)
107        // .field("back_buf", &self.back_buf)
108        // .field("max_buffer_size", &self.max_buffer_size)
109        .field("readiness", &self.readiness)
110        .field("interest", &self.interest)
111        .field("blocking", &self.blocking)
112        .finish()
113    }
114}
115
116impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
117    /// Creates a nonblocking channel on a given socket path
118    pub fn from_path(
119        path: &str,
120        buffer_size: u64,
121        max_buffer_size: u64,
122    ) -> Result<Channel<Tx, Rx>, ChannelError> {
123        let unix_stream = MioUnixStream::connect(path)
124            .map_err(|io_error| ChannelError::Connection(Some(io_error)))?;
125        Ok(Channel::new(unix_stream, buffer_size, max_buffer_size))
126    }
127
128    /// Creates a nonblocking channel, using a unix stream
129    pub fn new(sock: MioUnixStream, buffer_size: u64, max_buffer_size: u64) -> Channel<Tx, Rx> {
130        let buffer_size = buffer_size as usize;
131        let max_buffer_size = max_buffer_size as usize;
132        Channel {
133            sock,
134            front_buf: Buffer::with_capacity(buffer_size),
135            back_buf: Buffer::with_capacity(buffer_size),
136            initial_buffer_size: buffer_size,
137            max_buffer_size,
138            readiness: Ready::EMPTY,
139            interest: Ready::READABLE,
140            blocking: false,
141            front_high_watermark_logged: false,
142            back_high_watermark_logged: false,
143            phantom_tx: PhantomData,
144            phantom_rx: PhantomData,
145        }
146    }
147
148    pub fn into<Tx2: Debug + ProstMessage + Default, Rx2: Debug + ProstMessage + Default>(
149        self,
150    ) -> Channel<Tx2, Rx2> {
151        Channel {
152            sock: self.sock,
153            front_buf: self.front_buf,
154            back_buf: self.back_buf,
155            initial_buffer_size: self.initial_buffer_size,
156            max_buffer_size: self.max_buffer_size,
157            readiness: self.readiness,
158            interest: self.interest,
159            blocking: self.blocking,
160            front_high_watermark_logged: self.front_high_watermark_logged,
161            back_high_watermark_logged: self.back_high_watermark_logged,
162            phantom_tx: PhantomData,
163            phantom_rx: PhantomData,
164        }
165    }
166
167    // Since MioUnixStream does not have a set_nonblocking method, we have to use the standard library.
168    // We get the file descriptor of the MioUnixStream socket, create a standard library UnixStream,
169    // set it to nonblocking, let go of the file descriptor
170    fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), ChannelError> {
171        // SAFETY: `fd` is borrowed from `self.sock` for the duration of this
172        // block. We wrap it in a `StdUnixStream` to call `set_nonblocking`,
173        // then immediately release ownership again with `into_raw_fd` so the
174        // descriptor is not closed by `Drop`. `self.sock` retains the
175        // original ownership.
176        unsafe {
177            let fd = self.sock.as_raw_fd();
178            let stream = StdUnixStream::from_raw_fd(fd);
179            stream
180                .set_nonblocking(nonblocking)
181                .map_err(|error| ChannelError::BlockingStatus {
182                    fd,
183                    error: error.to_string(),
184                })?;
185            let _fd = stream.into_raw_fd();
186        }
187        self.blocking = !nonblocking;
188        Ok(())
189    }
190
191    /// set the read_timeout of the unix stream. This works only temporary, be sure to set the timeout to None afterwards.
192    fn set_timeout(&mut self, timeout: Option<Duration>) -> Result<(), ChannelError> {
193        // SAFETY: `fd` is borrowed from `self.sock` for the duration of this
194        // block. We wrap it in a `StdUnixStream` to call `set_read_timeout`,
195        // then immediately release ownership again with `into_raw_fd` so the
196        // descriptor is not closed by `Drop`. `self.sock` retains the
197        // original ownership.
198        unsafe {
199            let fd = self.sock.as_raw_fd();
200            let stream = StdUnixStream::from_raw_fd(fd);
201            stream
202                .set_read_timeout(timeout)
203                .map_err(|error| ChannelError::SetTimeout {
204                    fd,
205                    error: error.to_string(),
206                })?;
207            let _fd = stream.into_raw_fd();
208        }
209        Ok(())
210    }
211
212    /// set the channel to be blocking
213    pub fn blocking(&mut self) -> Result<(), ChannelError> {
214        self.set_nonblocking(false)
215    }
216
217    /// set the channel to be nonblocking
218    pub fn nonblocking(&mut self) -> Result<(), ChannelError> {
219        self.set_nonblocking(true)
220    }
221
222    pub fn is_blocking(&self) -> bool {
223        self.blocking
224    }
225
226    /// Get the raw file descriptor of the UNIX socket
227    pub fn fd(&self) -> RawFd {
228        self.sock.as_raw_fd()
229    }
230
231    pub fn handle_events(&mut self, events: Ready) {
232        self.readiness |= events;
233    }
234
235    pub fn readiness(&self) -> Ready {
236        self.readiness & self.interest
237    }
238
239    /// Compute the next buffer size using a doubling strategy, capped at max_buffer_size.
240    /// Returns None if the buffer is already at max capacity.
241    fn grow_size(&self, current_capacity: usize) -> Option<usize> {
242        if current_capacity >= self.max_buffer_size {
243            return None;
244        }
245        // Precondition: we only reach here below the ceiling, so a strictly
246        // larger size always exists within [current+1, max_buffer_size].
247        debug_assert!(
248            current_capacity < self.max_buffer_size,
249            "grow_size must only grow buffers that are strictly below the max ceiling"
250        );
251        // double the capacity, but don't exceed max
252        let new_size = min(current_capacity.saturating_mul(2), self.max_buffer_size);
253        // ensure we grow by at least something (in case current_capacity is 0)
254        let new_size = new_size.max(current_capacity + 1);
255        let new_size = min(new_size, self.max_buffer_size);
256        // Postconditions: the new capacity strictly grows (forward progress,
257        // no spin) and never overshoots the configured ceiling.
258        debug_assert!(
259            new_size > current_capacity,
260            "grow_size must make forward progress (new capacity strictly larger)"
261        );
262        debug_assert!(
263            new_size <= self.max_buffer_size,
264            "grow_size must never exceed the configured max_buffer_size ceiling"
265        );
266        Some(new_size)
267    }
268
269    /// Check if a buffer has exceeded the high watermark and log a warning once
270    fn check_high_watermark(
271        buffer_name: &str,
272        capacity: usize,
273        max: usize,
274        already_logged: &mut bool,
275    ) {
276        if *already_logged {
277            return;
278        }
279        let threshold = (max as f64 * HIGH_WATERMARK_RATIO) as usize;
280        if capacity >= threshold {
281            warn!(
282                "channel {} buffer reached high watermark: {} bytes ({:.0}% of {} max)",
283                buffer_name,
284                capacity,
285                (capacity as f64 / max as f64) * 100.0,
286                max,
287            );
288            *already_logged = true;
289        }
290    }
291
292    /// Check wether we want and can read or write, and calls the appropriate handler.
293    pub fn run(&mut self) -> Result<(), ChannelError> {
294        let interest = self.interest & self.readiness;
295
296        if interest.is_readable() {
297            let _ = self.readable()?;
298        }
299
300        if interest.is_writable() {
301            let _ = self.writable()?;
302        }
303        Ok(())
304    }
305
306    /// Handle readability by filling the front buffer with the socket data.
307    /// Grows the front buffer when full using a doubling strategy, up to max_buffer_size.
308    pub fn readable(&mut self) -> Result<usize, ChannelError> {
309        if !(self.interest & self.readiness).is_readable() {
310            return Err(ChannelError::Connection(None));
311        }
312
313        let mut count = 0usize;
314        loop {
315            let size = self.front_buf.available_space();
316            trace!("channel available space: {}", size);
317            if size == 0 {
318                // try to grow the buffer before giving up
319                if let Some(new_size) = self.grow_size(self.front_buf.capacity()) {
320                    Self::check_high_watermark(
321                        "front",
322                        new_size,
323                        self.max_buffer_size,
324                        &mut self.front_high_watermark_logged,
325                    );
326                    self.front_buf.grow(new_size);
327                    // The read buffer must never grow past the configured ceiling.
328                    debug_assert!(
329                        self.front_buf.capacity() <= self.max_buffer_size,
330                        "front buffer capacity must stay within the max_buffer_size ceiling"
331                    );
332                } else {
333                    self.interest.remove(Ready::READABLE);
334                    break;
335                }
336            }
337
338            // The slice handed to `read` is exactly the free tail; the kernel
339            // can never write more than that, so a successful read of N bytes
340            // is bounded by the space we just measured (post-grow).
341            let space_before = self.front_buf.available_space();
342            debug_assert!(
343                space_before > 0,
344                "readable must only call read() with non-empty space (zero-space path grows or breaks)"
345            );
346            let data_before = self.front_buf.available_data();
347            match self.sock.read(self.front_buf.space()) {
348                Ok(0) => {
349                    self.interest = Ready::EMPTY;
350                    self.readiness.remove(Ready::READABLE);
351                    self.readiness.insert(Ready::HUP);
352                    return Err(ChannelError::NoByteToRead);
353                }
354                Err(read_error) => match read_error.kind() {
355                    ErrorKind::WouldBlock => {
356                        self.readiness.remove(Ready::READABLE);
357                        break;
358                    }
359                    _ => {
360                        self.interest = Ready::EMPTY;
361                        self.readiness = Ready::EMPTY;
362                        return Err(ChannelError::Read(read_error));
363                    }
364                },
365                Ok(bytes_read) => {
366                    // A read can never deliver more bytes than the free space
367                    // it was handed; otherwise `fill` would corrupt offsets.
368                    debug_assert!(
369                        bytes_read <= space_before,
370                        "read delivered more bytes than the buffer space it was given"
371                    );
372                    count += bytes_read;
373                    self.front_buf.fill(bytes_read);
374                    // `fill` advances `end` by exactly the bytes read, so the
375                    // available data grows by that delta — pair-assert the
376                    // offset mutation.
377                    debug_assert_eq!(
378                        self.front_buf.available_data(),
379                        data_before + bytes_read,
380                        "front buffer available_data must increase by exactly bytes_read"
381                    );
382                }
383            };
384        }
385
386        Ok(count)
387    }
388
389    /// Handle writability by writing the content of the back buffer onto the socket.
390    /// Shrinks the back buffer back toward initial size once fully drained.
391    pub fn writable(&mut self) -> Result<usize, ChannelError> {
392        if !(self.interest & self.readiness).is_writable() {
393            return Err(ChannelError::Connection(None));
394        }
395
396        let mut count = 0usize;
397        loop {
398            let size = self.back_buf.available_data();
399            if size == 0 {
400                self.interest.remove(Ready::WRITABLE);
401                self.try_shrink_back_buf();
402                break;
403            }
404
405            let data_before = self.back_buf.available_data();
406            match self.sock.write(self.back_buf.data()) {
407                Ok(0) => {
408                    self.interest = Ready::EMPTY;
409                    self.readiness.insert(Ready::HUP);
410                    return Err(ChannelError::NoByteWritten);
411                }
412                Ok(bytes_written) => {
413                    // The kernel cannot accept more than the slice (`data()`)
414                    // it was handed; otherwise `consume` would underflow.
415                    debug_assert!(
416                        bytes_written <= data_before,
417                        "write reported more bytes than the buffer data it was given"
418                    );
419                    count += bytes_written;
420                    let consumed = self.back_buf.consume(bytes_written);
421                    // `consume` is saturating but the bound above guarantees an
422                    // exact consume here; pair-assert the offset mutation.
423                    debug_assert_eq!(
424                        consumed, bytes_written,
425                        "back buffer must consume exactly the bytes written to the socket"
426                    );
427                    debug_assert_eq!(
428                        self.back_buf.available_data(),
429                        data_before - bytes_written,
430                        "back buffer available_data must shrink by exactly bytes_written"
431                    );
432                }
433                Err(write_error) => match write_error.kind() {
434                    ErrorKind::WouldBlock => {
435                        self.readiness.remove(Ready::WRITABLE);
436                        break;
437                    }
438                    _ => {
439                        self.interest = Ready::EMPTY;
440                        self.readiness = Ready::EMPTY;
441                        return Err(ChannelError::Read(write_error));
442                    }
443                },
444            }
445        }
446
447        Ok(count)
448    }
449
450    /// Depending on the blocking status:
451    ///
452    /// Blocking: wait for the front buffer to be filled, and parse a message from it
453    ///
454    /// Nonblocking: parse a message from the front buffer, without waiting.
455    /// Prefer using `channel.readable()` before
456    pub fn read_message(&mut self) -> Result<Rx, ChannelError> {
457        if self.blocking {
458            self.read_message_blocking()
459        } else {
460            self.read_message_nonblocking()
461        }
462    }
463
464    fn read_message_blocking(&mut self) -> Result<Rx, ChannelError> {
465        self.read_message_blocking_timeout(None)
466    }
467
468    /// Parse a message from the front buffer, without waiting
469    fn read_message_nonblocking(&mut self) -> Result<Rx, ChannelError> {
470        if let Some(message) = self.try_read_delimited_message()? {
471            self.try_shrink_front_buf();
472            return Ok(message);
473        }
474
475        self.interest.insert(Ready::READABLE);
476        Err(ChannelError::NothingRead)
477    }
478
479    /// Wait for the front buffer to be filled, and parses a message from it.
480    pub fn read_message_blocking_timeout(
481        &mut self,
482        timeout: Option<Duration>,
483    ) -> Result<Rx, ChannelError> {
484        let now = std::time::Instant::now();
485
486        // 10 ms = 100 syscalls/sec on idle WouldBlock, pinning a CPU on
487        // long blocking waits with no payload. 100 ms is
488        // a usability-acceptable resolution for the outer `timeout`
489        // deadline check (the wait is bounded by `timeout`, not by this
490        // value) and drops the steady-state read syscall rate to 10/sec.
491        self.set_timeout(Some(Duration::from_millis(100)))?;
492
493        let status = loop {
494            if let Some(timeout) = timeout {
495                if now.elapsed() >= timeout {
496                    break Err(ChannelError::TimeoutReached(timeout));
497                }
498            }
499
500            if let Some(message) = self.try_read_delimited_message()? {
501                self.try_shrink_front_buf();
502                return Ok(message);
503            }
504
505            match self.sock.read(self.front_buf.space()) {
506                Ok(0) => return Err(ChannelError::NoByteToRead),
507                Ok(bytes_read) => self.front_buf.fill(bytes_read),
508                Err(io_error) => match io_error.kind() {
509                    ErrorKind::WouldBlock => continue, // ignore 10 millisecond timeouts
510                    _ => break Err(ChannelError::Read(io_error)),
511                },
512            };
513        };
514
515        self.set_timeout(None)?;
516
517        status
518    }
519
520    /// parse a prost message from the front buffer, grow it if necessary
521    fn try_read_delimited_message(&mut self) -> Result<Option<Rx>, ChannelError> {
522        // Invariant guarding all the slice indexing below: the front buffer can
523        // never have grown past the configured ceiling. Every grow site routes
524        // through `grow_size`/`max_buffer_size`; if this ever fired the length
525        // checks would be reasoning against a stale bound.
526        debug_assert!(
527            self.front_buf.capacity() <= self.max_buffer_size,
528            "front buffer capacity must never exceed max_buffer_size"
529        );
530        let buffer = self.front_buf.data();
531        // `data()` returns `memory[position..end]`, so its length is exactly the
532        // available data and can never exceed the buffer capacity.
533        debug_assert!(
534            buffer.len() <= self.front_buf.capacity(),
535            "available data slice cannot exceed buffer capacity"
536        );
537        if buffer.len() >= delimiter_size() {
538            let delimiter = buffer[..delimiter_size()]
539                .try_into()
540                .map_err(|_| ChannelError::MismatchBufferSize)?;
541            let message_len = usize::from_le_bytes(delimiter);
542
543            // Defense in depth: bound the parser-side length up-front.
544            // Without this an attacker who controls the
545            // first 8 bytes of a frame can declare an arbitrarily large
546            // message and drive `Buffer::grow` toward the
547            // `max_buffer_size` ceiling before any byte of payload has
548            // been read. Reject as `MessageTooLarge` so the read loop
549            // disconnects cleanly instead of running the doubling growth
550            // strategy on attacker-supplied numbers.
551            if message_len > self.max_buffer_size {
552                return Err(ChannelError::MessageTooLarge {
553                    message_len,
554                    capacity: self.front_buf.capacity(),
555                    max: self.max_buffer_size,
556                });
557            }
558
559            // A length-delimited frame is `[delimiter][payload]`. The declared
560            // `message_len` is the total frame size and MUST therefore be at
561            // least `delimiter_size()`. A peer-controlled value below that
562            // ceiling makes `&buffer[delimiter_size()..message_len]` slice
563            // backwards and panic; reject it the same way as oversized frames.
564            //
565            // Drop the bogus delimiter bytes before returning so the channel
566            // can re-sync on the peer's next frame. Without this, every
567            // subsequent `read_message()` re-reads the same bad header from
568            // the front buffer and the worker burns CPU on the same error
569            // until the peer disconnects.
570            if message_len < delimiter_size() {
571                self.front_buf.consume(delimiter_size());
572                return Err(ChannelError::MessageLengthUnderDelimiter {
573                    message_len,
574                    delimiter_size: delimiter_size(),
575                });
576            }
577
578            if buffer.len() >= message_len {
579                // By the time we slice, the two guards above have proven the
580                // length prefix is well-formed: it is at least the delimiter
581                // (so `delimiter_size()..message_len` runs forward) and at most
582                // the configured ceiling (so it cannot drive growth). The
583                // `buffer.len() >= message_len` branch then guarantees the whole
584                // frame is in the buffer. These are the exact invariants that
585                // keep the slice and the `consume` in bounds — never reachable
586                // from a malformed length, which already returned an error.
587                debug_assert!(
588                    message_len >= delimiter_size(),
589                    "decode path requires a frame at least as large as its delimiter"
590                );
591                debug_assert!(
592                    message_len <= self.max_buffer_size,
593                    "decode path requires the declared length within the max ceiling"
594                );
595                debug_assert!(
596                    message_len <= buffer.len(),
597                    "decode path requires the full frame to be buffered before slicing"
598                );
599                let available_before = self.front_buf.available_data();
600                debug_assert_eq!(
601                    available_before,
602                    buffer.len(),
603                    "available_data must equal the data slice length we validated against"
604                );
605                let message = Rx::decode(&buffer[delimiter_size()..message_len])
606                    .map_err(ChannelError::InvalidProtobufMessage)?;
607                let consumed = self.front_buf.consume(message_len);
608                // The whole frame (delimiter + payload) is consumed exactly:
609                // pair-assert that consume advanced by message_len and the data
610                // pointer moved forward by the same amount.
611                debug_assert_eq!(
612                    consumed, message_len,
613                    "must consume exactly the validated frame length"
614                );
615                debug_assert_eq!(
616                    self.front_buf.available_data(),
617                    available_before - message_len,
618                    "available_data must drop by exactly the consumed frame length"
619                );
620                return Ok(Some(message));
621            }
622        }
623
624        if self.front_buf.available_space() == 0 {
625            if self.front_buf.capacity() >= self.max_buffer_size {
626                return Err(ChannelError::BufferFull {
627                    capacity: self.front_buf.capacity(),
628                    max: self.max_buffer_size,
629                });
630            }
631            let new_size = self
632                .grow_size(self.front_buf.capacity())
633                .unwrap_or(self.max_buffer_size);
634            Self::check_high_watermark(
635                "front",
636                new_size,
637                self.max_buffer_size,
638                &mut self.front_high_watermark_logged,
639            );
640            self.front_buf.grow(new_size);
641        }
642        Ok(None)
643    }
644
645    /// Checks whether the channel is blocking or nonblocking, writes the message.
646    ///
647    /// If the channel is nonblocking, you have to flush using `channel.run()` afterwards
648    pub fn write_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
649        if self.blocking {
650            self.write_message_blocking(message)
651        } else {
652            self.write_message_nonblocking(message)
653        }
654    }
655
656    /// Writes the message in the buffer, but NOT on the socket.
657    /// you have to call channel.run() afterwards
658    fn write_message_nonblocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
659        self.write_delimited_message(message)?;
660
661        self.interest.insert(Ready::WRITABLE);
662
663        Ok(())
664    }
665
666    /// fills the back buffer with data AND writes on the socket
667    fn write_message_blocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
668        self.write_delimited_message(message)?;
669
670        loop {
671            let size = self.back_buf.available_data();
672            if size == 0 {
673                break;
674            }
675
676            match self.sock.write(self.back_buf.data()) {
677                Ok(0) => return Err(ChannelError::NoByteWritten),
678                Ok(bytes_written) => {
679                    self.back_buf.consume(bytes_written);
680                }
681                Err(_) => return Ok(()), // are we sure?
682            }
683        }
684        Ok(())
685    }
686
687    /// write a message on the back buffer, using our own delimiter (the delimiter of prost
688    /// is not trustworthy since its size may change)
689    pub fn write_delimited_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
690        let payload = message.encode_to_vec();
691
692        let payload_len = payload.len() + delimiter_size();
693
694        // The framed length is the payload plus a fixed delimiter, so it is by
695        // construction at least one delimiter wide — the mirror of the
696        // `MessageLengthUnderDelimiter` invariant the reader enforces.
697        debug_assert!(
698            payload_len >= delimiter_size(),
699            "framed length must include the fixed-size delimiter prefix"
700        );
701
702        let delimiter = payload_len.to_le_bytes();
703
704        if payload_len > self.back_buf.available_space() {
705            self.back_buf.shift();
706        }
707
708        let data_before = self.back_buf.available_data();
709        if payload_len > self.back_buf.available_space() {
710            let needed = payload_len - self.back_buf.available_space() + self.back_buf.capacity();
711            if needed > self.max_buffer_size {
712                return Err(ChannelError::MessageTooLarge {
713                    message_len: payload_len,
714                    capacity: self.back_buf.capacity(),
715                    max: self.max_buffer_size,
716                });
717            }
718            // Past the ceiling check, the required size is within the ceiling
719            // and at least the current capacity (we only enter on shortfall).
720            debug_assert!(
721                needed <= self.max_buffer_size,
722                "grow target must be within the max ceiling once the cap check passed"
723            );
724
725            let capacity_before = self.back_buf.capacity();
726            // use doubling strategy to reach at least `needed`, amortizing future writes
727            let mut new_length = self.back_buf.capacity();
728            while new_length < needed {
729                new_length = new_length.saturating_mul(2).max(new_length + 1);
730            }
731            new_length = min(new_length, self.max_buffer_size);
732            // Post-grow target: large enough to fit the frame yet capped at the
733            // configured ceiling and never below where we started.
734            debug_assert!(
735                new_length >= needed,
736                "doubling growth must reach at least the needed capacity"
737            );
738            debug_assert!(
739                new_length <= self.max_buffer_size,
740                "grown back buffer must stay within the max_buffer_size ceiling"
741            );
742            debug_assert!(
743                new_length >= capacity_before,
744                "growth must never shrink the back buffer"
745            );
746            Self::check_high_watermark(
747                "back",
748                new_length,
749                self.max_buffer_size,
750                &mut self.back_high_watermark_logged,
751            );
752            self.back_buf.grow(new_length);
753            // After the grow the frame must fit in the now-available space.
754            debug_assert!(
755                payload_len <= self.back_buf.available_space(),
756                "back buffer must have room for the full frame after growth"
757            );
758        }
759
760        self.back_buf
761            .write_all(&delimiter)
762            .map_err(ChannelError::Write)?;
763        self.back_buf
764            .write_all(&payload)
765            .map_err(ChannelError::Write)?;
766
767        // The two writes appended exactly `payload_len` bytes (delimiter +
768        // payload) to the back buffer's pending data.
769        debug_assert_eq!(
770            self.back_buf.available_data(),
771            data_before + payload_len,
772            "back buffer pending data must grow by exactly the framed length"
773        );
774        debug_assert!(
775            self.back_buf.capacity() <= self.max_buffer_size,
776            "back buffer capacity must never exceed the max_buffer_size ceiling"
777        );
778
779        Ok(())
780    }
781
782    /// Shrink the front buffer back toward initial_buffer_size when it is
783    /// mostly empty (data consumed) and was previously grown.
784    fn try_shrink_front_buf(&mut self) {
785        let capacity = self.front_buf.capacity();
786        if capacity <= self.initial_buffer_size {
787            return;
788        }
789        // Past the early return, we are strictly above the floor, so a shrink
790        // back to `initial_buffer_size` is a genuine reduction.
791        debug_assert!(
792            capacity > self.initial_buffer_size,
793            "shrink path only runs when capacity is above the initial floor"
794        );
795        // only shrink when the buffer has little pending data
796        if self.front_buf.available_data() * 4 < self.initial_buffer_size {
797            let data_before = self.front_buf.available_data();
798            self.front_buf.shrink(self.initial_buffer_size);
799            self.front_high_watermark_logged = false;
800            // Shrink preserves pending data and never drops below the floor.
801            debug_assert!(
802                self.front_buf.capacity() >= self.initial_buffer_size,
803                "front buffer must never shrink below the initial buffer size floor"
804            );
805            debug_assert_eq!(
806                self.front_buf.available_data(),
807                data_before,
808                "shrink must preserve all pending front-buffer data"
809            );
810            trace!(
811                "front buffer shrunk from {} to {} bytes",
812                capacity, self.initial_buffer_size
813            );
814        }
815    }
816
817    /// Shrink the back buffer back toward initial_buffer_size when fully drained.
818    fn try_shrink_back_buf(&mut self) {
819        let capacity = self.back_buf.capacity();
820        if capacity <= self.initial_buffer_size {
821            return;
822        }
823        debug_assert!(
824            capacity > self.initial_buffer_size,
825            "shrink path only runs when capacity is above the initial floor"
826        );
827        if self.back_buf.available_data() == 0 {
828            self.back_buf.shrink(self.initial_buffer_size);
829            self.back_high_watermark_logged = false;
830            // The back buffer is only shrunk once fully drained; it must end at
831            // the floor with no pending data resurrected.
832            debug_assert!(
833                self.back_buf.capacity() >= self.initial_buffer_size,
834                "back buffer must never shrink below the initial buffer size floor"
835            );
836            debug_assert_eq!(
837                self.back_buf.available_data(),
838                0,
839                "back buffer must stay empty across a drained shrink"
840            );
841            trace!(
842                "back buffer shrunk from {} to {} bytes",
843                capacity, self.initial_buffer_size
844            );
845        }
846    }
847}
848
/// Width in bytes of the length prefix written before every payload.
///
/// The prefix is one native `usize`, so this is 8 on 64-bit targets and 4 on
/// 32-bit ones. Both ends of a channel must therefore share a pointer width.
pub const fn delimiter_size() -> usize {
    (usize::BITS / u8::BITS) as usize
}
853
/// Outcome of creating a connected channel pair. The second channel has `Tx`
/// and `Rx` swapped, so the messages one end sends are the messages the other
/// end receives.
type ChannelResult<Tx, Rx> = Result<(Channel<Tx, Rx>, Channel<Rx, Tx>), ChannelError>;
855
856impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
857    /// creates a channel pair: `(blocking_channel, nonblocking_channel)`
858    pub fn generate(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
859        let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
860        let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
861        let mut command_channel = Channel::new(command, buffer_size, max_buffer_size);
862        command_channel.blocking()?;
863        Ok((command_channel, proxy_channel))
864    }
865
866    /// creates a pair of nonblocking channels
867    pub fn generate_nonblocking(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
868        let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
869        let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
870        let command_channel = Channel::new(command, buffer_size, max_buffer_size);
871        Ok((command_channel, proxy_channel))
872    }
873}
874
875impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Iterator
876    for Channel<Tx, Rx>
877{
878    type Item = Rx;
879    fn next(&mut self) -> Option<Self::Item> {
880        self.read_message().ok()
881    }
882}
883
884use mio::{Interest, Registry, Token};
885impl<Tx, Rx> Source for Channel<Tx, Rx> {
886    fn register(
887        &mut self,
888        registry: &Registry,
889        token: Token,
890        interests: Interest,
891    ) -> io::Result<()> {
892        self.sock.register(registry, token, interests)
893    }
894
895    fn reregister(
896        &mut self,
897        registry: &Registry,
898        token: Token,
899        interests: Interest,
900    ) -> io::Result<()> {
901        self.sock.reregister(registry, token, interests)
902    }
903
904    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
905        self.sock.deregister(registry)
906    }
907}
908
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    /// Minimal protobuf payload used by the tests below.
    #[derive(Clone, PartialEq, prost::Message)]
    pub struct ProtobufMessage {
        #[prost(uint32, required, tag = "1")]
        inner: u32,
    }

    /// A `(blocking, nonblocking)` pair with a 1000-byte initial buffer size
    /// and a 10000-byte ceiling.
    fn test_channels() -> (
        Channel<ProtobufMessage, ProtobufMessage>,
        Channel<ProtobufMessage, ProtobufMessage>,
    ) {
        Channel::generate(1000, 10000).expect("could not generate blocking channels for testing")
    }

    #[test]
    fn unblock_a_channel() {
        let (mut blocking, _nonblocking) = test_channels();
        assert!(blocking.nonblocking().is_ok())
    }

    #[test]
    fn generate_blocking_and_nonblocking_channels() {
        let (blocking_channel, nonblocking_channel) = test_channels();

        assert!(blocking_channel.is_blocking());
        assert!(!nonblocking_channel.is_blocking());

        let (nonblocking_channel_1, nonblocking_channel_2): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate_nonblocking(1000, 10000)
            .expect("could not generate nonblocking channels");

        assert!(!nonblocking_channel_1.is_blocking());
        assert!(!nonblocking_channel_2.is_blocking());
    }

    #[test]
    fn write_and_read_message_blocking() {
        let (mut blocking_channel, mut nonblocking_channel) = test_channels();

        let message_to_send = ProtobufMessage { inner: 42 };

        nonblocking_channel
            .blocking()
            .expect("Could not block channel");
        nonblocking_channel
            .write_message(&message_to_send)
            .expect("Could not write message on channel");

        trace!("we wrote a message!");

        trace!("reading message..");
        let message = blocking_channel
            .read_message()
            .expect("Could not read message on channel");
        trace!("read message!");

        assert_eq!(message, ProtobufMessage { inner: 42 });
    }

    #[test]
    fn read_message_blocking_with_timeout_fails() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 100 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message =
                reading_channel.read_message_blocking_timeout(Some(Duration::from_millis(100)));
            trace!("read message!");
            message
        });

        trace!("Waiting 200 milliseconds…");
        thread::sleep(Duration::from_millis(200));

        writing_channel
            .write_message(&ProtobufMessage { inner: 200 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive too late!");

        let arrived_too_late = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert!(arrived_too_late.is_err());
    }

    #[test]
    fn read_message_blocking_with_timeout_succeeds() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 200 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message = reading_channel
                .read_message_blocking_timeout(Some(Duration::from_millis(200)))
                .expect("Could not read message with timeout on blocking channel");
            trace!("read message!");
            message
        });

        trace!("Waiting 100 milliseconds…");
        thread::sleep(Duration::from_millis(100));

        writing_channel
            .write_message(&ProtobufMessage { inner: 100 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive on time!");

        let arrived_on_time = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert_eq!(arrived_on_time, ProtobufMessage { inner: 100 });
    }

    #[test]
    fn exhaustive_use_of_nonblocking_channels() {
        // - start from the (blocking, nonblocking) test pair and switch A to
        //   nonblocking, so that both A and B are nonblocking
        let (mut channel_a, mut channel_b) = test_channels();
        channel_a.nonblocking().expect("Could not unblock channel");

        // write on A
        channel_a
            .write_message(&ProtobufMessage { inner: 1 })
            .expect("Could not write message on channel");

        // set B as readable, normally mio tells when to, by giving events
        channel_b.handle_events(Ready::READABLE);

        // read on B
        let should_err = channel_b.read_message();
        assert!(should_err.is_err());

        // write another message on A
        channel_a
            .write_message(&ProtobufMessage { inner: 2 })
            .expect("Could not write message on channel");

        // insert a handle_events Ready::writable on A
        channel_a.handle_events(Ready::WRITABLE);

        // flush A with run()
        channel_a.run().expect("Failed to run the channel");

        // give the written bytes a moment to reach B
        thread::sleep(Duration::from_millis(100));

        // receive with B using run()
        channel_b.run().expect("Failed to run the channel");

        // use read_message() twice on B, check them
        let message_1 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_1, ProtobufMessage { inner: 1 });

        let message_2 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_2, ProtobufMessage { inner: 2 });
    }

    #[test]
    fn buffer_grows_with_doubling_strategy() {
        let (writing_channel, _reading_channel): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(100, 10000).expect("could not generate channels");

        assert_eq!(writing_channel.back_buf.capacity(), 100);

        assert_eq!(writing_channel.grow_size(100), Some(200));
        assert_eq!(writing_channel.grow_size(200), Some(400));
        assert_eq!(writing_channel.grow_size(5000), Some(10000));
        assert_eq!(writing_channel.grow_size(10000), None);
    }

    #[test]
    fn buffer_cap_returns_error() {
        let (mut writing_channel, _reading_channel): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(50, 50).expect("could not generate channels");

        writing_channel.blocking().expect("Could not block channel");

        let mut i = 0u32;
        let result = loop {
            let msg = ProtobufMessage { inner: i };
            match writing_channel.write_delimited_message(&msg) {
                Ok(()) => i += 1,
                Err(e) => break Err(e),
            }
            if i > 10000 {
                break Ok(());
            }
        };

        assert!(result.is_err());
        let err = result.unwrap_err();
        let err_msg = format!("{err}");
        assert!(
            err_msg.contains("too large") || err_msg.contains("cannot grow"),
            "unexpected error: {err_msg}"
        );
    }

    #[test]
    fn back_buffer_shrinks_after_drain() {
        let (mut channel, _other): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(100, 10000).expect("could not generate channels");

        // Write directly to the back buffer (without draining to socket)
        // to force growth. Each message is ~10 bytes (delimiter + tag + varint).
        for i in 0..20 {
            channel
                .write_delimited_message(&ProtobufMessage { inner: i })
                .expect("Could not write message");
        }

        let grown_capacity = channel.back_buf.capacity();
        assert!(
            grown_capacity > 100,
            "expected buffer growth, got capacity {grown_capacity}"
        );

        // Simulate full drain by consuming all data
        let data_len = channel.back_buf.available_data();
        channel.back_buf.consume(data_len);
        assert_eq!(channel.back_buf.available_data(), 0);

        channel.try_shrink_back_buf();
        assert_eq!(
            channel.back_buf.capacity(),
            100,
            "back buffer should shrink to initial size after drain"
        );
    }

    #[test]
    fn back_buffer_grows_with_doubling_on_write() {
        let (mut channel, _other): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(32, 10000).expect("could not generate channels");

        assert_eq!(channel.back_buf.capacity(), 32);

        // Write enough messages to force growth beyond initial capacity.
        // Each small ProtobufMessage encodes to 2 bytes (tag + 1-byte varint),
        // plus a delimiter_size() prefix (8 bytes on 64-bit) = ~10 bytes.
        for i in 0..10 {
            channel
                .write_delimited_message(&ProtobufMessage { inner: i })
                .expect("Could not write message");
        }

        let grown = channel.back_buf.capacity();
        assert!(grown > 32, "expected buffer growth beyond 32, got {grown}");
        // doubling from 32 should yield a power-of-two-like size (64, 128, 256, ...)
        // rather than the exact needed amount
        assert!(
            grown.is_power_of_two() || grown == 10000,
            "expected doubling growth pattern, got {grown}"
        );
    }

    /// Regression: a peer that writes a length-delimited frame whose
    /// declared length is *less than* the delimiter itself must be
    /// rejected with `MessageLengthUnderDelimiter`, never panic the
    /// reader with `slice index starts at N but ends at M`.
    ///
    /// Without the bounds check, `&buffer[delimiter_size()..message_len]`
    /// at `try_read_delimited_message` panics for any peer-controlled
    /// `message_len < delimiter_size()` (= 8 on 64-bit) — a one-packet
    /// denial-of-service against the master command socket.
    #[test]
    fn rejects_declared_length_below_delimiter() {
        let (mut reader, mut writer): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(1000, 10000).expect("could not generate channels");
        writer.blocking().expect("writer to block");
        reader.blocking().expect("reader to block");

        // Craft a delimiter that lies: message_len = 5 (< delimiter_size() = 8).
        // Send it as raw bytes, bypassing write_delimited_message.
        let bogus: usize = 5;
        let bytes = bogus.to_le_bytes();
        std::io::Write::write_all(&mut writer.sock, &bytes).expect("raw write of bogus delimiter");

        match reader.read_message() {
            Err(ChannelError::MessageLengthUnderDelimiter {
                message_len,
                delimiter_size,
            }) => {
                assert_eq!(message_len, 5);
                assert_eq!(delimiter_size, std::mem::size_of::<usize>());
            }
            other => panic!(
                "expected MessageLengthUnderDelimiter, got {other:?}\n\
                 NOTE: a panic here means the slice-OOB hardening was reverted",
            ),
        }
    }
}