simple_socket/
client.rs

1use std::io::{self, Read, Write};
2use std::marker::PhantomData;
3use std::net::{Shutdown, SocketAddr};
4use std::sync::{Mutex, MutexGuard};
5
6use bincode::Result;
7use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
8use serde::{de::DeserializeOwned, Serialize};
9use socket2::{Domain, Socket, Type};
10
11pub struct SocketClient<Req, Res>
12where
13    Req: Serialize,
14    Res: DeserializeOwned,
15{
16    buffer: Mutex<Buffer>,
17    stream: Socket,
18
19    _request: PhantomData<Req>,
20    _response: PhantomData<Res>,
21}
22
23impl<Req, Res> SocketClient<Req, Res>
24where
25    Req: Serialize,
26    Res: DeserializeOwned,
27{
28    pub fn try_new(addr: SocketAddr) -> io::Result<Self> {
29        let domain = match addr {
30            SocketAddr::V4(_) => Domain::ipv4(),
31            SocketAddr::V6(_) => Domain::ipv6(),
32        };
33
34        let socket = Socket::new(domain, Type::stream(), None)?;
35        socket.connect(&addr.into())?;
36
37        Ok(Self {
38            buffer: Mutex::default(),
39            stream: socket,
40
41            _request: PhantomData::default(),
42            _response: PhantomData::default(),
43        })
44    }
45
46    pub(crate) fn try_from_stream(stream: Socket) -> io::Result<Self> {
47        stream.set_nonblocking(true)?;
48
49        Ok(Self {
50            buffer: Mutex::default(),
51            stream,
52
53            _request: PhantomData::default(),
54            _response: PhantomData::default(),
55        })
56    }
57}
58
59impl<Req, Res> SocketClient<Req, Res>
60where
61    Req: Serialize,
62    Res: DeserializeOwned,
63{
64    pub fn request(&self, request: &Req) -> Result<Res> {
65        let mut buffer = self.buffer.lock().unwrap();
66
67        let stream = &mut &self.stream;
68
69        buffer.data.clear();
70        bincode::serialize_into(&mut buffer.data, request)?;
71
72        let size = buffer.data.len() as u64;
73        assert_ne!(size, 0, "Message must have one or more bytes.");
74
75        stream.write_u64::<NetworkEndian>(size)?;
76        stream.write_all(&buffer.data)?;
77
78        bincode::deserialize_from(stream)
79    }
80
81    pub(crate) fn response<F>(&self, handler: F) -> Result<SocketStatus>
82    where
83        F: FnMut(Res) -> Req,
84    {
85        let mut buffer = self.buffer.lock().unwrap();
86
87        let stream = &mut &self.stream;
88
89        if buffer.size.is_some() {
90            fill_buffer_and_handle(buffer, stream, handler)
91        } else {
92            let mut buf = [0; 8];
93            match stream.peek(&mut buf) {
94                Ok(8) => {
95                    stream.read_exact(&mut buf)?;
96
97                    let size = buf.as_ref().read_u64::<NetworkEndian>()? as usize;
98                    if size == 0 {
99                        return Ok(SocketStatus::Closed);
100                    }
101
102                    buffer.offset = 0;
103                    buffer.size = Some(size);
104                    buffer.data.resize(size, 0);
105                    fill_buffer_and_handle(buffer, stream, handler)
106                }
107                Ok(_) => Ok(SocketStatus::Alive),
108                Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(SocketStatus::Alive),
109                Err(error) => Err(error.into()),
110            }
111        }
112    }
113
114    fn stop(&self) -> io::Result<()> {
115        let stream = &mut &self.stream;
116
117        stream.write_u64::<NetworkEndian>(0)?;
118        stream.shutdown(Shutdown::Read)?;
119        Ok(())
120    }
121}
122
123fn fill_buffer_and_handle<Res, Req, F>(
124    mut buffer: MutexGuard<Buffer>,
125    stream: &mut &Socket,
126    mut handler: F,
127) -> Result<SocketStatus>
128where
129    Req: Serialize,
130    Res: DeserializeOwned,
131    F: FnMut(Res) -> Req,
132{
133    let offset = buffer.offset;
134
135    buffer.offset += match stream.read(&mut buffer.data[offset..]) {
136        Ok(size) => size,
137        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(SocketStatus::Alive),
138        Err(error) => return Err(error.into()),
139    };
140    if buffer.offset >= buffer.size.unwrap() {
141        let request = bincode::deserialize_from(&buffer.data[..])?;
142        buffer.size = None;
143
144        bincode::serialize_into(stream, &handler(request))?;
145        Ok(SocketStatus::Alive)
146    } else {
147        Ok(SocketStatus::Alive)
148    }
149}
150
151impl<Req, Res> Drop for SocketClient<Req, Res>
152where
153    Req: Serialize,
154    Res: DeserializeOwned,
155{
156    fn drop(&mut self) {
157        match self.stop() {
158            Ok(()) => (),
159            Err(ref e) if e.kind() == io::ErrorKind::NotConnected => (),
160            Err(error) => panic!(error),
161        }
162    }
163}
164
/// Connection state reported after servicing a connection with `SocketClient::response`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketStatus {
    /// The connection is still usable; poll it again to keep processing.
    Alive,
    /// The peer has finished with the connection; stop polling it.
    Closed,
}
169
/// Per-connection scratch state used for both outgoing and incoming frames.
#[derive(Default)]
struct Buffer {
    /// Length of the incoming frame currently being received, or `None` if no frame
    /// is in flight.
    size: Option<usize>,
    /// How many bytes of the in-flight frame have been received so far.
    offset: usize,
    /// Frame bytes: the encoded outgoing message, or the incoming payload.
    data: Vec<u8>,
}