1#![doc = include_str!("../readme.md")]
2mod util;
3
4use std::{
5 future::Future,
6 io::{self},
7 pin::Pin,
8 sync::Arc,
9};
10
11use serde::{Deserialize, Serialize};
12#[cfg(not(feature = "async-tcp"))]
13pub use std::net::TcpListener;
14#[cfg(not(feature = "async-tcp"))]
15use std::{
16 io::{Read as _, Write as _},
17 net::TcpStream,
18};
19pub use tcp_struct_macros::{register_impl, TCPShare};
20#[cfg(feature = "async-tcp")]
21pub use tokio::net::TcpListener;
22use tokio::sync::{Mutex, Notify};
23#[cfg(feature = "async-tcp")]
24use tokio::{
25 io::{AsyncReadExt as _, AsyncWriteExt as _},
26 net::TcpStream,
27};
28use util::{take_status_code, take_str};
29
30#[derive(thiserror::Error, Debug, Deserialize, Serialize)]
31pub enum Error {
32 #[error("Buffer too short")]
33 BufferTooShort,
34 #[error("unknown function")]
35 FunctionNotFound,
36 #[error("failed to convert bytes to string")]
37 Utf8Error,
38 #[error("?")]
39 StreamError(StreamError),
40 #[error("todo: remove later")]
41 Custom(String),
42 #[error("does not match struct")]
43 ApiMisMatch(String),
44}
45
/// Serializable snapshot of a `std::io::Error`.
///
/// `Error` cannot carry a `std::io::Error` directly across the wire, so the
/// `From<std::io::Error>` conversion for `Error` keeps only these two pieces.
/// Field order is part of the bincode encoding — do not reorder.
#[derive(Deserialize, Serialize, Debug)]
pub struct StreamError {
    /// Raw OS error code (`io::Error::raw_os_error`), `None` if the error did not come from the OS.
    pub code: Option<i32>,
    /// Display text of the error's `io::ErrorKind`.
    pub kind: String,
}
51
52impl From<std::io::Error> for Error {
53 fn from(err: std::io::Error) -> Self {
54 let kind = err.kind().to_string();
55 let code = err.raw_os_error();
56 Error::StreamError(StreamError { code, kind })
57 }
58}
59
60pub type Result<T> = std::result::Result<T, Error>;
61
62pub fn encode<T: Serialize>(data: T) -> Result<Vec<u8>> {
63 Ok(bincode::serialize(&data).unwrap())
64}
65
66pub fn decode<'a, T>(v: &'a [u8]) -> Result<T>
67where
68 T: serde::de::Deserialize<'a>,
69{
70 Ok(bincode::deserialize(v).unwrap())
71}
72
/// Sends one request to the server on `127.0.0.1:port` and returns the reply payload.
///
/// The request frame is written with a single `write_all`:
/// `[u32 BE: byte count of the rest][u32 NE: header len][header][u32 NE: func len][func][data]`.
///
/// The reply is read until the server closes the connection. Its leading status code is read
/// by `util::take_status_code`:
/// - `0` means success, and the remaining bytes are returned unchanged.
/// - Any other value means the remaining bytes are a bincode-encoded [`Error`], returned as `Err`.
///
/// # Errors
/// Returns an error for:
/// - connection or other I/O failures;
/// - a segment too long for its `u32` length prefix ([`Error::Custom`]);
/// - a reply that cannot be parsed;
/// - the error reported by the server.
#[cfg(feature = "async-tcp")]
pub async fn send_data(
    port: u16,
    magic_header: &str,
    func: &str,
    data: Vec<u8>,
) -> Result<Vec<u8>> {
    /// Builds the complete request frame, checking every length prefix.
    fn build_request(magic_header: &str, func: &str, data: Vec<u8>) -> Result<Vec<u8>> {
        let header = magic_header.as_bytes();
        let func = func.as_bytes();
        let mut frame: Vec<u8> = Vec::with_capacity(12 + header.len() + func.len() + data.len());
        // Placeholder for the total length; patched once the rest of the frame is known.
        frame.extend_from_slice(&[0; 4]);
        frame.extend_from_slice(&len_prefix(header.len())?.to_ne_bytes());
        frame.extend_from_slice(header);
        frame.extend_from_slice(&len_prefix(func.len())?.to_ne_bytes());
        frame.extend_from_slice(func);
        frame.extend(data);
        let total = len_prefix(frame.len() - 4)?;
        frame[..4].copy_from_slice(&total.to_be_bytes());
        Ok(frame)
    }

    /// Converts a segment length to `u32`, rejecting lengths the wire format cannot express
    /// (an `as` cast would silently truncate and desynchronize the framing).
    fn len_prefix(len: usize) -> Result<u32> {
        u32::try_from(len).map_err(|_| {
            Error::Custom(format!("segment of {len} bytes does not fit in a u32 length prefix"))
        })
    }

    /// Splits the reply into status and body, returning the payload or the server's error.
    fn parse_response(mut response: &[u8]) -> Result<Vec<u8>> {
        let status = take_status_code(&mut response)?;
        if status == 0 {
            Ok(response.to_vec())
        } else {
            Err(decode::<Error>(response)?)
        }
    }

    let request = build_request(magic_header, func, data)?;
    let mut stream = TcpStream::connect(("127.0.0.1", port)).await?;
    // One write for the whole frame: two small back-to-back writes can be held up by
    // Nagle's algorithm interacting with delayed ACKs.
    stream.write_all(&request).await?;
    let mut response = Vec::new();
    stream.read_to_end(&mut response).await?;
    parse_response(&response)
}
106
107#[cfg(not(feature = "async-tcp"))]
108pub fn send_data(port: u16, magic_header: &str, func: &str, data: Vec<u8>) -> Result<Vec<u8>> {
109 let mut stream = TcpStream::connect(("127.0.0.1", port))?;
110 let header = magic_header.as_bytes();
111 let func = func.as_bytes();
112 let mut buffer = vec![];
113 buffer.extend((header.len() as u32).to_ne_bytes());
114 buffer.extend(header);
115 buffer.extend((func.len() as u32).to_ne_bytes());
116 buffer.extend(func);
117 buffer.extend(data);
118 let length = buffer.len() as u32;
119 let mut response = vec![];
120
121 stream.write_all(&length.to_be_bytes())?;
122 stream.write_all(&buffer)?;
123 stream.read_to_end(&mut response)?;
124 let mut response: &[u8] = &response;
125 let status = take_status_code(&mut response)?;
126 if status == 0 {
127 Ok(response.to_vec())
128 } else {
129 let err: Result<Error> = decode(response);
130 match err {
131 Ok(err) => Err(err),
132 Err(err) => Err(err),
133 }
134 }
135}
136
/// Reads one request frame from `stream` and returns `(function name, argument bytes)`.
///
/// The frame is a big-endian `u32` byte count followed by that many bytes. Those bytes hold
/// the client's magic header and the function name, each read with `util::take_str`, and
/// then the raw argument bytes.
///
/// # Errors
/// Returns an error for:
/// - I/O failures;
/// - errors from `take_str`;
/// - [`Error::ApiMisMatch`] when the client's magic header differs from
///   `magic_header_server`.
///
/// The `"stop"` function is exempt from the header check.
///
/// NOTE(review): the byte count is peer-controlled and allocated up front.
#[cfg(feature = "async-tcp")]
async fn receive_data(
    stream: &mut TcpStream,
    magic_header_server: &str,
) -> Result<(String, Vec<u8>)> {
    let mut prefix = [0; 4];
    stream.read_exact(&mut prefix).await?;
    let mut frame = vec![0; u32::from_be_bytes(prefix) as usize];
    stream.read_exact(&mut frame).await?;

    let mut rest: &[u8] = &frame;
    let client_header = take_str(&mut rest)?;
    let function = take_str(&mut rest)?;

    // The function name is parsed before the check because "stop" bypasses it.
    let accepted = client_header == magic_header_server || function.as_str() == "stop";
    if !accepted {
        return Err(Error::ApiMisMatch(format!(
            "failed to match magic header, expected: {}, got: {}",
            magic_header_server, client_header
        )));
    }
    Ok((function, rest.to_vec()))
}
161
162#[cfg(not(feature = "async-tcp"))]
163fn receive_data(stream: &mut TcpStream, magic_header_server: &str) -> Result<(String, Vec<u8>)> {
164 let mut length_bytes = [0; 4];
165 stream.read_exact(&mut length_bytes)?;
166
167 let length = u32::from_be_bytes(length_bytes) as usize;
168 let mut buffer = vec![0; length];
169 stream.read_exact(&mut buffer)?;
170
171 let mut buffer: &[u8] = &buffer;
172 let magic_header_client = take_str(&mut buffer)?;
173 if magic_header_client != magic_header_server {
174 return Err(Error::ApiMisMatch(format!(
175 "failed to match magic header, expected: {}, got: {}",
176 magic_header_server, magic_header_client
177 )));
178 }
179 let fn_name = take_str(&mut buffer)?;
180 Ok((fn_name, buffer.to_vec()))
181}
182
/// Signature of the request dispatcher (`Receiver::request`).
///
/// It takes the function name, its argument bytes and the shared application state. It
/// returns a boxed `Send` future that resolves to the reply payload or an [`Error`].
/// Despite the name, this is a plain function pointer that *produces* a future.
type AsyncFuture<T> =
    fn(String, Vec<u8>, Arc<Mutex<T>>) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>;
