1use std::{
10 cmp::min,
11 fmt::Debug,
12 io::{self, ErrorKind, Read, Write},
13 marker::PhantomData,
14 os::unix::{
15 io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
16 net::UnixStream as StdUnixStream,
17 },
18 time::Duration,
19};
20
21use mio::{event::Source, net::UnixStream as MioUnixStream};
22use prost::{DecodeError, Message as ProstMessage};
23
24use crate::{buffer::growable::Buffer, ready::Ready};
25
/// Fraction of `max_buffer_size` at which a buffer growth is logged once as a
/// "high watermark" warning (see `check_high_watermark`).
const HIGH_WATERMARK_RATIO: f64 = 0.8;
28
29#[derive(thiserror::Error, Debug)]
30pub enum ChannelError {
31 #[error("io read error")]
32 Read(std::io::Error),
33 #[error("no byte written on the channel")]
34 NoByteWritten,
35 #[error("no byte left to read on the channel")]
36 NoByteToRead,
37 #[error(
38 "message ({message_len} bytes) too large for back buffer capacity ({capacity} bytes, max {max} bytes)"
39 )]
40 MessageTooLarge {
41 message_len: usize,
42 capacity: usize,
43 max: usize,
44 },
45 #[error(
46 "declared message length ({message_len} bytes) is shorter than the {delimiter_size}-byte length prefix"
47 )]
48 MessageLengthUnderDelimiter {
49 message_len: usize,
50 delimiter_size: usize,
51 },
52 #[error("channel could not write on the back buffer")]
53 Write(std::io::Error),
54 #[error("channel buffer is full ({capacity} bytes, max {max} bytes), cannot grow more")]
55 BufferFull { capacity: usize, max: usize },
56 #[error("Timeout is reached: {0:?}")]
57 TimeoutReached(Duration),
58 #[error("Could not read anything on the channel")]
59 NothingRead,
60 #[error("invalid char set in command message, ignoring: {0}")]
61 InvalidCharSet(String),
62 #[error("could not set the timeout of the unix stream with file descriptor {fd}: {error}")]
63 SetTimeout { fd: i32, error: String },
64 #[error(
65 "Could not change the blocking status ef the unix stream with file descriptor {fd}: {error}"
66 )]
67 BlockingStatus { fd: i32, error: String },
68 #[error("Connection error: {0:?}")]
69 Connection(Option<std::io::Error>),
70 #[error("Invalid protobuf message: {0}")]
71 InvalidProtobufMessage(DecodeError),
72 #[error("This should never happen (index out of bound on a tested buffer)")]
73 MismatchBufferSize,
74}
75
/// A length-delimited protobuf channel over a unix socket.
///
/// Messages of type `Tx` are sent and messages of type `Rx` are received. Each
/// message on the wire is prefixed by its total length (prefix included) encoded
/// as a little-endian `usize`.
pub struct Channel<Tx, Rx> {
    /// The underlying mio unix socket.
    pub sock: MioUnixStream,
    /// Incoming bytes read from `sock`, not yet decoded into messages.
    pub front_buf: Buffer,
    /// Outgoing encoded messages, not yet written to `sock`.
    pub back_buf: Buffer,
    /// Capacity both buffers start at and shrink back to once drained.
    initial_buffer_size: usize,
    /// Upper bound on how far either buffer may grow.
    max_buffer_size: usize,
    /// Events reported by the event loop (accumulated via `handle_events`).
    pub readiness: Ready,
    /// Events the channel currently wants to act on.
    pub interest: Ready,
    /// Whether the socket has been switched to blocking mode by this channel.
    blocking: bool,
    /// Set once the front-buffer high-watermark warning has been logged; reset on shrink.
    front_high_watermark_logged: bool,
    /// Set once the back-buffer high-watermark warning has been logged; reset on shrink.
    back_high_watermark_logged: bool,
    /// Marks the outgoing message type without storing a value.
    phantom_tx: PhantomData<Tx>,
    /// Marks the incoming message type without storing a value.
    phantom_rx: PhantomData<Rx>,
}
97
98impl<Tx, Rx> std::fmt::Debug for Channel<Tx, Rx> {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct(&format!(
101 "Channel<{}, {}>",
102 std::any::type_name::<Tx>(),
103 std::any::type_name::<Rx>()
104 ))
105 .field("sock", &self.sock.as_raw_fd())
106 .field("readiness", &self.readiness)
110 .field("interest", &self.interest)
111 .field("blocking", &self.blocking)
112 .finish()
113 }
114}
115
116impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
117 pub fn from_path(
119 path: &str,
120 buffer_size: u64,
121 max_buffer_size: u64,
122 ) -> Result<Channel<Tx, Rx>, ChannelError> {
123 let unix_stream = MioUnixStream::connect(path)
124 .map_err(|io_error| ChannelError::Connection(Some(io_error)))?;
125 Ok(Channel::new(unix_stream, buffer_size, max_buffer_size))
126 }
127
128 pub fn new(sock: MioUnixStream, buffer_size: u64, max_buffer_size: u64) -> Channel<Tx, Rx> {
130 let buffer_size = buffer_size as usize;
131 let max_buffer_size = max_buffer_size as usize;
132 Channel {
133 sock,
134 front_buf: Buffer::with_capacity(buffer_size),
135 back_buf: Buffer::with_capacity(buffer_size),
136 initial_buffer_size: buffer_size,
137 max_buffer_size,
138 readiness: Ready::EMPTY,
139 interest: Ready::READABLE,
140 blocking: false,
141 front_high_watermark_logged: false,
142 back_high_watermark_logged: false,
143 phantom_tx: PhantomData,
144 phantom_rx: PhantomData,
145 }
146 }
147
148 pub fn into<Tx2: Debug + ProstMessage + Default, Rx2: Debug + ProstMessage + Default>(
149 self,
150 ) -> Channel<Tx2, Rx2> {
151 Channel {
152 sock: self.sock,
153 front_buf: self.front_buf,
154 back_buf: self.back_buf,
155 initial_buffer_size: self.initial_buffer_size,
156 max_buffer_size: self.max_buffer_size,
157 readiness: self.readiness,
158 interest: self.interest,
159 blocking: self.blocking,
160 front_high_watermark_logged: self.front_high_watermark_logged,
161 back_high_watermark_logged: self.back_high_watermark_logged,
162 phantom_tx: PhantomData,
163 phantom_rx: PhantomData,
164 }
165 }
166
167 fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), ChannelError> {
171 unsafe {
177 let fd = self.sock.as_raw_fd();
178 let stream = StdUnixStream::from_raw_fd(fd);
179 stream
180 .set_nonblocking(nonblocking)
181 .map_err(|error| ChannelError::BlockingStatus {
182 fd,
183 error: error.to_string(),
184 })?;
185 let _fd = stream.into_raw_fd();
186 }
187 self.blocking = !nonblocking;
188 Ok(())
189 }
190
191 fn set_timeout(&mut self, timeout: Option<Duration>) -> Result<(), ChannelError> {
193 unsafe {
199 let fd = self.sock.as_raw_fd();
200 let stream = StdUnixStream::from_raw_fd(fd);
201 stream
202 .set_read_timeout(timeout)
203 .map_err(|error| ChannelError::SetTimeout {
204 fd,
205 error: error.to_string(),
206 })?;
207 let _fd = stream.into_raw_fd();
208 }
209 Ok(())
210 }
211
212 pub fn blocking(&mut self) -> Result<(), ChannelError> {
214 self.set_nonblocking(false)
215 }
216
217 pub fn nonblocking(&mut self) -> Result<(), ChannelError> {
219 self.set_nonblocking(true)
220 }
221
222 pub fn is_blocking(&self) -> bool {
223 self.blocking
224 }
225
226 pub fn fd(&self) -> RawFd {
228 self.sock.as_raw_fd()
229 }
230
231 pub fn handle_events(&mut self, events: Ready) {
232 self.readiness |= events;
233 }
234
235 pub fn readiness(&self) -> Ready {
236 self.readiness & self.interest
237 }
238
239 fn grow_size(&self, current_capacity: usize) -> Option<usize> {
242 if current_capacity >= self.max_buffer_size {
243 return None;
244 }
245 let new_size = min(current_capacity.saturating_mul(2), self.max_buffer_size);
247 let new_size = new_size.max(current_capacity + 1);
249 Some(min(new_size, self.max_buffer_size))
250 }
251
252 fn check_high_watermark(
254 buffer_name: &str,
255 capacity: usize,
256 max: usize,
257 already_logged: &mut bool,
258 ) {
259 if *already_logged {
260 return;
261 }
262 let threshold = (max as f64 * HIGH_WATERMARK_RATIO) as usize;
263 if capacity >= threshold {
264 warn!(
265 "channel {} buffer reached high watermark: {} bytes ({:.0}% of {} max)",
266 buffer_name,
267 capacity,
268 (capacity as f64 / max as f64) * 100.0,
269 max,
270 );
271 *already_logged = true;
272 }
273 }
274
275 pub fn run(&mut self) -> Result<(), ChannelError> {
277 let interest = self.interest & self.readiness;
278
279 if interest.is_readable() {
280 let _ = self.readable()?;
281 }
282
283 if interest.is_writable() {
284 let _ = self.writable()?;
285 }
286 Ok(())
287 }
288
289 pub fn readable(&mut self) -> Result<usize, ChannelError> {
292 if !(self.interest & self.readiness).is_readable() {
293 return Err(ChannelError::Connection(None));
294 }
295
296 let mut count = 0usize;
297 loop {
298 let size = self.front_buf.available_space();
299 trace!("channel available space: {}", size);
300 if size == 0 {
301 if let Some(new_size) = self.grow_size(self.front_buf.capacity()) {
303 Self::check_high_watermark(
304 "front",
305 new_size,
306 self.max_buffer_size,
307 &mut self.front_high_watermark_logged,
308 );
309 self.front_buf.grow(new_size);
310 } else {
311 self.interest.remove(Ready::READABLE);
312 break;
313 }
314 }
315
316 match self.sock.read(self.front_buf.space()) {
317 Ok(0) => {
318 self.interest = Ready::EMPTY;
319 self.readiness.remove(Ready::READABLE);
320 self.readiness.insert(Ready::HUP);
321 return Err(ChannelError::NoByteToRead);
322 }
323 Err(read_error) => match read_error.kind() {
324 ErrorKind::WouldBlock => {
325 self.readiness.remove(Ready::READABLE);
326 break;
327 }
328 _ => {
329 self.interest = Ready::EMPTY;
330 self.readiness = Ready::EMPTY;
331 return Err(ChannelError::Read(read_error));
332 }
333 },
334 Ok(bytes_read) => {
335 count += bytes_read;
336 self.front_buf.fill(bytes_read);
337 }
338 };
339 }
340
341 Ok(count)
342 }
343
344 pub fn writable(&mut self) -> Result<usize, ChannelError> {
347 if !(self.interest & self.readiness).is_writable() {
348 return Err(ChannelError::Connection(None));
349 }
350
351 let mut count = 0usize;
352 loop {
353 let size = self.back_buf.available_data();
354 if size == 0 {
355 self.interest.remove(Ready::WRITABLE);
356 self.try_shrink_back_buf();
357 break;
358 }
359
360 match self.sock.write(self.back_buf.data()) {
361 Ok(0) => {
362 self.interest = Ready::EMPTY;
363 self.readiness.insert(Ready::HUP);
364 return Err(ChannelError::NoByteWritten);
365 }
366 Ok(bytes_written) => {
367 count += bytes_written;
368 self.back_buf.consume(bytes_written);
369 }
370 Err(write_error) => match write_error.kind() {
371 ErrorKind::WouldBlock => {
372 self.readiness.remove(Ready::WRITABLE);
373 break;
374 }
375 _ => {
376 self.interest = Ready::EMPTY;
377 self.readiness = Ready::EMPTY;
378 return Err(ChannelError::Read(write_error));
379 }
380 },
381 }
382 }
383
384 Ok(count)
385 }
386
387 pub fn read_message(&mut self) -> Result<Rx, ChannelError> {
394 if self.blocking {
395 self.read_message_blocking()
396 } else {
397 self.read_message_nonblocking()
398 }
399 }
400
401 fn read_message_blocking(&mut self) -> Result<Rx, ChannelError> {
402 self.read_message_blocking_timeout(None)
403 }
404
405 fn read_message_nonblocking(&mut self) -> Result<Rx, ChannelError> {
407 if let Some(message) = self.try_read_delimited_message()? {
408 self.try_shrink_front_buf();
409 return Ok(message);
410 }
411
412 self.interest.insert(Ready::READABLE);
413 Err(ChannelError::NothingRead)
414 }
415
416 pub fn read_message_blocking_timeout(
418 &mut self,
419 timeout: Option<Duration>,
420 ) -> Result<Rx, ChannelError> {
421 let now = std::time::Instant::now();
422
423 self.set_timeout(Some(Duration::from_millis(100)))?;
429
430 let status = loop {
431 if let Some(timeout) = timeout {
432 if now.elapsed() >= timeout {
433 break Err(ChannelError::TimeoutReached(timeout));
434 }
435 }
436
437 if let Some(message) = self.try_read_delimited_message()? {
438 self.try_shrink_front_buf();
439 return Ok(message);
440 }
441
442 match self.sock.read(self.front_buf.space()) {
443 Ok(0) => return Err(ChannelError::NoByteToRead),
444 Ok(bytes_read) => self.front_buf.fill(bytes_read),
445 Err(io_error) => match io_error.kind() {
446 ErrorKind::WouldBlock => continue, _ => break Err(ChannelError::Read(io_error)),
448 },
449 };
450 };
451
452 self.set_timeout(None)?;
453
454 status
455 }
456
457 fn try_read_delimited_message(&mut self) -> Result<Option<Rx>, ChannelError> {
459 let buffer = self.front_buf.data();
460 if buffer.len() >= delimiter_size() {
461 let delimiter = buffer[..delimiter_size()]
462 .try_into()
463 .map_err(|_| ChannelError::MismatchBufferSize)?;
464 let message_len = usize::from_le_bytes(delimiter);
465
466 if message_len > self.max_buffer_size {
475 return Err(ChannelError::MessageTooLarge {
476 message_len,
477 capacity: self.front_buf.capacity(),
478 max: self.max_buffer_size,
479 });
480 }
481
482 if message_len < delimiter_size() {
494 self.front_buf.consume(delimiter_size());
495 return Err(ChannelError::MessageLengthUnderDelimiter {
496 message_len,
497 delimiter_size: delimiter_size(),
498 });
499 }
500
501 if buffer.len() >= message_len {
502 let message = Rx::decode(&buffer[delimiter_size()..message_len])
503 .map_err(ChannelError::InvalidProtobufMessage)?;
504 self.front_buf.consume(message_len);
505 return Ok(Some(message));
506 }
507 }
508
509 if self.front_buf.available_space() == 0 {
510 if self.front_buf.capacity() >= self.max_buffer_size {
511 return Err(ChannelError::BufferFull {
512 capacity: self.front_buf.capacity(),
513 max: self.max_buffer_size,
514 });
515 }
516 let new_size = self
517 .grow_size(self.front_buf.capacity())
518 .unwrap_or(self.max_buffer_size);
519 Self::check_high_watermark(
520 "front",
521 new_size,
522 self.max_buffer_size,
523 &mut self.front_high_watermark_logged,
524 );
525 self.front_buf.grow(new_size);
526 }
527 Ok(None)
528 }
529
530 pub fn write_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
534 if self.blocking {
535 self.write_message_blocking(message)
536 } else {
537 self.write_message_nonblocking(message)
538 }
539 }
540
541 fn write_message_nonblocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
544 self.write_delimited_message(message)?;
545
546 self.interest.insert(Ready::WRITABLE);
547
548 Ok(())
549 }
550
551 fn write_message_blocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
553 self.write_delimited_message(message)?;
554
555 loop {
556 let size = self.back_buf.available_data();
557 if size == 0 {
558 break;
559 }
560
561 match self.sock.write(self.back_buf.data()) {
562 Ok(0) => return Err(ChannelError::NoByteWritten),
563 Ok(bytes_written) => {
564 self.back_buf.consume(bytes_written);
565 }
566 Err(_) => return Ok(()), }
568 }
569 Ok(())
570 }
571
572 pub fn write_delimited_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
575 let payload = message.encode_to_vec();
576
577 let payload_len = payload.len() + delimiter_size();
578
579 let delimiter = payload_len.to_le_bytes();
580
581 if payload_len > self.back_buf.available_space() {
582 self.back_buf.shift();
583 }
584
585 if payload_len > self.back_buf.available_space() {
586 let needed = payload_len - self.back_buf.available_space() + self.back_buf.capacity();
587 if needed > self.max_buffer_size {
588 return Err(ChannelError::MessageTooLarge {
589 message_len: payload_len,
590 capacity: self.back_buf.capacity(),
591 max: self.max_buffer_size,
592 });
593 }
594
595 let mut new_length = self.back_buf.capacity();
597 while new_length < needed {
598 new_length = new_length.saturating_mul(2).max(new_length + 1);
599 }
600 new_length = min(new_length, self.max_buffer_size);
601 Self::check_high_watermark(
602 "back",
603 new_length,
604 self.max_buffer_size,
605 &mut self.back_high_watermark_logged,
606 );
607 self.back_buf.grow(new_length);
608 }
609
610 self.back_buf
611 .write_all(&delimiter)
612 .map_err(ChannelError::Write)?;
613 self.back_buf
614 .write_all(&payload)
615 .map_err(ChannelError::Write)?;
616
617 Ok(())
618 }
619
620 fn try_shrink_front_buf(&mut self) {
623 let capacity = self.front_buf.capacity();
624 if capacity <= self.initial_buffer_size {
625 return;
626 }
627 if self.front_buf.available_data() * 4 < self.initial_buffer_size {
629 self.front_buf.shrink(self.initial_buffer_size);
630 self.front_high_watermark_logged = false;
631 trace!(
632 "front buffer shrunk from {} to {} bytes",
633 capacity, self.initial_buffer_size
634 );
635 }
636 }
637
638 fn try_shrink_back_buf(&mut self) {
640 let capacity = self.back_buf.capacity();
641 if capacity <= self.initial_buffer_size {
642 return;
643 }
644 if self.back_buf.available_data() == 0 {
645 self.back_buf.shrink(self.initial_buffer_size);
646 self.back_high_watermark_logged = false;
647 trace!(
648 "back buffer shrunk from {} to {} bytes",
649 capacity, self.initial_buffer_size
650 );
651 }
652 }
653}
654
/// Size in bytes of the length prefix written before every message.
///
/// The prefix is a `usize`, so this is 8 on 64-bit targets and 4 on 32-bit ones.
pub const fn delimiter_size() -> usize {
    (usize::BITS / u8::BITS) as usize
}
659
660type ChannelResult<Tx, Rx> = Result<(Channel<Tx, Rx>, Channel<Rx, Tx>), ChannelError>;
661
662impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
663 pub fn generate(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
665 let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
666 let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
667 let mut command_channel = Channel::new(command, buffer_size, max_buffer_size);
668 command_channel.blocking()?;
669 Ok((command_channel, proxy_channel))
670 }
671
672 pub fn generate_nonblocking(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
674 let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
675 let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
676 let command_channel = Channel::new(command, buffer_size, max_buffer_size);
677 Ok((command_channel, proxy_channel))
678 }
679}
680
681impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Iterator
682 for Channel<Tx, Rx>
683{
684 type Item = Rx;
685 fn next(&mut self) -> Option<Self::Item> {
686 self.read_message().ok()
687 }
688}
689
690use mio::{Interest, Registry, Token};
691impl<Tx, Rx> Source for Channel<Tx, Rx> {
692 fn register(
693 &mut self,
694 registry: &Registry,
695 token: Token,
696 interests: Interest,
697 ) -> io::Result<()> {
698 self.sock.register(registry, token, interests)
699 }
700
701 fn reregister(
702 &mut self,
703 registry: &Registry,
704 token: Token,
705 interests: Interest,
706 ) -> io::Result<()> {
707 self.sock.reregister(registry, token, interests)
708 }
709
710 fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
711 self.sock.deregister(registry)
712 }
713}
714
/// Tests exercising channel pairs over real unix sockets.
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    /// Minimal protobuf message used as both `Tx` and `Rx` in these tests.
    #[derive(Clone, PartialEq, prost::Message)]
    pub struct ProtobufMessage {
        #[prost(uint32, required, tag = "1")]
        inner: u32,
    }

    /// Builds a `(blocking command side, nonblocking proxy side)` pair.
    ///
    /// Buffers start at 1000 bytes and are capped at 10000 bytes.
    fn test_channels() -> (
        Channel<ProtobufMessage, ProtobufMessage>,
        Channel<ProtobufMessage, ProtobufMessage>,
    ) {
        Channel::generate(1000, 10000).expect("could not generate blocking channels for testing")
    }

    /// A blocking channel can be switched to nonblocking mode.
    #[test]
    fn unblock_a_channel() {
        let (mut blocking, _nonblocking) = test_channels();
        assert!(blocking.nonblocking().is_ok())
    }

    /// `generate` blocks only the command side; `generate_nonblocking` blocks neither.
    #[test]
    fn generate_blocking_and_nonblocking_channels() {
        let (blocking_channel, nonblocking_channel) = test_channels();

        assert!(blocking_channel.is_blocking());
        assert!(!nonblocking_channel.is_blocking());

        let (nonblocking_channel_1, nonblocking_channel_2): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate_nonblocking(1000, 10000)
            .expect("could not generatie nonblocking channels");

        assert!(!nonblocking_channel_1.is_blocking());
        assert!(!nonblocking_channel_2.is_blocking());
    }

    /// A message written in blocking mode is read back unchanged on the other side.
    #[test]
    fn write_and_read_message_blocking() {
        let (mut blocking_channel, mut nonblocking_channel) = test_channels();

        let message_to_send = ProtobufMessage { inner: 42 };

        // Make the writer blocking so `write_message` flushes to the socket immediately.
        nonblocking_channel
            .blocking()
            .expect("Could not block channel");
        nonblocking_channel
            .write_message(&message_to_send)
            .expect("Could not write message on channel");

        trace!("we wrote a message!");

        trace!("reading message..");
        let message = blocking_channel
            .read_message()
            .expect("Could not read message on channel");
        trace!("read message!");

        assert_eq!(message, ProtobufMessage { inner: 42 });
    }

    /// A reader with a 100 ms deadline fails when the message is sent after ~200 ms.
    #[test]
    fn read_message_blocking_with_timeout_fails() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 100 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message =
                reading_channel.read_message_blocking_timeout(Some(Duration::from_millis(100)));
            trace!("read message!");
            message
        });

        trace!("Waiting 200 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(200));

        writing_channel
            .write_message(&ProtobufMessage { inner: 200 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive too late!");

        let arrived_too_late = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert!(arrived_too_late.is_err());
    }

    /// A reader with a 200 ms deadline succeeds when the message is sent after ~100 ms.
    #[test]
    fn read_message_blocking_with_timeout_succeeds() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 200 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message = reading_channel
                .read_message_blocking_timeout(Some(Duration::from_millis(200)))
                .expect("Could not read message with timeout on blocking channel");
            trace!("read message!");
            message
        });

        trace!("Waiting 100 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(100));

        writing_channel
            .write_message(&ProtobufMessage { inner: 100 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive on time!");

        let arrived_on_time = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert_eq!(arrived_on_time, ProtobufMessage { inner: 100 });
    }

    /// Drives two nonblocking channels by hand with `handle_events` and `run`.
    #[test]
    fn exhaustive_use_of_nonblocking_channels() {
        let (mut channel_a, mut channel_b) = test_channels();
        // `test_channels` returns channel_a blocking; make both sides nonblocking.
        channel_a.nonblocking().expect("Could not block channel");

        // In nonblocking mode this only queues the message in channel_a's back buffer.
        channel_a
            .write_message(&ProtobufMessage { inner: 1 })
            .expect("Could not write message on channel");

        channel_b.handle_events(Ready::READABLE);

        // Nothing has been flushed or read yet, so no message can be decoded.
        let should_err = channel_b.read_message();
        assert!(should_err.is_err());

        channel_a
            .write_message(&ProtobufMessage { inner: 2 })
            .expect("Could not write message on channel");

        channel_a.handle_events(Ready::WRITABLE);

        // Flush both queued messages to the socket.
        channel_a.run().expect("Failed to run the channel");

        // Give the bytes time to reach the peer before reading.
        thread::sleep(std::time::Duration::from_millis(100));

        // Read the socket into channel_b's front buffer.
        channel_b.run().expect("Failed to run the channel");

        let message_1 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_1, ProtobufMessage { inner: 1 });

        let message_2 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_2, ProtobufMessage { inner: 2 });
    }

    /// `grow_size` doubles the capacity, clamps it to the maximum, and stops at the maximum.
    #[test]
    fn buffer_grows_with_doubling_strategy() {
        let (writing_channel, _reading_channel): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(100, 10000).expect("could not generate channels");

        assert_eq!(writing_channel.back_buf.capacity(), 100);

        assert_eq!(writing_channel.grow_size(100), Some(200));
        assert_eq!(writing_channel.grow_size(200), Some(400));
        assert_eq!(writing_channel.grow_size(5000), Some(10000));
        assert_eq!(writing_channel.grow_size(10000), None);
    }

    /// Queueing messages without flushing eventually hits the buffer cap as an error.
    #[test]
    fn buffer_cap_returns_error() {
        let (mut writing_channel, _reading_channel): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(50, 50).expect("could not generate channels");

        writing_channel.blocking().expect("Could not block channel");

        // `write_delimited_message` only appends to the back buffer, so it must
        // fill up; the counter bound guards against an infinite loop.
        let mut i = 0u32;
        let result = loop {
            let msg = ProtobufMessage { inner: i };
            match writing_channel.write_delimited_message(&msg) {
                Ok(()) => i += 1,
                Err(e) => break Err(e),
            }
            if i > 10000 {
                break Ok(());
            }
        };

        assert!(result.is_err());
        let err = result.unwrap_err();
        let err_msg = format!("{err}");
        assert!(
            err_msg.contains("too large") || err_msg.contains("cannot grow"),
            "unexpected error: {err_msg}"
        );
    }

    /// After growing, a drained back buffer shrinks back to its initial capacity.
    #[test]
    fn back_buffer_shrinks_after_drain() {
        let (mut channel, _other): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(100, 10000).expect("could not generate channels");

        // Queue enough messages to force growth beyond the 100-byte initial size.
        for i in 0..20 {
            channel
                .write_delimited_message(&ProtobufMessage { inner: i })
                .expect("Could not write message");
        }

        let grown_capacity = channel.back_buf.capacity();
        assert!(
            grown_capacity > 100,
            "expected buffer growth, got capacity {grown_capacity}"
        );

        // Simulate a full flush by consuming everything directly.
        let data_len = channel.back_buf.available_data();
        channel.back_buf.consume(data_len);
        assert_eq!(channel.back_buf.available_data(), 0);

        channel.try_shrink_back_buf();
        assert_eq!(
            channel.back_buf.capacity(),
            100,
            "back buffer should shrink to initial size after drain"
        );
    }

    /// Back-buffer growth from 32 bytes follows the doubling pattern (or hits the cap).
    #[test]
    fn back_buffer_grows_with_doubling_on_write() {
        let (mut channel, _other): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(32, 10000).expect("could not generate channels");

        assert_eq!(channel.back_buf.capacity(), 32);

        for i in 0..10 {
            channel
                .write_delimited_message(&ProtobufMessage { inner: i })
                .expect("Could not write message");
        }

        let grown = channel.back_buf.capacity();
        assert!(grown > 32, "expected buffer growth beyond 32, got {grown}");
        // Starting from a power of two, repeated doubling stays on powers of two
        // unless clamped to the 10000-byte maximum.
        assert!(
            grown.is_power_of_two() || grown == 10000,
            "expected doubling growth pattern, got {grown}"
        );
    }

    /// A length prefix smaller than the prefix itself yields
    /// `MessageLengthUnderDelimiter` instead of a slice panic.
    #[test]
    fn rejects_declared_length_below_delimiter() {
        let (mut reader, mut writer): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate(1000, 10000).expect("could not generate channels");
        writer.blocking().expect("writer to block");
        reader.blocking().expect("reader to block");

        // Write a raw 5-byte declared length directly to the socket, bypassing
        // `write_delimited_message`, which would always produce a valid prefix.
        let bogus: usize = 5;
        let bytes = bogus.to_le_bytes();
        std::io::Write::write_all(&mut writer.sock, &bytes).expect("raw write of bogus delimiter");

        match reader.read_message() {
            Err(ChannelError::MessageLengthUnderDelimiter {
                message_len,
                delimiter_size,
            }) => {
                assert_eq!(message_len, 5);
                assert_eq!(delimiter_size, std::mem::size_of::<usize>());
            }
            other => panic!(
                "expected MessageLengthUnderDelimiter, got {other:?}\n\
                 NOTE: a panic here means the slice-OOB hardening was reverted",
            ),
        }
    }
}