1use std::collections::HashMap;
2use std::future::Future;
3use std::panic::AssertUnwindSafe;
4use std::sync::Arc;
5use std::time::Duration;
6
7use futures_util::{FutureExt, SinkExt, StreamExt, future::BoxFuture};
8use serde::de::DeserializeOwned;
9use serde_json::{Value, json};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, mpsc};
12use tokio::time::{MissedTickBehavior, interval, timeout};
13use tokio_tungstenite::{accept_async, tungstenite::Message};
14use uuid::Uuid;
15use validator::Validate;
16use wscall_protocol::{
17 EncryptionKind, ErrorPayload, FileAttachment, FrameCodec, PacketBody, PacketEnvelope,
18};
19
20use crate::server_types::{
21 ApiContext, ApiError, EventContext, ExceptionContext, ServerConnectionContext,
22 ServerDisconnectContext, ServerError, ServerHandle, ServerOutbound, ServerState,
23};
24
/// Longest a connection may go without delivering any WebSocket message before it is dropped.
const SERVER_IDLE_TIMEOUT: Duration = Duration::from_millis(45_000);
/// Period between server-initiated WebSocket pings on each connection.
const SERVER_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(15_000);
/// Number of outbound messages that may be buffered per connection for its writer task.
const SERVER_OUTBOUND_QUEUE_CAPACITY: usize = 256;
28
29type ApiHandler =
30 Arc<dyn Fn(ApiContext) -> BoxFuture<'static, Result<Value, ApiError>> + Send + Sync>;
31type Filter =
32 Arc<dyn Fn(ApiContext) -> BoxFuture<'static, Result<ApiContext, ApiError>> + Send + Sync>;
33type EventHandler =
34 Arc<dyn Fn(EventContext) -> BoxFuture<'static, Result<Value, ApiError>> + Send + Sync>;
35type ConnectionHandler = Arc<dyn Fn(ServerConnectionContext) -> BoxFuture<'static, ()> + Send + Sync>;
36type DisconnectHandler = Arc<dyn Fn(ServerDisconnectContext) -> BoxFuture<'static, ()> + Send + Sync>;
37type ExceptionHandler =
38 Arc<dyn Fn(ExceptionContext) -> BoxFuture<'static, ErrorPayload> + Send + Sync>;
39
40struct ApiRequestInput {
41 request_id: String,
42 route: String,
43 params: Value,
44 attachments: Vec<FileAttachment>,
45 metadata: Value,
46}
47
48struct EventEmitInput {
49 event_id: String,
50 name: String,
51 data: Value,
52 attachments: Vec<FileAttachment>,
53 metadata: Value,
54}
55
56impl ServerHandle {
57 pub async fn broadcast_event(
58 &self,
59 name: impl Into<String>,
60 data: Value,
61 attachments: Vec<FileAttachment>,
62 ) -> Result<(), ApiError> {
63 let packet = PacketEnvelope::with_encryption(
64 PacketBody::EventEmit {
65 event_id: Uuid::new_v4().to_string(),
66 name: name.into(),
67 data,
68 attachments,
69 metadata: json!({ "source": "server" }),
70 expect_ack: true,
71 },
72 self.default_encryption,
73 );
74
75 let clients = self.state.clients.read().await;
76 let senders = clients.values().cloned().collect::<Vec<_>>();
77 drop(clients);
78
79 for sender in senders {
80 sender
81 .try_send(ServerOutbound::Packet(packet.clone()))
82 .map_err(|_| ApiError::internal("failed to queue broadcast event"))?;
83 }
84 Ok(())
85 }
86
87 pub async fn send_event_to(
88 &self,
89 connection_id: &str,
90 name: impl Into<String>,
91 data: Value,
92 attachments: Vec<FileAttachment>,
93 ) -> Result<(), ApiError> {
94 let packet = PacketEnvelope::with_encryption(
95 PacketBody::EventEmit {
96 event_id: Uuid::new_v4().to_string(),
97 name: name.into(),
98 data,
99 attachments,
100 metadata: json!({ "source": "server" }),
101 expect_ack: true,
102 },
103 self.default_encryption,
104 );
105
106 let clients = self.state.clients.read().await;
107 let sender = clients
108 .get(connection_id)
109 .cloned()
110 .ok_or_else(|| ApiError::not_found("target connection not found"))?;
111 drop(clients);
112 sender
113 .try_send(ServerOutbound::Packet(packet))
114 .map_err(|_| ApiError::internal("failed to queue direct event"))
115 }
116
117 pub async fn connection_count(&self) -> usize {
118 self.state.clients.read().await.len()
119 }
120}
121
/// WebSocket RPC server: holds the registered routes, filters, event handlers and lifecycle
/// hooks, plus the frame codec used for every connection.
///
/// Configure it with the builder/registration methods, then call `listen` to start serving.
// NOTE: field order also fixes drop order of the user-supplied handler closures.
pub struct WscallServer {
    // Shared connection table; also handed out to `ServerHandle`s.
    state: Arc<ServerState>,
    // Route name -> handler for `ApiRequest` packets.
    routes: HashMap<String, ApiHandler>,
    // Run in registration order before every route handler.
    filters: Vec<Filter>,
    // Event name -> handler for client-emitted events.
    event_handlers: HashMap<String, EventHandler>,
    // Run after a connection is registered.
    connection_handlers: Vec<ConnectionHandler>,
    // Run after a connection is torn down.
    disconnect_handlers: Vec<DisconnectHandler>,
    // Optional hook that converts errors into wire `ErrorPayload`s.
    exception_handler: Option<ExceptionHandler>,
    // Encodes outgoing / decodes incoming binary frames.
    codec: FrameCodec,
    // Encryption applied to every packet this server builds.
    default_encryption: EncryptionKind,
}
133
134impl Default for WscallServer {
135 fn default() -> Self {
136 Self::new()
137 }
138}
139
140impl WscallServer {
141 pub fn new() -> Self {
142 Self {
143 state: Arc::new(ServerState {
144 clients: RwLock::new(HashMap::new()),
145 }),
146 routes: HashMap::new(),
147 filters: Vec::new(),
148 event_handlers: HashMap::new(),
149 connection_handlers: Vec::new(),
150 disconnect_handlers: Vec::new(),
151 exception_handler: None,
152 codec: FrameCodec::plaintext(),
153 default_encryption: EncryptionKind::None,
154 }
155 }
156
157 pub fn with_chacha20_key(mut self, key: [u8; 32]) -> Self {
158 self.codec = self.codec.clone().with_chacha20_key(key);
159 self.default_encryption = EncryptionKind::ChaCha20;
160 self
161 }
162
163 pub fn with_aes256_key(mut self, key: [u8; 32]) -> Self {
164 self.codec = self.codec.clone().with_aes256_key(key);
165 self.default_encryption = EncryptionKind::Aes256;
166 self
167 }
168
169 pub fn handle(&self) -> ServerHandle {
170 ServerHandle {
171 state: Arc::clone(&self.state),
172 default_encryption: self.default_encryption,
173 }
174 }
175
176 pub fn route<F, Fut>(&mut self, route: impl Into<String>, handler: F)
177 where
178 F: Fn(ApiContext) -> Fut + Send + Sync + 'static,
179 Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
180 {
181 let handler = Arc::new(move |ctx: ApiContext| {
182 Box::pin(handler(ctx)) as BoxFuture<'static, Result<Value, ApiError>>
183 });
184 self.routes.insert(route.into(), handler);
185 }
186
187 pub fn typed_route<T, F, Fut>(&mut self, route: impl Into<String>, handler: F)
188 where
189 T: DeserializeOwned + Send + 'static,
190 F: Fn(ApiContext, T) -> Fut + Send + Sync + 'static,
191 Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
192 {
193 let handler = Arc::new(handler);
194 self.route(route, move |ctx| {
195 let handler = Arc::clone(&handler);
196 let params = ctx.bind::<T>();
197 async move {
198 let params = params?;
199 handler(ctx, params).await
200 }
201 });
202 }
203
204 pub fn validated_route<T, F, Fut>(&mut self, route: impl Into<String>, handler: F)
205 where
206 T: DeserializeOwned + Validate + Send + 'static,
207 F: Fn(ApiContext, T) -> Fut + Send + Sync + 'static,
208 Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
209 {
210 let handler = Arc::new(handler);
211 self.route(route, move |ctx| {
212 let handler = Arc::clone(&handler);
213 let params = ctx.bind_validated::<T>();
214 async move {
215 let params = params?;
216 handler(ctx, params).await
217 }
218 });
219 }
220
221 pub fn filter<F, Fut>(&mut self, filter: F)
222 where
223 F: Fn(ApiContext) -> Fut + Send + Sync + 'static,
224 Fut: Future<Output = Result<ApiContext, ApiError>> + Send + 'static,
225 {
226 let filter = Arc::new(move |ctx: ApiContext| {
227 Box::pin(filter(ctx)) as BoxFuture<'static, Result<ApiContext, ApiError>>
228 });
229 self.filters.push(filter);
230 }
231
232 pub fn event_handler<F, Fut>(&mut self, name: impl Into<String>, handler: F)
233 where
234 F: Fn(EventContext) -> Fut + Send + Sync + 'static,
235 Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
236 {
237 let handler = Arc::new(move |ctx: EventContext| {
238 Box::pin(handler(ctx)) as BoxFuture<'static, Result<Value, ApiError>>
239 });
240 self.event_handlers.insert(name.into(), handler);
241 }
242
243 pub fn on_connected<F, Fut>(&mut self, handler: F)
244 where
245 F: Fn(ServerConnectionContext) -> Fut + Send + Sync + 'static,
246 Fut: Future<Output = ()> + Send + 'static,
247 {
248 let handler = Arc::new(move |ctx: ServerConnectionContext| {
249 Box::pin(handler(ctx)) as BoxFuture<'static, ()>
250 });
251 self.connection_handlers.push(handler);
252 }
253
254 pub fn on_disconnected<F, Fut>(&mut self, handler: F)
255 where
256 F: Fn(ServerDisconnectContext) -> Fut + Send + Sync + 'static,
257 Fut: Future<Output = ()> + Send + 'static,
258 {
259 let handler = Arc::new(move |ctx: ServerDisconnectContext| {
260 Box::pin(handler(ctx)) as BoxFuture<'static, ()>
261 });
262 self.disconnect_handlers.push(handler);
263 }
264
265 pub fn exception_handler<F, Fut>(&mut self, handler: F)
266 where
267 F: Fn(ExceptionContext) -> Fut + Send + Sync + 'static,
268 Fut: Future<Output = ErrorPayload> + Send + 'static,
269 {
270 self.exception_handler = Some(Arc::new(move |ctx: ExceptionContext| {
271 Box::pin(handler(ctx)) as BoxFuture<'static, ErrorPayload>
272 }));
273 }
274
275 pub async fn listen(self, address: &str) -> Result<(), ServerError> {
276 let listener = TcpListener::bind(address).await?;
277 println!("WSCALL server listening on ws://{address}/socket");
278
279 let shared = Arc::new(self);
280 loop {
281 let (stream, peer) = listener.accept().await?;
282 let server = Arc::clone(&shared);
283 tokio::spawn(async move {
284 if let Err(error) = server.serve_connection(stream, peer).await {
285 eprintln!("connection {peer:?} failed: {error}");
286 }
287 });
288 }
289 }
290
291 async fn serve_connection(
292 self: Arc<Self>,
293 stream: TcpStream,
294 peer: std::net::SocketAddr,
295 ) -> Result<(), ServerError> {
296 let websocket = accept_async(stream).await?;
297 let connection_id = Uuid::new_v4().to_string();
298 let (mut sink, mut stream) = websocket.split();
299 let (tx, mut rx) = mpsc::channel::<ServerOutbound>(SERVER_OUTBOUND_QUEUE_CAPACITY);
300
301 self.state
302 .clients
303 .write()
304 .await
305 .insert(connection_id.clone(), tx.clone());
306
307 self.notify_connected(&connection_id, Some(peer)).await;
308
309 let codec = self.codec.clone();
310 let writer = tokio::spawn(async move {
311 while let Some(outbound) = rx.recv().await {
312 match outbound {
313 ServerOutbound::Packet(packet) => {
314 let bytes = codec.encode(&packet)?;
315 sink.send(Message::Binary(bytes)).await?;
316 }
317 ServerOutbound::Ping(payload) => {
318 sink.send(Message::Ping(payload)).await?;
319 }
320 ServerOutbound::Pong(payload) => {
321 sink.send(Message::Pong(payload)).await?;
322 }
323 ServerOutbound::Close => {
324 let _ = sink.send(Message::Close(None)).await;
325 break;
326 }
327 }
328 }
329 Ok::<(), ServerError>(())
330 });
331
332 let heartbeat_tx = tx.clone();
333 let heartbeat = tokio::spawn(async move {
334 let mut ticker = interval(SERVER_HEARTBEAT_INTERVAL);
335 ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
336 loop {
337 ticker.tick().await;
338 if heartbeat_tx
339 .send(ServerOutbound::Ping(Vec::new()))
340 .await
341 .is_err()
342 {
343 break;
344 }
345 }
346 });
347
348 let result = async {
349 self.handle()
350 .send_event_to(
351 &connection_id,
352 "system.notice",
353 json!({ "message": "connected", "connection_id": connection_id }),
354 Vec::new(),
355 )
356 .await
357 .map_err(ServerError::Api)?;
358
359 loop {
360 let next_message = timeout(SERVER_IDLE_TIMEOUT, stream.next()).await;
361 let Some(message) =
362 next_message.map_err(|_| ServerError::IdleTimeout(connection_id.clone()))?
363 else {
364 break Ok(());
365 };
366
367 match message? {
368 Message::Binary(bytes) => {
369 let packet = self.codec.decode(&bytes)?;
370 self.process_packet(&connection_id, Some(peer), packet)
371 .await?;
372 }
373 Message::Close(_) => break Ok(()),
374 Message::Ping(payload) => {
375 if tx
376 .send(ServerOutbound::Pong(payload.to_vec()))
377 .await
378 .is_err()
379 {
380 break Ok(());
381 }
382 }
383 Message::Pong(_) => {}
384 Message::Text(_) => {}
385 Message::Frame(_) => {}
386 }
387 }
388 }
389 .await;
390
391 self.state.clients.write().await.remove(&connection_id);
392 let _ = tx.send(ServerOutbound::Close).await;
393 heartbeat.abort();
394 writer.abort();
395 self.notify_disconnected(
396 &connection_id,
397 Some(peer),
398 Self::disconnect_reason(&result),
399 )
400 .await;
401 result
402 }
403
404 async fn process_packet(
405 &self,
406 connection_id: &str,
407 peer_addr: Option<std::net::SocketAddr>,
408 packet: PacketEnvelope,
409 ) -> Result<(), ServerError> {
410 match packet.body {
411 PacketBody::ApiRequest {
412 request_id,
413 route,
414 params,
415 attachments,
416 metadata,
417 } => {
418 let response = self
419 .run_api_request(
420 connection_id,
421 peer_addr,
422 ApiRequestInput {
423 request_id: request_id.clone(),
424 route,
425 params,
426 attachments,
427 metadata,
428 },
429 )
430 .await;
431 self.queue_for(connection_id, response).await?;
432 }
433 PacketBody::EventEmit {
434 event_id,
435 name,
436 data,
437 attachments,
438 metadata,
439 ..
440 } => {
441 let ack = self
442 .run_event(
443 connection_id,
444 peer_addr,
445 EventEmitInput {
446 event_id: event_id.clone(),
447 name,
448 data,
449 attachments,
450 metadata,
451 },
452 )
453 .await;
454 self.queue_for(connection_id, ack).await?;
455 }
456 PacketBody::EventAck {
457 event_id,
458 ok,
459 receipt,
460 error,
461 } => {
462 println!(
463 "received event ack from {} for {}: ok={}, receipt={}, error={:?}",
464 connection_id, event_id, ok, receipt, error
465 );
466 }
467 PacketBody::ApiResponse { .. } => {}
468 }
469 Ok(())
470 }
471
472 async fn queue_for(
473 &self,
474 connection_id: &str,
475 packet: PacketEnvelope,
476 ) -> Result<(), ServerError> {
477 let clients = self.state.clients.read().await;
478 let sender = clients
479 .get(connection_id)
480 .cloned()
481 .ok_or_else(|| ServerError::Api(ApiError::not_found("connection is closed")))?;
482 drop(clients);
483 sender
484 .try_send(ServerOutbound::Packet(packet))
485 .map_err(|error| match error {
486 tokio::sync::mpsc::error::TrySendError::Full(_) => {
487 ServerError::OutboundQueueFull(connection_id.to_string())
488 }
489 tokio::sync::mpsc::error::TrySendError::Closed(_) => {
490 ServerError::Api(ApiError::internal("failed to queue outbound packet"))
491 }
492 })
493 }
494
495 async fn run_api_request(
496 &self,
497 connection_id: &str,
498 peer_addr: Option<std::net::SocketAddr>,
499 request: ApiRequestInput,
500 ) -> PacketEnvelope {
501 let ApiRequestInput {
502 request_id,
503 route,
504 params,
505 attachments,
506 metadata,
507 } = request;
508
509 let mut ctx = ApiContext {
510 connection_id: connection_id.to_string(),
511 peer_addr,
512 request_id: request_id.clone(),
513 route: route.clone(),
514 params,
515 attachments,
516 metadata,
517 server: self.handle(),
518 };
519
520 for filter in &self.filters {
521 match filter(ctx).await {
522 Ok(next_ctx) => ctx = next_ctx,
523 Err(error) => {
524 return self
525 .api_error_packet(connection_id, Some(request_id), route, error)
526 .await;
527 }
528 }
529 }
530
531 let Some(handler) = self.routes.get(&ctx.route) else {
532 return self
533 .api_error_packet(
534 connection_id,
535 Some(request_id),
536 route,
537 ApiError::not_found("route not found"),
538 )
539 .await;
540 };
541
542 match AssertUnwindSafe(handler(ctx)).catch_unwind().await {
543 Ok(Ok(data)) => PacketEnvelope::with_encryption(
544 PacketBody::ApiResponse {
545 request_id,
546 ok: true,
547 status: 200,
548 data,
549 error: None,
550 metadata: json!({}),
551 },
552 self.default_encryption,
553 ),
554 Ok(Err(error)) => {
555 self.api_error_packet(connection_id, Some(request_id), route, error)
556 .await
557 }
558 Err(_) => {
559 self.api_error_packet(
560 connection_id,
561 Some(request_id),
562 route,
563 ApiError::internal("handler panicked"),
564 )
565 .await
566 }
567 }
568 }
569
570 async fn run_event(
571 &self,
572 connection_id: &str,
573 peer_addr: Option<std::net::SocketAddr>,
574 event: EventEmitInput,
575 ) -> PacketEnvelope {
576 let EventEmitInput {
577 event_id,
578 name,
579 data,
580 attachments,
581 metadata,
582 } = event;
583
584 let ctx = EventContext {
585 connection_id: connection_id.to_string(),
586 peer_addr,
587 event_id: event_id.clone(),
588 name: name.clone(),
589 data,
590 attachments,
591 metadata,
592 server: self.handle(),
593 };
594
595 let Some(handler) = self.event_handlers.get(&name) else {
596 return PacketEnvelope::with_encryption(
597 PacketBody::EventAck {
598 event_id,
599 ok: false,
600 receipt: json!({}),
601 error: Some(ApiError::not_found("event handler not found").into_payload()),
602 },
603 self.default_encryption,
604 );
605 };
606
607 match AssertUnwindSafe(handler(ctx)).catch_unwind().await {
608 Ok(Ok(receipt)) => PacketEnvelope::with_encryption(
609 PacketBody::EventAck {
610 event_id,
611 ok: true,
612 receipt,
613 error: None,
614 },
615 self.default_encryption,
616 ),
617 Ok(Err(error)) => PacketEnvelope::with_encryption(
618 PacketBody::EventAck {
619 event_id: event_id.clone(),
620 ok: false,
621 receipt: json!({}),
622 error: Some(
623 self.map_exception(ExceptionContext {
624 connection_id: connection_id.to_string(),
625 request_id: Some(event_id.clone()),
626 target: name,
627 message_kind: "event",
628 error,
629 })
630 .await,
631 ),
632 },
633 self.default_encryption,
634 ),
635 Err(_) => PacketEnvelope::with_encryption(
636 PacketBody::EventAck {
637 event_id: event_id.clone(),
638 ok: false,
639 receipt: json!({}),
640 error: Some(
641 self.map_exception(ExceptionContext {
642 connection_id: connection_id.to_string(),
643 request_id: Some(event_id.clone()),
644 target: name,
645 message_kind: "event",
646 error: ApiError::internal("event handler panicked"),
647 })
648 .await,
649 ),
650 },
651 self.default_encryption,
652 ),
653 }
654 }
655
656 async fn api_error_packet(
657 &self,
658 connection_id: &str,
659 request_id: Option<String>,
660 route: String,
661 error: ApiError,
662 ) -> PacketEnvelope {
663 let request_id = request_id.unwrap_or_else(|| Uuid::new_v4().to_string());
664 let status = error.status;
665 let payload = self
666 .map_exception(ExceptionContext {
667 connection_id: connection_id.to_string(),
668 request_id: Some(request_id.clone()),
669 target: route,
670 message_kind: "api",
671 error,
672 })
673 .await;
674
675 PacketEnvelope::with_encryption(
676 PacketBody::ApiResponse {
677 request_id,
678 ok: false,
679 status,
680 data: json!({}),
681 error: Some(payload),
682 metadata: json!({}),
683 },
684 self.default_encryption,
685 )
686 }
687
688 async fn notify_connected(
689 &self,
690 connection_id: &str,
691 peer_addr: Option<std::net::SocketAddr>,
692 ) {
693 let handlers = self.connection_handlers.clone();
694 for handler in handlers {
695 let context = ServerConnectionContext {
696 connection_id: connection_id.to_string(),
697 peer_addr,
698 server: self.handle(),
699 };
700
701 if AssertUnwindSafe(handler(context)).catch_unwind().await.is_err() {
702 eprintln!("server connected handler panicked");
703 }
704 }
705 }
706
707 async fn notify_disconnected(
708 &self,
709 connection_id: &str,
710 peer_addr: Option<std::net::SocketAddr>,
711 reason: String,
712 ) {
713 let handlers = self.disconnect_handlers.clone();
714 for handler in handlers {
715 let context = ServerDisconnectContext {
716 connection_id: connection_id.to_string(),
717 peer_addr,
718 reason: reason.clone(),
719 server: self.handle(),
720 };
721
722 if AssertUnwindSafe(handler(context)).catch_unwind().await.is_err() {
723 eprintln!("server disconnected handler panicked");
724 }
725 }
726 }
727
728 fn disconnect_reason(result: &Result<(), ServerError>) -> String {
729 match result {
730 Ok(()) => "connection closed".to_string(),
731 Err(ServerError::IdleTimeout(_)) => "idle timeout".to_string(),
732 Err(error) => error.to_string(),
733 }
734 }
735
736 async fn map_exception(&self, context: ExceptionContext) -> ErrorPayload {
737 match &self.exception_handler {
738 Some(handler) => handler(context).await,
739 None => context.error.into_payload(),
740 }
741 }
742}