185async fn handle_client<T>(
186 mut stream: TcpStream,
187 magic_header: &str,
188 future: AsyncFuture<T>,
189 app_data: Arc<Mutex<T>>,
190) {
191 #[cfg(feature = "async-tcp")]
192 let data = receive_data(&mut stream, magic_header).await;
193 #[cfg(not(feature = "async-tcp"))]
194 let data = receive_data(&mut stream, magic_header);
195 let (func, data) = match data {
196 Ok(v) => v,
197 Err(err) => {
198 let mut response_buffer = Vec::new();
199 response_buffer.extend_from_slice(&[0, 0, 0, 1]);
200 if let Ok(err) = encode(&err) {
201 response_buffer.extend_from_slice(&err);
202 }
203 #[cfg(feature = "async-tcp")]
204 let _ = stream.write_all(&response_buffer).await;
205 #[cfg(not(feature = "async-tcp"))]
206 let _ = stream.write_all(&response_buffer);
207 return;
208 }
209 };
210
211 let response = future(func, data, app_data).await;
212
213 let mut response_buffer = Vec::new();
214 match response {
215 Ok(data) => {
216 response_buffer.extend_from_slice(&[0, 0, 0, 0]);
217 response_buffer.extend_from_slice(&data);
218 }
219 Err(err) => {
220 response_buffer.extend_from_slice(&[0, 0, 0, 1]);
221 if let Ok(err) = encode(&err) {
222 response_buffer.extend_from_slice(&err);
223 }
224 }
225 }
226 #[cfg(feature = "async-tcp")]
227 let _ = stream.write_all(&response_buffer).await;
228 #[cfg(not(feature = "async-tcp"))]
229 let _ = stream.write_all(&response_buffer);
230}
231
232pub trait Receiver<T: Send + 'static> {
233 fn request(
234 func: String,
235 data: Vec<u8>,
236 app_data: Arc<Mutex<T>>,
237 ) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>;
238
239 fn get_app_data(&self) -> Arc<Mutex<T>>;
240
241 #[allow(async_fn_in_trait)]
242 async fn start_from_listener(
243 &self,
244 listener: TcpListener,
245 magic_header: &str,
246 ) -> io::Result<()> {
247 let thread_count = 8;
248
249 let futures = Arc::new(Mutex::new(vec![]));
250 let notify = Arc::new(Notify::new());
251 for _ in 0..thread_count {
252 let futures = futures.clone();
253 let app_data = self.get_app_data();
254 let notify = notify.clone();
255 let magic_header = magic_header.to_owned();
256 tokio::spawn(async move {
257 loop {
258 notify.notified().await;
259 let item = futures.lock().await.pop();
260 if let Some(stream) = item {
261 handle_client(stream, &magic_header, Self::request, app_data.clone()).await;
262 }
263 }
264 });
265 }
266
267 loop {
268 #[cfg(feature = "async-tcp")]
269 if let Ok((stream, _)) = listener.accept().await {
270 futures.lock().await.push(stream);
271 notify.notify_one();
272 }
273 #[cfg(not(feature = "async-tcp"))]
274 if let Ok((stream, _)) = listener.accept() {
275 futures.lock().await.push(stream);
276 notify.notify_one();
277 }
278 }
279 }
280
281 #[allow(async_fn_in_trait)]
282 async fn start(&self, port: u16, magic_header: &str) -> io::Result<()> {
283 let listener = create_listener(port).await?;
284 self.start_from_listener(listener, magic_header).await
285 }
286}
287
/// Binds a TCP listener on the loopback address `127.0.0.1` at `port`, so only local clients
/// can connect.
///
/// Without the `async-tcp` feature, this performs a blocking `std::net` bind inside an
/// `async fn`.
async fn create_listener(port: u16) -> io::Result<TcpListener> {
    let address = format!("127.0.0.1:{port}");
    #[cfg(feature = "async-tcp")]
    let bound = TcpListener::bind(address).await;
    #[cfg(not(feature = "async-tcp"))]
    let bound = TcpListener::bind(address);
    bound
}
295
296pub trait Starter {
297 #[allow(async_fn_in_trait)]
298 async fn start(self, port: u16, header: &str) -> std::io::Result<()>;
299 #[allow(async_fn_in_trait)]
300 async fn start_from_listener(self, listener: TcpListener, header: &str) -> std::io::Result<()>;
301 #[allow(async_fn_in_trait)]
302 async fn start_gen<T: Starter>(
303 port: u16,
304 magic_header: &str,
305 gen: impl FnOnce() -> T,
306 ) -> std::io::Result<()> {
307 let listener = create_listener(port).await?;
308 let app_data = gen();
309 app_data.start_from_listener(listener, magic_header).await
310 }
311}