rust_nostr_server/
web.rs

1use crate::IncomingMessage;
2use crate::RateLimiter;
3use futures_util::stream::SplitSink;
4use futures_util::stream::StreamExt;
5use futures_util::SinkExt;
6use log::debug;
7use nostr::message::MessageHandleError;
8use nostr::types::url::form_urlencoded::Serializer;
9use nostr::ClientMessage;
10use nostr_database::DatabaseError;
11use std::fmt;
12use std::net::SocketAddr;
13use std::time::Duration;
14use tokio_tungstenite::tungstenite::http::status;
15use tokio_tungstenite::WebSocketStream;
16//use std::os::macos::raw;
17use crate::HandlerResult;
18use tokio::net::{TcpListener, TcpStream};
19use tokio_tungstenite::accept_async;
20use tokio_tungstenite::tungstenite::protocol::Message;
21
/// Debug log line emitted once a WebSocket handshake has completed.
const CONNECTED: &str = "New WebSocket connection";
/// Debug log line emitted when the peer sends a close frame with no payload.
const CLOSE: &str = "Received close message";

25#[derive(Clone, Debug)]
26pub struct WebServer {
27    addr: SocketAddr,
28    handler: IncomingMessage,
29    limiter: RateLimiter,
30}
31
32impl fmt::Display for WebServer {
33    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34        write!(f, "WebServer at {}, handler: {:?}", self.addr, self.handler)
35    }
36}
37
38impl WebServer {
39    pub async fn new(port: u16) -> Self {
40        let address = Self::create_address(port);
41        let message_handler = Self::new_handler().await;
42        let rate_limiter = Self::rate_limiter(100, Duration::from_secs(1));
43
44        debug!("WebServer created at {}", address);
45        WebServer {
46            addr: address,
47            handler: message_handler,
48            limiter: rate_limiter,
49        }
50    }
51
52    fn rate_limiter(max: usize, period: Duration) -> RateLimiter {
53        RateLimiter::new(max, period)
54    }
55
56    async fn new_handler() -> IncomingMessage {
57        IncomingMessage::new().await.unwrap_or_else(|e| {
58            eprintln!("Failed to create message handler: {}", e);
59            std::process::exit(1);
60        })
61    }
62
63    fn create_address(port: u16) -> SocketAddr {
64        format!("127.0.0.1:{}", port)
65            .parse::<SocketAddr>()
66            .unwrap_or_else(|e| {
67                eprintln!("Failed to parse address: {}", e);
68                std::process::exit(1);
69            })
70    }
71
72    pub async fn start_listening(&self) -> Result<TcpListener, std::io::Error> {
73        let listener: TcpListener = TcpListener::bind(&self.addr).await?;
74        println!("WebSocket server running at {}", self.addr);
75        Ok(listener)
76    }
77
78    pub async fn accept_connection(&self, listener: TcpListener) {
79        while let Ok((stream, _)) = listener.accept().await {
80            let this: WebServer = self.clone();
81            tokio::spawn(async move {
82                this.handle_connection(stream).await;
83            });
84        }
85    }
86
87    pub async fn run(&self) {
88        match self.start_listening().await {
89            Ok(l) => {
90                self.accept_connection(l).await;
91            }
92            Err(e) => {
93                log::error!("Failed to start listening: {}", e);
94            }
95        };
96    }
97
98    pub async fn echo_message(
99        &self,
100        write: &mut futures_util::stream::SplitSink<
101            tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
102            tokio_tungstenite::tungstenite::protocol::Message,
103        >,
104        message: &Message,
105    ) {
106        if let Err(e) = write.send(message.clone()).await {
107            log::error!("Failed to echo message: {}", e);
108        }
109    }
110
111    async fn close_connection(
112        &self,
113        write: &mut futures_util::stream::SplitSink<
114            tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
115            tokio_tungstenite::tungstenite::protocol::Message,
116        >,
117    ) {
118        let close_message = tokio_tungstenite::tungstenite::protocol::Message::Close(Some(
119            tokio_tungstenite::tungstenite::protocol::CloseFrame {
120                code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
121                reason: "".into(),
122            },
123        ));
124
125        let _ = write.send(close_message).await;
126
127        if let Err(e) = write.close().await {
128            log::error!("Failed to close WebSocket stream: {}", e);
129        }
130    }
131}
132
133impl WebServer {
134    async fn handle_connection(&self, stream: TcpStream) {
135        let mut conn = Conn::new(self).await;
136        conn.handle(stream).await;
137    }
138}
139
140#[derive(Debug, Clone)]
141struct Conn<'a> {
142    server: &'a WebServer,
143    is_verified: bool,
144}
145
146impl<'a> Conn<'a> {
147    async fn new(server: &'a WebServer) -> Self {
148        Conn {
149            server,
150            is_verified: false,
151        }
152    }
153    fn verify(&mut self) {
154        self.is_verified = true;
155    }
156
157    fn is_verified(&self) -> bool {
158        self.is_verified
159    }
160    async fn handle(&mut self, stream: TcpStream) {
161        let ws_stream: tokio_tungstenite::WebSocketStream<TcpStream> =
162            match accept_async(stream).await {
163                Ok(ws) => ws,
164                Err(e) => {
165                    log::error!("WebSocket handler failed: {}", e);
166                    return;
167                }
168            };
169        let limiter = &self.server.limiter;
170        log::debug!("{}", CONNECTED);
171        let (mut write, mut read) = ws_stream.split();
172        while let Some(message) = read.next().await {
173            if limiter.acquire().await.is_err() {
174                log::error!("Rate limit exceeded");
175                return;
176            }
177            match message {
178                Ok(msg) => match msg {
179                    Message::Text(txt) => {
180                        let certified = self.is_verified();
181                        let m = self.server.handler.to_client_message(&txt).await;
182                        let m: ClientMessage = match m {
183                            Ok(message) => message,
184                            Err(err) => {
185                                log::error!("Failed to parse message: {}", err);
186                                panic!("Failed to parse message");
187                            }
188                        };
189
190                        let results = self.server.handler.handlers(m, certified).await;
191                        let results: HandlerResult = match results {
192                            Ok(result) => result,
193                            Err(err) => {
194                                log::error!("Failed to handle message: {}", err);
195                                panic!("Failed to handle message");
196                            }
197                        };
198
199                        self.handle_result(results, &mut write).await;
200                    }
201                    Message::Binary(bin) => {
202                        println!("Received binary: {:?}", bin);
203                    }
204
205                    Message::Close(Some(close_frame)) => {
206                        log::debug!("Received close frame: {:?}", close_frame);
207                        self.server.close_connection(&mut write).await;
208                        break;
209                    }
210
211                    Message::Close(None) => {
212                        log::debug!("{}", CLOSE);
213                        self.server.close_connection(&mut write).await;
214                        break;
215                    }
216                    _ => {}
217                },
218                Err(e) => {
219                    log::error!("WebSocket handler failed: {}", e);
220                    return;
221                }
222            }
223        }
224    }
225    async fn handle_result(
226        &mut self,
227        result: HandlerResult,
228        mut write: &mut SplitSink<WebSocketStream<TcpStream>, Message>,
229    ) {
230        match result {
231            HandlerResult::String(msg) => {
232                let message: Message = Message::Text(msg);
233                self.server.echo_message(&mut write, &message).await;
234                //self.close_connection(&mut write).await;
235            }
236            HandlerResult::Strings(msgs) => {
237                for msg in msgs {
238                    let message: Message = Message::Text(msg);
239                    self.server.echo_message(&mut write, &message).await;
240                    //self.close_connection(&mut write).await;
241                }
242            }
243            HandlerResult::Close(do_close) => {
244                let message: Message = Message::Text(do_close.get_data().await.to_string());
245                self.server.echo_message(&mut write, &message).await;
246                self.server.close_connection(&mut write).await;
247            }
248            HandlerResult::Auth(do_auth, status) => {
249                let message: Message = Message::Text(do_auth.get_data().await.to_string());
250                self.server.echo_message(&mut write, &message).await;
251                if status {
252                    self.verify();
253                }
254                //self.close_connection(&mut write).await;
255            }
256            HandlerResult::Event(do_event) => {
257                let message: Message = Message::Text(do_event.get_data().await.to_string());
258                self.server.echo_message(&mut write, &message).await;
259                //self.close_connection(&mut write).await;
260            }
261            HandlerResult::Req(do_req) => {
262                let msgs: &Vec<String> = do_req.get_data().await;
263                for msg in msgs {
264                    let message: Message = Message::Text(msg.to_string());
265                    self.server.echo_message(&mut write, &message).await;
266                    //self.close_connection(&mut write).await;
267                }
268            }
269            HandlerResult::Count(do_count) => {
270                let message: Message = Message::Text(do_count.get_data().await.to_string());
271                self.server.echo_message(&mut write, &message).await;
272                //self.close_connection(&mut write).await;
273            }
274        }
275    }
276}