1use std::{
2 cmp::min,
3 fmt::Debug,
4 io::{self, ErrorKind, Read, Write},
5 marker::PhantomData,
6 os::unix::{
7 io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
8 net::UnixStream as StdUnixStream,
9 },
10 time::Duration,
11};
12
13use mio::{event::Source, net::UnixStream as MioUnixStream};
14use prost::{DecodeError, Message as ProstMessage};
15
16use crate::{buffer::growable::Buffer, ready::Ready};
17
18#[derive(thiserror::Error, Debug)]
19pub enum ChannelError {
20 #[error("io read error")]
21 Read(std::io::Error),
22 #[error("no byte written on the channel")]
23 NoByteWritten,
24 #[error("no byte left to read on the channel")]
25 NoByteToRead,
26 #[error(
27 "message too large for the capacity of the back fuffer ({0}. Consider increasing the back buffer size"
28 )]
29 MessageTooLarge(usize),
30 #[error("channel could not write on the back buffer")]
31 Write(std::io::Error),
32 #[error("channel buffer is full ({0} bytes), cannot grow more")]
33 BufferFull(usize),
34 #[error("Timeout is reached: {0:?}")]
35 TimeoutReached(Duration),
36 #[error("Could not read anything on the channel")]
37 NothingRead,
38 #[error("invalid char set in command message, ignoring: {0}")]
39 InvalidCharSet(String),
40 #[error("could not set the timeout of the unix stream with file descriptor {fd}: {error}")]
41 SetTimeout { fd: i32, error: String },
42 #[error(
43 "Could not change the blocking status ef the unix stream with file descriptor {fd}: {error}"
44 )]
45 BlockingStatus { fd: i32, error: String },
46 #[error("Connection error: {0:?}")]
47 Connection(Option<std::io::Error>),
48 #[error("Invalid protobuf message: {0}")]
49 InvalidProtobufMessage(DecodeError),
50 #[error("This should never happen (index out of bound on a tested buffer)")]
51 MismatchBufferSize,
52}
53
/// A bidirectional, message-oriented channel over a Unix stream socket.
///
/// Outgoing `Tx` messages are protobuf-encoded, prefixed with their length and
/// queued in `back_buf`; incoming bytes are accumulated in `front_buf` and
/// decoded into `Rx` messages.
pub struct Channel<Tx, Rx> {
    /// The underlying socket.
    pub sock: MioUnixStream,
    /// Bytes received from `sock` that have not been decoded into messages yet.
    pub front_buf: Buffer,
    /// Encoded outgoing messages waiting to be written to `sock`.
    pub back_buf: Buffer,
    /// Upper bound, in bytes, on how far either buffer may grow.
    max_buffer_size: u64,
    /// Events the socket has been reported ready for (accumulated by `handle_events`).
    pub readiness: Ready,
    /// Events the channel currently wants to act on.
    pub interest: Ready,
    /// Whether the socket was put in blocking mode through this channel.
    blocking: bool,
    /// Ties the outgoing message type to the channel without storing one.
    phantom_tx: PhantomData<Tx>,
    /// Ties the incoming message type to the channel without storing one.
    phantom_rx: PhantomData<Rx>,
}
70
71impl<Tx, Rx> std::fmt::Debug for Channel<Tx, Rx> {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.debug_struct(&format!(
74 "Channel<{}, {}>",
75 std::any::type_name::<Tx>(),
76 std::any::type_name::<Rx>()
77 ))
78 .field("sock", &self.sock.as_raw_fd())
79 .field("readiness", &self.readiness)
83 .field("interest", &self.interest)
84 .field("blocking", &self.blocking)
85 .finish()
86 }
87}
88
89impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
90 pub fn from_path(
92 path: &str,
93 buffer_size: u64,
94 max_buffer_size: u64,
95 ) -> Result<Channel<Tx, Rx>, ChannelError> {
96 let unix_stream = MioUnixStream::connect(path)
97 .map_err(|io_error| ChannelError::Connection(Some(io_error)))?;
98 Ok(Channel::new(unix_stream, buffer_size, max_buffer_size))
99 }
100
101 pub fn new(sock: MioUnixStream, buffer_size: u64, max_buffer_size: u64) -> Channel<Tx, Rx> {
103 Channel {
104 sock,
105 front_buf: Buffer::with_capacity(buffer_size as usize),
106 back_buf: Buffer::with_capacity(buffer_size as usize),
107 max_buffer_size,
108 readiness: Ready::EMPTY,
109 interest: Ready::READABLE,
110 blocking: false,
111 phantom_tx: PhantomData,
112 phantom_rx: PhantomData,
113 }
114 }
115
116 pub fn into<Tx2: Debug + ProstMessage + Default, Rx2: Debug + ProstMessage + Default>(
117 self,
118 ) -> Channel<Tx2, Rx2> {
119 Channel {
120 sock: self.sock,
121 front_buf: self.front_buf,
122 back_buf: self.back_buf,
123 max_buffer_size: self.max_buffer_size,
124 readiness: self.readiness,
125 interest: self.interest,
126 blocking: self.blocking,
127 phantom_tx: PhantomData,
128 phantom_rx: PhantomData,
129 }
130 }
131
132 fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), ChannelError> {
136 unsafe {
137 let fd = self.sock.as_raw_fd();
138 let stream = StdUnixStream::from_raw_fd(fd);
139 stream
140 .set_nonblocking(nonblocking)
141 .map_err(|error| ChannelError::BlockingStatus {
142 fd,
143 error: error.to_string(),
144 })?;
145 let _fd = stream.into_raw_fd();
146 }
147 self.blocking = !nonblocking;
148 Ok(())
149 }
150
151 fn set_timeout(&mut self, timeout: Option<Duration>) -> Result<(), ChannelError> {
153 unsafe {
154 let fd = self.sock.as_raw_fd();
155 let stream = StdUnixStream::from_raw_fd(fd);
156 stream
157 .set_read_timeout(timeout)
158 .map_err(|error| ChannelError::SetTimeout {
159 fd,
160 error: error.to_string(),
161 })?;
162 let _fd = stream.into_raw_fd();
163 }
164 Ok(())
165 }
166
167 pub fn blocking(&mut self) -> Result<(), ChannelError> {
169 self.set_nonblocking(false)
170 }
171
172 pub fn nonblocking(&mut self) -> Result<(), ChannelError> {
174 self.set_nonblocking(true)
175 }
176
177 pub fn is_blocking(&self) -> bool {
178 self.blocking
179 }
180
181 pub fn fd(&self) -> RawFd {
183 self.sock.as_raw_fd()
184 }
185
186 pub fn handle_events(&mut self, events: Ready) {
187 self.readiness |= events;
188 }
189
190 pub fn readiness(&self) -> Ready {
191 self.readiness & self.interest
192 }
193
194 pub fn run(&mut self) -> Result<(), ChannelError> {
196 let interest = self.interest & self.readiness;
197
198 if interest.is_readable() {
199 let _ = self.readable()?;
200 }
201
202 if interest.is_writable() {
203 let _ = self.writable()?;
204 }
205 Ok(())
206 }
207
208 pub fn readable(&mut self) -> Result<usize, ChannelError> {
210 if !(self.interest & self.readiness).is_readable() {
211 return Err(ChannelError::Connection(None));
212 }
213
214 let mut count = 0usize;
215 loop {
216 let size = self.front_buf.available_space();
217 trace!("channel available space: {}", size);
218 if size == 0 {
219 self.interest.remove(Ready::READABLE);
220 break;
221 }
222
223 match self.sock.read(self.front_buf.space()) {
224 Ok(0) => {
225 self.interest = Ready::EMPTY;
226 self.readiness.remove(Ready::READABLE);
227 self.readiness.insert(Ready::HUP);
228 return Err(ChannelError::NoByteToRead);
229 }
230 Err(read_error) => match read_error.kind() {
231 ErrorKind::WouldBlock => {
232 self.readiness.remove(Ready::READABLE);
233 break;
234 }
235 _ => {
236 self.interest = Ready::EMPTY;
237 self.readiness = Ready::EMPTY;
238 return Err(ChannelError::Read(read_error));
239 }
240 },
241 Ok(bytes_read) => {
242 count += bytes_read;
243 self.front_buf.fill(bytes_read);
244 }
245 };
246 }
247
248 Ok(count)
249 }
250
251 pub fn writable(&mut self) -> Result<usize, ChannelError> {
253 if !(self.interest & self.readiness).is_writable() {
254 return Err(ChannelError::Connection(None));
255 }
256
257 let mut count = 0usize;
258 loop {
259 let size = self.back_buf.available_data();
260 if size == 0 {
261 self.interest.remove(Ready::WRITABLE);
262 break;
263 }
264
265 match self.sock.write(self.back_buf.data()) {
266 Ok(0) => {
267 self.interest = Ready::EMPTY;
268 self.readiness.insert(Ready::HUP);
269 return Err(ChannelError::NoByteWritten);
270 }
271 Ok(bytes_written) => {
272 count += bytes_written;
273 self.back_buf.consume(bytes_written);
274 }
275 Err(write_error) => match write_error.kind() {
276 ErrorKind::WouldBlock => {
277 self.readiness.remove(Ready::WRITABLE);
278 break;
279 }
280 _ => {
281 self.interest = Ready::EMPTY;
282 self.readiness = Ready::EMPTY;
283 return Err(ChannelError::Read(write_error));
284 }
285 },
286 }
287 }
288
289 Ok(count)
290 }
291
292 pub fn read_message(&mut self) -> Result<Rx, ChannelError> {
299 if self.blocking {
300 self.read_message_blocking()
301 } else {
302 self.read_message_nonblocking()
303 }
304 }
305
306 fn read_message_blocking(&mut self) -> Result<Rx, ChannelError> {
307 self.read_message_blocking_timeout(None)
308 }
309
310 fn read_message_nonblocking(&mut self) -> Result<Rx, ChannelError> {
312 if let Some(message) = self.try_read_delimited_message()? {
313 return Ok(message);
314 }
315
316 self.interest.insert(Ready::READABLE);
317 Err(ChannelError::NothingRead)
318 }
319
320 pub fn read_message_blocking_timeout(
322 &mut self,
323 timeout: Option<Duration>,
324 ) -> Result<Rx, ChannelError> {
325 let now = std::time::Instant::now();
326
327 self.set_timeout(Some(Duration::from_millis(10)))?;
329
330 let status = loop {
331 if let Some(timeout) = timeout {
332 if now.elapsed() >= timeout {
333 break Err(ChannelError::TimeoutReached(timeout));
334 }
335 }
336
337 if let Some(message) = self.try_read_delimited_message()? {
338 return Ok(message);
339 }
340
341 match self.sock.read(self.front_buf.space()) {
342 Ok(0) => return Err(ChannelError::NoByteToRead),
343 Ok(bytes_read) => self.front_buf.fill(bytes_read),
344 Err(io_error) => match io_error.kind() {
345 ErrorKind::WouldBlock => continue, _ => break Err(ChannelError::Read(io_error)),
347 },
348 };
349 };
350
351 self.set_timeout(None)?;
352
353 status
354 }
355
356 fn try_read_delimited_message(&mut self) -> Result<Option<Rx>, ChannelError> {
358 let buffer = self.front_buf.data();
359 if buffer.len() >= delimiter_size() {
360 let delimiter = buffer[..delimiter_size()]
361 .try_into()
362 .map_err(|_| ChannelError::MismatchBufferSize)?;
363 let message_len = usize::from_le_bytes(delimiter);
364
365 if buffer.len() >= message_len {
366 let message = Rx::decode(&buffer[delimiter_size()..message_len])
367 .map_err(ChannelError::InvalidProtobufMessage)?;
368 self.front_buf.consume(message_len);
369 return Ok(Some(message));
370 }
371 }
372
373 if self.front_buf.available_space() == 0 {
374 if (self.front_buf.capacity() as u64) >= self.max_buffer_size {
375 return Err(ChannelError::BufferFull(self.front_buf.capacity()));
376 }
377 let new_size = min(
378 self.front_buf.capacity() + 5000,
379 self.max_buffer_size as usize,
380 );
381 self.front_buf.grow(new_size);
382 }
383 Ok(None)
384 }
385
386 pub fn write_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
390 if self.blocking {
391 self.write_message_blocking(message)
392 } else {
393 self.write_message_nonblocking(message)
394 }
395 }
396
397 fn write_message_nonblocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
400 self.write_delimited_message(message)?;
401
402 self.interest.insert(Ready::WRITABLE);
403
404 Ok(())
405 }
406
407 fn write_message_blocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
409 self.write_delimited_message(message)?;
410
411 loop {
412 let size = self.back_buf.available_data();
413 if size == 0 {
414 break;
415 }
416
417 match self.sock.write(self.back_buf.data()) {
418 Ok(0) => return Err(ChannelError::NoByteWritten),
419 Ok(bytes_written) => {
420 self.back_buf.consume(bytes_written);
421 }
422 Err(_) => return Ok(()), }
424 }
425 Ok(())
426 }
427
428 pub fn write_delimited_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
431 let payload = message.encode_to_vec();
432
433 let payload_len = payload.len() + delimiter_size();
434
435 let delimiter = payload_len.to_le_bytes();
436
437 if payload_len > self.back_buf.available_space() {
438 self.back_buf.shift();
439 }
440
441 if payload_len > self.back_buf.available_space() {
442 if payload_len - self.back_buf.available_space() + self.back_buf.capacity()
443 > (self.max_buffer_size as usize)
444 {
445 return Err(ChannelError::MessageTooLarge(self.back_buf.capacity()));
446 }
447
448 let new_length =
449 payload_len - self.back_buf.available_space() + self.back_buf.capacity();
450 self.back_buf.grow(new_length);
451 }
452
453 self.back_buf
454 .write_all(&delimiter)
455 .map_err(ChannelError::Write)?;
456 self.back_buf
457 .write_all(&payload)
458 .map_err(ChannelError::Write)?;
459
460 Ok(())
461 }
462}
463
/// Width in bytes of the length field that precedes every message on the wire:
/// one native `usize`.
pub const fn delimiter_size() -> usize {
    (usize::BITS / 8) as usize
}
468
/// Outcome of creating a connected pair of channels. The first channel sends
/// `Tx` and receives `Rx`; the second is its mirror (sends `Rx`, receives `Tx`).
type ChannelResult<Tx, Rx> = Result<(Channel<Tx, Rx>, Channel<Rx, Tx>), ChannelError>;
470
471impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
472 pub fn generate(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
474 let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
475 let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
476 let mut command_channel = Channel::new(command, buffer_size, max_buffer_size);
477 command_channel.blocking()?;
478 Ok((command_channel, proxy_channel))
479 }
480
481 pub fn generate_nonblocking(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
483 let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
484 let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
485 let command_channel = Channel::new(command, buffer_size, max_buffer_size);
486 Ok((command_channel, proxy_channel))
487 }
488}
489
490impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Iterator
491 for Channel<Tx, Rx>
492{
493 type Item = Rx;
494 fn next(&mut self) -> Option<Self::Item> {
495 self.read_message().ok()
496 }
497}
498
499use mio::{Interest, Registry, Token};
500impl<Tx, Rx> Source for Channel<Tx, Rx> {
501 fn register(
502 &mut self,
503 registry: &Registry,
504 token: Token,
505 interests: Interest,
506 ) -> io::Result<()> {
507 self.sock.register(registry, token, interests)
508 }
509
510 fn reregister(
511 &mut self,
512 registry: &Registry,
513 token: Token,
514 interests: Interest,
515 ) -> io::Result<()> {
516 self.sock.reregister(registry, token, interests)
517 }
518
519 fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
520 self.sock.deregister(registry)
521 }
522}
523
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    /// Minimal protobuf message used to exercise the channels.
    #[derive(Clone, PartialEq, prost::Message)]
    pub struct ProtobufMessage {
        #[prost(uint32, required, tag = "1")]
        inner: u32,
    }

    /// Connected pair from `Channel::generate`: the first channel is blocking,
    /// the second non-blocking.
    fn test_channels() -> (
        Channel<ProtobufMessage, ProtobufMessage>,
        Channel<ProtobufMessage, ProtobufMessage>,
    ) {
        Channel::generate(1000, 10000).expect("could not generate blocking channels for testing")
    }

    #[test]
    fn unblock_a_channel() {
        let (mut blocking, _nonblocking) = test_channels();
        assert!(blocking.nonblocking().is_ok())
    }

    #[test]
    fn generate_blocking_and_nonblocking_channels() {
        let (blocking_channel, nonblocking_channel) = test_channels();

        assert!(blocking_channel.is_blocking());
        assert!(!nonblocking_channel.is_blocking());

        let (nonblocking_channel_1, nonblocking_channel_2): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate_nonblocking(1000, 10000)
            .expect("could not generate nonblocking channels");

        assert!(!nonblocking_channel_1.is_blocking());
        assert!(!nonblocking_channel_2.is_blocking());
    }

    #[test]
    fn write_and_read_message_blocking() {
        let (mut blocking_channel, mut nonblocking_channel) = test_channels();

        let message_to_send = ProtobufMessage { inner: 42 };

        nonblocking_channel
            .blocking()
            .expect("Could not block channel");
        nonblocking_channel
            .write_message(&message_to_send)
            .expect("Could not write message on channel");

        trace!("we wrote a message!");

        trace!("reading message..");
        let message = blocking_channel
            .read_message()
            .expect("Could not read message on channel");
        trace!("read message!");

        assert_eq!(message, ProtobufMessage { inner: 42 });
    }

    #[test]
    fn read_message_blocking_with_timeout_fails() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 100 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message =
                reading_channel.read_message_blocking_timeout(Some(Duration::from_millis(100)));
            trace!("read message!");
            message
        });

        trace!("Waiting 200 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(200));

        writing_channel
            .write_message(&ProtobufMessage { inner: 200 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive too late!");

        let arrived_too_late = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert!(arrived_too_late.is_err());
    }

    #[test]
    fn read_message_blocking_with_timeout_succeeds() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 200 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message = reading_channel
                .read_message_blocking_timeout(Some(Duration::from_millis(200)))
                .expect("Could not read message with timeout on blocking channel");
            trace!("read message!");
            message
        });

        trace!("Waiting 100 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(100));

        writing_channel
            .write_message(&ProtobufMessage { inner: 100 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive on time!");

        let arrived_on_time = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert_eq!(arrived_on_time, ProtobufMessage { inner: 100 });
    }

    #[test]
    fn exhaustive_use_of_nonblocking_channels() {
        let (mut channel_a, mut channel_b) = test_channels();
        channel_a.nonblocking().expect("Could not unblock channel");

        // In non-blocking mode the message is only queued in the back buffer.
        channel_a
            .write_message(&ProtobufMessage { inner: 1 })
            .expect("Could not write message on channel");

        channel_b.handle_events(Ready::READABLE);

        // Nothing has been flushed yet, so there is nothing to decode.
        let should_err = channel_b.read_message();
        assert!(should_err.is_err());

        channel_a
            .write_message(&ProtobufMessage { inner: 2 })
            .expect("Could not write message on channel");

        // Flushing both queued messages requires the channel to be ready for writing.
        channel_a.handle_events(Ready::WRITABLE);

        channel_a.run().expect("Failed to run the channel");

        thread::sleep(std::time::Duration::from_millis(100));

        channel_b.run().expect("Failed to run the channel");

        let message_1 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_1, ProtobufMessage { inner: 1 });

        let message_2 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_2, ProtobufMessage { inner: 2 });
    }
}