1use crate::IncomingMessage;
2use crate::RateLimiter;
3use futures_util::stream::SplitSink;
4use futures_util::stream::StreamExt;
5use futures_util::SinkExt;
6use log::debug;
7use nostr::message::MessageHandleError;
8use nostr::types::url::form_urlencoded::Serializer;
9use nostr::ClientMessage;
10use nostr_database::DatabaseError;
11use std::fmt;
12use std::net::SocketAddr;
13use std::time::Duration;
14use tokio_tungstenite::tungstenite::http::status;
15use tokio_tungstenite::WebSocketStream;
16use crate::HandlerResult;
18use tokio::net::{TcpListener, TcpStream};
19use tokio_tungstenite::accept_async;
20use tokio_tungstenite::tungstenite::protocol::Message;
21
/// Debug log line emitted after a client's WebSocket handshake succeeds.
const CONNECTED: &str = "New WebSocket connection";
/// Debug log line emitted when a client sends a close message without a close frame.
const CLOSE: &str = "Received close message";
24
25#[derive(Clone, Debug)]
26pub struct WebServer {
27 addr: SocketAddr,
28 handler: IncomingMessage,
29 limiter: RateLimiter,
30}
31
32impl fmt::Display for WebServer {
33 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
34 write!(f, "WebServer at {}, handler: {:?}", self.addr, self.handler)
35 }
36}
37
38impl WebServer {
39 pub async fn new(port: u16) -> Self {
40 let address = Self::create_address(port);
41 let message_handler = Self::new_handler().await;
42 let rate_limiter = Self::rate_limiter(100, Duration::from_secs(1));
43
44 debug!("WebServer created at {}", address);
45 WebServer {
46 addr: address,
47 handler: message_handler,
48 limiter: rate_limiter,
49 }
50 }
51
52 fn rate_limiter(max: usize, period: Duration) -> RateLimiter {
53 RateLimiter::new(max, period)
54 }
55
56 async fn new_handler() -> IncomingMessage {
57 IncomingMessage::new().await.unwrap_or_else(|e| {
58 eprintln!("Failed to create message handler: {}", e);
59 std::process::exit(1);
60 })
61 }
62
63 fn create_address(port: u16) -> SocketAddr {
64 format!("127.0.0.1:{}", port)
65 .parse::<SocketAddr>()
66 .unwrap_or_else(|e| {
67 eprintln!("Failed to parse address: {}", e);
68 std::process::exit(1);
69 })
70 }
71
72 pub async fn start_listening(&self) -> Result<TcpListener, std::io::Error> {
73 let listener: TcpListener = TcpListener::bind(&self.addr).await?;
74 println!("WebSocket server running at {}", self.addr);
75 Ok(listener)
76 }
77
78 pub async fn accept_connection(&self, listener: TcpListener) {
79 while let Ok((stream, _)) = listener.accept().await {
80 let this: WebServer = self.clone();
81 tokio::spawn(async move {
82 this.handle_connection(stream).await;
83 });
84 }
85 }
86
87 pub async fn run(&self) {
88 match self.start_listening().await {
89 Ok(l) => {
90 self.accept_connection(l).await;
91 }
92 Err(e) => {
93 log::error!("Failed to start listening: {}", e);
94 }
95 };
96 }
97
98 pub async fn echo_message(
99 &self,
100 write: &mut futures_util::stream::SplitSink<
101 tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
102 tokio_tungstenite::tungstenite::protocol::Message,
103 >,
104 message: &Message,
105 ) {
106 if let Err(e) = write.send(message.clone()).await {
107 log::error!("Failed to echo message: {}", e);
108 }
109 }
110
111 async fn close_connection(
112 &self,
113 write: &mut futures_util::stream::SplitSink<
114 tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>,
115 tokio_tungstenite::tungstenite::protocol::Message,
116 >,
117 ) {
118 let close_message = tokio_tungstenite::tungstenite::protocol::Message::Close(Some(
119 tokio_tungstenite::tungstenite::protocol::CloseFrame {
120 code: tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode::Normal,
121 reason: "".into(),
122 },
123 ));
124
125 let _ = write.send(close_message).await;
126
127 if let Err(e) = write.close().await {
128 log::error!("Failed to close WebSocket stream: {}", e);
129 }
130 }
131}
132
133impl WebServer {
134 async fn handle_connection(&self, stream: TcpStream) {
135 let mut conn = Conn::new(self).await;
136 conn.handle(stream).await;
137 }
138}
139
140#[derive(Debug, Clone)]
141struct Conn<'a> {
142 server: &'a WebServer,
143 is_verified: bool,
144}
145
146impl<'a> Conn<'a> {
147 async fn new(server: &'a WebServer) -> Self {
148 Conn {
149 server,
150 is_verified: false,
151 }
152 }
153 fn verify(&mut self) {
154 self.is_verified = true;
155 }
156
157 fn is_verified(&self) -> bool {
158 self.is_verified
159 }
160 async fn handle(&mut self, stream: TcpStream) {
161 let ws_stream: tokio_tungstenite::WebSocketStream<TcpStream> =
162 match accept_async(stream).await {
163 Ok(ws) => ws,
164 Err(e) => {
165 log::error!("WebSocket handler failed: {}", e);
166 return;
167 }
168 };
169 let limiter = &self.server.limiter;
170 log::debug!("{}", CONNECTED);
171 let (mut write, mut read) = ws_stream.split();
172 while let Some(message) = read.next().await {
173 if limiter.acquire().await.is_err() {
174 log::error!("Rate limit exceeded");
175 return;
176 }
177 match message {
178 Ok(msg) => match msg {
179 Message::Text(txt) => {
180 let certified = self.is_verified();
181 let m = self.server.handler.to_client_message(&txt).await;
182 let m: ClientMessage = match m {
183 Ok(message) => message,
184 Err(err) => {
185 log::error!("Failed to parse message: {}", err);
186 panic!("Failed to parse message");
187 }
188 };
189
190 let results = self.server.handler.handlers(m, certified).await;
191 let results: HandlerResult = match results {
192 Ok(result) => result,
193 Err(err) => {
194 log::error!("Failed to handle message: {}", err);
195 panic!("Failed to handle message");
196 }
197 };
198
199 self.handle_result(results, &mut write).await;
200 }
201 Message::Binary(bin) => {
202 println!("Received binary: {:?}", bin);
203 }
204
205 Message::Close(Some(close_frame)) => {
206 log::debug!("Received close frame: {:?}", close_frame);
207 self.server.close_connection(&mut write).await;
208 break;
209 }
210
211 Message::Close(None) => {
212 log::debug!("{}", CLOSE);
213 self.server.close_connection(&mut write).await;
214 break;
215 }
216 _ => {}
217 },
218 Err(e) => {
219 log::error!("WebSocket handler failed: {}", e);
220 return;
221 }
222 }
223 }
224 }
225 async fn handle_result(
226 &mut self,
227 result: HandlerResult,
228 mut write: &mut SplitSink<WebSocketStream<TcpStream>, Message>,
229 ) {
230 match result {
231 HandlerResult::String(msg) => {
232 let message: Message = Message::Text(msg);
233 self.server.echo_message(&mut write, &message).await;
234 }
236 HandlerResult::Strings(msgs) => {
237 for msg in msgs {
238 let message: Message = Message::Text(msg);
239 self.server.echo_message(&mut write, &message).await;
240 }
242 }
243 HandlerResult::Close(do_close) => {
244 let message: Message = Message::Text(do_close.get_data().await.to_string());
245 self.server.echo_message(&mut write, &message).await;
246 self.server.close_connection(&mut write).await;
247 }
248 HandlerResult::Auth(do_auth, status) => {
249 let message: Message = Message::Text(do_auth.get_data().await.to_string());
250 self.server.echo_message(&mut write, &message).await;
251 if status {
252 self.verify();
253 }
254 }
256 HandlerResult::Event(do_event) => {
257 let message: Message = Message::Text(do_event.get_data().await.to_string());
258 self.server.echo_message(&mut write, &message).await;
259 }
261 HandlerResult::Req(do_req) => {
262 let msgs: &Vec<String> = do_req.get_data().await;
263 for msg in msgs {
264 let message: Message = Message::Text(msg.to_string());
265 self.server.echo_message(&mut write, &message).await;
266 }
268 }
269 HandlerResult::Count(do_count) => {
270 let message: Message = Message::Text(do_count.get_data().await.to_string());
271 self.server.echo_message(&mut write, &message).await;
272 }
274 }
275 }
276}