tcp_channel/
receiver.rs

1use std::io::{BufReader, Read};
2use std::net::{TcpListener, TcpStream, ToSocketAddrs};
3use std::marker::PhantomData;
4
5use bincode::Config;
6use byteorder::ReadBytesExt;
7use serde::de::DeserializeOwned;
8
9use crate::{ChannelRecv, Endian, BigEndian, RecvError};
10
/// Default upper bound, in bytes, on the length of a single incoming message (64 MiB).
pub const DEFAULT_MAX_SIZE: usize = 64 * 1024 * 1024;
12
13/// The receiving side of a channel.
14pub struct Receiver<T: DeserializeOwned, E: Endian, R: Read = BufReader<TcpStream>> {
15    reader: R,
16    config: Config,
17    max_size: usize,
18    _marker: PhantomData<(T, E)>,
19
20    // This buffer is used for storing the currently read bytes in case the stream is nonblocking.
21    // Otherwise, bincode would deserialize only the currently read bytes.
22    buffer: Vec<u8>,
23
24    bytes_read: usize,
25    bytes_to_read: usize,
26}
27
/// Entry point of the builder API: a more convenient way of setting up receivers.
pub struct ReceiverBuilder;

/// Builder state that tracks, at the type level, the value type `T`, the
/// reader type `R` and the endianness `E` chosen so far, together with the
/// configured size limit.
pub struct TypedReceiverBuilder<T, R, E> {
    /// Maximum message size, in bytes, passed on to the receiver being built.
    max_size: usize,
    /// Zero-sized marker carrying the three type parameters.
    _marker: PhantomData<(T, R, E)>,
}
35impl ReceiverBuilder {
36    /// Begin building a new, buffered channel.
37    pub fn new() -> TypedReceiverBuilder<(), BufReader<TcpStream>, BigEndian> {
38        Self::buffered()
39    }
40    /// Begin building a new, buffered channel.
41    pub fn buffered() -> TypedReceiverBuilder<(), BufReader<TcpStream>, BigEndian> {
42        TypedReceiverBuilder {
43            _marker: PhantomData,
44            max_size: DEFAULT_MAX_SIZE,
45        }
46    }
47    /// Begin building a new, non-buffered channel.
48    pub fn realtime() -> TypedReceiverBuilder<(), TcpStream, BigEndian> {
49        TypedReceiverBuilder {
50            _marker: PhantomData,
51            max_size: DEFAULT_MAX_SIZE,
52        }
53    }
54}
55impl<T, R, E> TypedReceiverBuilder<T, R, E> {
56    /// Specify the type to send.
57    pub fn with_type<U: DeserializeOwned>(self) -> TypedReceiverBuilder<U, R, E> {
58        TypedReceiverBuilder {
59            _marker: PhantomData,
60            max_size: self.max_size,
61        }
62    }
63    /// Specify the underlying reader type.
64    pub fn with_reader<S: Read>(self) -> TypedReceiverBuilder<T, S, E> {
65        TypedReceiverBuilder {
66            _marker: PhantomData,
67            max_size: self.max_size,
68        }
69    }
70    /// Specify the endianness.
71    pub fn with_endianness<F: Endian>(self) -> TypedReceiverBuilder<T, R, F> {
72        TypedReceiverBuilder {
73            _marker: PhantomData,
74            max_size: self.max_size,
75        }
76    }
77    /// Specify the max size to be allocated when receiving.
78    pub fn with_max_size(self, max_size: usize) -> Self {
79        Self {
80            _marker: PhantomData,
81            max_size,
82        }
83    }
84}
85impl<T: DeserializeOwned, R: Read, E: Endian> TypedReceiverBuilder<T, R, E> {
86    /// Initialize the receiver with the current variables.
87    pub fn build(self, reader: R) -> Receiver<T, E, R> {
88        Receiver {
89            _marker: PhantomData,
90            reader,
91            config: E::config(),
92            max_size: self.max_size,
93            buffer: Vec::new(),
94            bytes_read: 0,
95            bytes_to_read: 0,
96        }
97    }
98}
99impl<T: DeserializeOwned, E: Endian> TypedReceiverBuilder<T, BufReader<TcpStream>, E> {
100    /// Listen for a sender, binding the listener to the specified address.
101    pub fn listen_once<A: ToSocketAddrs>(self, address: A) -> std::io::Result<Receiver<T, E, BufReader<TcpStream>>> {
102        let listener = TcpListener::bind(address)?;
103
104        let (stream, _) = listener.accept()?;
105
106        Ok(Receiver {
107            config: E::config(),
108            _marker: PhantomData,
109            reader: BufReader::new(stream),
110            max_size: self.max_size,
111            buffer: Vec::new(),
112            bytes_read: 0,
113            bytes_to_read: 0,
114        })
115    }
116}
117impl<T: DeserializeOwned, E: Endian> TypedReceiverBuilder<T, TcpStream, E> {
118    /// Listen for a sender, binding the listener to the specified address.
119    pub fn listen_once<A: ToSocketAddrs>(self, address: A) -> std::io::Result<Receiver<T, E, TcpStream>> {
120        let listener = TcpListener::bind(address)?;
121
122        let (stream, _) = listener.accept()?;
123
124        Ok(Receiver {
125            config: E::config(),
126            _marker: PhantomData,
127            reader: stream,
128            max_size: self.max_size,
129            buffer: Vec::new(),
130            bytes_read: 0,
131            bytes_to_read: 0,
132        })
133    }
134}
135
136impl<T: DeserializeOwned, E: Endian, R: Read> ChannelRecv<T> for Receiver<T, E, R> {
137    type Error = RecvError;
138
139    fn recv(&mut self) -> Result<T, RecvError> {
140        if self.bytes_to_read == 0 {
141            let length = self.reader.read_u64::<E>()? as usize;
142            if length > self.max_size {
143                return Err(RecvError::TooLarge(length))
144            }
145
146            if self.buffer.len() < length {
147                self.buffer.extend(std::iter::repeat(0).take(length - self.buffer.len()));
148            }
149
150            self.bytes_to_read = length;
151            self.bytes_read = 0;
152        }
153
154        loop {
155            match self.reader.read(&mut self.buffer[self.bytes_read..self.bytes_to_read]) {
156                Ok(size) => {
157                    self.bytes_read += size;
158                    if self.bytes_read >= self.bytes_to_read {
159                        let length = self.bytes_to_read;
160                        self.bytes_to_read = 0;
161                        return Ok(self.config.deserialize(&self.buffer[0..length])?)
162                    }
163                },
164                Err(error) => return Err(error.into()),
165            }
166        }
167
168    }
169}