sozu_command_lib/
channel.rs

1use std::{
2    cmp::min,
3    fmt::Debug,
4    io::{self, ErrorKind, Read, Write},
5    marker::PhantomData,
6    os::unix::{
7        io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
8        net::UnixStream as StdUnixStream,
9    },
10    time::Duration,
11};
12
13use mio::{event::Source, net::UnixStream as MioUnixStream};
14use prost::{DecodeError, Message as ProstMessage};
15
16use crate::{buffer::growable::Buffer, ready::Ready};
17
18#[derive(thiserror::Error, Debug)]
19pub enum ChannelError {
20    #[error("io read error")]
21    Read(std::io::Error),
22    #[error("no byte written on the channel")]
23    NoByteWritten,
24    #[error("no byte left to read on the channel")]
25    NoByteToRead,
26    #[error(
27        "message too large for the capacity of the back fuffer ({0}. Consider increasing the back buffer size"
28    )]
29    MessageTooLarge(usize),
30    #[error("channel could not write on the back buffer")]
31    Write(std::io::Error),
32    #[error("channel buffer is full ({0} bytes), cannot grow more")]
33    BufferFull(usize),
34    #[error("Timeout is reached: {0:?}")]
35    TimeoutReached(Duration),
36    #[error("Could not read anything on the channel")]
37    NothingRead,
38    #[error("invalid char set in command message, ignoring: {0}")]
39    InvalidCharSet(String),
40    #[error("could not set the timeout of the unix stream with file descriptor {fd}: {error}")]
41    SetTimeout { fd: i32, error: String },
42    #[error(
43        "Could not change the blocking status ef the unix stream with file descriptor {fd}: {error}"
44    )]
45    BlockingStatus { fd: i32, error: String },
46    #[error("Connection error: {0:?}")]
47    Connection(Option<std::io::Error>),
48    #[error("Invalid protobuf message: {0}")]
49    InvalidProtobufMessage(DecodeError),
50    #[error("This should never happen (index out of bound on a tested buffer)")]
51    MismatchBufferSize,
52}
53
54/// Channel meant for communication between Sōzu processes over a UNIX socket.
55/// It wraps a unix socket using the mio crate, and transmit prost messages
56/// by serializing them in a binary format, with a fix-sized delimiter.
57/// To function, channels must come in pairs, one for each agent.
58/// They can function in a blocking or non-blocking way.
59pub struct Channel<Tx, Rx> {
60    pub sock: MioUnixStream,
61    pub front_buf: Buffer,
62    pub back_buf: Buffer,
63    max_buffer_size: u64,
64    pub readiness: Ready,
65    pub interest: Ready,
66    blocking: bool,
67    phantom_tx: PhantomData<Tx>,
68    phantom_rx: PhantomData<Rx>,
69}
70
71impl<Tx, Rx> std::fmt::Debug for Channel<Tx, Rx> {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct(&format!(
74            "Channel<{}, {}>",
75            std::any::type_name::<Tx>(),
76            std::any::type_name::<Rx>()
77        ))
78        .field("sock", &self.sock.as_raw_fd())
79        // .field("front_buf", &self.front_buf)
80        // .field("back_buf", &self.back_buf)
81        // .field("max_buffer_size", &self.max_buffer_size)
82        .field("readiness", &self.readiness)
83        .field("interest", &self.interest)
84        .field("blocking", &self.blocking)
85        .finish()
86    }
87}
88
89impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
90    /// Creates a nonblocking channel on a given socket path
91    pub fn from_path(
92        path: &str,
93        buffer_size: u64,
94        max_buffer_size: u64,
95    ) -> Result<Channel<Tx, Rx>, ChannelError> {
96        let unix_stream = MioUnixStream::connect(path)
97            .map_err(|io_error| ChannelError::Connection(Some(io_error)))?;
98        Ok(Channel::new(unix_stream, buffer_size, max_buffer_size))
99    }
100
101    /// Creates a nonblocking channel, using a unix stream
102    pub fn new(sock: MioUnixStream, buffer_size: u64, max_buffer_size: u64) -> Channel<Tx, Rx> {
103        Channel {
104            sock,
105            front_buf: Buffer::with_capacity(buffer_size as usize),
106            back_buf: Buffer::with_capacity(buffer_size as usize),
107            max_buffer_size,
108            readiness: Ready::EMPTY,
109            interest: Ready::READABLE,
110            blocking: false,
111            phantom_tx: PhantomData,
112            phantom_rx: PhantomData,
113        }
114    }
115
116    pub fn into<Tx2: Debug + ProstMessage + Default, Rx2: Debug + ProstMessage + Default>(
117        self,
118    ) -> Channel<Tx2, Rx2> {
119        Channel {
120            sock: self.sock,
121            front_buf: self.front_buf,
122            back_buf: self.back_buf,
123            max_buffer_size: self.max_buffer_size,
124            readiness: self.readiness,
125            interest: self.interest,
126            blocking: self.blocking,
127            phantom_tx: PhantomData,
128            phantom_rx: PhantomData,
129        }
130    }
131
132    // Since MioUnixStream does not have a set_nonblocking method, we have to use the standard library.
133    // We get the file descriptor of the MioUnixStream socket, create a standard library UnixStream,
134    // set it to nonblocking, let go of the file descriptor
135    fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), ChannelError> {
136        unsafe {
137            let fd = self.sock.as_raw_fd();
138            let stream = StdUnixStream::from_raw_fd(fd);
139            stream
140                .set_nonblocking(nonblocking)
141                .map_err(|error| ChannelError::BlockingStatus {
142                    fd,
143                    error: error.to_string(),
144                })?;
145            let _fd = stream.into_raw_fd();
146        }
147        self.blocking = !nonblocking;
148        Ok(())
149    }
150
151    /// set the read_timeout of the unix stream. This works only temporary, be sure to set the timeout to None afterwards.
152    fn set_timeout(&mut self, timeout: Option<Duration>) -> Result<(), ChannelError> {
153        unsafe {
154            let fd = self.sock.as_raw_fd();
155            let stream = StdUnixStream::from_raw_fd(fd);
156            stream
157                .set_read_timeout(timeout)
158                .map_err(|error| ChannelError::SetTimeout {
159                    fd,
160                    error: error.to_string(),
161                })?;
162            let _fd = stream.into_raw_fd();
163        }
164        Ok(())
165    }
166
167    /// set the channel to be blocking
168    pub fn blocking(&mut self) -> Result<(), ChannelError> {
169        self.set_nonblocking(false)
170    }
171
172    /// set the channel to be nonblocking
173    pub fn nonblocking(&mut self) -> Result<(), ChannelError> {
174        self.set_nonblocking(true)
175    }
176
177    pub fn is_blocking(&self) -> bool {
178        self.blocking
179    }
180
181    /// Get the raw file descriptor of the UNIX socket
182    pub fn fd(&self) -> RawFd {
183        self.sock.as_raw_fd()
184    }
185
186    pub fn handle_events(&mut self, events: Ready) {
187        self.readiness |= events;
188    }
189
190    pub fn readiness(&self) -> Ready {
191        self.readiness & self.interest
192    }
193
194    /// Check wether we want and can read or write, and calls the appropriate handler.
195    pub fn run(&mut self) -> Result<(), ChannelError> {
196        let interest = self.interest & self.readiness;
197
198        if interest.is_readable() {
199            let _ = self.readable()?;
200        }
201
202        if interest.is_writable() {
203            let _ = self.writable()?;
204        }
205        Ok(())
206    }
207
208    /// Handle readability by filling the front buffer with the socket data.
209    pub fn readable(&mut self) -> Result<usize, ChannelError> {
210        if !(self.interest & self.readiness).is_readable() {
211            return Err(ChannelError::Connection(None));
212        }
213
214        let mut count = 0usize;
215        loop {
216            let size = self.front_buf.available_space();
217            trace!("channel available space: {}", size);
218            if size == 0 {
219                self.interest.remove(Ready::READABLE);
220                break;
221            }
222
223            match self.sock.read(self.front_buf.space()) {
224                Ok(0) => {
225                    self.interest = Ready::EMPTY;
226                    self.readiness.remove(Ready::READABLE);
227                    self.readiness.insert(Ready::HUP);
228                    return Err(ChannelError::NoByteToRead);
229                }
230                Err(read_error) => match read_error.kind() {
231                    ErrorKind::WouldBlock => {
232                        self.readiness.remove(Ready::READABLE);
233                        break;
234                    }
235                    _ => {
236                        self.interest = Ready::EMPTY;
237                        self.readiness = Ready::EMPTY;
238                        return Err(ChannelError::Read(read_error));
239                    }
240                },
241                Ok(bytes_read) => {
242                    count += bytes_read;
243                    self.front_buf.fill(bytes_read);
244                }
245            };
246        }
247
248        Ok(count)
249    }
250
251    /// Handle writability by writing the content of the back buffer onto the socket
252    pub fn writable(&mut self) -> Result<usize, ChannelError> {
253        if !(self.interest & self.readiness).is_writable() {
254            return Err(ChannelError::Connection(None));
255        }
256
257        let mut count = 0usize;
258        loop {
259            let size = self.back_buf.available_data();
260            if size == 0 {
261                self.interest.remove(Ready::WRITABLE);
262                break;
263            }
264
265            match self.sock.write(self.back_buf.data()) {
266                Ok(0) => {
267                    self.interest = Ready::EMPTY;
268                    self.readiness.insert(Ready::HUP);
269                    return Err(ChannelError::NoByteWritten);
270                }
271                Ok(bytes_written) => {
272                    count += bytes_written;
273                    self.back_buf.consume(bytes_written);
274                }
275                Err(write_error) => match write_error.kind() {
276                    ErrorKind::WouldBlock => {
277                        self.readiness.remove(Ready::WRITABLE);
278                        break;
279                    }
280                    _ => {
281                        self.interest = Ready::EMPTY;
282                        self.readiness = Ready::EMPTY;
283                        return Err(ChannelError::Read(write_error));
284                    }
285                },
286            }
287        }
288
289        Ok(count)
290    }
291
292    /// Depending on the blocking status:
293    ///
294    /// Blocking: wait for the front buffer to be filled, and parse a message from it
295    ///
296    /// Nonblocking: parse a message from the front buffer, without waiting.
297    /// Prefer using `channel.readable()` before
298    pub fn read_message(&mut self) -> Result<Rx, ChannelError> {
299        if self.blocking {
300            self.read_message_blocking()
301        } else {
302            self.read_message_nonblocking()
303        }
304    }
305
306    fn read_message_blocking(&mut self) -> Result<Rx, ChannelError> {
307        self.read_message_blocking_timeout(None)
308    }
309
310    /// Parse a message from the front buffer, without waiting
311    fn read_message_nonblocking(&mut self) -> Result<Rx, ChannelError> {
312        if let Some(message) = self.try_read_delimited_message()? {
313            return Ok(message);
314        }
315
316        self.interest.insert(Ready::READABLE);
317        Err(ChannelError::NothingRead)
318    }
319
320    /// Wait for the front buffer to be filled, and parses a message from it.
321    pub fn read_message_blocking_timeout(
322        &mut self,
323        timeout: Option<Duration>,
324    ) -> Result<Rx, ChannelError> {
325        let now = std::time::Instant::now();
326
327        // set a very small timeout, to repeat the loop often
328        self.set_timeout(Some(Duration::from_millis(10)))?;
329
330        let status = loop {
331            if let Some(timeout) = timeout {
332                if now.elapsed() >= timeout {
333                    break Err(ChannelError::TimeoutReached(timeout));
334                }
335            }
336
337            if let Some(message) = self.try_read_delimited_message()? {
338                return Ok(message);
339            }
340
341            match self.sock.read(self.front_buf.space()) {
342                Ok(0) => return Err(ChannelError::NoByteToRead),
343                Ok(bytes_read) => self.front_buf.fill(bytes_read),
344                Err(io_error) => match io_error.kind() {
345                    ErrorKind::WouldBlock => continue, // ignore 10 millisecond timeouts
346                    _ => break Err(ChannelError::Read(io_error)),
347                },
348            };
349        };
350
351        self.set_timeout(None)?;
352
353        status
354    }
355
356    /// parse a prost message from the front buffer, grow it if necessary
357    fn try_read_delimited_message(&mut self) -> Result<Option<Rx>, ChannelError> {
358        let buffer = self.front_buf.data();
359        if buffer.len() >= delimiter_size() {
360            let delimiter = buffer[..delimiter_size()]
361                .try_into()
362                .map_err(|_| ChannelError::MismatchBufferSize)?;
363            let message_len = usize::from_le_bytes(delimiter);
364
365            if buffer.len() >= message_len {
366                let message = Rx::decode(&buffer[delimiter_size()..message_len])
367                    .map_err(ChannelError::InvalidProtobufMessage)?;
368                self.front_buf.consume(message_len);
369                return Ok(Some(message));
370            }
371        }
372
373        if self.front_buf.available_space() == 0 {
374            if (self.front_buf.capacity() as u64) >= self.max_buffer_size {
375                return Err(ChannelError::BufferFull(self.front_buf.capacity()));
376            }
377            let new_size = min(
378                self.front_buf.capacity() + 5000,
379                self.max_buffer_size as usize,
380            );
381            self.front_buf.grow(new_size);
382        }
383        Ok(None)
384    }
385
386    /// Checks whether the channel is blocking or nonblocking, writes the message.
387    ///
388    /// If the channel is nonblocking, you have to flush using `channel.run()` afterwards
389    pub fn write_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
390        if self.blocking {
391            self.write_message_blocking(message)
392        } else {
393            self.write_message_nonblocking(message)
394        }
395    }
396
397    /// Writes the message in the buffer, but NOT on the socket.
398    /// you have to call channel.run() afterwards
399    fn write_message_nonblocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
400        self.write_delimited_message(message)?;
401
402        self.interest.insert(Ready::WRITABLE);
403
404        Ok(())
405    }
406
407    /// fills the back buffer with data AND writes on the socket
408    fn write_message_blocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
409        self.write_delimited_message(message)?;
410
411        loop {
412            let size = self.back_buf.available_data();
413            if size == 0 {
414                break;
415            }
416
417            match self.sock.write(self.back_buf.data()) {
418                Ok(0) => return Err(ChannelError::NoByteWritten),
419                Ok(bytes_written) => {
420                    self.back_buf.consume(bytes_written);
421                }
422                Err(_) => return Ok(()), // are we sure?
423            }
424        }
425        Ok(())
426    }
427
428    /// write a message on the back buffer, using our own delimiter (the delimiter of prost
429    /// is not trustworthy since its size may change)
430    pub fn write_delimited_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
431        let payload = message.encode_to_vec();
432
433        let payload_len = payload.len() + delimiter_size();
434
435        let delimiter = payload_len.to_le_bytes();
436
437        if payload_len > self.back_buf.available_space() {
438            self.back_buf.shift();
439        }
440
441        if payload_len > self.back_buf.available_space() {
442            if payload_len - self.back_buf.available_space() + self.back_buf.capacity()
443                > (self.max_buffer_size as usize)
444            {
445                return Err(ChannelError::MessageTooLarge(self.back_buf.capacity()));
446            }
447
448            let new_length =
449                payload_len - self.back_buf.available_space() + self.back_buf.capacity();
450            self.back_buf.grow(new_length);
451        }
452
453        self.back_buf
454            .write_all(&delimiter)
455            .map_err(ChannelError::Write)?;
456        self.back_buf
457            .write_all(&payload)
458            .map_err(ChannelError::Write)?;
459
460        Ok(())
461    }
462}
463
/// Number of bytes of the length prefix written before every payload:
/// the width of a `usize` on the current platform.
pub const fn delimiter_size() -> usize {
    (usize::BITS / 8) as usize
}
468
469type ChannelResult<Tx, Rx> = Result<(Channel<Tx, Rx>, Channel<Rx, Tx>), ChannelError>;
470
471impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
472    /// creates a channel pair: `(blocking_channel, nonblocking_channel)`
473    pub fn generate(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
474        let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
475        let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
476        let mut command_channel = Channel::new(command, buffer_size, max_buffer_size);
477        command_channel.blocking()?;
478        Ok((command_channel, proxy_channel))
479    }
480
481    /// creates a pair of nonblocking channels
482    pub fn generate_nonblocking(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
483        let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
484        let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
485        let command_channel = Channel::new(command, buffer_size, max_buffer_size);
486        Ok((command_channel, proxy_channel))
487    }
488}
489
490impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Iterator
491    for Channel<Tx, Rx>
492{
493    type Item = Rx;
494    fn next(&mut self) -> Option<Self::Item> {
495        self.read_message().ok()
496    }
497}
498
499use mio::{Interest, Registry, Token};
500impl<Tx, Rx> Source for Channel<Tx, Rx> {
501    fn register(
502        &mut self,
503        registry: &Registry,
504        token: Token,
505        interests: Interest,
506    ) -> io::Result<()> {
507        self.sock.register(registry, token, interests)
508    }
509
510    fn reregister(
511        &mut self,
512        registry: &Registry,
513        token: Token,
514        interests: Interest,
515    ) -> io::Result<()> {
516        self.sock.reregister(registry, token, interests)
517    }
518
519    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
520        self.sock.deregister(registry)
521    }
522}
523
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    #[derive(Clone, PartialEq, prost::Message)]
    pub struct ProtobufMessage {
        #[prost(uint32, required, tag = "1")]
        inner: u32,
    }

    /// Creates a connected pair `(blocking, nonblocking)` of test channels.
    fn test_channels() -> (
        Channel<ProtobufMessage, ProtobufMessage>,
        Channel<ProtobufMessage, ProtobufMessage>,
    ) {
        Channel::generate(1000, 10000).expect("could not generate blocking channels for testing")
    }

    #[test]
    fn unblock_a_channel() {
        let (mut blocking, _nonblocking) = test_channels();
        assert!(blocking.nonblocking().is_ok());
        assert!(!blocking.is_blocking());
    }

    #[test]
    fn generate_blocking_and_nonblocking_channels() {
        let (blocking_channel, nonblocking_channel) = test_channels();

        assert!(blocking_channel.is_blocking());
        assert!(!nonblocking_channel.is_blocking());

        let (nonblocking_channel_1, nonblocking_channel_2): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate_nonblocking(1000, 10000)
            .expect("could not generate nonblocking channels");

        assert!(!nonblocking_channel_1.is_blocking());
        assert!(!nonblocking_channel_2.is_blocking());
    }

    #[test]
    fn write_and_read_message_blocking() {
        let (mut blocking_channel, mut nonblocking_channel) = test_channels();

        let message_to_send = ProtobufMessage { inner: 42 };

        nonblocking_channel
            .blocking()
            .expect("Could not block channel");
        nonblocking_channel
            .write_message(&message_to_send)
            .expect("Could not write message on channel");

        trace!("we wrote a message!");

        trace!("reading message..");
        let message = blocking_channel
            .read_message()
            .expect("Could not read message on channel");
        trace!("read message!");

        assert_eq!(message, ProtobufMessage { inner: 42 });
    }

    #[test]
    fn read_message_blocking_with_timeout_fails() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 100 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message =
                reading_channel.read_message_blocking_timeout(Some(Duration::from_millis(100)));
            trace!("read message!");
            message
        });

        // wait well past the reader's timeout, so a late thread start cannot make the
        // message arrive in time
        trace!("Waiting 300 milliseconds…");
        thread::sleep(Duration::from_millis(300));

        writing_channel
            .write_message(&ProtobufMessage { inner: 200 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive too late!");

        let arrived_too_late = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert!(matches!(
            arrived_too_late,
            Err(ChannelError::TimeoutReached(_))
        ));
    }

    #[test]
    fn read_message_blocking_with_timeout_succeeds() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        // the timeout is much larger than the writer's delay, to stay reliable under load
        trace!("reading message in a detached thread, with a timeout of 1000 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message = reading_channel
                .read_message_blocking_timeout(Some(Duration::from_millis(1000)))
                .expect("Could not read message with timeout on blocking channel");
            trace!("read message!");
            message
        });

        trace!("Waiting 100 milliseconds…");
        thread::sleep(Duration::from_millis(100));

        writing_channel
            .write_message(&ProtobufMessage { inner: 100 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive on time!");

        let arrived_on_time = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert_eq!(arrived_on_time, ProtobufMessage { inner: 100 });
    }

    #[test]
    fn exhaustive_use_of_nonblocking_channels() {
        // two channels A and B: A starts blocking and is switched to nonblocking,
        // B is nonblocking from the start
        let (mut channel_a, mut channel_b) = test_channels();
        channel_a
            .nonblocking()
            .expect("Could not unblock channel");

        // write on A (only buffered, nothing is sent yet)
        channel_a
            .write_message(&ProtobufMessage { inner: 1 })
            .expect("Could not write message on channel");

        // set B as readable, normally mio tells when to, by giving events
        channel_b.handle_events(Ready::READABLE);

        // read on B: nothing has been flushed yet
        let should_err = channel_b.read_message();
        assert!(should_err.is_err());

        // write another message on A
        channel_a
            .write_message(&ProtobufMessage { inner: 2 })
            .expect("Could not write message on channel");

        // insert a handle_events Ready::writable on A
        channel_a.handle_events(Ready::WRITABLE);

        // flush A with run()
        channel_a.run().expect("Failed to run the channel");

        // give the data time to reach B's socket
        thread::sleep(Duration::from_millis(100));

        // receive with B using run()
        channel_b.run().expect("Failed to run the channel");

        // use read_message() twice on B, check them
        let message_1 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_1, ProtobufMessage { inner: 1 });

        let message_2 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_2, ProtobufMessage { inner: 2 });
    }
}