tcp_struct/
lib.rs

1#![doc = include_str!("../readme.md")]
2mod util;
3
4use std::{
5    future::Future,
6    io::{self},
7    pin::Pin,
8    sync::Arc,
9};
10
11use serde::{Deserialize, Serialize};
12#[cfg(not(feature = "async-tcp"))]
13pub use std::net::TcpListener;
14#[cfg(not(feature = "async-tcp"))]
15use std::{
16    io::{Read as _, Write as _},
17    net::TcpStream,
18};
19pub use tcp_struct_macros::{register_impl, TCPShare};
20#[cfg(feature = "async-tcp")]
21pub use tokio::net::TcpListener;
22use tokio::sync::{Mutex, Notify};
23#[cfg(feature = "async-tcp")]
24use tokio::{
25    io::{AsyncReadExt as _, AsyncWriteExt as _},
26    net::TcpStream,
27};
28use util::{take_status_code, take_str};
29
30#[derive(thiserror::Error, Debug, Deserialize, Serialize)]
31pub enum Error {
32    #[error("Buffer too short")]
33    BufferTooShort,
34    #[error("unknown function")]
35    FunctionNotFound,
36    #[error("failed to convert bytes to string")]
37    Utf8Error,
38    #[error("?")]
39    StreamError(StreamError),
40    #[error("todo: remove later")]
41    Custom(String),
42    #[error("does not match struct")]
43    ApiMisMatch(String),
44}
45
46#[derive(Deserialize, Serialize, Debug)]
47pub struct StreamError {
48    pub code: Option<i32>,
49    pub kind: String,
50}
51
52impl From<std::io::Error> for Error {
53    fn from(err: std::io::Error) -> Self {
54        let kind = err.kind().to_string();
55        let code = err.raw_os_error();
56        Error::StreamError(StreamError { code, kind })
57    }
58}
59
60pub type Result<T> = std::result::Result<T, Error>;
61
62pub fn encode<T: Serialize>(data: T) -> Result<Vec<u8>> {
63    Ok(bincode::serialize(&data).unwrap())
64}
65
66pub fn decode<'a, T>(v: &'a [u8]) -> Result<T>
67where
68    T: serde::de::Deserialize<'a>,
69{
70    Ok(bincode::deserialize(v).unwrap())
71}
72
73#[cfg(feature = "async-tcp")]
74pub async fn send_data(
75    port: u16,
76    magic_header: &str,
77    func: &str,
78    data: Vec<u8>,
79) -> Result<Vec<u8>> {
80    let mut stream = TcpStream::connect(("127.0.0.1", port)).await?;
81    let header = magic_header.as_bytes();
82    let func = func.as_bytes();
83    let mut buffer = vec![];
84    buffer.extend((header.len() as u32).to_ne_bytes());
85    buffer.extend(header);
86    buffer.extend((func.len() as u32).to_ne_bytes());
87    buffer.extend(func);
88    buffer.extend(data);
89    let length = buffer.len() as u32;
90    let mut response = vec![];
91    stream.write_all(&length.to_be_bytes()).await?;
92    stream.write_all(&buffer).await?;
93    stream.read_to_end(&mut response).await?;
94    let mut response: &[u8] = &response;
95    let status = take_status_code(&mut response)?;
96    if status == 0 {
97        Ok(response.to_vec())
98    } else {
99        let err: Result<Error> = decode(response);
100        match err {
101            Ok(err) => Err(err),
102            Err(err) => Err(err),
103        }
104    }
105}
106
107#[cfg(not(feature = "async-tcp"))]
108pub fn send_data(port: u16, magic_header: &str, func: &str, data: Vec<u8>) -> Result<Vec<u8>> {
109    let mut stream = TcpStream::connect(("127.0.0.1", port))?;
110    let header = magic_header.as_bytes();
111    let func = func.as_bytes();
112    let mut buffer = vec![];
113    buffer.extend((header.len() as u32).to_ne_bytes());
114    buffer.extend(header);
115    buffer.extend((func.len() as u32).to_ne_bytes());
116    buffer.extend(func);
117    buffer.extend(data);
118    let length = buffer.len() as u32;
119    let mut response = vec![];
120
121    stream.write_all(&length.to_be_bytes())?;
122    stream.write_all(&buffer)?;
123    stream.read_to_end(&mut response)?;
124    let mut response: &[u8] = &response;
125    let status = take_status_code(&mut response)?;
126    if status == 0 {
127        Ok(response.to_vec())
128    } else {
129        let err: Result<Error> = decode(response);
130        match err {
131            Ok(err) => Err(err),
132            Err(err) => Err(err),
133        }
134    }
135}
136
/// Read one request frame from `stream` and check the client's magic header.
///
/// The frame is a big-endian `u32` body length, then a body from which `take_str` reads
/// the client's magic header and the function name. The remaining bytes are the argument
/// payload. Returns `(function_name, payload)`.
///
/// # Errors
/// - I/O failures are returned, converted via `From<std::io::Error>`.
/// - Errors from `take_str` are returned.
/// - [`Error::ApiMisMatch`] is returned when the client's header differs from
///   `magic_header_server`. The `"stop"` function is exempt and is accepted with any
///   header.
#[cfg(feature = "async-tcp")]
async fn receive_data(
    stream: &mut TcpStream,
    magic_header_server: &str,
) -> Result<(String, Vec<u8>)> {
    let mut length_bytes = [0; 4];
    stream.read_exact(&mut length_bytes).await?;

    let length = u32::from_be_bytes(length_bytes) as usize;
    let mut buffer = vec![0; length];
    stream.read_exact(&mut buffer).await?;

    let mut buffer: &[u8] = &buffer;
    let magic_header_client = take_str(&mut buffer)?;
    // The function name is read before the header check because "stop" bypasses it.
    let fn_name = take_str(&mut buffer)?;
    // NOTE(review): the blocking variant of this function has no "stop" exemption —
    // confirm which behavior is intended.
    if magic_header_client != magic_header_server && fn_name.as_str() != "stop" {
        return Err(Error::ApiMisMatch(format!(
            "failed to match magic header, expected: {}, got: {}",
            magic_header_server, magic_header_client
        )));
    }
    Ok((fn_name, buffer.to_vec()))
}
161
162#[cfg(not(feature = "async-tcp"))]
163fn receive_data(stream: &mut TcpStream, magic_header_server: &str) -> Result<(String, Vec<u8>)> {
164    let mut length_bytes = [0; 4];
165    stream.read_exact(&mut length_bytes)?;
166
167    let length = u32::from_be_bytes(length_bytes) as usize;
168    let mut buffer = vec![0; length];
169    stream.read_exact(&mut buffer)?;
170
171    let mut buffer: &[u8] = &buffer;
172    let magic_header_client = take_str(&mut buffer)?;
173    if magic_header_client != magic_header_server {
174        return Err(Error::ApiMisMatch(format!(
175            "failed to match magic header, expected: {}, got: {}",
176            magic_header_server, magic_header_client
177        )));
178    }
179    let fn_name = take_str(&mut buffer)?;
180    Ok((fn_name, buffer.to_vec()))
181}
182
183type AsyncFuture<T> =
184    fn(String, Vec<u8>, Arc<Mutex<T>>) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>;
185async fn handle_client<T>(
186    mut stream: TcpStream,
187    magic_header: &str,
188    future: AsyncFuture<T>,
189    app_data: Arc<Mutex<T>>,
190) {
191    #[cfg(feature = "async-tcp")]
192    let data = receive_data(&mut stream, magic_header).await;
193    #[cfg(not(feature = "async-tcp"))]
194    let data = receive_data(&mut stream, magic_header);
195    let (func, data) = match data {
196        Ok(v) => v,
197        Err(err) => {
198            let mut response_buffer = Vec::new();
199            response_buffer.extend_from_slice(&[0, 0, 0, 1]);
200            if let Ok(err) = encode(&err) {
201                response_buffer.extend_from_slice(&err);
202            }
203            #[cfg(feature = "async-tcp")]
204            let _ = stream.write_all(&response_buffer).await;
205            #[cfg(not(feature = "async-tcp"))]
206            let _ = stream.write_all(&response_buffer);
207            return;
208        }
209    };
210
211    let response = future(func, data, app_data).await;
212
213    let mut response_buffer = Vec::new();
214    match response {
215        Ok(data) => {
216            response_buffer.extend_from_slice(&[0, 0, 0, 0]);
217            response_buffer.extend_from_slice(&data);
218        }
219        Err(err) => {
220            response_buffer.extend_from_slice(&[0, 0, 0, 1]);
221            if let Ok(err) = encode(&err) {
222                response_buffer.extend_from_slice(&err);
223            }
224        }
225    }
226    #[cfg(feature = "async-tcp")]
227    let _ = stream.write_all(&response_buffer).await;
228    #[cfg(not(feature = "async-tcp"))]
229    let _ = stream.write_all(&response_buffer);
230}
231
232pub trait Receiver<T: Send + 'static> {
233    fn request(
234        func: String,
235        data: Vec<u8>,
236        app_data: Arc<Mutex<T>>,
237    ) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>;
238
239    fn get_app_data(&self) -> Arc<Mutex<T>>;
240
241    #[allow(async_fn_in_trait)]
242    async fn start_from_listener(
243        &self,
244        listener: TcpListener,
245        magic_header: &str,
246    ) -> io::Result<()> {
247        let thread_count = 8;
248
249        let futures = Arc::new(Mutex::new(vec![]));
250        let notify = Arc::new(Notify::new());
251        for _ in 0..thread_count {
252            let futures = futures.clone();
253            let app_data = self.get_app_data();
254            let notify = notify.clone();
255            let magic_header = magic_header.to_owned();
256            tokio::spawn(async move {
257                loop {
258                    notify.notified().await;
259                    let item = futures.lock().await.pop();
260                    if let Some(stream) = item {
261                        handle_client(stream, &magic_header, Self::request, app_data.clone()).await;
262                    }
263                }
264            });
265        }
266
267        loop {
268            #[cfg(feature = "async-tcp")]
269            if let Ok((stream, _)) = listener.accept().await {
270                futures.lock().await.push(stream);
271                notify.notify_one();
272            }
273            #[cfg(not(feature = "async-tcp"))]
274            if let Ok((stream, _)) = listener.accept() {
275                futures.lock().await.push(stream);
276                notify.notify_one();
277            }
278        }
279    }
280
281    #[allow(async_fn_in_trait)]
282    async fn start(&self, port: u16, magic_header: &str) -> io::Result<()> {
283        let listener = create_listener(port).await?;
284        self.start_from_listener(listener, magic_header).await
285    }
286}
287
/// Bind a listener on the loopback interface (127.0.0.1) at `port`.
///
/// With the `async-tcp` feature this is tokio's listener; otherwise it is the blocking
/// std listener. Bind errors are returned unchanged.
async fn create_listener(port: u16) -> io::Result<TcpListener> {
    let loopback = ("127.0.0.1", port);
    #[cfg(feature = "async-tcp")]
    let bound = TcpListener::bind(loopback).await?;
    #[cfg(not(feature = "async-tcp"))]
    let bound = TcpListener::bind(loopback)?;
    Ok(bound)
}
295
296pub trait Starter {
297    #[allow(async_fn_in_trait)]
298    async fn start(self, port: u16, header: &str) -> std::io::Result<()>;
299    #[allow(async_fn_in_trait)]
300    async fn start_from_listener(self, listener: TcpListener, header: &str) -> std::io::Result<()>;
301    #[allow(async_fn_in_trait)]
302    async fn start_gen<T: Starter>(
303        port: u16,
304        magic_header: &str,
305        gen: impl FnOnce() -> T,
306    ) -> std::io::Result<()> {
307        let listener = create_listener(port).await?;
308        let app_data = gen();
309        app_data.start_from_listener(listener, magic_header).await
310    }
311}