1use std::{
2 cmp::min,
3 fmt::Debug,
4 io::{self, ErrorKind, Read, Write},
5 marker::PhantomData,
6 os::unix::{
7 io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
8 net::UnixStream as StdUnixStream,
9 },
10 time::Duration,
11};
12
13use mio::{event::Source, net::UnixStream as MioUnixStream};
14use prost::{DecodeError, Message as ProstMessage};
15
16use crate::{buffer::growable::Buffer, ready::Ready};
17
18#[derive(thiserror::Error, Debug)]
19pub enum ChannelError {
20 #[error("io read error")]
21 Read(std::io::Error),
22 #[error("no byte written on the channel")]
23 NoByteWritten,
24 #[error("no byte left to read on the channel")]
25 NoByteToRead,
26 #[error("message too large for the capacity of the back fuffer ({0}. Consider increasing the back buffer size")]
27 MessageTooLarge(usize),
28 #[error("channel could not write on the back buffer")]
29 Write(std::io::Error),
30 #[error("channel buffer is full ({0} bytes), cannot grow more")]
31 BufferFull(usize),
32 #[error("Timeout is reached: {0:?}")]
33 TimeoutReached(Duration),
34 #[error("Could not read anything on the channel")]
35 NothingRead,
36 #[error("invalid char set in command message, ignoring: {0}")]
37 InvalidCharSet(String),
38 #[error("could not set the timeout of the unix stream with file descriptor {fd}: {error}")]
39 SetTimeout { fd: i32, error: String },
40 #[error("Could not change the blocking status ef the unix stream with file descriptor {fd}: {error}")]
41 BlockingStatus { fd: i32, error: String },
42 #[error("Connection error: {0:?}")]
43 Connection(Option<std::io::Error>),
44 #[error("Invalid protobuf message: {0}")]
45 InvalidProtobufMessage(DecodeError),
46 #[error("This should never happen (index out of bound on a tested buffer)")]
47 MismatchBufferSize,
48}
49
/// A channel exchanging length-delimited protobuf messages over a unix socket.
///
/// `Tx` is the type of the messages written on the channel, `Rx` the type of the
/// messages read from it.
pub struct Channel<Tx, Rx> {
    /// The underlying unix stream.
    pub sock: MioUnixStream,
    /// Buffer receiving the bytes read from the socket, from which messages are decoded.
    pub front_buf: Buffer,
    /// Buffer holding encoded messages until they are written on the socket.
    pub back_buf: Buffer,
    /// Upper bound, in bytes, on how far either buffer may grow.
    max_buffer_size: u64,
    /// Events reported for the socket that have not been exhausted yet.
    pub readiness: Ready,
    /// Events the channel currently wants to act upon.
    pub interest: Ready,
    /// Blocking mode last recorded by `set_nonblocking` (`false` on a new channel).
    blocking: bool,
    /// Marker tying the channel to its outgoing message type without storing one.
    phantom_tx: PhantomData<Tx>,
    /// Marker tying the channel to its incoming message type without storing one.
    phantom_rx: PhantomData<Rx>,
}
66
67impl<Tx, Rx> std::fmt::Debug for Channel<Tx, Rx> {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 f.debug_struct(&format!(
70 "Channel<{}, {}>",
71 std::any::type_name::<Tx>(),
72 std::any::type_name::<Rx>()
73 ))
74 .field("sock", &self.sock.as_raw_fd())
75 .field("readiness", &self.readiness)
79 .field("interest", &self.interest)
80 .field("blocking", &self.blocking)
81 .finish()
82 }
83}
84
85impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
86 pub fn from_path(
88 path: &str,
89 buffer_size: u64,
90 max_buffer_size: u64,
91 ) -> Result<Channel<Tx, Rx>, ChannelError> {
92 let unix_stream = MioUnixStream::connect(path)
93 .map_err(|io_error| ChannelError::Connection(Some(io_error)))?;
94 Ok(Channel::new(unix_stream, buffer_size, max_buffer_size))
95 }
96
97 pub fn new(sock: MioUnixStream, buffer_size: u64, max_buffer_size: u64) -> Channel<Tx, Rx> {
99 Channel {
100 sock,
101 front_buf: Buffer::with_capacity(buffer_size as usize),
102 back_buf: Buffer::with_capacity(buffer_size as usize),
103 max_buffer_size,
104 readiness: Ready::EMPTY,
105 interest: Ready::READABLE,
106 blocking: false,
107 phantom_tx: PhantomData,
108 phantom_rx: PhantomData,
109 }
110 }
111
112 pub fn into<Tx2: Debug + ProstMessage + Default, Rx2: Debug + ProstMessage + Default>(
113 self,
114 ) -> Channel<Tx2, Rx2> {
115 Channel {
116 sock: self.sock,
117 front_buf: self.front_buf,
118 back_buf: self.back_buf,
119 max_buffer_size: self.max_buffer_size,
120 readiness: self.readiness,
121 interest: self.interest,
122 blocking: self.blocking,
123 phantom_tx: PhantomData,
124 phantom_rx: PhantomData,
125 }
126 }
127
128 fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), ChannelError> {
132 unsafe {
133 let fd = self.sock.as_raw_fd();
134 let stream = StdUnixStream::from_raw_fd(fd);
135 stream
136 .set_nonblocking(nonblocking)
137 .map_err(|error| ChannelError::BlockingStatus {
138 fd,
139 error: error.to_string(),
140 })?;
141 let _fd = stream.into_raw_fd();
142 }
143 self.blocking = !nonblocking;
144 Ok(())
145 }
146
147 fn set_timeout(&mut self, timeout: Option<Duration>) -> Result<(), ChannelError> {
149 unsafe {
150 let fd = self.sock.as_raw_fd();
151 let stream = StdUnixStream::from_raw_fd(fd);
152 stream
153 .set_read_timeout(timeout)
154 .map_err(|error| ChannelError::SetTimeout {
155 fd,
156 error: error.to_string(),
157 })?;
158 let _fd = stream.into_raw_fd();
159 }
160 Ok(())
161 }
162
/// Switches the socket to blocking mode (calls `set_nonblocking(false)`).
///
/// # Errors
///
/// Propagates the error of `set_nonblocking`.
pub fn blocking(&mut self) -> Result<(), ChannelError> {
    self.set_nonblocking(false)
}
167
/// Switches the socket to nonblocking mode (calls `set_nonblocking(true)`).
///
/// # Errors
///
/// Propagates the error of `set_nonblocking`.
pub fn nonblocking(&mut self) -> Result<(), ChannelError> {
    self.set_nonblocking(true)
}
172
/// Returns the `blocking` flag: `false` on a new channel, then whatever mode was last
/// applied successfully through `blocking()` / `nonblocking()`.
pub fn is_blocking(&self) -> bool {
    self.blocking
}
176
/// Returns the raw file descriptor of the underlying socket. Ownership stays with the channel.
pub fn fd(&self) -> RawFd {
    self.sock.as_raw_fd()
}
181
/// Adds `events` to the stored readiness with `|=`, keeping the readiness already recorded.
pub fn handle_events(&mut self, events: Ready) {
    self.readiness |= events;
}
185
/// Returns the part of the stored readiness the channel is interested in
/// (`readiness & interest`).
pub fn readiness(&self) -> Ready {
    self.readiness & self.interest
}
189
190 pub fn run(&mut self) -> Result<(), ChannelError> {
192 let interest = self.interest & self.readiness;
193
194 if interest.is_readable() {
195 let _ = self.readable()?;
196 }
197
198 if interest.is_writable() {
199 let _ = self.writable()?;
200 }
201 Ok(())
202 }
203
204 pub fn readable(&mut self) -> Result<usize, ChannelError> {
206 if !(self.interest & self.readiness).is_readable() {
207 return Err(ChannelError::Connection(None));
208 }
209
210 let mut count = 0usize;
211 loop {
212 let size = self.front_buf.available_space();
213 trace!("channel available space: {}", size);
214 if size == 0 {
215 self.interest.remove(Ready::READABLE);
216 break;
217 }
218
219 match self.sock.read(self.front_buf.space()) {
220 Ok(0) => {
221 self.interest = Ready::EMPTY;
222 self.readiness.remove(Ready::READABLE);
223 self.readiness.insert(Ready::HUP);
224 return Err(ChannelError::NoByteToRead);
225 }
226 Err(read_error) => match read_error.kind() {
227 ErrorKind::WouldBlock => {
228 self.readiness.remove(Ready::READABLE);
229 break;
230 }
231 _ => {
232 self.interest = Ready::EMPTY;
233 self.readiness = Ready::EMPTY;
234 return Err(ChannelError::Read(read_error));
235 }
236 },
237 Ok(bytes_read) => {
238 count += bytes_read;
239 self.front_buf.fill(bytes_read);
240 }
241 };
242 }
243
244 Ok(count)
245 }
246
247 pub fn writable(&mut self) -> Result<usize, ChannelError> {
249 if !(self.interest & self.readiness).is_writable() {
250 return Err(ChannelError::Connection(None));
251 }
252
253 let mut count = 0usize;
254 loop {
255 let size = self.back_buf.available_data();
256 if size == 0 {
257 self.interest.remove(Ready::WRITABLE);
258 break;
259 }
260
261 match self.sock.write(self.back_buf.data()) {
262 Ok(0) => {
263 self.interest = Ready::EMPTY;
264 self.readiness.insert(Ready::HUP);
265 return Err(ChannelError::NoByteWritten);
266 }
267 Ok(bytes_written) => {
268 count += bytes_written;
269 self.back_buf.consume(bytes_written);
270 }
271 Err(write_error) => match write_error.kind() {
272 ErrorKind::WouldBlock => {
273 self.readiness.remove(Ready::WRITABLE);
274 break;
275 }
276 _ => {
277 self.interest = Ready::EMPTY;
278 self.readiness = Ready::EMPTY;
279 return Err(ChannelError::Read(write_error));
280 }
281 },
282 }
283 }
284
285 Ok(count)
286 }
287
288 pub fn read_message(&mut self) -> Result<Rx, ChannelError> {
295 if self.blocking {
296 self.read_message_blocking()
297 } else {
298 self.read_message_nonblocking()
299 }
300 }
301
/// Reads one message with no deadline: delegates to `read_message_blocking_timeout`
/// with `None` as timeout.
fn read_message_blocking(&mut self) -> Result<Rx, ChannelError> {
    self.read_message_blocking_timeout(None)
}
305
306 fn read_message_nonblocking(&mut self) -> Result<Rx, ChannelError> {
308 if let Some(message) = self.try_read_delimited_message()? {
309 return Ok(message);
310 }
311
312 self.interest.insert(Ready::READABLE);
313 Err(ChannelError::NothingRead)
314 }
315
316 pub fn read_message_blocking_timeout(
318 &mut self,
319 timeout: Option<Duration>,
320 ) -> Result<Rx, ChannelError> {
321 let now = std::time::Instant::now();
322
323 self.set_timeout(Some(Duration::from_millis(10)))?;
325
326 let status = loop {
327 if let Some(timeout) = timeout {
328 if now.elapsed() >= timeout {
329 break Err(ChannelError::TimeoutReached(timeout));
330 }
331 }
332
333 if let Some(message) = self.try_read_delimited_message()? {
334 return Ok(message);
335 }
336
337 match self.sock.read(self.front_buf.space()) {
338 Ok(0) => return Err(ChannelError::NoByteToRead),
339 Ok(bytes_read) => self.front_buf.fill(bytes_read),
340 Err(io_error) => match io_error.kind() {
341 ErrorKind::WouldBlock => continue, _ => break Err(ChannelError::Read(io_error)),
343 },
344 };
345 };
346
347 self.set_timeout(None)?;
348
349 status
350 }
351
352 fn try_read_delimited_message(&mut self) -> Result<Option<Rx>, ChannelError> {
354 let buffer = self.front_buf.data();
355 if buffer.len() >= delimiter_size() {
356 let delimiter = buffer[..delimiter_size()]
357 .try_into()
358 .map_err(|_| ChannelError::MismatchBufferSize)?;
359 let message_len = usize::from_le_bytes(delimiter);
360
361 if buffer.len() >= message_len {
362 let message = Rx::decode(&buffer[delimiter_size()..message_len])
363 .map_err(ChannelError::InvalidProtobufMessage)?;
364 self.front_buf.consume(message_len);
365 return Ok(Some(message));
366 }
367 }
368
369 if self.front_buf.available_space() == 0 {
370 if (self.front_buf.capacity() as u64) >= self.max_buffer_size {
371 return Err(ChannelError::BufferFull(self.front_buf.capacity()));
372 }
373 let new_size = min(
374 self.front_buf.capacity() + 5000,
375 self.max_buffer_size as usize,
376 );
377 self.front_buf.grow(new_size);
378 }
379 Ok(None)
380 }
381
382 pub fn write_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
386 if self.blocking {
387 self.write_message_blocking(message)
388 } else {
389 self.write_message_nonblocking(message)
390 }
391 }
392
393 fn write_message_nonblocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
396 self.write_delimited_message(message)?;
397
398 self.interest.insert(Ready::WRITABLE);
399
400 Ok(())
401 }
402
403 fn write_message_blocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
405 self.write_delimited_message(message)?;
406
407 loop {
408 let size = self.back_buf.available_data();
409 if size == 0 {
410 break;
411 }
412
413 match self.sock.write(self.back_buf.data()) {
414 Ok(0) => return Err(ChannelError::NoByteWritten),
415 Ok(bytes_written) => {
416 self.back_buf.consume(bytes_written);
417 }
418 Err(_) => return Ok(()), }
420 }
421 Ok(())
422 }
423
424 pub fn write_delimited_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
427 let payload = message.encode_to_vec();
428
429 let payload_len = payload.len() + delimiter_size();
430
431 let delimiter = payload_len.to_le_bytes();
432
433 if payload_len > self.back_buf.available_space() {
434 self.back_buf.shift();
435 }
436
437 if payload_len > self.back_buf.available_space() {
438 if payload_len - self.back_buf.available_space() + self.back_buf.capacity()
439 > (self.max_buffer_size as usize)
440 {
441 return Err(ChannelError::MessageTooLarge(self.back_buf.capacity()));
442 }
443
444 let new_length =
445 payload_len - self.back_buf.available_space() + self.back_buf.capacity();
446 self.back_buf.grow(new_length);
447 }
448
449 self.back_buf
450 .write_all(&delimiter)
451 .map_err(ChannelError::Write)?;
452 self.back_buf
453 .write_all(&payload)
454 .map_err(ChannelError::Write)?;
455
456 Ok(())
457 }
458}
459
/// Size in bytes of the length prefix put in front of every message: the width of a
/// `usize` on the current platform.
pub const fn delimiter_size() -> usize {
    (usize::BITS / 8) as usize
}
464
/// Result of creating a connected pair of channels: the second channel has the message
/// types of the first one swapped, so what one end sends is what the other receives.
type ChannelResult<Tx, Rx> = Result<(Channel<Tx, Rx>, Channel<Rx, Tx>), ChannelError>;
466
467impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
468 pub fn generate(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
470 let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
471 let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
472 let mut command_channel = Channel::new(command, buffer_size, max_buffer_size);
473 command_channel.blocking()?;
474 Ok((command_channel, proxy_channel))
475 }
476
477 pub fn generate_nonblocking(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
479 let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
480 let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
481 let command_channel = Channel::new(command, buffer_size, max_buffer_size);
482 Ok((command_channel, proxy_channel))
483 }
484}
485
486impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Iterator
487 for Channel<Tx, Rx>
488{
489 type Item = Rx;
490 fn next(&mut self) -> Option<Self::Item> {
491 self.read_message().ok()
492 }
493}
494
495use mio::{Interest, Registry, Token};
496impl<Tx, Rx> Source for Channel<Tx, Rx> {
497 fn register(
498 &mut self,
499 registry: &Registry,
500 token: Token,
501 interests: Interest,
502 ) -> io::Result<()> {
503 self.sock.register(registry, token, interests)
504 }
505
506 fn reregister(
507 &mut self,
508 registry: &Registry,
509 token: Token,
510 interests: Interest,
511 ) -> io::Result<()> {
512 self.sock.reregister(registry, token, interests)
513 }
514
515 fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
516 self.sock.deregister(registry)
517 }
518}
519
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    /// Minimal protobuf message used as payload in the tests.
    #[derive(Clone, PartialEq, prost::Message)]
    pub struct ProtobufMessage {
        #[prost(uint32, required, tag = "1")]
        inner: u32,
    }

    /// Builds a pair of connected channels: the first one blocking, the second one not.
    fn test_channels() -> (
        Channel<ProtobufMessage, ProtobufMessage>,
        Channel<ProtobufMessage, ProtobufMessage>,
    ) {
        Channel::generate(1000, 10000).expect("could not generate blocking channels for testing")
    }

    #[test]
    fn unblock_a_channel() {
        let (mut blocking, _nonblocking) = test_channels();
        assert!(blocking.nonblocking().is_ok())
    }

    #[test]
    fn generate_blocking_and_nonblocking_channels() {
        let (blocking_channel, nonblocking_channel) = test_channels();

        assert!(blocking_channel.is_blocking());
        assert!(!nonblocking_channel.is_blocking());

        let (nonblocking_channel_1, nonblocking_channel_2): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate_nonblocking(1000, 10000)
            .expect("could not generate nonblocking channels");

        assert!(!nonblocking_channel_1.is_blocking());
        assert!(!nonblocking_channel_2.is_blocking());
    }

    #[test]
    fn write_and_read_message_blocking() {
        let (mut blocking_channel, mut nonblocking_channel) = test_channels();

        let message_to_send = ProtobufMessage { inner: 42 };

        // Make the writing end blocking too, so the message is written on the socket right away.
        nonblocking_channel
            .blocking()
            .expect("Could not block channel");
        nonblocking_channel
            .write_message(&message_to_send)
            .expect("Could not write message on channel");

        trace!("we wrote a message!");

        trace!("reading message..");
        let message = blocking_channel
            .read_message()
            .expect("Could not read message on channel");
        trace!("read message!");

        assert_eq!(message, ProtobufMessage { inner: 42 });
    }

    #[test]
    fn read_message_blocking_with_timeout_fails() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 100 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message =
                reading_channel.read_message_blocking_timeout(Some(Duration::from_millis(100)));
            trace!("read message!");
            message
        });

        // Wait longer than the reader's timeout before writing.
        trace!("Waiting 200 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(200));

        writing_channel
            .write_message(&ProtobufMessage { inner: 200 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive too late!");

        let arrived_too_late = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert!(arrived_too_late.is_err());
    }

    #[test]
    fn read_message_blocking_with_timeout_succeeds() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 200 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message = reading_channel
                .read_message_blocking_timeout(Some(Duration::from_millis(200)))
                .expect("Could not read message with timeout on blocking channel");
            trace!("read message!");
            message
        });

        // Write well before the reader's timeout expires.
        trace!("Waiting 100 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(100));

        writing_channel
            .write_message(&ProtobufMessage { inner: 100 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive on time!");

        let arrived_on_time = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert_eq!(arrived_on_time, ProtobufMessage { inner: 100 });
    }

    #[test]
    fn exhaustive_use_of_nonblocking_channels() {
        let (mut channel_a, mut channel_b) = test_channels();
        // `test_channels` makes the first channel blocking: switch it back.
        channel_a.nonblocking().expect("Could not unblock channel");

        // Queue a first message on A; nothing reaches the socket until A runs.
        channel_a
            .write_message(&ProtobufMessage { inner: 1 })
            .expect("Could not write message on channel");

        // Flag B as readable by hand.
        channel_b.handle_events(Ready::READABLE);

        // B has not read the socket yet, so no message can be decoded from its buffer.
        let should_err = channel_b.read_message();
        assert!(should_err.is_err());

        // Queue a second message on A.
        channel_a
            .write_message(&ProtobufMessage { inner: 2 })
            .expect("Could not write message on channel");

        // Flag A as writable by hand, then let it write its back buffer on the socket.
        channel_a.handle_events(Ready::WRITABLE);

        channel_a.run().expect("Failed to run the channel");

        thread::sleep(std::time::Duration::from_millis(100));

        // Let B read the socket into its front buffer.
        channel_b.run().expect("Failed to run the channel");

        // Both messages are now buffered on B, in the order they were written.
        let message_1 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_1, ProtobufMessage { inner: 1 });

        let message_2 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_2, ProtobufMessage { inner: 2 });
    }
}