// umbral_socket/stream/server.rs

use std::collections::HashMap;
use std::io::Error;
use std::io::ErrorKind;
use std::io::Result;
use std::path::Path;
use std::sync::Arc;

use bytes::Bytes;
use futures::future::BoxFuture;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::UnixListener;
use tokio::net::UnixStream;
12
13type Handler<S> = Arc<dyn Fn(Arc<S>, Bytes) -> BoxFuture<'static, Result<Bytes>> + Send + Sync>;
14
15pub struct UmbralServer<S> {
16    state: Arc<S>,
17    handlers: HashMap<String, Handler<S>>,
18}
19
20impl<S: Send + Sync + 'static> UmbralServer<S> {
21    pub fn new(state: S) -> Self {
22        Self {
23            state: Arc::new(state),
24            handlers: HashMap::new(),
25        }
26    }
27
28    pub fn route<F, Fut>(mut self, method: &str, handler: F) -> Self
29    where
30        F: Fn(Arc<S>, Bytes) -> Fut + Send + Sync + 'static,
31        Fut: futures::Future<Output = Result<Bytes>> + Send + 'static,
32    {
33        let handler_arc: Handler<S> =
34            Arc::new(move |state, payload| Box::pin(handler(state, payload)));
35        self.handlers.insert(method.to_string(), handler_arc);
36        self
37    }
38
39    pub async fn run(self, socket: &str) -> Result<()> {
40        let path = Path::new(socket);
41        if path.exists() {
42            tokio::fs::remove_file(path).await?;
43        }
44        let listener = UnixListener::bind(path)?;
45        let server_arc = Arc::new(self);
46        println!("Umbral Server listening on \"{}\"", socket);
47        loop {
48            let (stream, _) = listener.accept().await?;
49            let server_clone = server_arc.clone();
50            tokio::spawn(async move {
51                if let Err(e) = server_clone.handle_connection(stream).await {
52                    eprintln!("Error processing connection: {}", e);
53                }
54            });
55        }
56    }
57
58    async fn handle_connection(&self, mut stream: UnixStream) -> Result<()> {
59        loop {
60            let mut buffer = [0; 1024];
61            let n = match stream.read(&mut buffer).await {
62                Ok(0) => return Ok(()),
63                Ok(n) => n,
64                Err(e) => return Err(e),
65            };
66            let message = String::from_utf8_lossy(&buffer[..n]);
67
68            let response = if let Some((method, payload)) = message.trim().split_once("[%]") {
69                if let Some(handler) = self.handlers.get(method) {
70                    let state_clone = self.state.clone();
71                    let payload_bytes = Bytes::from(payload.as_bytes().to_vec());
72                    handler(state_clone, payload_bytes).await
73                } else {
74                    Ok(Bytes::from_static(b"METHOD NOT FOUND"))
75                }
76            } else {
77                Ok(Bytes::from_static(b"INVALID PROTOCOL"))
78            };
79
80            match response {
81                Ok(response_bytes) => {
82                    let len = response_bytes.len() as u32;
83                    stream.write_all(&len.to_be_bytes()).await?;
84                    stream.write_all(&response_bytes).await?;
85                }
86                Err(e) => {
87                    let err_msg = Bytes::from(format!("Handler error: {}", e));
88                    let len = err_msg.len() as u32;
89                    stream.write_all(&len.to_be_bytes()).await?;
90                    stream.write_all(&err_msg).await?;
91                }
92            }
93        }
94    }
95}