1use std::{
10 cmp::min,
11 fmt::Debug,
12 io::{self, ErrorKind, Read, Write},
13 marker::PhantomData,
14 os::unix::{
15 io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
16 net::UnixStream as StdUnixStream,
17 },
18 time::Duration,
19};
20
21use mio::{event::Source, net::UnixStream as MioUnixStream};
22use prost::{DecodeError, Message as ProstMessage};
23
24use crate::{buffer::growable::Buffer, ready::Ready};
25
/// Fraction of `max_buffer_size` at which a buffer's growth is logged once as a warning
/// (the flag is cleared again when the buffer shrinks back to its initial size).
const HIGH_WATERMARK_RATIO: f64 = 0.8;
28
29#[derive(thiserror::Error, Debug)]
30pub enum ChannelError {
31 #[error("io read error")]
32 Read(std::io::Error),
33 #[error("no byte written on the channel")]
34 NoByteWritten,
35 #[error("no byte left to read on the channel")]
36 NoByteToRead,
37 #[error(
38 "message ({message_len} bytes) too large for back buffer capacity ({capacity} bytes, max {max} bytes)"
39 )]
40 MessageTooLarge {
41 message_len: usize,
42 capacity: usize,
43 max: usize,
44 },
45 #[error(
46 "declared message length ({message_len} bytes) is shorter than the {delimiter_size}-byte length prefix"
47 )]
48 MessageLengthUnderDelimiter {
49 message_len: usize,
50 delimiter_size: usize,
51 },
52 #[error("channel could not write on the back buffer")]
53 Write(std::io::Error),
54 #[error("channel buffer is full ({capacity} bytes, max {max} bytes), cannot grow more")]
55 BufferFull { capacity: usize, max: usize },
56 #[error("Timeout is reached: {0:?}")]
57 TimeoutReached(Duration),
58 #[error("Could not read anything on the channel")]
59 NothingRead,
60 #[error("invalid char set in command message, ignoring: {0}")]
61 InvalidCharSet(String),
62 #[error("could not set the timeout of the unix stream with file descriptor {fd}: {error}")]
63 SetTimeout { fd: i32, error: String },
64 #[error(
65 "Could not change the blocking status ef the unix stream with file descriptor {fd}: {error}"
66 )]
67 BlockingStatus { fd: i32, error: String },
68 #[error("Connection error: {0:?}")]
69 Connection(Option<std::io::Error>),
70 #[error("Invalid protobuf message: {0}")]
71 InvalidProtobufMessage(DecodeError),
72 #[error("This should never happen (index out of bound on a tested buffer)")]
73 MismatchBufferSize,
74}
75
76pub struct Channel<Tx, Rx> {
82 pub sock: MioUnixStream,
83 pub front_buf: Buffer,
84 pub back_buf: Buffer,
85 initial_buffer_size: usize,
86 max_buffer_size: usize,
87 pub readiness: Ready,
88 pub interest: Ready,
89 blocking: bool,
90 front_high_watermark_logged: bool,
92 back_high_watermark_logged: bool,
94 phantom_tx: PhantomData<Tx>,
95 phantom_rx: PhantomData<Rx>,
96}
97
98impl<Tx, Rx> std::fmt::Debug for Channel<Tx, Rx> {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct(&format!(
101 "Channel<{}, {}>",
102 std::any::type_name::<Tx>(),
103 std::any::type_name::<Rx>()
104 ))
105 .field("sock", &self.sock.as_raw_fd())
106 .field("readiness", &self.readiness)
110 .field("interest", &self.interest)
111 .field("blocking", &self.blocking)
112 .finish()
113 }
114}
115
116impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
117 pub fn from_path(
119 path: &str,
120 buffer_size: u64,
121 max_buffer_size: u64,
122 ) -> Result<Channel<Tx, Rx>, ChannelError> {
123 let unix_stream = MioUnixStream::connect(path)
124 .map_err(|io_error| ChannelError::Connection(Some(io_error)))?;
125 Ok(Channel::new(unix_stream, buffer_size, max_buffer_size))
126 }
127
128 pub fn new(sock: MioUnixStream, buffer_size: u64, max_buffer_size: u64) -> Channel<Tx, Rx> {
130 let buffer_size = buffer_size as usize;
131 let max_buffer_size = max_buffer_size as usize;
132 Channel {
133 sock,
134 front_buf: Buffer::with_capacity(buffer_size),
135 back_buf: Buffer::with_capacity(buffer_size),
136 initial_buffer_size: buffer_size,
137 max_buffer_size,
138 readiness: Ready::EMPTY,
139 interest: Ready::READABLE,
140 blocking: false,
141 front_high_watermark_logged: false,
142 back_high_watermark_logged: false,
143 phantom_tx: PhantomData,
144 phantom_rx: PhantomData,
145 }
146 }
147
148 pub fn into<Tx2: Debug + ProstMessage + Default, Rx2: Debug + ProstMessage + Default>(
149 self,
150 ) -> Channel<Tx2, Rx2> {
151 Channel {
152 sock: self.sock,
153 front_buf: self.front_buf,
154 back_buf: self.back_buf,
155 initial_buffer_size: self.initial_buffer_size,
156 max_buffer_size: self.max_buffer_size,
157 readiness: self.readiness,
158 interest: self.interest,
159 blocking: self.blocking,
160 front_high_watermark_logged: self.front_high_watermark_logged,
161 back_high_watermark_logged: self.back_high_watermark_logged,
162 phantom_tx: PhantomData,
163 phantom_rx: PhantomData,
164 }
165 }
166
167 fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), ChannelError> {
171 unsafe {
177 let fd = self.sock.as_raw_fd();
178 let stream = StdUnixStream::from_raw_fd(fd);
179 stream
180 .set_nonblocking(nonblocking)
181 .map_err(|error| ChannelError::BlockingStatus {
182 fd,
183 error: error.to_string(),
184 })?;
185 let _fd = stream.into_raw_fd();
186 }
187 self.blocking = !nonblocking;
188 Ok(())
189 }
190
191 fn set_timeout(&mut self, timeout: Option<Duration>) -> Result<(), ChannelError> {
193 unsafe {
199 let fd = self.sock.as_raw_fd();
200 let stream = StdUnixStream::from_raw_fd(fd);
201 stream
202 .set_read_timeout(timeout)
203 .map_err(|error| ChannelError::SetTimeout {
204 fd,
205 error: error.to_string(),
206 })?;
207 let _fd = stream.into_raw_fd();
208 }
209 Ok(())
210 }
211
212 pub fn blocking(&mut self) -> Result<(), ChannelError> {
214 self.set_nonblocking(false)
215 }
216
217 pub fn nonblocking(&mut self) -> Result<(), ChannelError> {
219 self.set_nonblocking(true)
220 }
221
222 pub fn is_blocking(&self) -> bool {
223 self.blocking
224 }
225
226 pub fn fd(&self) -> RawFd {
228 self.sock.as_raw_fd()
229 }
230
231 pub fn handle_events(&mut self, events: Ready) {
232 self.readiness |= events;
233 }
234
235 pub fn readiness(&self) -> Ready {
236 self.readiness & self.interest
237 }
238
239 fn grow_size(&self, current_capacity: usize) -> Option<usize> {
242 if current_capacity >= self.max_buffer_size {
243 return None;
244 }
245 debug_assert!(
248 current_capacity < self.max_buffer_size,
249 "grow_size must only grow buffers that are strictly below the max ceiling"
250 );
251 let new_size = min(current_capacity.saturating_mul(2), self.max_buffer_size);
253 let new_size = new_size.max(current_capacity + 1);
255 let new_size = min(new_size, self.max_buffer_size);
256 debug_assert!(
259 new_size > current_capacity,
260 "grow_size must make forward progress (new capacity strictly larger)"
261 );
262 debug_assert!(
263 new_size <= self.max_buffer_size,
264 "grow_size must never exceed the configured max_buffer_size ceiling"
265 );
266 Some(new_size)
267 }
268
269 fn check_high_watermark(
271 buffer_name: &str,
272 capacity: usize,
273 max: usize,
274 already_logged: &mut bool,
275 ) {
276 if *already_logged {
277 return;
278 }
279 let threshold = (max as f64 * HIGH_WATERMARK_RATIO) as usize;
280 if capacity >= threshold {
281 warn!(
282 "channel {} buffer reached high watermark: {} bytes ({:.0}% of {} max)",
283 buffer_name,
284 capacity,
285 (capacity as f64 / max as f64) * 100.0,
286 max,
287 );
288 *already_logged = true;
289 }
290 }
291
292 pub fn run(&mut self) -> Result<(), ChannelError> {
294 let interest = self.interest & self.readiness;
295
296 if interest.is_readable() {
297 let _ = self.readable()?;
298 }
299
300 if interest.is_writable() {
301 let _ = self.writable()?;
302 }
303 Ok(())
304 }
305
306 pub fn readable(&mut self) -> Result<usize, ChannelError> {
309 if !(self.interest & self.readiness).is_readable() {
310 return Err(ChannelError::Connection(None));
311 }
312
313 let mut count = 0usize;
314 loop {
315 let size = self.front_buf.available_space();
316 trace!("channel available space: {}", size);
317 if size == 0 {
318 if let Some(new_size) = self.grow_size(self.front_buf.capacity()) {
320 Self::check_high_watermark(
321 "front",
322 new_size,
323 self.max_buffer_size,
324 &mut self.front_high_watermark_logged,
325 );
326 self.front_buf.grow(new_size);
327 debug_assert!(
329 self.front_buf.capacity() <= self.max_buffer_size,
330 "front buffer capacity must stay within the max_buffer_size ceiling"
331 );
332 } else {
333 self.interest.remove(Ready::READABLE);
334 break;
335 }
336 }
337
338 let space_before = self.front_buf.available_space();
342 debug_assert!(
343 space_before > 0,
344 "readable must only call read() with non-empty space (zero-space path grows or breaks)"
345 );
346 let data_before = self.front_buf.available_data();
347 match self.sock.read(self.front_buf.space()) {
348 Ok(0) => {
349 self.interest = Ready::EMPTY;
350 self.readiness.remove(Ready::READABLE);
351 self.readiness.insert(Ready::HUP);
352 return Err(ChannelError::NoByteToRead);
353 }
354 Err(read_error) => match read_error.kind() {
355 ErrorKind::WouldBlock => {
356 self.readiness.remove(Ready::READABLE);
357 break;
358 }
359 _ => {
360 self.interest = Ready::EMPTY;
361 self.readiness = Ready::EMPTY;
362 return Err(ChannelError::Read(read_error));
363 }
364 },
365 Ok(bytes_read) => {
366 debug_assert!(
369 bytes_read <= space_before,
370 "read delivered more bytes than the buffer space it was given"
371 );
372 count += bytes_read;
373 self.front_buf.fill(bytes_read);
374 debug_assert_eq!(
378 self.front_buf.available_data(),
379 data_before + bytes_read,
380 "front buffer available_data must increase by exactly bytes_read"
381 );
382 }
383 };
384 }
385
386 Ok(count)
387 }
388
389 pub fn writable(&mut self) -> Result<usize, ChannelError> {
392 if !(self.interest & self.readiness).is_writable() {
393 return Err(ChannelError::Connection(None));
394 }
395
396 let mut count = 0usize;
397 loop {
398 let size = self.back_buf.available_data();
399 if size == 0 {
400 self.interest.remove(Ready::WRITABLE);
401 self.try_shrink_back_buf();
402 break;
403 }
404
405 let data_before = self.back_buf.available_data();
406 match self.sock.write(self.back_buf.data()) {
407 Ok(0) => {
408 self.interest = Ready::EMPTY;
409 self.readiness.insert(Ready::HUP);
410 return Err(ChannelError::NoByteWritten);
411 }
412 Ok(bytes_written) => {
413 debug_assert!(
416 bytes_written <= data_before,
417 "write reported more bytes than the buffer data it was given"
418 );
419 count += bytes_written;
420 let consumed = self.back_buf.consume(bytes_written);
421 debug_assert_eq!(
424 consumed, bytes_written,
425 "back buffer must consume exactly the bytes written to the socket"
426 );
427 debug_assert_eq!(
428 self.back_buf.available_data(),
429 data_before - bytes_written,
430 "back buffer available_data must shrink by exactly bytes_written"
431 );
432 }
433 Err(write_error) => match write_error.kind() {
434 ErrorKind::WouldBlock => {
435 self.readiness.remove(Ready::WRITABLE);
436 break;
437 }
438 _ => {
439 self.interest = Ready::EMPTY;
440 self.readiness = Ready::EMPTY;
441 return Err(ChannelError::Read(write_error));
442 }
443 },
444 }
445 }
446
447 Ok(count)
448 }
449
450 pub fn read_message(&mut self) -> Result<Rx, ChannelError> {
457 if self.blocking {
458 self.read_message_blocking()
459 } else {
460 self.read_message_nonblocking()
461 }
462 }
463
464 fn read_message_blocking(&mut self) -> Result<Rx, ChannelError> {
465 self.read_message_blocking_timeout(None)
466 }
467
468 fn read_message_nonblocking(&mut self) -> Result<Rx, ChannelError> {
470 if let Some(message) = self.try_read_delimited_message()? {
471 self.try_shrink_front_buf();
472 return Ok(message);
473 }
474
475 self.interest.insert(Ready::READABLE);
476 Err(ChannelError::NothingRead)
477 }
478
479 pub fn read_message_blocking_timeout(
481 &mut self,
482 timeout: Option<Duration>,
483 ) -> Result<Rx, ChannelError> {
484 let now = std::time::Instant::now();
485
486 self.set_timeout(Some(Duration::from_millis(100)))?;
492
493 let status = loop {
494 if let Some(timeout) = timeout {
495 if now.elapsed() >= timeout {
496 break Err(ChannelError::TimeoutReached(timeout));
497 }
498 }
499
500 if let Some(message) = self.try_read_delimited_message()? {
501 self.try_shrink_front_buf();
502 return Ok(message);
503 }
504
505 match self.sock.read(self.front_buf.space()) {
506 Ok(0) => return Err(ChannelError::NoByteToRead),
507 Ok(bytes_read) => self.front_buf.fill(bytes_read),
508 Err(io_error) => match io_error.kind() {
509 ErrorKind::WouldBlock => continue, _ => break Err(ChannelError::Read(io_error)),
511 },
512 };
513 };
514
515 self.set_timeout(None)?;
516
517 status
518 }
519
520 fn try_read_delimited_message(&mut self) -> Result<Option<Rx>, ChannelError> {
522 debug_assert!(
527 self.front_buf.capacity() <= self.max_buffer_size,
528 "front buffer capacity must never exceed max_buffer_size"
529 );
530 let buffer = self.front_buf.data();
531 debug_assert!(
534 buffer.len() <= self.front_buf.capacity(),
535 "available data slice cannot exceed buffer capacity"
536 );
537 if buffer.len() >= delimiter_size() {
538 let delimiter = buffer[..delimiter_size()]
539 .try_into()
540 .map_err(|_| ChannelError::MismatchBufferSize)?;
541 let message_len = usize::from_le_bytes(delimiter);
542
543 if message_len > self.max_buffer_size {
552 return Err(ChannelError::MessageTooLarge {
553 message_len,
554 capacity: self.front_buf.capacity(),
555 max: self.max_buffer_size,
556 });
557 }
558
559 if message_len < delimiter_size() {
571 self.front_buf.consume(delimiter_size());
572 return Err(ChannelError::MessageLengthUnderDelimiter {
573 message_len,
574 delimiter_size: delimiter_size(),
575 });
576 }
577
578 if buffer.len() >= message_len {
579 debug_assert!(
588 message_len >= delimiter_size(),
589 "decode path requires a frame at least as large as its delimiter"
590 );
591 debug_assert!(
592 message_len <= self.max_buffer_size,
593 "decode path requires the declared length within the max ceiling"
594 );
595 debug_assert!(
596 message_len <= buffer.len(),
597 "decode path requires the full frame to be buffered before slicing"
598 );
599 let available_before = self.front_buf.available_data();
600 debug_assert_eq!(
601 available_before,
602 buffer.len(),
603 "available_data must equal the data slice length we validated against"
604 );
605 let message = Rx::decode(&buffer[delimiter_size()..message_len])
606 .map_err(ChannelError::InvalidProtobufMessage)?;
607 let consumed = self.front_buf.consume(message_len);
608 debug_assert_eq!(
612 consumed, message_len,
613 "must consume exactly the validated frame length"
614 );
615 debug_assert_eq!(
616 self.front_buf.available_data(),
617 available_before - message_len,
618 "available_data must drop by exactly the consumed frame length"
619 );
620 return Ok(Some(message));
621 }
622 }
623
624 if self.front_buf.available_space() == 0 {
625 if self.front_buf.capacity() >= self.max_buffer_size {
626 return Err(ChannelError::BufferFull {
627 capacity: self.front_buf.capacity(),
628 max: self.max_buffer_size,
629 });
630 }
631 let new_size = self
632 .grow_size(self.front_buf.capacity())
633 .unwrap_or(self.max_buffer_size);
634 Self::check_high_watermark(
635 "front",
636 new_size,
637 self.max_buffer_size,
638 &mut self.front_high_watermark_logged,
639 );
640 self.front_buf.grow(new_size);
641 }
642 Ok(None)
643 }
644
645 pub fn write_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
649 if self.blocking {
650 self.write_message_blocking(message)
651 } else {
652 self.write_message_nonblocking(message)
653 }
654 }
655
656 fn write_message_nonblocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
659 self.write_delimited_message(message)?;
660
661 self.interest.insert(Ready::WRITABLE);
662
663 Ok(())
664 }
665
666 fn write_message_blocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
668 self.write_delimited_message(message)?;
669
670 loop {
671 let size = self.back_buf.available_data();
672 if size == 0 {
673 break;
674 }
675
676 match self.sock.write(self.back_buf.data()) {
677 Ok(0) => return Err(ChannelError::NoByteWritten),
678 Ok(bytes_written) => {
679 self.back_buf.consume(bytes_written);
680 }
681 Err(_) => return Ok(()), }
683 }
684 Ok(())
685 }
686
687 pub fn write_delimited_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
690 let payload = message.encode_to_vec();
691
692 let payload_len = payload.len() + delimiter_size();
693
694 debug_assert!(
698 payload_len >= delimiter_size(),
699 "framed length must include the fixed-size delimiter prefix"
700 );
701
702 let delimiter = payload_len.to_le_bytes();
703
704 if payload_len > self.back_buf.available_space() {
705 self.back_buf.shift();
706 }
707
708 let data_before = self.back_buf.available_data();
709 if payload_len > self.back_buf.available_space() {
710 let needed = payload_len - self.back_buf.available_space() + self.back_buf.capacity();
711 if needed > self.max_buffer_size {
712 return Err(ChannelError::MessageTooLarge {
713 message_len: payload_len,
714 capacity: self.back_buf.capacity(),
715 max: self.max_buffer_size,
716 });
717 }
718 debug_assert!(
721 needed <= self.max_buffer_size,
722 "grow target must be within the max ceiling once the cap check passed"
723 );
724
725 let capacity_before = self.back_buf.capacity();
726 let mut new_length = self.back_buf.capacity();
728 while new_length < needed {
729 new_length = new_length.saturating_mul(2).max(new_length + 1);
730 }
731 new_length = min(new_length, self.max_buffer_size);
732 debug_assert!(
735 new_length >= needed,
736 "doubling growth must reach at least the needed capacity"
737 );
738 debug_assert!(
739 new_length <= self.max_buffer_size,
740 "grown back buffer must stay within the max_buffer_size ceiling"
741 );
742 debug_assert!(
743 new_length >= capacity_before,
744 "growth must never shrink the back buffer"
745 );
746 Self::check_high_watermark(
747 "back",
748 new_length,
749 self.max_buffer_size,
750 &mut self.back_high_watermark_logged,
751 );
752 self.back_buf.grow(new_length);
753 debug_assert!(
755 payload_len <= self.back_buf.available_space(),
756 "back buffer must have room for the full frame after growth"
757 );
758 }
759
760 self.back_buf
761 .write_all(&delimiter)
762 .map_err(ChannelError::Write)?;
763 self.back_buf
764 .write_all(&payload)
765 .map_err(ChannelError::Write)?;
766
767 debug_assert_eq!(
770 self.back_buf.available_data(),
771 data_before + payload_len,
772 "back buffer pending data must grow by exactly the framed length"
773 );
774 debug_assert!(
775 self.back_buf.capacity() <= self.max_buffer_size,
776 "back buffer capacity must never exceed the max_buffer_size ceiling"
777 );
778
779 Ok(())
780 }
781
782 fn try_shrink_front_buf(&mut self) {
785 let capacity = self.front_buf.capacity();
786 if capacity <= self.initial_buffer_size {
787 return;
788 }
789 debug_assert!(
792 capacity > self.initial_buffer_size,
793 "shrink path only runs when capacity is above the initial floor"
794 );
795 if self.front_buf.available_data() * 4 < self.initial_buffer_size {
797 let data_before = self.front_buf.available_data();
798 self.front_buf.shrink(self.initial_buffer_size);
799 self.front_high_watermark_logged = false;
800 debug_assert!(
802 self.front_buf.capacity() >= self.initial_buffer_size,
803 "front buffer must never shrink below the initial buffer size floor"
804 );
805 debug_assert_eq!(
806 self.front_buf.available_data(),
807 data_before,
808 "shrink must preserve all pending front-buffer data"
809 );
810 trace!(
811 "front buffer shrunk from {} to {} bytes",
812 capacity, self.initial_buffer_size
813 );
814 }
815 }
816
817 fn try_shrink_back_buf(&mut self) {
819 let capacity = self.back_buf.capacity();
820 if capacity <= self.initial_buffer_size {
821 return;
822 }
823 debug_assert!(
824 capacity > self.initial_buffer_size,
825 "shrink path only runs when capacity is above the initial floor"
826 );
827 if self.back_buf.available_data() == 0 {
828 self.back_buf.shrink(self.initial_buffer_size);
829 self.back_high_watermark_logged = false;
830 debug_assert!(
833 self.back_buf.capacity() >= self.initial_buffer_size,
834 "back buffer must never shrink below the initial buffer size floor"
835 );
836 debug_assert_eq!(
837 self.back_buf.available_data(),
838 0,
839 "back buffer must stay empty across a drained shrink"
840 );
841 trace!(
842 "back buffer shrunk from {} to {} bytes",
843 capacity, self.initial_buffer_size
844 );
845 }
846 }
847}
848
/// Width in bytes of the little-endian length prefix that starts every frame.
///
/// The prefix is a native `usize` holding the total frame length (prefix plus protobuf
/// payload), so both ends of a channel must agree on pointer width.
pub const fn delimiter_size() -> usize {
    (usize::BITS / u8::BITS) as usize
}
853
854type ChannelResult<Tx, Rx> = Result<(Channel<Tx, Rx>, Channel<Rx, Tx>), ChannelError>;
855
856impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
857 pub fn generate(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
859 let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
860 let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
861 let mut command_channel = Channel::new(command, buffer_size, max_buffer_size);
862 command_channel.blocking()?;
863 Ok((command_channel, proxy_channel))
864 }
865
866 pub fn generate_nonblocking(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
868 let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
869 let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
870 let command_channel = Channel::new(command, buffer_size, max_buffer_size);
871 Ok((command_channel, proxy_channel))
872 }
873}
874
875impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Iterator
876 for Channel<Tx, Rx>
877{
878 type Item = Rx;
879 fn next(&mut self) -> Option<Self::Item> {
880 self.read_message().ok()
881 }
882}
883
884use mio::{Interest, Registry, Token};
885impl<Tx, Rx> Source for Channel<Tx, Rx> {
886 fn register(
887 &mut self,
888 registry: &Registry,
889 token: Token,
890 interests: Interest,
891 ) -> io::Result<()> {
892 self.sock.register(registry, token, interests)
893 }
894
895 fn reregister(
896 &mut self,
897 registry: &Registry,
898 token: Token,
899 interests: Interest,
900 ) -> io::Result<()> {
901 self.sock.reregister(registry, token, interests)
902 }
903
904 fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
905 self.sock.deregister(registry)
906 }
907}
908
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    #[derive(Clone, PartialEq, prost::Message)]
    pub struct ProtobufMessage {
        #[prost(uint32, required, tag = "1")]
        inner: u32,
    }

    /// Builds a connected pair: the first channel is blocking, the second non-blocking.
    fn test_channels() -> (
        Channel<ProtobufMessage, ProtobufMessage>,
        Channel<ProtobufMessage, ProtobufMessage>,
    ) {
        Channel::generate(1000, 10000).expect("could not generate blocking channels for testing")
    }

    #[test]
    fn unblock_a_channel() {
        let (mut blocking, _nonblocking) = test_channels();
        assert!(blocking.nonblocking().is_ok())
    }

    #[test]
    fn generate_blocking_and_nonblocking_channels() {
        let (blocking_channel, nonblocking_channel) = test_channels();

        assert!(blocking_channel.is_blocking());
        assert!(!nonblocking_channel.is_blocking());

        let (nonblocking_channel_1, nonblocking_channel_2): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate_nonblocking(1000, 10000)
            .expect("could not generate nonblocking channels");

        assert!(!nonblocking_channel_1.is_blocking());
        assert!(!nonblocking_channel_2.is_blocking());
    }

    #[test]
    fn write_and_read_message_blocking() {
        let (mut blocking_channel, mut nonblocking_channel) = test_channels();

        let message_to_send = ProtobufMessage { inner: 42 };

        nonblocking_channel
            .blocking()
            .expect("Could not block channel");
        nonblocking_channel
            .write_message(&message_to_send)
            .expect("Could not write message on channel");

        trace!("we wrote a message!");

        trace!("reading message..");
        let message = blocking_channel
            .read_message()
            .expect("Could not read message on channel");
        trace!("read message!");

        assert_eq!(message, ProtobufMessage { inner: 42 });
    }

    #[test]
    fn read_message_blocking_with_timeout_fails() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 100 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message =
                reading_channel.read_message_blocking_timeout(Some(Duration::from_millis(100)));
            trace!("read message!");
            message
        });

        trace!("Waiting 200 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(200));

        writing_channel
            .write_message(&ProtobufMessage { inner: 200 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive too late!");

        let arrived_too_late = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert!(arrived_too_late.is_err());
    }

    #[test]
    fn read_message_blocking_with_timeout_succeeds() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 200 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message = reading_channel
                .read_message_blocking_timeout(Some(Duration::from_millis(200)))
                .expect("Could not read message with timeout on blocking channel");
            trace!("read message!");
            message
        });

        trace!("Waiting 100 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(100));

        writing_channel
            .write_message(&ProtobufMessage { inner: 100 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive on time!");

        let arrived_on_time = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert_eq!(arrived_on_time, ProtobufMessage { inner: 100 });
    }

    #[test]
    fn exhaustive_use_of_nonblocking_channels() {
        let (mut channel_a, mut channel_b) = test_channels();
        channel_a.nonblocking().expect("Could not unblock channel");

        // Queued only: nothing reaches the socket until the writable path runs.
        channel_a
            .write_message(&ProtobufMessage { inner: 1 })
            .expect("Could not write message on channel");

        channel_b.handle_events(Ready::READABLE);

        // Nothing has been read into channel_b's buffer yet.
        let should_err = channel_b.read_message();
        assert!(should_err.is_err());

        channel_a
            .write_message(&ProtobufMessage { inner: 2 })
            .expect("Could not write message on channel");

        channel_a.handle_events(Ready::WRITABLE);

        // Flush both queued frames to the socket.
        channel_a.run().expect("Failed to run the channel");

        thread::sleep(std::time::Duration::from_millis(100));

        // Pull the bytes into channel_b's front buffer.
        channel_b.run().expect("Failed to run the channel");

        let message_1 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_1, ProtobufMessage { inner: 1 });

        let message_2 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_2, ProtobufMessage { inner: 2 });
    }

    #[test]
    fn buffer_grows_with_doubling_strategy() {
        let (writing_channel, _reading_channel): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(100, 10000).expect("could not generate channels");

        assert_eq!(writing_channel.back_buf.capacity(), 100);

        assert_eq!(writing_channel.grow_size(100), Some(200));
        assert_eq!(writing_channel.grow_size(200), Some(400));
        assert_eq!(writing_channel.grow_size(5000), Some(10000));
        assert_eq!(writing_channel.grow_size(10000), None);
    }

    #[test]
    fn buffer_cap_returns_error() {
        let (mut writing_channel, _reading_channel): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(50, 50).expect("could not generate channels");

        writing_channel.blocking().expect("Could not block channel");

        let mut i = 0u32;
        let result = loop {
            let msg = ProtobufMessage { inner: i };
            match writing_channel.write_delimited_message(&msg) {
                Ok(()) => i += 1,
                Err(e) => break Err(e),
            }
            if i > 10000 {
                break Ok(());
            }
        };

        assert!(result.is_err());
        let err = result.unwrap_err();
        let err_msg = format!("{err}");
        assert!(
            err_msg.contains("too large") || err_msg.contains("cannot grow"),
            "unexpected error: {err_msg}"
        );
    }

    #[test]
    fn back_buffer_shrinks_after_drain() {
        let (mut channel, _other): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(100, 10000).expect("could not generate channels");

        // Queue enough frames to force the back buffer past its initial 100 bytes.
        for i in 0..20 {
            channel
                .write_delimited_message(&ProtobufMessage { inner: i })
                .expect("Could not write message");
        }

        let grown_capacity = channel.back_buf.capacity();
        assert!(
            grown_capacity > 100,
            "expected buffer growth, got capacity {grown_capacity}"
        );

        // Simulate a complete flush.
        let data_len = channel.back_buf.available_data();
        channel.back_buf.consume(data_len);
        assert_eq!(channel.back_buf.available_data(), 0);

        channel.try_shrink_back_buf();
        assert_eq!(
            channel.back_buf.capacity(),
            100,
            "back buffer should shrink to initial size after drain"
        );
    }

    #[test]
    fn back_buffer_grows_with_doubling_on_write() {
        let (mut channel, _other): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(32, 10000).expect("could not generate channels");

        assert_eq!(channel.back_buf.capacity(), 32);

        for i in 0..10 {
            channel
                .write_delimited_message(&ProtobufMessage { inner: i })
                .expect("Could not write message");
        }

        let grown = channel.back_buf.capacity();
        assert!(grown > 32, "expected buffer growth beyond 32, got {grown}");
        // Starting from 32, pure doubling only ever yields powers of two (or the ceiling).
        assert!(
            grown.is_power_of_two() || grown == 10000,
            "expected doubling growth pattern, got {grown}"
        );
    }

    #[test]
    fn rejects_declared_length_below_delimiter() {
        let (mut reader, mut writer): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(1000, 10000).expect("could not generate channels");
        writer.blocking().expect("writer to block");
        reader.blocking().expect("reader to block");

        // A length prefix smaller than the prefix itself can never describe a valid frame.
        let bogus: usize = 5;
        let bytes = bogus.to_le_bytes();
        std::io::Write::write_all(&mut writer.sock, &bytes).expect("raw write of bogus delimiter");

        match reader.read_message() {
            Err(ChannelError::MessageLengthUnderDelimiter {
                message_len,
                delimiter_size,
            }) => {
                assert_eq!(message_len, 5);
                assert_eq!(delimiter_size, std::mem::size_of::<usize>());
            }
            other => panic!(
                "expected MessageLengthUnderDelimiter, got {other:?}\n\
                 NOTE: a panic here means the slice-OOB hardening was reverted",
            ),
        }
    }
}