sozu_command_lib/
channel.rs

1use std::{
2    cmp::min,
3    fmt::Debug,
4    io::{self, ErrorKind, Read, Write},
5    marker::PhantomData,
6    os::unix::{
7        io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
8        net::UnixStream as StdUnixStream,
9    },
10    time::Duration,
11};
12
13use mio::{event::Source, net::UnixStream as MioUnixStream};
14use prost::{DecodeError, Message as ProstMessage};
15
16use crate::{buffer::growable::Buffer, ready::Ready};
17
18#[derive(thiserror::Error, Debug)]
19pub enum ChannelError {
20    #[error("io read error")]
21    Read(std::io::Error),
22    #[error("no byte written on the channel")]
23    NoByteWritten,
24    #[error("no byte left to read on the channel")]
25    NoByteToRead,
26    #[error("message too large for the capacity of the back fuffer ({0}. Consider increasing the back buffer size")]
27    MessageTooLarge(usize),
28    #[error("channel could not write on the back buffer")]
29    Write(std::io::Error),
30    #[error("channel buffer is full ({0} bytes), cannot grow more")]
31    BufferFull(usize),
32    #[error("Timeout is reached: {0:?}")]
33    TimeoutReached(Duration),
34    #[error("Could not read anything on the channel")]
35    NothingRead,
36    #[error("invalid char set in command message, ignoring: {0}")]
37    InvalidCharSet(String),
38    #[error("could not set the timeout of the unix stream with file descriptor {fd}: {error}")]
39    SetTimeout { fd: i32, error: String },
40    #[error("Could not change the blocking status ef the unix stream with file descriptor {fd}: {error}")]
41    BlockingStatus { fd: i32, error: String },
42    #[error("Connection error: {0:?}")]
43    Connection(Option<std::io::Error>),
44    #[error("Invalid protobuf message: {0}")]
45    InvalidProtobufMessage(DecodeError),
46    #[error("This should never happen (index out of bound on a tested buffer)")]
47    MismatchBufferSize,
48}
49
50/// Channel meant for communication between Sōzu processes over a UNIX socket.
51/// It wraps a unix socket using the mio crate, and transmit prost messages
52/// by serializing them in a binary format, with a fix-sized delimiter.
53/// To function, channels must come in pairs, one for each agent.
54/// They can function in a blocking or non-blocking way.
55pub struct Channel<Tx, Rx> {
56    pub sock: MioUnixStream,
57    pub front_buf: Buffer,
58    pub back_buf: Buffer,
59    max_buffer_size: u64,
60    pub readiness: Ready,
61    pub interest: Ready,
62    blocking: bool,
63    phantom_tx: PhantomData<Tx>,
64    phantom_rx: PhantomData<Rx>,
65}
66
67impl<Tx, Rx> std::fmt::Debug for Channel<Tx, Rx> {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        f.debug_struct(&format!(
70            "Channel<{}, {}>",
71            std::any::type_name::<Tx>(),
72            std::any::type_name::<Rx>()
73        ))
74        .field("sock", &self.sock.as_raw_fd())
75        // .field("front_buf", &self.front_buf)
76        // .field("back_buf", &self.back_buf)
77        // .field("max_buffer_size", &self.max_buffer_size)
78        .field("readiness", &self.readiness)
79        .field("interest", &self.interest)
80        .field("blocking", &self.blocking)
81        .finish()
82    }
83}
84
85impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
86    /// Creates a nonblocking channel on a given socket path
87    pub fn from_path(
88        path: &str,
89        buffer_size: u64,
90        max_buffer_size: u64,
91    ) -> Result<Channel<Tx, Rx>, ChannelError> {
92        let unix_stream = MioUnixStream::connect(path)
93            .map_err(|io_error| ChannelError::Connection(Some(io_error)))?;
94        Ok(Channel::new(unix_stream, buffer_size, max_buffer_size))
95    }
96
97    /// Creates a nonblocking channel, using a unix stream
98    pub fn new(sock: MioUnixStream, buffer_size: u64, max_buffer_size: u64) -> Channel<Tx, Rx> {
99        Channel {
100            sock,
101            front_buf: Buffer::with_capacity(buffer_size as usize),
102            back_buf: Buffer::with_capacity(buffer_size as usize),
103            max_buffer_size,
104            readiness: Ready::EMPTY,
105            interest: Ready::READABLE,
106            blocking: false,
107            phantom_tx: PhantomData,
108            phantom_rx: PhantomData,
109        }
110    }
111
112    pub fn into<Tx2: Debug + ProstMessage + Default, Rx2: Debug + ProstMessage + Default>(
113        self,
114    ) -> Channel<Tx2, Rx2> {
115        Channel {
116            sock: self.sock,
117            front_buf: self.front_buf,
118            back_buf: self.back_buf,
119            max_buffer_size: self.max_buffer_size,
120            readiness: self.readiness,
121            interest: self.interest,
122            blocking: self.blocking,
123            phantom_tx: PhantomData,
124            phantom_rx: PhantomData,
125        }
126    }
127
128    // Since MioUnixStream does not have a set_nonblocking method, we have to use the standard library.
129    // We get the file descriptor of the MioUnixStream socket, create a standard library UnixStream,
130    // set it to nonblocking, let go of the file descriptor
131    fn set_nonblocking(&mut self, nonblocking: bool) -> Result<(), ChannelError> {
132        unsafe {
133            let fd = self.sock.as_raw_fd();
134            let stream = StdUnixStream::from_raw_fd(fd);
135            stream
136                .set_nonblocking(nonblocking)
137                .map_err(|error| ChannelError::BlockingStatus {
138                    fd,
139                    error: error.to_string(),
140                })?;
141            let _fd = stream.into_raw_fd();
142        }
143        self.blocking = !nonblocking;
144        Ok(())
145    }
146
147    /// set the read_timeout of the unix stream. This works only temporary, be sure to set the timeout to None afterwards.
148    fn set_timeout(&mut self, timeout: Option<Duration>) -> Result<(), ChannelError> {
149        unsafe {
150            let fd = self.sock.as_raw_fd();
151            let stream = StdUnixStream::from_raw_fd(fd);
152            stream
153                .set_read_timeout(timeout)
154                .map_err(|error| ChannelError::SetTimeout {
155                    fd,
156                    error: error.to_string(),
157                })?;
158            let _fd = stream.into_raw_fd();
159        }
160        Ok(())
161    }
162
163    /// set the channel to be blocking
164    pub fn blocking(&mut self) -> Result<(), ChannelError> {
165        self.set_nonblocking(false)
166    }
167
168    /// set the channel to be nonblocking
169    pub fn nonblocking(&mut self) -> Result<(), ChannelError> {
170        self.set_nonblocking(true)
171    }
172
173    pub fn is_blocking(&self) -> bool {
174        self.blocking
175    }
176
177    /// Get the raw file descriptor of the UNIX socket
178    pub fn fd(&self) -> RawFd {
179        self.sock.as_raw_fd()
180    }
181
182    pub fn handle_events(&mut self, events: Ready) {
183        self.readiness |= events;
184    }
185
186    pub fn readiness(&self) -> Ready {
187        self.readiness & self.interest
188    }
189
190    /// Check wether we want and can read or write, and calls the appropriate handler.
191    pub fn run(&mut self) -> Result<(), ChannelError> {
192        let interest = self.interest & self.readiness;
193
194        if interest.is_readable() {
195            let _ = self.readable()?;
196        }
197
198        if interest.is_writable() {
199            let _ = self.writable()?;
200        }
201        Ok(())
202    }
203
204    /// Handle readability by filling the front buffer with the socket data.
205    pub fn readable(&mut self) -> Result<usize, ChannelError> {
206        if !(self.interest & self.readiness).is_readable() {
207            return Err(ChannelError::Connection(None));
208        }
209
210        let mut count = 0usize;
211        loop {
212            let size = self.front_buf.available_space();
213            trace!("channel available space: {}", size);
214            if size == 0 {
215                self.interest.remove(Ready::READABLE);
216                break;
217            }
218
219            match self.sock.read(self.front_buf.space()) {
220                Ok(0) => {
221                    self.interest = Ready::EMPTY;
222                    self.readiness.remove(Ready::READABLE);
223                    self.readiness.insert(Ready::HUP);
224                    return Err(ChannelError::NoByteToRead);
225                }
226                Err(read_error) => match read_error.kind() {
227                    ErrorKind::WouldBlock => {
228                        self.readiness.remove(Ready::READABLE);
229                        break;
230                    }
231                    _ => {
232                        self.interest = Ready::EMPTY;
233                        self.readiness = Ready::EMPTY;
234                        return Err(ChannelError::Read(read_error));
235                    }
236                },
237                Ok(bytes_read) => {
238                    count += bytes_read;
239                    self.front_buf.fill(bytes_read);
240                }
241            };
242        }
243
244        Ok(count)
245    }
246
247    /// Handle writability by writing the content of the back buffer onto the socket
248    pub fn writable(&mut self) -> Result<usize, ChannelError> {
249        if !(self.interest & self.readiness).is_writable() {
250            return Err(ChannelError::Connection(None));
251        }
252
253        let mut count = 0usize;
254        loop {
255            let size = self.back_buf.available_data();
256            if size == 0 {
257                self.interest.remove(Ready::WRITABLE);
258                break;
259            }
260
261            match self.sock.write(self.back_buf.data()) {
262                Ok(0) => {
263                    self.interest = Ready::EMPTY;
264                    self.readiness.insert(Ready::HUP);
265                    return Err(ChannelError::NoByteWritten);
266                }
267                Ok(bytes_written) => {
268                    count += bytes_written;
269                    self.back_buf.consume(bytes_written);
270                }
271                Err(write_error) => match write_error.kind() {
272                    ErrorKind::WouldBlock => {
273                        self.readiness.remove(Ready::WRITABLE);
274                        break;
275                    }
276                    _ => {
277                        self.interest = Ready::EMPTY;
278                        self.readiness = Ready::EMPTY;
279                        return Err(ChannelError::Read(write_error));
280                    }
281                },
282            }
283        }
284
285        Ok(count)
286    }
287
288    /// Depending on the blocking status:
289    ///
290    /// Blocking: wait for the front buffer to be filled, and parse a message from it
291    ///
292    /// Nonblocking: parse a message from the front buffer, without waiting.
293    /// Prefer using `channel.readable()` before
294    pub fn read_message(&mut self) -> Result<Rx, ChannelError> {
295        if self.blocking {
296            self.read_message_blocking()
297        } else {
298            self.read_message_nonblocking()
299        }
300    }
301
302    fn read_message_blocking(&mut self) -> Result<Rx, ChannelError> {
303        self.read_message_blocking_timeout(None)
304    }
305
306    /// Parse a message from the front buffer, without waiting
307    fn read_message_nonblocking(&mut self) -> Result<Rx, ChannelError> {
308        if let Some(message) = self.try_read_delimited_message()? {
309            return Ok(message);
310        }
311
312        self.interest.insert(Ready::READABLE);
313        Err(ChannelError::NothingRead)
314    }
315
316    /// Wait for the front buffer to be filled, and parses a message from it.
317    pub fn read_message_blocking_timeout(
318        &mut self,
319        timeout: Option<Duration>,
320    ) -> Result<Rx, ChannelError> {
321        let now = std::time::Instant::now();
322
323        // set a very small timeout, to repeat the loop often
324        self.set_timeout(Some(Duration::from_millis(10)))?;
325
326        let status = loop {
327            if let Some(timeout) = timeout {
328                if now.elapsed() >= timeout {
329                    break Err(ChannelError::TimeoutReached(timeout));
330                }
331            }
332
333            if let Some(message) = self.try_read_delimited_message()? {
334                return Ok(message);
335            }
336
337            match self.sock.read(self.front_buf.space()) {
338                Ok(0) => return Err(ChannelError::NoByteToRead),
339                Ok(bytes_read) => self.front_buf.fill(bytes_read),
340                Err(io_error) => match io_error.kind() {
341                    ErrorKind::WouldBlock => continue, // ignore 10 millisecond timeouts
342                    _ => break Err(ChannelError::Read(io_error)),
343                },
344            };
345        };
346
347        self.set_timeout(None)?;
348
349        status
350    }
351
352    /// parse a prost message from the front buffer, grow it if necessary
353    fn try_read_delimited_message(&mut self) -> Result<Option<Rx>, ChannelError> {
354        let buffer = self.front_buf.data();
355        if buffer.len() >= delimiter_size() {
356            let delimiter = buffer[..delimiter_size()]
357                .try_into()
358                .map_err(|_| ChannelError::MismatchBufferSize)?;
359            let message_len = usize::from_le_bytes(delimiter);
360
361            if buffer.len() >= message_len {
362                let message = Rx::decode(&buffer[delimiter_size()..message_len])
363                    .map_err(ChannelError::InvalidProtobufMessage)?;
364                self.front_buf.consume(message_len);
365                return Ok(Some(message));
366            }
367        }
368
369        if self.front_buf.available_space() == 0 {
370            if (self.front_buf.capacity() as u64) >= self.max_buffer_size {
371                return Err(ChannelError::BufferFull(self.front_buf.capacity()));
372            }
373            let new_size = min(
374                self.front_buf.capacity() + 5000,
375                self.max_buffer_size as usize,
376            );
377            self.front_buf.grow(new_size);
378        }
379        Ok(None)
380    }
381
382    /// Checks whether the channel is blocking or nonblocking, writes the message.
383    ///
384    /// If the channel is nonblocking, you have to flush using `channel.run()` afterwards
385    pub fn write_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
386        if self.blocking {
387            self.write_message_blocking(message)
388        } else {
389            self.write_message_nonblocking(message)
390        }
391    }
392
393    /// Writes the message in the buffer, but NOT on the socket.
394    /// you have to call channel.run() afterwards
395    fn write_message_nonblocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
396        self.write_delimited_message(message)?;
397
398        self.interest.insert(Ready::WRITABLE);
399
400        Ok(())
401    }
402
403    /// fills the back buffer with data AND writes on the socket
404    fn write_message_blocking(&mut self, message: &Tx) -> Result<(), ChannelError> {
405        self.write_delimited_message(message)?;
406
407        loop {
408            let size = self.back_buf.available_data();
409            if size == 0 {
410                break;
411            }
412
413            match self.sock.write(self.back_buf.data()) {
414                Ok(0) => return Err(ChannelError::NoByteWritten),
415                Ok(bytes_written) => {
416                    self.back_buf.consume(bytes_written);
417                }
418                Err(_) => return Ok(()), // are we sure?
419            }
420        }
421        Ok(())
422    }
423
424    /// write a message on the back buffer, using our own delimiter (the delimiter of prost
425    /// is not trustworthy since its size may change)
426    pub fn write_delimited_message(&mut self, message: &Tx) -> Result<(), ChannelError> {
427        let payload = message.encode_to_vec();
428
429        let payload_len = payload.len() + delimiter_size();
430
431        let delimiter = payload_len.to_le_bytes();
432
433        if payload_len > self.back_buf.available_space() {
434            self.back_buf.shift();
435        }
436
437        if payload_len > self.back_buf.available_space() {
438            if payload_len - self.back_buf.available_space() + self.back_buf.capacity()
439                > (self.max_buffer_size as usize)
440            {
441                return Err(ChannelError::MessageTooLarge(self.back_buf.capacity()));
442            }
443
444            let new_length =
445                payload_len - self.back_buf.available_space() + self.back_buf.capacity();
446            self.back_buf.grow(new_length);
447        }
448
449        self.back_buf
450            .write_all(&delimiter)
451            .map_err(ChannelError::Write)?;
452        self.back_buf
453            .write_all(&payload)
454            .map_err(ChannelError::Write)?;
455
456        Ok(())
457    }
458}
459
/// Number of bytes of the length prefix written before every payload.
///
/// This is the width of a `usize` on this platform, so both ends of a channel must share the
/// same pointer width.
pub const fn delimiter_size() -> usize {
    (usize::BITS / u8::BITS) as usize
}
464
465type ChannelResult<Tx, Rx> = Result<(Channel<Tx, Rx>, Channel<Rx, Tx>), ChannelError>;
466
467impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Channel<Tx, Rx> {
468    /// creates a channel pair: `(blocking_channel, nonblocking_channel)`
469    pub fn generate(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
470        let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
471        let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
472        let mut command_channel = Channel::new(command, buffer_size, max_buffer_size);
473        command_channel.blocking()?;
474        Ok((command_channel, proxy_channel))
475    }
476
477    /// creates a pair of nonblocking channels
478    pub fn generate_nonblocking(buffer_size: u64, max_buffer_size: u64) -> ChannelResult<Tx, Rx> {
479        let (command, proxy) = MioUnixStream::pair().map_err(ChannelError::Read)?;
480        let proxy_channel = Channel::new(proxy, buffer_size, max_buffer_size);
481        let command_channel = Channel::new(command, buffer_size, max_buffer_size);
482        Ok((command_channel, proxy_channel))
483    }
484}
485
486impl<Tx: Debug + ProstMessage + Default, Rx: Debug + ProstMessage + Default> Iterator
487    for Channel<Tx, Rx>
488{
489    type Item = Rx;
490    fn next(&mut self) -> Option<Self::Item> {
491        self.read_message().ok()
492    }
493}
494
495use mio::{Interest, Registry, Token};
496impl<Tx, Rx> Source for Channel<Tx, Rx> {
497    fn register(
498        &mut self,
499        registry: &Registry,
500        token: Token,
501        interests: Interest,
502    ) -> io::Result<()> {
503        self.sock.register(registry, token, interests)
504    }
505
506    fn reregister(
507        &mut self,
508        registry: &Registry,
509        token: Token,
510        interests: Interest,
511    ) -> io::Result<()> {
512        self.sock.reregister(registry, token, interests)
513    }
514
515    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
516        self.sock.deregister(registry)
517    }
518}
519
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    /// Minimal protobuf message used to exercise the channels.
    #[derive(Clone, PartialEq, prost::Message)]
    pub struct ProtobufMessage {
        #[prost(uint32, required, tag = "1")]
        inner: u32,
    }

    /// Returns a connected `(blocking, nonblocking)` channel pair.
    fn test_channels() -> (
        Channel<ProtobufMessage, ProtobufMessage>,
        Channel<ProtobufMessage, ProtobufMessage>,
    ) {
        Channel::generate(1000, 10000).expect("could not generate blocking channels for testing")
    }

    #[test]
    fn unblock_a_channel() {
        let (mut blocking, _nonblocking) = test_channels();
        assert!(blocking.nonblocking().is_ok());
        assert!(!blocking.is_blocking());
    }

    #[test]
    fn generate_blocking_and_nonblocking_channels() {
        let (blocking_channel, nonblocking_channel) = test_channels();

        assert!(blocking_channel.is_blocking());
        assert!(!nonblocking_channel.is_blocking());

        let (nonblocking_channel_1, nonblocking_channel_2): (
            Channel<ProtobufMessage, ProtobufMessage>,
            Channel<ProtobufMessage, ProtobufMessage>,
        ) = Channel::generate_nonblocking(1000, 10000)
            .expect("could not generate nonblocking channels");

        assert!(!nonblocking_channel_1.is_blocking());
        assert!(!nonblocking_channel_2.is_blocking());
    }

    #[test]
    fn delimited_message_is_prefixed_with_its_total_length() {
        let (mut channel, _peer) = test_channels();
        let message = ProtobufMessage { inner: 7 };

        channel
            .write_delimited_message(&message)
            .expect("Could not write message in the back buffer");

        let encoded = message.encode_to_vec();
        let (prefix, payload) = channel.back_buf.data().split_at(delimiter_size());
        let announced_len = usize::from_le_bytes(
            prefix
                .try_into()
                .expect("the prefix is exactly delimiter_size() bytes long"),
        );

        // the announced length counts the delimiter itself
        assert_eq!(announced_len, delimiter_size() + encoded.len());
        assert_eq!(payload, encoded.as_slice());
    }

    #[test]
    fn write_and_read_message_blocking() {
        let (mut blocking_channel, mut nonblocking_channel) = test_channels();

        let message_to_send = ProtobufMessage { inner: 42 };

        nonblocking_channel
            .blocking()
            .expect("Could not block channel");
        nonblocking_channel
            .write_message(&message_to_send)
            .expect("Could not write message on channel");

        trace!("we wrote a message!");

        trace!("reading message..");
        let message = blocking_channel
            .read_message()
            .expect("Could not read message on channel");
        trace!("read message!");

        assert_eq!(message, ProtobufMessage { inner: 42 });
    }

    #[test]
    fn read_message_blocking_with_timeout_fails() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 100 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message =
                reading_channel.read_message_blocking_timeout(Some(Duration::from_millis(100)));
            trace!("read message!");
            message
        });

        trace!("Waiting 200 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(200));

        writing_channel
            .write_message(&ProtobufMessage { inner: 200 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive too late!");

        let arrived_too_late = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert!(arrived_too_late.is_err());
    }

    #[test]
    fn read_message_blocking_with_timeout_succeeds() {
        let (mut reading_channel, mut writing_channel) = test_channels();
        writing_channel.blocking().expect("Could not block channel");

        trace!("reading message in a detached thread, with a timeout of 200 milliseconds...");
        let awaiting_with_timeout = thread::spawn(move || {
            let message = reading_channel
                .read_message_blocking_timeout(Some(Duration::from_millis(200)))
                .expect("Could not read message with timeout on blocking channel");
            trace!("read message!");
            message
        });

        trace!("Waiting 100 milliseconds…");
        thread::sleep(std::time::Duration::from_millis(100));

        writing_channel
            .write_message(&ProtobufMessage { inner: 100 })
            .expect("Could not write message on channel");
        trace!("we wrote a message that should arrive on time!");

        let arrived_on_time = awaiting_with_timeout
            .join()
            .expect("error with receiving message from awaiting thread");

        assert_eq!(arrived_on_time, ProtobufMessage { inner: 100 });
    }

    #[test]
    fn exhaustive_use_of_nonblocking_channels() {
        // - two channels A and B, both nonblocking once A has been unblocked
        let (mut channel_a, mut channel_b) = test_channels();
        channel_a.nonblocking().expect("Could not unblock channel");

        // write on A (only buffered, nothing reaches the socket yet)
        channel_a
            .write_message(&ProtobufMessage { inner: 1 })
            .expect("Could not write message on channel");

        // set B as readable, normally mio tells when to, by giving events
        channel_b.handle_events(Ready::READABLE);

        // read on B: nothing was flushed yet, so this must fail
        let should_err = channel_b.read_message();
        assert!(should_err.is_err());

        // write another message on A
        channel_a
            .write_message(&ProtobufMessage { inner: 2 })
            .expect("Could not write message on channel");

        // insert a handle_events Ready::writable on A
        channel_a.handle_events(Ready::WRITABLE);

        // flush A with run()
        channel_a.run().expect("Failed to run the channel");

        // give the bytes time to arrive
        thread::sleep(std::time::Duration::from_millis(100));

        // receive with B using run()
        channel_b.run().expect("Failed to run the channel");

        // use read_message() twice on B, check them
        let message_1 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_1, ProtobufMessage { inner: 1 });

        let message_2 = channel_b
            .read_message()
            .expect("Could not read message on channel");
        assert_eq!(message_2, ProtobufMessage { inner: 2 });
    }
}