Skip to main content

wscall_server/
server_runtime.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::panic::AssertUnwindSafe;
4use std::sync::Arc;
5use std::time::Duration;
6
7use futures_util::{FutureExt, SinkExt, StreamExt, future::BoxFuture};
8use serde::de::DeserializeOwned;
9use serde_json::{Value, json};
10use tokio::net::{TcpListener, TcpStream};
11use tokio::sync::{RwLock, mpsc};
12use tokio::time::{MissedTickBehavior, interval, timeout};
13use tokio_tungstenite::{accept_async, tungstenite::Message};
14use uuid::Uuid;
15use validator::Validate;
16use wscall_protocol::{
17    EncryptionKind, ErrorPayload, FileAttachment, FrameCodec, PacketBody, PacketEnvelope,
18};
19
20use crate::server_types::{
21    ApiContext, ApiError, EventContext, ExceptionContext, ServerConnectionContext,
22    ServerDisconnectContext, ServerError, ServerHandle, ServerOutbound, ServerState,
23};
24
/// A connection that yields no WebSocket message of any kind (pongs included)
/// for this long is dropped with an idle-timeout error.
const SERVER_IDLE_TIMEOUT: Duration = Duration::from_millis(45_000);
/// Period between the keep-alive pings the server sends on each connection.
const SERVER_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(15_000);
/// Capacity of each connection's bounded outbound queue; non-blocking
/// `try_send` callers fail once it is full.
const SERVER_OUTBOUND_QUEUE_CAPACITY: usize = 256;
28
29type ApiHandler =
30    Arc<dyn Fn(ApiContext) -> BoxFuture<'static, Result<Value, ApiError>> + Send + Sync>;
31type Filter =
32    Arc<dyn Fn(ApiContext) -> BoxFuture<'static, Result<ApiContext, ApiError>> + Send + Sync>;
33type EventHandler =
34    Arc<dyn Fn(EventContext) -> BoxFuture<'static, Result<Value, ApiError>> + Send + Sync>;
35type ConnectionHandler = Arc<dyn Fn(ServerConnectionContext) -> BoxFuture<'static, ()> + Send + Sync>;
36type DisconnectHandler = Arc<dyn Fn(ServerDisconnectContext) -> BoxFuture<'static, ()> + Send + Sync>;
37type ExceptionHandler =
38    Arc<dyn Fn(ExceptionContext) -> BoxFuture<'static, ErrorPayload> + Send + Sync>;
39
40struct ApiRequestInput {
41    request_id: String,
42    route: String,
43    params: Value,
44    attachments: Vec<FileAttachment>,
45    metadata: Value,
46}
47
48struct EventEmitInput {
49    event_id: String,
50    name: String,
51    data: Value,
52    attachments: Vec<FileAttachment>,
53    metadata: Value,
54}
55
56impl ServerHandle {
57    pub async fn broadcast_event(
58        &self,
59        name: impl Into<String>,
60        data: Value,
61        attachments: Vec<FileAttachment>,
62    ) -> Result<(), ApiError> {
63        let packet = PacketEnvelope::with_encryption(
64            PacketBody::EventEmit {
65                event_id: Uuid::new_v4().to_string(),
66                name: name.into(),
67                data,
68                attachments,
69                metadata: json!({ "source": "server" }),
70                expect_ack: true,
71            },
72            self.default_encryption,
73        );
74
75        let clients = self.state.clients.read().await;
76        let senders = clients.values().cloned().collect::<Vec<_>>();
77        drop(clients);
78
79        for sender in senders {
80            sender
81                .try_send(ServerOutbound::Packet(packet.clone()))
82                .map_err(|_| ApiError::internal("failed to queue broadcast event"))?;
83        }
84        Ok(())
85    }
86
87    pub async fn send_event_to(
88        &self,
89        connection_id: &str,
90        name: impl Into<String>,
91        data: Value,
92        attachments: Vec<FileAttachment>,
93    ) -> Result<(), ApiError> {
94        let packet = PacketEnvelope::with_encryption(
95            PacketBody::EventEmit {
96                event_id: Uuid::new_v4().to_string(),
97                name: name.into(),
98                data,
99                attachments,
100                metadata: json!({ "source": "server" }),
101                expect_ack: true,
102            },
103            self.default_encryption,
104        );
105
106        let clients = self.state.clients.read().await;
107        let sender = clients
108            .get(connection_id)
109            .cloned()
110            .ok_or_else(|| ApiError::not_found("target connection not found"))?;
111        drop(clients);
112        sender
113            .try_send(ServerOutbound::Packet(packet))
114            .map_err(|_| ApiError::internal("failed to queue direct event"))
115    }
116
117    pub async fn connection_count(&self) -> usize {
118        self.state.clients.read().await.len()
119    }
120}
121
122pub struct WscallServer {
123    state: Arc<ServerState>,
124    routes: HashMap<String, ApiHandler>,
125    filters: Vec<Filter>,
126    event_handlers: HashMap<String, EventHandler>,
127    connection_handlers: Vec<ConnectionHandler>,
128    disconnect_handlers: Vec<DisconnectHandler>,
129    exception_handler: Option<ExceptionHandler>,
130    codec: FrameCodec,
131    default_encryption: EncryptionKind,
132}
133
134impl Default for WscallServer {
135    fn default() -> Self {
136        Self::new()
137    }
138}
139
140impl WscallServer {
141    pub fn new() -> Self {
142        Self {
143            state: Arc::new(ServerState {
144                clients: RwLock::new(HashMap::new()),
145            }),
146            routes: HashMap::new(),
147            filters: Vec::new(),
148            event_handlers: HashMap::new(),
149            connection_handlers: Vec::new(),
150            disconnect_handlers: Vec::new(),
151            exception_handler: None,
152            codec: FrameCodec::plaintext(),
153            default_encryption: EncryptionKind::None,
154        }
155    }
156
157    pub fn with_chacha20_key(mut self, key: [u8; 32]) -> Self {
158        self.codec = self.codec.clone().with_chacha20_key(key);
159        self.default_encryption = EncryptionKind::ChaCha20;
160        self
161    }
162
163    pub fn with_aes256_key(mut self, key: [u8; 32]) -> Self {
164        self.codec = self.codec.clone().with_aes256_key(key);
165        self.default_encryption = EncryptionKind::Aes256;
166        self
167    }
168
169    pub fn handle(&self) -> ServerHandle {
170        ServerHandle {
171            state: Arc::clone(&self.state),
172            default_encryption: self.default_encryption,
173        }
174    }
175
176    pub fn route<F, Fut>(&mut self, route: impl Into<String>, handler: F)
177    where
178        F: Fn(ApiContext) -> Fut + Send + Sync + 'static,
179        Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
180    {
181        let handler = Arc::new(move |ctx: ApiContext| {
182            Box::pin(handler(ctx)) as BoxFuture<'static, Result<Value, ApiError>>
183        });
184        self.routes.insert(route.into(), handler);
185    }
186
187    pub fn typed_route<T, F, Fut>(&mut self, route: impl Into<String>, handler: F)
188    where
189        T: DeserializeOwned + Send + 'static,
190        F: Fn(ApiContext, T) -> Fut + Send + Sync + 'static,
191        Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
192    {
193        let handler = Arc::new(handler);
194        self.route(route, move |ctx| {
195            let handler = Arc::clone(&handler);
196            let params = ctx.bind::<T>();
197            async move {
198                let params = params?;
199                handler(ctx, params).await
200            }
201        });
202    }
203
204    pub fn validated_route<T, F, Fut>(&mut self, route: impl Into<String>, handler: F)
205    where
206        T: DeserializeOwned + Validate + Send + 'static,
207        F: Fn(ApiContext, T) -> Fut + Send + Sync + 'static,
208        Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
209    {
210        let handler = Arc::new(handler);
211        self.route(route, move |ctx| {
212            let handler = Arc::clone(&handler);
213            let params = ctx.bind_validated::<T>();
214            async move {
215                let params = params?;
216                handler(ctx, params).await
217            }
218        });
219    }
220
221    pub fn filter<F, Fut>(&mut self, filter: F)
222    where
223        F: Fn(ApiContext) -> Fut + Send + Sync + 'static,
224        Fut: Future<Output = Result<ApiContext, ApiError>> + Send + 'static,
225    {
226        let filter = Arc::new(move |ctx: ApiContext| {
227            Box::pin(filter(ctx)) as BoxFuture<'static, Result<ApiContext, ApiError>>
228        });
229        self.filters.push(filter);
230    }
231
232    pub fn event_handler<F, Fut>(&mut self, name: impl Into<String>, handler: F)
233    where
234        F: Fn(EventContext) -> Fut + Send + Sync + 'static,
235        Fut: Future<Output = Result<Value, ApiError>> + Send + 'static,
236    {
237        let handler = Arc::new(move |ctx: EventContext| {
238            Box::pin(handler(ctx)) as BoxFuture<'static, Result<Value, ApiError>>
239        });
240        self.event_handlers.insert(name.into(), handler);
241    }
242
243    pub fn on_connected<F, Fut>(&mut self, handler: F)
244    where
245        F: Fn(ServerConnectionContext) -> Fut + Send + Sync + 'static,
246        Fut: Future<Output = ()> + Send + 'static,
247    {
248        let handler = Arc::new(move |ctx: ServerConnectionContext| {
249            Box::pin(handler(ctx)) as BoxFuture<'static, ()>
250        });
251        self.connection_handlers.push(handler);
252    }
253
254    pub fn on_disconnected<F, Fut>(&mut self, handler: F)
255    where
256        F: Fn(ServerDisconnectContext) -> Fut + Send + Sync + 'static,
257        Fut: Future<Output = ()> + Send + 'static,
258    {
259        let handler = Arc::new(move |ctx: ServerDisconnectContext| {
260            Box::pin(handler(ctx)) as BoxFuture<'static, ()>
261        });
262        self.disconnect_handlers.push(handler);
263    }
264
265    pub fn exception_handler<F, Fut>(&mut self, handler: F)
266    where
267        F: Fn(ExceptionContext) -> Fut + Send + Sync + 'static,
268        Fut: Future<Output = ErrorPayload> + Send + 'static,
269    {
270        self.exception_handler = Some(Arc::new(move |ctx: ExceptionContext| {
271            Box::pin(handler(ctx)) as BoxFuture<'static, ErrorPayload>
272        }));
273    }
274
275    pub async fn listen(self, address: &str) -> Result<(), ServerError> {
276        let listener = TcpListener::bind(address).await?;
277        println!("WSCALL server listening on ws://{address}/socket");
278
279        let shared = Arc::new(self);
280        loop {
281            let (stream, peer) = listener.accept().await?;
282            let server = Arc::clone(&shared);
283            tokio::spawn(async move {
284                if let Err(error) = server.serve_connection(stream, peer).await {
285                    eprintln!("connection {peer:?} failed: {error}");
286                }
287            });
288        }
289    }
290
291    async fn serve_connection(
292        self: Arc<Self>,
293        stream: TcpStream,
294        peer: std::net::SocketAddr,
295    ) -> Result<(), ServerError> {
296        let websocket = accept_async(stream).await?;
297        let connection_id = Uuid::new_v4().to_string();
298        let (mut sink, mut stream) = websocket.split();
299        let (tx, mut rx) = mpsc::channel::<ServerOutbound>(SERVER_OUTBOUND_QUEUE_CAPACITY);
300
301        self.state
302            .clients
303            .write()
304            .await
305            .insert(connection_id.clone(), tx.clone());
306
307        self.notify_connected(&connection_id, Some(peer)).await;
308
309        let codec = self.codec.clone();
310        let writer = tokio::spawn(async move {
311            while let Some(outbound) = rx.recv().await {
312                match outbound {
313                    ServerOutbound::Packet(packet) => {
314                        let bytes = codec.encode(&packet)?;
315                        sink.send(Message::Binary(bytes)).await?;
316                    }
317                    ServerOutbound::Ping(payload) => {
318                        sink.send(Message::Ping(payload)).await?;
319                    }
320                    ServerOutbound::Pong(payload) => {
321                        sink.send(Message::Pong(payload)).await?;
322                    }
323                    ServerOutbound::Close => {
324                        let _ = sink.send(Message::Close(None)).await;
325                        break;
326                    }
327                }
328            }
329            Ok::<(), ServerError>(())
330        });
331
332        let heartbeat_tx = tx.clone();
333        let heartbeat = tokio::spawn(async move {
334            let mut ticker = interval(SERVER_HEARTBEAT_INTERVAL);
335            ticker.set_missed_tick_behavior(MissedTickBehavior::Delay);
336            loop {
337                ticker.tick().await;
338                if heartbeat_tx
339                    .send(ServerOutbound::Ping(Vec::new()))
340                    .await
341                    .is_err()
342                {
343                    break;
344                }
345            }
346        });
347
348        let result = async {
349            self.handle()
350                .send_event_to(
351                    &connection_id,
352                    "system.notice",
353                    json!({ "message": "connected", "connection_id": connection_id }),
354                    Vec::new(),
355                )
356                .await
357                .map_err(ServerError::Api)?;
358
359            loop {
360                let next_message = timeout(SERVER_IDLE_TIMEOUT, stream.next()).await;
361                let Some(message) =
362                    next_message.map_err(|_| ServerError::IdleTimeout(connection_id.clone()))?
363                else {
364                    break Ok(());
365                };
366
367                match message? {
368                    Message::Binary(bytes) => {
369                        let packet = self.codec.decode(&bytes)?;
370                        self.process_packet(&connection_id, Some(peer), packet)
371                            .await?;
372                    }
373                    Message::Close(_) => break Ok(()),
374                    Message::Ping(payload) => {
375                        if tx
376                            .send(ServerOutbound::Pong(payload.to_vec()))
377                            .await
378                            .is_err()
379                        {
380                            break Ok(());
381                        }
382                    }
383                    Message::Pong(_) => {}
384                    Message::Text(_) => {}
385                    Message::Frame(_) => {}
386                }
387            }
388        }
389        .await;
390
391        self.state.clients.write().await.remove(&connection_id);
392        let _ = tx.send(ServerOutbound::Close).await;
393        heartbeat.abort();
394        writer.abort();
395        self.notify_disconnected(
396            &connection_id,
397            Some(peer),
398            Self::disconnect_reason(&result),
399        )
400        .await;
401        result
402    }
403
404    async fn process_packet(
405        &self,
406        connection_id: &str,
407        peer_addr: Option<std::net::SocketAddr>,
408        packet: PacketEnvelope,
409    ) -> Result<(), ServerError> {
410        match packet.body {
411            PacketBody::ApiRequest {
412                request_id,
413                route,
414                params,
415                attachments,
416                metadata,
417            } => {
418                let response = self
419                    .run_api_request(
420                        connection_id,
421                        peer_addr,
422                        ApiRequestInput {
423                            request_id: request_id.clone(),
424                            route,
425                            params,
426                            attachments,
427                            metadata,
428                        },
429                    )
430                    .await;
431                self.queue_for(connection_id, response).await?;
432            }
433            PacketBody::EventEmit {
434                event_id,
435                name,
436                data,
437                attachments,
438                metadata,
439                ..
440            } => {
441                let ack = self
442                    .run_event(
443                        connection_id,
444                        peer_addr,
445                        EventEmitInput {
446                            event_id: event_id.clone(),
447                            name,
448                            data,
449                            attachments,
450                            metadata,
451                        },
452                    )
453                    .await;
454                self.queue_for(connection_id, ack).await?;
455            }
456            PacketBody::EventAck {
457                event_id,
458                ok,
459                receipt,
460                error,
461            } => {
462                println!(
463                    "received event ack from {} for {}: ok={}, receipt={}, error={:?}",
464                    connection_id, event_id, ok, receipt, error
465                );
466            }
467            PacketBody::ApiResponse { .. } => {}
468        }
469        Ok(())
470    }
471
472    async fn queue_for(
473        &self,
474        connection_id: &str,
475        packet: PacketEnvelope,
476    ) -> Result<(), ServerError> {
477        let clients = self.state.clients.read().await;
478        let sender = clients
479            .get(connection_id)
480            .cloned()
481            .ok_or_else(|| ServerError::Api(ApiError::not_found("connection is closed")))?;
482        drop(clients);
483        sender
484            .try_send(ServerOutbound::Packet(packet))
485            .map_err(|error| match error {
486                tokio::sync::mpsc::error::TrySendError::Full(_) => {
487                    ServerError::OutboundQueueFull(connection_id.to_string())
488                }
489                tokio::sync::mpsc::error::TrySendError::Closed(_) => {
490                    ServerError::Api(ApiError::internal("failed to queue outbound packet"))
491                }
492            })
493    }
494
495    async fn run_api_request(
496        &self,
497        connection_id: &str,
498        peer_addr: Option<std::net::SocketAddr>,
499        request: ApiRequestInput,
500    ) -> PacketEnvelope {
501        let ApiRequestInput {
502            request_id,
503            route,
504            params,
505            attachments,
506            metadata,
507        } = request;
508
509        let mut ctx = ApiContext {
510            connection_id: connection_id.to_string(),
511            peer_addr,
512            request_id: request_id.clone(),
513            route: route.clone(),
514            params,
515            attachments,
516            metadata,
517            server: self.handle(),
518        };
519
520        for filter in &self.filters {
521            match filter(ctx).await {
522                Ok(next_ctx) => ctx = next_ctx,
523                Err(error) => {
524                    return self
525                        .api_error_packet(connection_id, Some(request_id), route, error)
526                        .await;
527                }
528            }
529        }
530
531        let Some(handler) = self.routes.get(&ctx.route) else {
532            return self
533                .api_error_packet(
534                    connection_id,
535                    Some(request_id),
536                    route,
537                    ApiError::not_found("route not found"),
538                )
539                .await;
540        };
541
542        match AssertUnwindSafe(handler(ctx)).catch_unwind().await {
543            Ok(Ok(data)) => PacketEnvelope::with_encryption(
544                PacketBody::ApiResponse {
545                    request_id,
546                    ok: true,
547                    status: 200,
548                    data,
549                    error: None,
550                    metadata: json!({}),
551                },
552                self.default_encryption,
553            ),
554            Ok(Err(error)) => {
555                self.api_error_packet(connection_id, Some(request_id), route, error)
556                    .await
557            }
558            Err(_) => {
559                self.api_error_packet(
560                    connection_id,
561                    Some(request_id),
562                    route,
563                    ApiError::internal("handler panicked"),
564                )
565                .await
566            }
567        }
568    }
569
570    async fn run_event(
571        &self,
572        connection_id: &str,
573        peer_addr: Option<std::net::SocketAddr>,
574        event: EventEmitInput,
575    ) -> PacketEnvelope {
576        let EventEmitInput {
577            event_id,
578            name,
579            data,
580            attachments,
581            metadata,
582        } = event;
583
584        let ctx = EventContext {
585            connection_id: connection_id.to_string(),
586            peer_addr,
587            event_id: event_id.clone(),
588            name: name.clone(),
589            data,
590            attachments,
591            metadata,
592            server: self.handle(),
593        };
594
595        let Some(handler) = self.event_handlers.get(&name) else {
596            return PacketEnvelope::with_encryption(
597                PacketBody::EventAck {
598                    event_id,
599                    ok: false,
600                    receipt: json!({}),
601                    error: Some(ApiError::not_found("event handler not found").into_payload()),
602                },
603                self.default_encryption,
604            );
605        };
606
607        match AssertUnwindSafe(handler(ctx)).catch_unwind().await {
608            Ok(Ok(receipt)) => PacketEnvelope::with_encryption(
609                PacketBody::EventAck {
610                    event_id,
611                    ok: true,
612                    receipt,
613                    error: None,
614                },
615                self.default_encryption,
616            ),
617            Ok(Err(error)) => PacketEnvelope::with_encryption(
618                PacketBody::EventAck {
619                    event_id: event_id.clone(),
620                    ok: false,
621                    receipt: json!({}),
622                    error: Some(
623                        self.map_exception(ExceptionContext {
624                            connection_id: connection_id.to_string(),
625                            request_id: Some(event_id.clone()),
626                            target: name,
627                            message_kind: "event",
628                            error,
629                        })
630                        .await,
631                    ),
632                },
633                self.default_encryption,
634            ),
635            Err(_) => PacketEnvelope::with_encryption(
636                PacketBody::EventAck {
637                    event_id: event_id.clone(),
638                    ok: false,
639                    receipt: json!({}),
640                    error: Some(
641                        self.map_exception(ExceptionContext {
642                            connection_id: connection_id.to_string(),
643                            request_id: Some(event_id.clone()),
644                            target: name,
645                            message_kind: "event",
646                            error: ApiError::internal("event handler panicked"),
647                        })
648                        .await,
649                    ),
650                },
651                self.default_encryption,
652            ),
653        }
654    }
655
656    async fn api_error_packet(
657        &self,
658        connection_id: &str,
659        request_id: Option<String>,
660        route: String,
661        error: ApiError,
662    ) -> PacketEnvelope {
663        let request_id = request_id.unwrap_or_else(|| Uuid::new_v4().to_string());
664        let status = error.status;
665        let payload = self
666            .map_exception(ExceptionContext {
667                connection_id: connection_id.to_string(),
668                request_id: Some(request_id.clone()),
669                target: route,
670                message_kind: "api",
671                error,
672            })
673            .await;
674
675        PacketEnvelope::with_encryption(
676            PacketBody::ApiResponse {
677                request_id,
678                ok: false,
679                status,
680                data: json!({}),
681                error: Some(payload),
682                metadata: json!({}),
683            },
684            self.default_encryption,
685        )
686    }
687
688    async fn notify_connected(
689        &self,
690        connection_id: &str,
691        peer_addr: Option<std::net::SocketAddr>,
692    ) {
693        let handlers = self.connection_handlers.clone();
694        for handler in handlers {
695            let context = ServerConnectionContext {
696                connection_id: connection_id.to_string(),
697                peer_addr,
698                server: self.handle(),
699            };
700
701            if AssertUnwindSafe(handler(context)).catch_unwind().await.is_err() {
702                eprintln!("server connected handler panicked");
703            }
704        }
705    }
706
707    async fn notify_disconnected(
708        &self,
709        connection_id: &str,
710        peer_addr: Option<std::net::SocketAddr>,
711        reason: String,
712    ) {
713        let handlers = self.disconnect_handlers.clone();
714        for handler in handlers {
715            let context = ServerDisconnectContext {
716                connection_id: connection_id.to_string(),
717                peer_addr,
718                reason: reason.clone(),
719                server: self.handle(),
720            };
721
722            if AssertUnwindSafe(handler(context)).catch_unwind().await.is_err() {
723                eprintln!("server disconnected handler panicked");
724            }
725        }
726    }
727
728    fn disconnect_reason(result: &Result<(), ServerError>) -> String {
729        match result {
730            Ok(()) => "connection closed".to_string(),
731            Err(ServerError::IdleTimeout(_)) => "idle timeout".to_string(),
732            Err(error) => error.to_string(),
733        }
734    }
735
736    async fn map_exception(&self, context: ExceptionContext) -> ErrorPayload {
737        match &self.exception_handler {
738            Some(handler) => handler(context).await,
739            None => context.error.into_payload(),
740        }
741    }
742}