// wscall_server/server_runtime.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::panic::AssertUnwindSafe;
4use std::sync::Arc;
5use std::time::Duration;
6
7use futures_util::{FutureExt, SinkExt, StreamExt, future::BoxFuture};
8use serde::de::DeserializeOwned;
9use serde_json::{Value, json};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, mpsc};
12use tokio::time::{MissedTickBehavior, interval, timeout};
13use tokio_tungstenite::{accept_async, tungstenite::Message};
14use uuid::Uuid;
15use validator::Validate;
16use wscall_protocol::{
17    EncryptionKind, ErrorPayload, FileAttachment, FrameCodec, PacketBody, PacketEnvelope,
18};
19
20use crate::server_types::{
21    ApiContext, ApiError, EventContext, ExceptionContext, ServerError, ServerHandle,
22    ServerOutbound, ServerState,
23};
24
/// Period between server-initiated WebSocket pings on each connection.
const SERVER_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
/// Longest wait for the next inbound frame before the session ends with an
/// idle-timeout error (three heartbeat intervals).
const SERVER_IDLE_TIMEOUT: Duration = Duration::from_secs(45);
/// Capacity of each connection's bounded outbound queue.
const SERVER_OUTBOUND_QUEUE_CAPACITY: usize = 256;
28
29type ApiHandler =
30    Arc<dyn Fn(ApiContext) -> BoxFuture<'static, Result<Value, ApiError>> + Send + Sync>;
31type Filter =
32    Arc<dyn Fn(ApiContext) -> BoxFuture<'static, Result<ApiContext, ApiError>> + Send + Sync>;
33type EventHandler =
34    Arc<dyn Fn(EventContext) -> BoxFuture<'static, Result<Value, ApiError>> + Send + Sync>;
35type ExceptionHandler =
36    Arc<dyn Fn(ExceptionContext) -> BoxFuture<'static, ErrorPayload> + Send + Sync>;
37
38struct ApiRequestInput {
39    request_id: String,
40    route: String,
41    params: Value,
42    attachments: Vec<FileAttachment>,
43    metadata: Value,
44}
45
46struct EventEmitInput {
47    event_id: String,
48    name: String,
49    data: Value,
50    attachments: Vec<FileAttachment>,
51    metadata: Value,
52}
53
54impl ServerHandle {
55    pub async fn broadcast_event(
56        &self,
57        name: impl Into<String>,
58        data: Value,
59        attachments: Vec<FileAttachment>,
60    ) -> Result<(), ApiError> {
61        let packet = PacketEnvelope::with_encryption(
62            PacketBody::EventEmit {
63                event_id: Uuid::new_v4().to_string(),
64                name: name.into(),
65                data,
66                attachments,
67                metadata: json!({ "source": "server" }),
68                expect_ack: true,
69            },
70            self.default_encryption,
71        );
72
73        let clients = self.state.clients.read().await;
74        let senders = clients.values().cloned().collect::<Vec<_>>();
75        drop(clients);
76
77        for sender in senders {
78            sender
79                .try_send(ServerOutbound::Packet(packet.clone()))
80                .map_err(|_| ApiError::internal("failed to queue broadcast event"))?;
81        }
82        Ok(())
83    }
84
85    pub async fn send_event_to(
86        &self,
87        connection_id: &str,
88        name: impl Into<String>,
89        data: Value,
90        attachments: Vec<FileAttachment>,
91    ) -> Result<(), ApiError> {
92        let packet = PacketEnvelope::with_encryption(
93            PacketBody::EventEmit {
94                event_id: Uuid::new_v4().to_string(),
95                name: name.into(),
96                data,
97                attachments,
98                metadata: json!({ "source": "server" }),
99                expect_ack: true,
100            },
101            self.default_encryption,
102        );
103
104        let clients = self.state.clients.read().await;
105        let sender = clients
106            .get(connection_id)
107            .cloned()
108            .ok_or_else(|| ApiError::not_found("target connection not found"))?;
109        drop(clients);
110        sender
111            .try_send(ServerOutbound::Packet(packet))
112            .map_err(|_| ApiError::internal("failed to queue direct event"))
113    }
114
115    pub async fn connection_count(&self) -> usize {
116        self.state.clients.read().await.len()
117    }
118}
119
120pub struct WscallServer {
121    state: Arc<ServerState>,
122    routes: HashMap<String, ApiHandler>,
123    filters: Vec<Filter>,
124    event_handlers: HashMap<String, EventHandler>,
125    exception_handler: Option<ExceptionHandler>,
126    codec: FrameCodec,
127    default_encryption: EncryptionKind,
128}
129
130impl Default for WscallServer {
131    fn default() -> Self {
132        Self::new()
133    }
134}
135
136impl WscallServer {
137    pub fn new() -> Self {
138        Self {
139            state: Arc::new(ServerState {
140                clients: RwLock::new(HashMap::new()),
141            }),
142            routes: HashMap::new(),
143            filters: Vec::new(),
144            event_handlers: HashMap::new(),
145            exception_handler: None,
146            codec: FrameCodec::plaintext(),
147            default_encryption: EncryptionKind::None,
148        }
149    }
150
151    pub fn with_chacha20_key(mut self, key: [u8; 32]) -> Self {
152        self.codec = self.codec.clone().with_chacha20_key(key);
153        self.default_encryption = EncryptionKind::ChaCha20;
154        self
155    }
156
157    pub fn with_aes256_key(mut self, key: [u8; 32]) -> Self {
158        self.codec = self.codec.clone().with_aes256_key(key);
159        self.default_encryption = EncryptionKind::Aes256;
160        self
161    }
162
163    pub fn handle(&self) -> ServerHandle {
164        ServerHandle {
165            state: Arc::clone(&self.state),
166            default_encryption: self.default_encryption,
167        }
168    }
169
170    pub fn route<F, Fut>(&mut self, route: impl Into<String>, handler: F)
171    where
172        F: Fn(ApiContext) -> Fut + Send + Sync + 'static,
173        Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
174    {
175        let handler = Arc::new(move |ctx: ApiContext| {
176            Box::pin(handler(ctx)) as BoxFuture<'static, Result<Value, ApiError>>
177        });
178        self.routes.insert(route.into(), handler);
179    }
180
181    pub fn typed_route<T, F, Fut>(&mut self, route: impl Into<String>, handler: F)
182    where
183        T: DeserializeOwned + Send + 'static,
184        F: Fn(ApiContext, T) -> Fut + Send + Sync + 'static,
185        Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
186    {
187        let handler = Arc::new(handler);
188        self.route(route, move |ctx| {
189            let handler = Arc::clone(&handler);
190            let params = ctx.bind::<T>();
191            async move {
192                let params = params?;
193                handler(ctx, params).await
194            }
195        });
196    }
197
198    pub fn validated_route<T, F, Fut>(&mut self, route: impl Into<String>, handler: F)
199    where
200        T: DeserializeOwned + Validate + Send + 'static,
201        F: Fn(ApiContext, T) -> Fut + Send + Sync + 'static,
202        Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
203    {
204        let handler = Arc::new(handler);
205        self.route(route, move |ctx| {
206            let handler = Arc::clone(&handler);
207            let params = ctx.bind_validated::<T>();
208            async move {
209                let params = params?;
210                handler(ctx, params).await
211            }
212        });
213    }
214
215    pub fn filter<F, Fut>(&mut self, filter: F)
216    where
217        F: Fn(ApiContext) -> Fut + Send + Sync + 'static,
218        Fut: Future<Output = Result<ApiContext, ApiError>> + Send + 'static,
219    {
220        let filter = Arc::new(move |ctx: ApiContext| {
221            Box::pin(filter(ctx)) as BoxFuture<'static, Result<ApiContext, ApiError>>
222        });
223        self.filters.push(filter);
224    }
225
226    pub fn event_handler<F, Fut>(&mut self, name: impl Into<String>, handler: F)
227    where
228        F: Fn(EventContext) -> Fut + Send + Sync + 'static,
229        Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
230    {
231        let handler = Arc::new(move |ctx: EventContext| {
232            Box::pin(handler(ctx)) as BoxFuture<'static, Result<Value, ApiError>>
233        });
234        self.event_handlers.insert(name.into(), handler);
235    }
236
237    pub fn exception_handler<F, Fut>(&mut self, handler: F)
238    where
239        F: Fn(ExceptionContext) -> Fut + Send + Sync + 'static,
240        Fut: Future<Output = ErrorPayload> + Send + 'static,
241    {
242        self.exception_handler = Some(Arc::new(move |ctx: ExceptionContext| {
243            Box::pin(handler(ctx)) as BoxFuture<'static, ErrorPayload>
244        }));
245    }
246
247    pub async fn listen(self, address: &str) -> Result<(), ServerError> {
248        let listener = TcpListener::bind(address).await?;
249        println!("WSCALL server listening on ws://{address}/socket");
250
251        let shared = Arc::new(self);
252        loop {
253            let (stream, peer) = listener.accept().await?;
254            let server = Arc::clone(&shared);
255            tokio::spawn(async move {
256                if let Err(error) = server.serve_connection(stream, peer).await {
257                    eprintln!("connection {peer:?} failed: {error}");
258                }
259            });
260        }
261    }
262
263    async fn serve_connection(
264        self: Arc<Self>,
265        stream: TcpStream,
266        peer: std::net::SocketAddr,
267    ) -> Result<(), ServerError> {
268        let websocket = accept_async(stream).await?;
269        let connection_id = Uuid::new_v4().to_string();
270        let (mut sink, mut stream) = websocket.split();
271        let (tx, mut rx) = mpsc::channel::<ServerOutbound>(SERVER_OUTBOUND_QUEUE_CAPACITY);
272
273        self.state
274            .clients
275            .write()
276            .await
277            .insert(connection_id.clone(), tx.clone());
278
279        let codec = self.codec.clone();
280        let writer = tokio::spawn(async move {
281            while let Some(outbound) = rx.recv().await {
282                match outbound {
283                    ServerOutbound::Packet(packet) => {
284                        let bytes = codec.encode(&packet)?;
285                        sink.send(Message::Binary(bytes)).await?;
286                    }
287                    ServerOutbound::Ping(payload) => {
288                        sink.send(Message::Ping(payload)).await?;
289                    }
290                    ServerOutbound::Pong(payload) => {
291                        sink.send(Message::Pong(payload)).await?;
292                    }
293                    ServerOutbound::Close => {
294                        let _ = sink.send(Message::Close(None)).await;
295                        break;
296                    }
297                }
298            }
299            Ok::<(), ServerError>(())
300        });
301
302        let heartbeat_tx = tx.clone();
303        let heartbeat = tokio::spawn(async move {
304            let mut ticker = interval(SERVER_HEARTBEAT_INTERVAL);
305            ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
306            loop {
307                ticker.tick().await;
308                if heartbeat_tx
309                    .send(ServerOutbound::Ping(Vec::new()))
310                    .await
311                    .is_err()
312                {
313                    break;
314                }
315            }
316        });
317
318        self.handle()
319            .send_event_to(
320                &connection_id,
321                "system.notice",
322                json!({ "message": "connected", "connection_id": connection_id }),
323                Vec::new(),
324            )
325            .await
326            .map_err(ServerError::Api)?;
327
328        let result = loop {
329            let next_message = timeout(SERVER_IDLE_TIMEOUT, stream.next()).await;
330            let Some(message) =
331                next_message.map_err(|_| ServerError::IdleTimeout(connection_id.clone()))?
332            else {
333                break Ok(());
334            };
335
336            match message? {
337                Message::Binary(bytes) => {
338                    let packet = self.codec.decode(&bytes)?;
339                    self.process_packet(&connection_id, Some(peer), packet)
340                        .await?;
341                }
342                Message::Close(_) => break Ok(()),
343                Message::Ping(payload) => {
344                    if tx
345                        .send(ServerOutbound::Pong(payload.to_vec()))
346                        .await
347                        .is_err()
348                    {
349                        break Ok(());
350                    }
351                }
352                Message::Pong(_) => {}
353                Message::Text(_) => {}
354                Message::Frame(_) => {}
355            }
356        };
357
358        self.state.clients.write().await.remove(&connection_id);
359        let _ = tx.send(ServerOutbound::Close).await;
360        heartbeat.abort();
361        writer.abort();
362        result
363    }
364
365    async fn process_packet(
366        &self,
367        connection_id: &str,
368        peer_addr: Option<std::net::SocketAddr>,
369        packet: PacketEnvelope,
370    ) -> Result<(), ServerError> {
371        match packet.body {
372            PacketBody::ApiRequest {
373                request_id,
374                route,
375                params,
376                attachments,
377                metadata,
378            } => {
379                let response = self
380                    .run_api_request(
381                        connection_id,
382                        peer_addr,
383                        ApiRequestInput {
384                            request_id: request_id.clone(),
385                            route,
386                            params,
387                            attachments,
388                            metadata,
389                        },
390                    )
391                    .await;
392                self.queue_for(connection_id, response).await?;
393            }
394            PacketBody::EventEmit {
395                event_id,
396                name,
397                data,
398                attachments,
399                metadata,
400                ..
401            } => {
402                let ack = self
403                    .run_event(
404                        connection_id,
405                        peer_addr,
406                        EventEmitInput {
407                            event_id: event_id.clone(),
408                            name,
409                            data,
410                            attachments,
411                            metadata,
412                        },
413                    )
414                    .await;
415                self.queue_for(connection_id, ack).await?;
416            }
417            PacketBody::EventAck {
418                event_id,
419                ok,
420                receipt,
421                error,
422            } => {
423                println!(
424                    "received event ack from {} for {}: ok={}, receipt={}, error={:?}",
425                    connection_id, event_id, ok, receipt, error
426                );
427            }
428            PacketBody::ApiResponse { .. } => {}
429        }
430        Ok(())
431    }
432
433    async fn queue_for(
434        &self,
435        connection_id: &str,
436        packet: PacketEnvelope,
437    ) -> Result<(), ServerError> {
438        let clients = self.state.clients.read().await;
439        let sender = clients
440            .get(connection_id)
441            .cloned()
442            .ok_or_else(|| ServerError::Api(ApiError::not_found("connection is closed")))?;
443        drop(clients);
444        sender
445            .try_send(ServerOutbound::Packet(packet))
446            .map_err(|error| match error {
447                tokio::sync::mpsc::error::TrySendError::Full(_) => {
448                    ServerError::OutboundQueueFull(connection_id.to_string())
449                }
450                tokio::sync::mpsc::error::TrySendError::Closed(_) => {
451                    ServerError::Api(ApiError::internal("failed to queue outbound packet"))
452                }
453            })
454    }
455
456    async fn run_api_request(
457        &self,
458        connection_id: &str,
459        peer_addr: Option<std::net::SocketAddr>,
460        request: ApiRequestInput,
461    ) -> PacketEnvelope {
462        let ApiRequestInput {
463            request_id,
464            route,
465            params,
466            attachments,
467            metadata,
468        } = request;
469
470        let mut ctx = ApiContext {
471            connection_id: connection_id.to_string(),
472            peer_addr,
473            request_id: request_id.clone(),
474            route: route.clone(),
475            params,
476            attachments,
477            metadata,
478            server: self.handle(),
479        };
480
481        for filter in &self.filters {
482            match filter(ctx).await {
483                Ok(next_ctx) => ctx = next_ctx,
484                Err(error) => {
485                    return self
486                        .api_error_packet(connection_id, Some(request_id), route, error)
487                        .await;
488                }
489            }
490        }
491
492        let Some(handler) = self.routes.get(&ctx.route) else {
493            return self
494                .api_error_packet(
495                    connection_id,
496                    Some(request_id),
497                    route,
498                    ApiError::not_found("route not found"),
499                )
500                .await;
501        };
502
503        match AssertUnwindSafe(handler(ctx)).catch_unwind().await {
504            Ok(Ok(data)) => PacketEnvelope::with_encryption(
505                PacketBody::ApiResponse {
506                    request_id,
507                    ok: true,
508                    status: 200,
509                    data,
510                    error: None,
511                    metadata: json!({}),
512                },
513                self.default_encryption,
514            ),
515            Ok(Err(error)) => {
516                self.api_error_packet(connection_id, Some(request_id), route, error)
517                    .await
518            }
519            Err(_) => {
520                self.api_error_packet(
521                    connection_id,
522                    Some(request_id),
523                    route,
524                    ApiError::internal("handler panicked"),
525                )
526                .await
527            }
528        }
529    }
530
531    async fn run_event(
532        &self,
533        connection_id: &str,
534        peer_addr: Option<std::net::SocketAddr>,
535        event: EventEmitInput,
536    ) -> PacketEnvelope {
537        let EventEmitInput {
538            event_id,
539            name,
540            data,
541            attachments,
542            metadata,
543        } = event;
544
545        let ctx = EventContext {
546            connection_id: connection_id.to_string(),
547            peer_addr,
548            event_id: event_id.clone(),
549            name: name.clone(),
550            data,
551            attachments,
552            metadata,
553            server: self.handle(),
554        };
555
556        let Some(handler) = self.event_handlers.get(&name) else {
557            return PacketEnvelope::with_encryption(
558                PacketBody::EventAck {
559                    event_id,
560                    ok: false,
561                    receipt: json!({}),
562                    error: Some(ApiError::not_found("event handler not found").into_payload()),
563                },
564                self.default_encryption,
565            );
566        };
567
568        match AssertUnwindSafe(handler(ctx)).catch_unwind().await {
569            Ok(Ok(receipt)) => PacketEnvelope::with_encryption(
570                PacketBody::EventAck {
571                    event_id,
572                    ok: true,
573                    receipt,
574                    error: None,
575                },
576                self.default_encryption,
577            ),
578            Ok(Err(error)) => PacketEnvelope::with_encryption(
579                PacketBody::EventAck {
580                    event_id: event_id.clone(),
581                    ok: false,
582                    receipt: json!({}),
583                    error: Some(
584                        self.map_exception(ExceptionContext {
585                            connection_id: connection_id.to_string(),
586                            request_id: Some(event_id.clone()),
587                            target: name,
588                            message_kind: "event",
589                            error,
590                        })
591                        .await,
592                    ),
593                },
594                self.default_encryption,
595            ),
596            Err(_) => PacketEnvelope::with_encryption(
597                PacketBody::EventAck {
598                    event_id: event_id.clone(),
599                    ok: false,
600                    receipt: json!({}),
601                    error: Some(
602                        self.map_exception(ExceptionContext {
603                            connection_id: connection_id.to_string(),
604                            request_id: Some(event_id.clone()),
605                            target: name,
606                            message_kind: "event",
607                            error: ApiError::internal("event handler panicked"),
608                        })
609                        .await,
610                    ),
611                },
612                self.default_encryption,
613            ),
614        }
615    }
616
617    async fn api_error_packet(
618        &self,
619        connection_id: &str,
620        request_id: Option<String>,
621        route: String,
622        error: ApiError,
623    ) -> PacketEnvelope {
624        let request_id = request_id.unwrap_or_else(|| Uuid::new_v4().to_string());
625        let status = error.status;
626        let payload = self
627            .map_exception(ExceptionContext {
628                connection_id: connection_id.to_string(),
629                request_id: Some(request_id.clone()),
630                target: route,
631                message_kind: "api",
632                error,
633            })
634            .await;
635
636        PacketEnvelope::with_encryption(
637            PacketBody::ApiResponse {
638                request_id,
639                ok: false,
640                status,
641                data: json!({}),
642                error: Some(payload),
643                metadata: json!({}),
644            },
645            self.default_encryption,
646        )
647    }
648
649    async fn map_exception(&self, context: ExceptionContext) -> ErrorPayload {
650        match &self.exception_handler {
651            Some(handler) => handler(context).await,
652            None => context.error.into_payload(),
653        }
654    }
655}