1use std::io::{self, Read, Write};
2use std::marker::PhantomData;
3use std::net::{Shutdown, SocketAddr};
4use std::sync::{Mutex, MutexGuard};
5
6use bincode::Result;
7use byteorder::{NetworkEndian, ReadBytesExt, WriteBytesExt};
8use serde::{de::DeserializeOwned, Serialize};
9use socket2::{Domain, Socket, Type};
10
11pub struct SocketClient<Req, Res>
12where
13 Req: Serialize,
14 Res: DeserializeOwned,
15{
16 buffer: Mutex<Buffer>,
17 stream: Socket,
18
19 _request: PhantomData<Req>,
20 _response: PhantomData<Res>,
21}
22
23impl<Req, Res> SocketClient<Req, Res>
24where
25 Req: Serialize,
26 Res: DeserializeOwned,
27{
28 pub fn try_new(addr: SocketAddr) -> io::Result<Self> {
29 let domain = match addr {
30 SocketAddr::V4(_) => Domain::ipv4(),
31 SocketAddr::V6(_) => Domain::ipv6(),
32 };
33
34 let socket = Socket::new(domain, Type::stream(), None)?;
35 socket.connect(&addr.into())?;
36
37 Ok(Self {
38 buffer: Mutex::default(),
39 stream: socket,
40
41 _request: PhantomData::default(),
42 _response: PhantomData::default(),
43 })
44 }
45
46 pub(crate) fn try_from_stream(stream: Socket) -> io::Result<Self> {
47 stream.set_nonblocking(true)?;
48
49 Ok(Self {
50 buffer: Mutex::default(),
51 stream,
52
53 _request: PhantomData::default(),
54 _response: PhantomData::default(),
55 })
56 }
57}
58
59impl<Req, Res> SocketClient<Req, Res>
60where
61 Req: Serialize,
62 Res: DeserializeOwned,
63{
64 pub fn request(&self, request: &Req) -> Result<Res> {
65 let mut buffer = self.buffer.lock().unwrap();
66
67 let stream = &mut &self.stream;
68
69 buffer.data.clear();
70 bincode::serialize_into(&mut buffer.data, request)?;
71
72 let size = buffer.data.len() as u64;
73 assert_ne!(size, 0, "Message must have one or more bytes.");
74
75 stream.write_u64::<NetworkEndian>(size)?;
76 stream.write_all(&buffer.data)?;
77
78 bincode::deserialize_from(stream)
79 }
80
81 pub(crate) fn response<F>(&self, handler: F) -> Result<SocketStatus>
82 where
83 F: FnMut(Res) -> Req,
84 {
85 let mut buffer = self.buffer.lock().unwrap();
86
87 let stream = &mut &self.stream;
88
89 if buffer.size.is_some() {
90 fill_buffer_and_handle(buffer, stream, handler)
91 } else {
92 let mut buf = [0; 8];
93 match stream.peek(&mut buf) {
94 Ok(8) => {
95 stream.read_exact(&mut buf)?;
96
97 let size = buf.as_ref().read_u64::<NetworkEndian>()? as usize;
98 if size == 0 {
99 return Ok(SocketStatus::Closed);
100 }
101
102 buffer.offset = 0;
103 buffer.size = Some(size);
104 buffer.data.resize(size, 0);
105 fill_buffer_and_handle(buffer, stream, handler)
106 }
107 Ok(_) => Ok(SocketStatus::Alive),
108 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(SocketStatus::Alive),
109 Err(error) => Err(error.into()),
110 }
111 }
112 }
113
114 fn stop(&self) -> io::Result<()> {
115 let stream = &mut &self.stream;
116
117 stream.write_u64::<NetworkEndian>(0)?;
118 stream.shutdown(Shutdown::Read)?;
119 Ok(())
120 }
121}
122
123fn fill_buffer_and_handle<Res, Req, F>(
124 mut buffer: MutexGuard<Buffer>,
125 stream: &mut &Socket,
126 mut handler: F,
127) -> Result<SocketStatus>
128where
129 Req: Serialize,
130 Res: DeserializeOwned,
131 F: FnMut(Res) -> Req,
132{
133 let offset = buffer.offset;
134
135 buffer.offset += match stream.read(&mut buffer.data[offset..]) {
136 Ok(size) => size,
137 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(SocketStatus::Alive),
138 Err(error) => return Err(error.into()),
139 };
140 if buffer.offset >= buffer.size.unwrap() {
141 let request = bincode::deserialize_from(&buffer.data[..])?;
142 buffer.size = None;
143
144 bincode::serialize_into(stream, &handler(request))?;
145 Ok(SocketStatus::Alive)
146 } else {
147 Ok(SocketStatus::Alive)
148 }
149}
150
151impl<Req, Res> Drop for SocketClient<Req, Res>
152where
153 Req: Serialize,
154 Res: DeserializeOwned,
155{
156 fn drop(&mut self) {
157 match self.stop() {
158 Ok(()) => (),
159 Err(ref e) if e.kind() == io::ErrorKind::NotConnected => (),
160 Err(error) => panic!(error),
161 }
162 }
163}
164
/// Outcome of polling a socket client for incoming messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketStatus {
    /// The connection is still usable; poll again later.
    Alive,
    /// The peer closed the connection (for example by sending a zero-length
    /// message header).
    Closed,
}
169
/// Reusable byte buffer plus the framing state of the message being received.
#[derive(Default)]
struct Buffer {
    /// Length of the message currently being received, or `None` between messages.
    size: Option<usize>,
    /// Number of bytes of that message filled in so far.
    offset: usize,
    /// Message bytes; also reused to encode outgoing requests.
    data: Vec<u8>,
}