rustdtp/
server.rs

1//! Protocol server implementation.
2
3use super::command_channel::*;
4use super::timeout::*;
5use crate::crypto::*;
6use crate::error::{Error, Result};
7use crate::util::*;
8use rsa::pkcs8::EncodePublicKey;
9use serde::de::DeserializeOwned;
10use serde::ser::Serialize;
11use std::collections::HashMap;
12use std::future::Future;
13use std::marker::PhantomData;
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::sync::Arc;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
19use tokio::sync::mpsc::{channel, Receiver, Sender};
20use tokio::task::JoinHandle;
21
22/// Configuration for a server's event callbacks.
23///
24/// # Events
25///
26/// There are four events for which callbacks can be registered:
27///
28///  - `connect`
29///  - `disconnect`
30///  - `receive`
31///  - `stop`
32///
33/// All callbacks are optional, and can be registered for any combination of
34/// these events. Note that each callback must be provided as a function or
35/// closure returning a thread-safe future. The future will be awaited by the
36/// runtime.
37///
38/// # Example
39///
40/// ```no_run
41/// # use rustdtp::prelude::*;
42///
43/// # #[tokio::main]
44/// # async fn main() {
45/// let server = Server::builder()
46///     .sending::<usize>()
47///     .receiving::<String>()
48///     .with_event_callbacks(
49///         ServerEventCallbacks::new()
50///             .on_connect(move |client_id| async move {
51///                 // some async operation...
52///                 println!("Client with ID {} connected", client_id);
53///             })
54///             .on_disconnect(move |client_id| async move {
55///                 // some async operation...
56///                 println!("Client with ID {} disconnected", client_id);
57///             })
58///             .on_receive(move |client_id, data| async move {
59///                 // some async operation...
60///                 println!("Received data from client with ID {}: {}", client_id, data);
61///             })
62///             .on_stop(move || async move {
63///                 // some async operation...
64///                 println!("Server closed");
65///             })
66///     )
67///     .start(("0.0.0.0", 0))
68///     .await
69///     .unwrap();
70/// # }
71/// ```
72#[allow(clippy::type_complexity)]
73#[must_use = "event callbacks do nothing unless you configure them for a server"]
74pub struct ServerEventCallbacks<R>
75where
76    R: DeserializeOwned + 'static,
77{
78    /// The `connect` event callback.
79    connect: Option<Arc<dyn Fn(usize) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>>,
80    /// The `disconnect` event callback.
81    disconnect:
82        Option<Arc<dyn Fn(usize) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>>,
83    /// The `receive` event callback.
84    receive:
85        Option<Arc<dyn Fn(usize, R) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>>,
86    /// The `stop` event callback.
87    stop: Option<Arc<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>>,
88}
89
90impl<R> ServerEventCallbacks<R>
91where
92    R: DeserializeOwned + 'static,
93{
94    /// Creates a new server event callbacks configuration with all callbacks
95    /// empty.
96    pub const fn new() -> Self {
97        Self {
98            connect: None,
99            disconnect: None,
100            receive: None,
101            stop: None,
102        }
103    }
104
105    /// Registers a callback on the `connect` event.
106    pub fn on_connect<C, F>(mut self, callback: C) -> Self
107    where
108        C: Fn(usize) -> F + Send + Sync + 'static,
109        F: Future<Output = ()> + Send + 'static,
110    {
111        self.connect = Some(Arc::new(move |client_id| Box::pin((callback)(client_id))));
112        self
113    }
114
115    /// Registers a callback on the `disconnect` event.
116    pub fn on_disconnect<C, F>(mut self, callback: C) -> Self
117    where
118        C: Fn(usize) -> F + Send + Sync + 'static,
119        F: Future<Output = ()> + Send + 'static,
120    {
121        self.disconnect = Some(Arc::new(move |client_id| Box::pin((callback)(client_id))));
122        self
123    }
124
125    /// Registers a callback on the `receive` event.
126    pub fn on_receive<C, F>(mut self, callback: C) -> Self
127    where
128        C: Fn(usize, R) -> F + Send + Sync + 'static,
129        F: Future<Output = ()> + Send + 'static,
130    {
131        self.receive = Some(Arc::new(move |client_id, data| {
132            Box::pin((callback)(client_id, data))
133        }));
134        self
135    }
136
137    /// Registers a callback on the `stop` event.
138    pub fn on_stop<C, F>(mut self, callback: C) -> Self
139    where
140        C: Fn() -> F + Send + Sync + 'static,
141        F: Future<Output = ()> + Send + 'static,
142    {
143        self.stop = Some(Arc::new(move || Box::pin((callback)())));
144        self
145    }
146}
147
148impl<R> Default for ServerEventCallbacks<R>
149where
150    R: DeserializeOwned + 'static,
151{
152    fn default() -> Self {
153        Self::new()
154    }
155}
156
157/// An event handling trait for the server.
158///
159/// # Events
160///
161/// There are four events for which methods can be implemented:
162///
163///  - `connect`
164///  - `disconnect`
165///  - `receive`
166///  - `stop`
167///
168/// All method implementations are optional, and can be registered for any
169/// combination of these events. Note that the type that implements the trait
170/// must be `Send + Sync`, and that all event method futures must be `Send`.
171///
172/// # Example
173///
174/// ```no_run
175/// # use rustdtp::prelude::*;
176///
177/// # #[tokio::main]
178/// # async fn main() {
179/// struct MyServerHandler;
180///
181/// impl ServerEventHandler<String> for MyServerHandler {
182///     async fn on_connect(&self, client_id: usize) {
183///         // some async operation...
184///         println!("Client with ID {} connected", client_id);
185///     }
186///
187///     async fn on_disconnect(&self, client_id: usize) {
188///         // some async operation...
189///         println!("Client with ID {} disconnected", client_id);
190///     }
191///
192///     async fn on_receive(&self, client_id: usize, data: String) {
193///         // some async operation...
194///         println!("Received data from client with ID {}: {}", client_id, data);
195///     }
196///
197///     async fn on_stop(&self) {
198///         // some async operation...
199///         println!("Server closed");
200///     }
201/// }
202///
203/// let server = Server::builder()
204///     .sending::<usize>()
205///     .receiving::<String>()
206///     .with_event_handler(MyServerHandler)
207///     .start(("0.0.0.0", 0))
208///     .await
209///     .unwrap();
210/// # }
211/// ```
212pub trait ServerEventHandler<R>
213where
214    Self: Send + Sync,
215    R: DeserializeOwned + 'static,
216{
217    /// Handles the `connect` event.
218    #[allow(unused_variables)]
219    fn on_connect(&self, client_id: usize) -> impl Future<Output = ()> + Send {
220        async {}
221    }
222
223    /// Handles the `disconnect` event.
224    #[allow(unused_variables)]
225    fn on_disconnect(&self, client_id: usize) -> impl Future<Output = ()> + Send {
226        async {}
227    }
228
229    /// Handles the `receive` event.
230    #[allow(unused_variables)]
231    fn on_receive(&self, client_id: usize, data: R) -> impl Future<Output = ()> + Send {
232        async {}
233    }
234
235    /// Handles the `stop` event.
236    fn on_stop(&self) -> impl Future<Output = ()> + Send {
237        async {}
238    }
239}
240
241/// Unknown server sending type.
242pub struct ServerSendingUnknown;
243
244/// Known server sending type, stored as the type parameter `S`.
245pub struct ServerSending<S>(PhantomData<fn() -> S>)
246where
247    S: Serialize + 'static;
248
249/// A server sending marker trait.
250trait ServerSendingConfig {}
251
252impl ServerSendingConfig for ServerSendingUnknown {}
253
254impl<S> ServerSendingConfig for ServerSending<S> where S: Serialize + 'static {}
255
256/// Unknown server receiving type.
257pub struct ServerReceivingUnknown;
258
259/// Known server receiving type, stored as the type parameter `R`.
260pub struct ServerReceiving<R>(PhantomData<fn() -> R>)
261where
262    R: DeserializeOwned + 'static;
263
264/// A server receiving marker trait.
265trait ServerReceivingConfig {}
266
267impl ServerReceivingConfig for ServerReceivingUnknown {}
268
269impl<R> ServerReceivingConfig for ServerReceiving<R> where R: DeserializeOwned + 'static {}
270
271/// Unknown server event reporting type.
272pub struct ServerEventReportingUnknown;
273
274/// Known server event reporting type, stored as the type parameter `E`.
275pub struct ServerEventReporting<E>(E);
276
277/// Server event reporting via callbacks.
278pub struct ServerEventReportingCallbacks<R>(ServerEventCallbacks<R>)
279where
280    R: DeserializeOwned + 'static;
281
282/// Server event reporting via an event handler.
283pub struct ServerEventReportingHandler<R, H>
284where
285    R: DeserializeOwned + 'static,
286    H: ServerEventHandler<R>,
287{
288    /// The event handler instance.
289    handler: H,
290    /// Phantom `R` owner.
291    phantom_receive: PhantomData<fn() -> R>,
292}
293
294/// Server event reporting via a channel.
295pub struct ServerEventReportingChannel;
296
297/// A server event reporting marker trait.
298trait ServerEventReportingConfig {}
299
300impl ServerEventReportingConfig for ServerEventReportingUnknown {}
301
302impl<R> ServerEventReportingConfig for ServerEventReporting<ServerEventReportingCallbacks<R>> where
303    R: DeserializeOwned + 'static
304{
305}
306
307impl<R, H> ServerEventReportingConfig for ServerEventReporting<ServerEventReportingHandler<R, H>>
308where
309    R: DeserializeOwned + 'static,
310    H: ServerEventHandler<R>,
311{
312}
313
314impl ServerEventReportingConfig for ServerEventReporting<ServerEventReportingChannel> {}
315
316/// A builder for the [`Server`].
317///
318/// An instance of this can be constructed using `ServerBuilder::new()` or
319/// `Server::builder()`. The configuration information exists primarily at the
320/// type-level, so it is impossible to misconfigure this.
321///
322/// This method of configuration is technically not necessary, but it is far
323/// clearer and more explicit than simply configuring the `Server` type. Plus,
324/// it provides additional ways of detecting events.
325///
326/// # Configuration
327///
328/// To configure the server, first provide the types that will be sent and
329/// received through the server using the `.sending::<...>()` and
330/// `.receiving::<...>()` methods. Then specify the way in which events will
331/// be detected. There are three methods of receiving events:
332///
333/// - via callback functions (`.with_event_callbacks(...)`)
334/// - via implementation of a handler trait (`.with_event_handler(...)`)
335/// - via a channel (`.with_event_channel()`)
336///
337/// The channel method is the most versatile, hence why it's the `Server`'s
338/// default implementation. The other methods are provided to support a
339/// greater variety of program architectures.
340///
341/// Once configured, the `.start(...)` method, which is effectively identical
342/// to the `Server::start(...)` method, can be called to start the server.
343///
344/// # Example
345///
346/// ```no_run
347/// # use rustdtp::prelude::*;
348///
349/// # #[tokio::main]
350/// # async fn main() {
351/// let (server, server_events) = Server::builder()
352///     .sending::<usize>()
353///     .receiving::<String>()
354///     .with_event_channel()
355///     .start(("0.0.0.0", 0))
356///     .await
357///     .unwrap();
358/// # }
359/// ```
360#[allow(private_bounds)]
361#[must_use = "server builders do nothing unless `start` is called"]
362pub struct ServerBuilder<SC, RC, EC>
363where
364    SC: ServerSendingConfig,
365    RC: ServerReceivingConfig,
366    EC: ServerEventReportingConfig,
367{
368    /// Phantom marker for `SC` and `RC`.
369    marker: PhantomData<fn() -> (SC, RC)>,
370    /// The event reporting configuration.
371    event_reporting: EC,
372}
373
374impl ServerBuilder<ServerSendingUnknown, ServerReceivingUnknown, ServerEventReportingUnknown> {
375    /// Creates a new server builder.
376    pub const fn new() -> Self {
377        Self {
378            marker: PhantomData,
379            event_reporting: ServerEventReportingUnknown,
380        }
381    }
382}
383
384impl Default
385    for ServerBuilder<ServerSendingUnknown, ServerReceivingUnknown, ServerEventReportingUnknown>
386{
387    fn default() -> Self {
388        Self::new()
389    }
390}
391
392#[allow(private_bounds)]
393impl<RC, EC> ServerBuilder<ServerSendingUnknown, RC, EC>
394where
395    RC: ServerReceivingConfig,
396    EC: ServerEventReportingConfig,
397{
398    /// Configures the type of data the server intends to send to clients.
399    pub fn sending<S>(self) -> ServerBuilder<ServerSending<S>, RC, EC>
400    where
401        S: Serialize + 'static,
402    {
403        ServerBuilder {
404            marker: PhantomData,
405            event_reporting: self.event_reporting,
406        }
407    }
408}
409
410#[allow(private_bounds)]
411impl<SC, EC> ServerBuilder<SC, ServerReceivingUnknown, EC>
412where
413    SC: ServerSendingConfig,
414    EC: ServerEventReportingConfig,
415{
416    /// Configures the type of data the server intends to receive from
417    /// clients.
418    pub fn receiving<R>(self) -> ServerBuilder<SC, ServerReceiving<R>, EC>
419    where
420        R: DeserializeOwned + 'static,
421    {
422        ServerBuilder {
423            marker: PhantomData,
424            event_reporting: self.event_reporting,
425        }
426    }
427}
428
429impl<S, R> ServerBuilder<ServerSending<S>, ServerReceiving<R>, ServerEventReportingUnknown>
430where
431    S: Serialize + 'static,
432    R: DeserializeOwned + 'static,
433{
434    /// Configures the server to receive events via callbacks.
435    ///
436    /// Using callbacks is typically considered an anti-pattern in Rust, so
437    /// this should only be used if it makes sense in the context of the
438    /// design of the code utilizing this API.
439    ///
440    /// See [`ServerEventCallbacks`] for more information and examples.
441    pub fn with_event_callbacks(
442        self,
443        callbacks: ServerEventCallbacks<R>,
444    ) -> ServerBuilder<
445        ServerSending<S>,
446        ServerReceiving<R>,
447        ServerEventReporting<ServerEventReportingCallbacks<R>>,
448    >
449    where
450        R: DeserializeOwned + 'static,
451    {
452        ServerBuilder {
453            marker: PhantomData,
454            event_reporting: ServerEventReporting(ServerEventReportingCallbacks(callbacks)),
455        }
456    }
457
458    /// Configures the server to receive events via a trait implementation.
459    ///
460    /// This provides an approach to event handling that closely aligns with
461    /// object-oriented practices.
462    ///
463    /// See [`ServerEventHandler`] for more information and examples.
464    pub fn with_event_handler<H>(
465        self,
466        handler: H,
467    ) -> ServerBuilder<
468        ServerSending<S>,
469        ServerReceiving<R>,
470        ServerEventReporting<ServerEventReportingHandler<R, H>>,
471    >
472    where
473        H: ServerEventHandler<R>,
474    {
475        ServerBuilder {
476            marker: PhantomData,
477            event_reporting: ServerEventReporting(ServerEventReportingHandler {
478                handler,
479                phantom_receive: PhantomData,
480            }),
481        }
482    }
483
484    /// Configures the server to receive events via a channel.
485    ///
486    /// This is the most versatile event handling strategy. In fact, all other
487    /// event handling options use this implementation under the hood.
488    /// Because of its flexibility, this will typically be the desired
489    /// approach.
490    pub fn with_event_channel(
491        self,
492    ) -> ServerBuilder<
493        ServerSending<S>,
494        ServerReceiving<R>,
495        ServerEventReporting<ServerEventReportingChannel>,
496    > {
497        ServerBuilder {
498            marker: PhantomData,
499            event_reporting: ServerEventReporting(ServerEventReportingChannel),
500        }
501    }
502}
503
504impl<S, R>
505    ServerBuilder<
506        ServerSending<S>,
507        ServerReceiving<R>,
508        ServerEventReporting<ServerEventReportingCallbacks<R>>,
509    >
510where
511    S: Serialize + 'static,
512    R: DeserializeOwned + 'static,
513{
514    /// Starts the server. This is effectively identical to [`Server::start`].
515    ///
516    /// # Errors
517    ///
518    /// The set of errors that can occur are identical to that of
519    /// [`Server::start`].
520    #[allow(clippy::future_not_send)]
521    pub async fn start<A>(self, addr: A) -> Result<ServerHandle<S>>
522    where
523        A: ToSocketAddrs,
524    {
525        let (server, mut server_events) = Server::<S, R>::start(addr).await?;
526        let callbacks = self.event_reporting.0 .0;
527
528        tokio::spawn(async move {
529            while let Ok(event) = server_events.next_raw().await {
530                match event {
531                    ServerEventRawSafe::Connect { client_id } => {
532                        if let Some(ref connect) = callbacks.connect {
533                            let connect = Arc::clone(connect);
534                            tokio::spawn(async move {
535                                (*connect)(client_id).await;
536                            });
537                        }
538                    }
539                    ServerEventRawSafe::Disconnect { client_id } => {
540                        if let Some(ref disconnect) = callbacks.disconnect {
541                            let disconnect = Arc::clone(disconnect);
542                            tokio::spawn(async move {
543                                (*disconnect)(client_id).await;
544                            });
545                        }
546                    }
547                    ServerEventRawSafe::Receive { client_id, data } => {
548                        if let Some(ref receive) = callbacks.receive {
549                            let receive = Arc::clone(receive);
550                            tokio::spawn(async move {
551                                let data = data.deserialize();
552                                (*receive)(client_id, data).await;
553                            });
554                        }
555                    }
556                    ServerEventRawSafe::Stop => {
557                        if let Some(ref stop) = callbacks.stop {
558                            let stop = Arc::clone(stop);
559                            tokio::spawn(async move {
560                                (*stop)().await;
561                            });
562                        }
563                    }
564                }
565            }
566        });
567
568        Ok(server)
569    }
570}
571
572impl<S, R, H>
573    ServerBuilder<
574        ServerSending<S>,
575        ServerReceiving<R>,
576        ServerEventReporting<ServerEventReportingHandler<R, H>>,
577    >
578where
579    S: Serialize + 'static,
580    R: DeserializeOwned + 'static,
581    H: ServerEventHandler<R> + 'static,
582{
583    /// Starts the server. This is effectively identical to [`Server::start`].
584    ///
585    /// # Errors
586    ///
587    /// The set of errors that can occur are identical to that of
588    /// [`Server::start`].
589    #[allow(clippy::future_not_send)]
590    pub async fn start<A>(self, addr: A) -> Result<ServerHandle<S>>
591    where
592        A: ToSocketAddrs,
593    {
594        let (server, mut server_events) = Server::<S, R>::start(addr).await?;
595        let handler = Arc::new(self.event_reporting.0.handler);
596
597        tokio::spawn(async move {
598            while let Ok(event) = server_events.next_raw().await {
599                match event {
600                    ServerEventRawSafe::Connect { client_id } => {
601                        let handler = Arc::clone(&handler);
602                        tokio::spawn(async move {
603                            handler.on_connect(client_id).await;
604                        });
605                    }
606                    ServerEventRawSafe::Disconnect { client_id } => {
607                        let handler = Arc::clone(&handler);
608                        tokio::spawn(async move {
609                            handler.on_disconnect(client_id).await;
610                        });
611                    }
612                    ServerEventRawSafe::Receive { client_id, data } => {
613                        let handler = Arc::clone(&handler);
614                        tokio::spawn(async move {
615                            let data = data.deserialize();
616                            handler.on_receive(client_id, data).await;
617                        });
618                    }
619                    ServerEventRawSafe::Stop => {
620                        let handler = Arc::clone(&handler);
621                        tokio::spawn(async move {
622                            handler.on_stop().await;
623                        });
624                    }
625                }
626            }
627        });
628
629        Ok(server)
630    }
631}
632
633impl<S, R>
634    ServerBuilder<
635        ServerSending<S>,
636        ServerReceiving<R>,
637        ServerEventReporting<ServerEventReportingChannel>,
638    >
639where
640    S: Serialize + 'static,
641    R: DeserializeOwned + 'static,
642{
643    /// Starts the server. This is effectively identical to [`Server::start`].
644    ///
645    /// # Errors
646    ///
647    /// The set of errors that can occur are identical to that of
648    /// [`Server::start`].
649    #[allow(clippy::future_not_send)]
650    pub async fn start<A>(self, addr: A) -> Result<(ServerHandle<S>, ServerEventStream<R>)>
651    where
652        A: ToSocketAddrs,
653    {
654        Server::<S, R>::start(addr).await
655    }
656}
657
/// An instruction passed from a server handle to the background server task.
pub enum ServerCommand {
    /// Shut the server down.
    Stop,
    /// Deliver serialized data to one client.
    Send {
        /// Which client should receive the data.
        client_id: usize,
        /// The bytes to deliver.
        data: Vec<u8>,
    },
    /// Deliver serialized data to every connected client.
    SendAll {
        /// The bytes to deliver.
        data: Vec<u8>,
    },
    /// Query the address the server is bound to.
    GetAddr,
    /// Query the address of one client.
    GetClientAddr {
        /// Which client to look up.
        client_id: usize,
    },
    /// Disconnect one client.
    RemoveClient {
        /// Which client to disconnect.
        client_id: usize,
    },
}
687
688/// The return value of a command executed on the background server task.
689pub enum ServerCommandReturn {
690    /// Stop return value.
691    Stop(Result<()>),
692    /// Sent data return value.
693    Send(Result<()>),
694    /// Sent data to all return value.
695    SendAll(Result<()>),
696    /// Local server address return value.
697    GetAddr(Result<SocketAddr>),
698    /// Client address return value.
699    GetClientAddr(Result<SocketAddr>),
700    /// Disconnect client return value.
701    RemoveClient(Result<()>),
702}
703
/// An instruction passed from the background server task to a per-client
/// background task.
pub enum ServerClientCommand {
    /// Deliver data to this client.
    Send {
        /// The already-serialized bytes, shared via `Arc` so they need not be copied.
        data: Arc<[u8]>,
    },
    /// Query this client's address.
    GetAddr,
    /// Disconnect this client.
    Remove,
}
716
717/// The return value of a command executed on a client background task.
718pub enum ServerClientCommandReturn {
719    /// Send data return value.
720    Send(Result<()>),
721    /// Client address return value.
722    GetAddr(Result<SocketAddr>),
723    /// Disconnect client return value.
724    Remove(Result<()>),
725}
726
727/// An event from the server.
728///
729/// ```no_run
730/// use rustdtp::prelude::*;
731///
732/// #[tokio::main]
733/// async fn main() {
734///     // Create the server
735///     let (mut server, mut server_events) = Server::builder()
736///         .sending::<()>()
737///         .receiving::<String>()
738///         .with_event_channel()
739///         .start(("0.0.0.0", 0))
740///         .await
741///         .unwrap();
742///
743///     // Iterate over events
744///     while let Ok(event) = server_events.next().await {
745///         match event {
746///             ServerEvent::Connect { client_id } => {
747///                 println!("Client with ID {} connected", client_id);
748///             }
749///             ServerEvent::Disconnect { client_id } => {
750///                 println!("Client with ID {} disconnected", client_id);
751///             }
752///             ServerEvent::Receive { client_id, data } => {
753///                 println!("Client with ID {} sent: {}", client_id, data);
754///             }
755///             ServerEvent::Stop => {
756///                 // No more events will be sent, and the loop will end
757///                 println!("Server closed");
758///             }
759///         }
760///     }
761/// }
762/// ```
763#[derive(Debug, Clone)]
764pub enum ServerEvent<R>
765where
766    R: DeserializeOwned + 'static,
767{
768    /// A client connected.
769    Connect {
770        /// The ID of the client that connected.
771        client_id: usize,
772    },
773    /// A client disconnected.
774    Disconnect {
775        /// The ID of the client that disconnected.
776        client_id: usize,
777    },
778    /// Data received from a client.
779    Receive {
780        /// The ID of the client that sent the data.
781        client_id: usize,
782        /// The data itself.
783        data: R,
784    },
785    /// Server stopped.
786    Stop,
787}
788
/// Mirror of `ServerEvent` in which received data is still serialized bytes.
#[derive(Debug, Clone)]
enum ServerEventRaw {
    /// A new client has connected.
    Connect {
        /// ID assigned to the connecting client.
        client_id: usize,
    },
    /// A client has disconnected.
    Disconnect {
        /// ID of the departing client.
        client_id: usize,
    },
    /// A client sent data.
    Receive {
        /// ID of the sending client.
        client_id: usize,
        /// The payload exactly as received, not yet deserialized.
        data: Vec<u8>,
    },
    /// The server has stopped.
    Stop,
}
812
813impl ServerEventRaw {
814    /// Deserializes this instance into a `ServerEvent`.
815    fn deserialize<R>(&self) -> Result<ServerEvent<R>>
816    where
817        R: DeserializeOwned + 'static,
818    {
819        match self {
820            Self::Connect { client_id } => Ok(ServerEvent::Connect {
821                client_id: *client_id,
822            }),
823            Self::Disconnect { client_id } => Ok(ServerEvent::Disconnect {
824                client_id: *client_id,
825            }),
826            Self::Receive { client_id, data } => {
827                Ok(
828                    serde_json::from_slice(data).map(|data| ServerEvent::Receive {
829                        client_id: *client_id,
830                        data,
831                    })?,
832                )
833            }
834            Self::Stop => Ok(ServerEvent::Stop),
835        }
836    }
837}
838
839/// The serialized data component of a server receive event. The data is
840/// guaranteed to be deserializable into an instance of `R`.
841#[derive(Debug, Clone)]
842struct ServerEventRawSafeData<R>
843where
844    R: DeserializeOwned + 'static,
845{
846    /// The raw data.
847    data: Vec<u8>,
848    /// Phantom marker for `R`.
849    marker: PhantomData<fn() -> R>,
850}
851
852/// Identical to `ServerEventRaw`, but with the guarantee that the data can be
853/// deserialized into an instance of `R`.
854#[derive(Debug, Clone)]
855enum ServerEventRawSafe<R>
856where
857    R: DeserializeOwned + 'static,
858{
859    /// A client connected.
860    Connect {
861        /// The ID of the client that connected.
862        client_id: usize,
863    },
864    /// A client disconnected.
865    Disconnect {
866        /// The ID of the client that disconnected.
867        client_id: usize,
868    },
869    /// Data received from a client.
870    Receive {
871        /// The ID of the client that sent the data.
872        client_id: usize,
873        /// The data itself.
874        data: ServerEventRawSafeData<R>,
875    },
876    /// Server stopped.
877    Stop,
878}
879
880impl<R> TryFrom<ServerEventRaw> for ServerEventRawSafe<R>
881where
882    R: DeserializeOwned + 'static,
883{
884    type Error = Error;
885
886    fn try_from(value: ServerEventRaw) -> std::result::Result<Self, Self::Error> {
887        value.deserialize::<R>()?;
888
889        Ok(match value {
890            ServerEventRaw::Connect { client_id } => Self::Connect { client_id },
891            ServerEventRaw::Disconnect { client_id } => Self::Disconnect { client_id },
892            ServerEventRaw::Receive { client_id, data } => Self::Receive {
893                client_id,
894                data: ServerEventRawSafeData {
895                    data,
896                    marker: PhantomData,
897                },
898            },
899            ServerEventRaw::Stop => Self::Stop,
900        })
901    }
902}
903
904impl<R> ServerEventRawSafeData<R>
905where
906    R: DeserializeOwned + 'static,
907{
908    /// Deserialize the raw data into an instance of `R`. This is guaranteed to
909    /// succeed.
910    fn deserialize(&self) -> R {
911        serde_json::from_slice(&self.data).unwrap()
912    }
913}
914
915impl<R> ServerEventRawSafe<R>
916where
917    R: DeserializeOwned + 'static,
918{
919    /// Deserializes this instance into a `ServerEvent`.
920    #[allow(dead_code)]
921    fn deserialize(&self) -> ServerEvent<R> {
922        match self {
923            Self::Connect { client_id } => ServerEvent::Connect {
924                client_id: *client_id,
925            },
926            Self::Disconnect { client_id } => ServerEvent::Disconnect {
927                client_id: *client_id,
928            },
929            Self::Receive { client_id, data } => ServerEvent::Receive {
930                client_id: *client_id,
931                data: data.deserialize(),
932            },
933            Self::Stop => ServerEvent::Stop,
934        }
935    }
936}
937
938/// An asynchronous stream of server events.
939pub struct ServerEventStream<R>
940where
941    R: DeserializeOwned + 'static,
942{
943    /// The event receiver channel.
944    event_receiver: Receiver<ServerEventRaw>,
945    /// Phantom marker for `R`.
946    marker: PhantomData<fn() -> R>,
947}
948
949impl<R> ServerEventStream<R>
950where
951    R: DeserializeOwned + 'static,
952{
953    /// Consumes and returns the next value in the stream.
954    ///
955    /// # Errors
956    ///
957    /// This will return an error if the stream is closed, or if there was an
958    /// error while deserializing data received.
959    pub async fn next(&mut self) -> Result<ServerEvent<R>> {
960        match self.event_receiver.recv().await {
961            Some(serialized_event) => serialized_event.deserialize(),
962            None => Err(Error::ConnectionClosed),
963        }
964    }
965
966    /// Identical to `next`, but doesn't deserialize the event. It does,
967    /// however, validate that the event can be deserialized without error.
968    async fn next_raw(&mut self) -> Result<ServerEventRawSafe<R>> {
969        match self.event_receiver.recv().await {
970            Some(serialized_event) => serialized_event.try_into(),
971            None => Err(Error::ConnectionClosed),
972        }
973    }
974}
975
976/// A handle to the server.
977pub struct ServerHandle<S>
978where
979    S: Serialize + 'static,
980{
981    /// The channel through which commands can be sent to the background task.
982    server_command_sender: CommandChannelSender<ServerCommand, ServerCommandReturn>,
983    /// The handle to the background task.
984    server_task_handle: JoinHandle<Result<()>>,
985    /// Phantom marker for `S`.
986    marker: PhantomData<fn() -> S>,
987}
988
989impl<S> ServerHandle<S>
990where
991    S: Serialize + 'static,
992{
993    /// Stop the server, disconnect all clients, and shut down all network
994    /// interfaces.
995    ///
996    /// Returns a result of the error variant if an error occurred while
997    /// disconnecting clients.
998    ///
999    /// ```no_run
1000    /// use rustdtp::prelude::*;
1001    ///
1002    /// #[tokio::main]
1003    /// async fn main() {
1004    ///     // Create the server
1005    ///     let (mut server, mut server_events) = Server::builder()
1006    ///         .sending::<()>()
1007    ///         .receiving::<String>()
1008    ///         .with_event_channel()
1009    ///         .start(("0.0.0.0", 0))
1010    ///         .await
1011    ///         .unwrap();
1012    ///
1013    ///     // Wait for events until a client requests the server be stopped
1014    ///     while let Ok(event) = server_events.next().await {
1015    ///         match event {
1016    ///             // Stop the server when a client requests it be stopped
1017    ///             ServerEvent::Receive { client_id, data } => {
1018    ///                 if data.as_str() == "Stop the server!" {
1019    ///                     println!("Server stop requested");
1020    ///                     server.stop().await.unwrap();
1021    ///                     break;
1022    ///                 }
1023    ///             }
1024    ///             _ => {}  // Do nothing for other events
1025    ///         }
1026    ///     }
1027    ///
1028    ///     // The last event should be a stop event
1029    ///     assert!(matches!(server_events.next().await.unwrap(), ServerEvent::Stop));
1030    /// }
1031    /// ```
1032    ///
1033    /// # Errors
1034    ///
1035    /// This will return an error if the server socket has already closed, or if
1036    /// the underlying server loop returned an error.
1037    #[allow(clippy::missing_panics_doc)]
1038    pub async fn stop(mut self) -> Result<()> {
1039        let value = self
1040            .server_command_sender
1041            .send_command(ServerCommand::Stop)
1042            .await?;
1043        // `unwrap` is allowed, as an error is returned only when the underlying
1044        // task panics, which it never should
1045        self.server_task_handle.await.unwrap()?;
1046        unwrap_enum!(value, ServerCommandReturn::Stop)
1047    }
1048
1049    /// Send data to a client.
1050    ///
1051    /// - `client_id`: the ID of the client to send the data to.
1052    /// - `data`: the data to send.
1053    ///
1054    /// Returns a result of the error variant if an error occurred while
1055    /// sending.
1056    ///
1057    /// ```no_run
1058    /// use rustdtp::prelude::*;
1059    ///
1060    /// #[tokio::main]
1061    /// async fn main() {
1062    ///     // Create the server
1063    ///     let (mut server, mut server_events) = Server::builder()
1064    ///         .sending::<String>()
1065    ///         .receiving::<()>()
1066    ///         .with_event_channel()
1067    ///         .start(("0.0.0.0", 0))
1068    ///         .await
1069    ///         .unwrap();
1070    ///
1071    ///     // Iterate over events
1072    ///     while let Ok(event) = server_events.next().await {
1073    ///         match event {
1074    ///             // When a client connects, send a greeting
1075    ///             ServerEvent::Connect { client_id } => {
1076    ///                 server.send(client_id, format!("Hello, client {}!", client_id)).await.unwrap();
1077    ///             }
1078    ///             _ => {}  // Do nothing for other events
1079    ///         }
1080    ///     }
1081    /// }
1082    /// ```
1083    ///
1084    /// # Errors
1085    ///
1086    /// This will return an error if the server socket has closed, or if data
1087    /// serialization fails.
1088    #[allow(clippy::future_not_send)]
1089    pub async fn send(&mut self, client_id: usize, data: S) -> Result<()> {
1090        let data_serialized = serde_json::to_vec(&data)?;
1091        let value = self
1092            .server_command_sender
1093            .send_command(ServerCommand::Send {
1094                client_id,
1095                data: data_serialized,
1096            })
1097            .await?;
1098        unwrap_enum!(value, ServerCommandReturn::Send)
1099    }
1100
1101    /// Send data to all clients.
1102    ///
1103    /// - `data`: the data to send.
1104    ///
1105    /// Returns a result of the error variant if an error occurred while
1106    /// sending.
1107    ///
1108    /// ```no_run
1109    /// use rustdtp::prelude::*;
1110    ///
1111    /// #[tokio::main]
1112    /// async fn main() {
1113    ///     // Create the server
1114    ///     let (mut server, mut server_events) = Server::builder()
1115    ///         .sending::<String>()
1116    ///         .receiving::<()>()
1117    ///         .with_event_channel()
1118    ///         .start(("0.0.0.0", 0))
1119    ///         .await
1120    ///         .unwrap();
1121    ///
1122    ///     // Iterate over events
1123    ///     while let Ok(event) = server_events.next().await {
1124    ///         match event {
1125    ///             // When a client connects, notify all clients
1126    ///             ServerEvent::Connect { client_id } => {
1127    ///                 server.send_all(format!("A new client with ID {} has joined!", client_id)).await.unwrap();
1128    ///             }
1129    ///             _ => {}  // Do nothing for other events
1130    ///         }
1131    ///     }
1132    /// }
1133    /// ```
1134    ///
1135    /// # Errors
1136    ///
1137    /// This will return an error if the server socket has closed, or if data
1138    /// serialization fails.
1139    #[allow(clippy::future_not_send)]
1140    pub async fn send_all(&mut self, data: S) -> Result<()> {
1141        let data_serialized = serde_json::to_vec(&data)?;
1142        let value = self
1143            .server_command_sender
1144            .send_command(ServerCommand::SendAll {
1145                data: data_serialized,
1146            })
1147            .await?;
1148        unwrap_enum!(value, ServerCommandReturn::SendAll)
1149    }
1150
1151    /// Get the address the server is listening on.
1152    ///
1153    /// Returns a result containing the address the server is listening on, or
1154    /// the error variant if an error occurred.
1155    ///
1156    /// ```no_run
1157    /// use rustdtp::prelude::*;
1158    ///
1159    /// #[tokio::main]
1160    /// async fn main() {
1161    ///     // Create the server
1162    ///     let (mut server, mut server_events) = Server::builder()
1163    ///         .sending::<()>()
1164    ///         .receiving::<()>()
1165    ///         .with_event_channel()
1166    ///         .start(("0.0.0.0", 0))
1167    ///         .await
1168    ///         .unwrap();
1169    ///
1170    ///     // Get the server address
1171    ///     let addr = server.get_addr().await.unwrap();
1172    ///     println!("Server listening on {}", addr);
1173    /// }
1174    /// ```
1175    ///
1176    /// # Errors
1177    ///
1178    /// This will return an error if the server socket has closed.
1179    pub async fn get_addr(&mut self) -> Result<SocketAddr> {
1180        let value = self
1181            .server_command_sender
1182            .send_command(ServerCommand::GetAddr)
1183            .await?;
1184        unwrap_enum!(value, ServerCommandReturn::GetAddr)
1185    }
1186
1187    /// Get the address of a connected client.
1188    ///
1189    /// - `client_id`: the ID of the client.
1190    ///
1191    /// Returns a result containing the address of the client, or the error
1192    /// variant if the client ID is invalid.
1193    ///
1194    /// ```no_run
1195    /// use rustdtp::prelude::*;
1196    ///
1197    /// #[tokio::main]
1198    /// async fn main() {
1199    ///     // Create the server
1200    ///     let (mut server, mut server_events) = Server::builder()
1201    ///         .sending::<()>()
1202    ///         .receiving::<()>()
1203    ///         .with_event_channel()
1204    ///         .start(("0.0.0.0", 0))
1205    ///         .await
1206    ///         .unwrap();
1207    ///
1208    ///     // Iterate over events
1209    ///     while let Ok(event) = server_events.next().await {
1210    ///         match event {
1211    ///             // When a client connects, get their address
1212    ///             ServerEvent::Connect { client_id } => {
1213    ///                 let addr = server.get_client_addr(client_id).await.unwrap();
1214    ///                 println!("Client with ID {} connected from {}", client_id, addr);
1215    ///             }
1216    ///             _ => {}  // Do nothing for other events
1217    ///         }
1218    ///     }
1219    /// }
1220    /// ```
1221    ///
1222    /// # Errors
1223    ///
1224    /// This will return an error if the server socket has closed, or if the
1225    /// client ID is invalid.
1226    pub async fn get_client_addr(&mut self, client_id: usize) -> Result<SocketAddr> {
1227        let value = self
1228            .server_command_sender
1229            .send_command(ServerCommand::GetClientAddr { client_id })
1230            .await?;
1231        unwrap_enum!(value, ServerCommandReturn::GetClientAddr)
1232    }
1233
1234    /// Disconnect a client from the server.
1235    ///
1236    /// - `client_id`: the ID of the client.
1237    ///
1238    /// Returns a result of the error variant if an error occurred while
1239    /// disconnecting the client, or if the client ID is invalid.
1240    ///
1241    /// ```no_run
1242    /// use rustdtp::prelude::*;
1243    ///
1244    /// #[tokio::main]
1245    /// async fn main() {
1246    ///     // Create the server
1247    ///     let (mut server, mut server_events) = Server::builder()
1248    ///         .sending::<String>()
1249    ///         .receiving::<i32>()
1250    ///         .with_event_channel()
1251    ///         .start(("0.0.0.0", 0))
1252    ///         .await
1253    ///         .unwrap();
1254    ///
1255    ///     // Iterate over events
1256    ///     while let Ok(event) = server_events.next().await {
1257    ///         match event {
1258    ///             // Disconnect a client if they send an even number
1259    ///             ServerEvent::Receive { client_id, data } => {
1260    ///                 if data % 2 == 0 {
1261    ///                     println!("Disconnecting client with ID {}", client_id);
1262    ///                     server.send(client_id, "Even numbers are not allowed".to_owned()).await.unwrap();
1263    ///                     server.remove_client(client_id).await.unwrap();
1264    ///                 }
1265    ///             }
1266    ///             _ => {}  // Do nothing for other events
1267    ///         }
1268    ///     }
1269    ///
1270    ///     // The last event should be a stop event
1271    ///     assert!(matches!(server_events.next().await.unwrap(), ServerEvent::Stop));
1272    /// }
1273    /// ```
1274    ///
1275    /// # Errors
1276    ///
1277    /// This will return an error if the server socket has closed, or if the
1278    /// client ID is invalid.
1279    pub async fn remove_client(&mut self, client_id: usize) -> Result<()> {
1280        let value = self
1281            .server_command_sender
1282            .send_command(ServerCommand::RemoveClient { client_id })
1283            .await?;
1284        unwrap_enum!(value, ServerCommandReturn::RemoveClient)
1285    }
1286}
1287
1288/// A socket server.
1289///
1290/// The server takes two generic parameters:
1291///
1292/// - `S`: the type of data that will be **sent** to clients.
1293/// - `R`: the type of data that will be **received** from clients.
1294///
1295/// Both types must be serializable in order to be sent through the socket. When
1296/// creating clients, the types should be swapped, since the server's send type will be the client's receive type and vice versa.
1297///
1298/// ```no_run
1299/// use rustdtp::prelude::*;
1300///
1301/// #[tokio::main]
1302/// async fn main() {
1303///     // Create a server that receives strings and returns the length of each string
1304///     let (mut server, mut server_events) = Server::builder()
1305///         .sending::<usize>()
1306///         .receiving::<String>()
1307///         .with_event_channel()
1308///         .start(("0.0.0.0", 0))
1309///         .await
1310///         .unwrap();
1311///
1312///     // Iterate over events
1313///     while let Ok(event) = server_events.next().await {
1314///         match event {
1315///             ServerEvent::Connect { client_id } => {
1316///                 println!("Client with ID {} connected", client_id);
1317///             }
1318///             ServerEvent::Disconnect { client_id } => {
1319///                 println!("Client with ID {} disconnected", client_id);
1320///             }
1321///             ServerEvent::Receive { client_id, data } => {
1322///                 // Send back the length of the string
1323///                 server.send(client_id, data.len()).await.unwrap();
1324///             }
1325///             ServerEvent::Stop => {
1326///                 // No more events will be sent, and the loop will end
1327///                 println!("Server closed");
1328///             }
1329///         }
1330///     }
1331/// }
1332/// ```
1333pub struct Server<S, R>
1334where
1335    S: Serialize + 'static,
1336    R: DeserializeOwned + 'static,
1337{
1338    /// Phantom marker for `S` and `R`.
1339    marker: PhantomData<fn() -> (S, R)>,
1340}
1341
1342impl Server<(), ()> {
1343    /// Constructs a server builder. Use this for a clearer, more explicit,
1344    /// and more featureful server configuration. See [`ServerBuilder`] for
1345    /// more information.
1346    pub const fn builder(
1347    ) -> ServerBuilder<ServerSendingUnknown, ServerReceivingUnknown, ServerEventReportingUnknown>
1348    {
1349        ServerBuilder::new()
1350    }
1351}
1352
1353impl<S, R> Server<S, R>
1354where
1355    S: Serialize + 'static,
1356    R: DeserializeOwned + 'static,
1357{
1358    /// Start a socket server.
1359    ///
1360    /// - `addr`: the address for the server to listen on.
1361    ///
1362    /// Returns a result containing a handle to the server and a channel from
1363    /// which to receive server events, or the error variant if an error
1364    /// occurred while starting the server.
1365    ///
1366    /// ```no_run
1367    /// use rustdtp::prelude::*;
1368    ///
1369    /// #[tokio::main]
1370    /// async fn main() {
1371    ///     let (mut server, mut server_events) = Server::builder()
1372    ///         .sending::<()>()
1373    ///         .receiving::<()>()
1374    ///         .with_event_channel()
1375    ///         .start(("0.0.0.0", 0))
1376    ///         .await
1377    ///         .unwrap();
1378    /// }
1379    /// ```
1380    ///
1381    /// Neither the server handle nor the event receiver should be dropped until
1382    /// the server has been stopped. Prematurely dropping either one can cause
1383    /// unintended behavior.
1384    ///
1385    /// # Errors
1386    ///
1387    /// This will return an error if a TCP listener cannot be bound to the
1388    /// provided address.
1389    #[allow(clippy::future_not_send)]
1390    pub async fn start<A>(addr: A) -> Result<(ServerHandle<S>, ServerEventStream<R>)>
1391    where
1392        A: ToSocketAddrs,
1393    {
1394        // Server TCP listener
1395        let listener = TcpListener::bind(addr).await?;
1396        // Channels for sending commands from the server handle to the background server task
1397        let (server_command_sender, server_command_receiver) = command_channel();
1398        // Channels for sending event notifications from the background server task
1399        let (server_event_sender, server_event_receiver) = channel(CHANNEL_BUFFER_SIZE);
1400
1401        // Start the background server task, saving the join handle for when the server is stopped
1402        let server_task_handle = tokio::spawn(server_handler(
1403            listener,
1404            server_event_sender,
1405            server_command_receiver,
1406        ));
1407
1408        // Create a handle for the server
1409        let server_handle = ServerHandle {
1410            server_command_sender,
1411            server_task_handle,
1412            marker: PhantomData,
1413        };
1414
1415        // Create an event stream for the server
1416        let server_event_stream = ServerEventStream {
1417            event_receiver: server_event_receiver,
1418            marker: PhantomData,
1419        };
1420
1421        Ok((server_handle, server_event_stream))
1422    }
1423}
1424
1425/// The server client loop. Handles received data and commands.
1426#[allow(clippy::too_many_lines)]
1427async fn server_client_loop(
1428    client_id: usize,
1429    mut socket: TcpStream,
1430    server_client_event_sender: Sender<ServerEventRaw>,
1431    mut client_command_receiver: CommandChannelReceiver<
1432        ServerClientCommand,
1433        ServerClientCommandReturn,
1434    >,
1435) -> Result<()> {
1436    // Generate RSA keys
1437    let (rsa_pub, rsa_priv) = rsa_keys().await?;
1438    // Convert the RSA public key into a string...
1439    let rsa_pub_str = rsa_pub
1440        .to_public_key_pem(rsa::pkcs1::LineEnding::LF)
1441        .map_err(|_| Error::InvalidRsaKeyEncoding)?;
1442    // ...and then into bytes
1443    let rsa_pub_bytes = rsa_pub_str.as_bytes();
1444    // Create the buffer containing the RSA public key and its size
1445    let mut rsa_pub_buffer = encode_message_size(rsa_pub_bytes.len()).to_vec();
1446    // Extend the buffer with the RSA public key bytes
1447    rsa_pub_buffer.extend(rsa_pub_bytes);
1448    // Send the RSA public key to the client
1449    socket.write_all(&rsa_pub_buffer).await?;
1450    // Flush the stream
1451    socket.flush().await?;
1452
1453    // Buffer in which to receive the size portion of the AES key
1454    let mut aes_key_size_buffer = [0; LEN_SIZE];
1455    // Read the AES key from the client
1456    handshake_timeout! {
1457        socket.read_exact(&mut aes_key_size_buffer[..])
1458    }??;
1459
1460    // Decode the size portion of the AES key
1461    let aes_key_size = decode_message_size(&aes_key_size_buffer);
1462    // Initialize the buffer for the AES key
1463    let mut aes_key_buffer = vec![0; aes_key_size];
1464
1465    // Read the AES key portion from the client socket, returning an error if
1466    // the socket could not be read
1467    data_read_timeout! {
1468        socket.read_exact(&mut aes_key_buffer[..])
1469    }??;
1470
1471    // Decrypt the AES key
1472    let aes_key_decrypted = rsa_decrypt(rsa_priv, aes_key_buffer.into()).await?;
1473
1474    // Assert that the AES key is the correct size
1475    let aes_key: [u8; AES_KEY_SIZE] = aes_key_decrypted
1476        .try_into()
1477        .map_err(|_| Error::InvalidAesKeySize)?;
1478
1479    // Buffer in which to receive the size portion of a message
1480    let mut size_buffer = [0; LEN_SIZE];
1481
1482    // Client loop
1483    loop {
1484        // Await messages from the client
1485        // and commands from the background server task
1486        tokio::select! {
1487            // Read the size portion from the client socket
1488            read_value = socket.read(&mut size_buffer[..]) => {
1489                // Return an error if the socket could not be read
1490                let n_size = read_value?;
1491
1492                // If there were no bytes read, or if there were fewer bytes
1493                // read than there should have been, close the socket
1494                if n_size != LEN_SIZE {
1495                    socket.shutdown().await?;
1496                    break;
1497                };
1498
1499                // Decode the size portion of the message
1500                let encrypted_data_size = decode_message_size(&size_buffer);
1501                // Initialize the buffer for the data portion of the message
1502                let mut encrypted_data_buffer = vec![0; encrypted_data_size];
1503
1504                // Read the data portion from the client socket, returning an
1505                // error if the socket could not be read
1506                let n_data = data_read_timeout! {
1507                    socket.read_exact(&mut encrypted_data_buffer[..])
1508                }??;
1509
1510                // If there were no bytes read, or if there were fewer bytes
1511                // read than there should have been, close the socket
1512                if n_data != encrypted_data_size {
1513                    socket.shutdown().await?;
1514                    break;
1515                }
1516
1517                // Decrypt the data
1518                let data_serialized = aes_decrypt(aes_key, encrypted_data_buffer.into()).await?;
1519
1520                // Send an event to note that a piece of data has been received from
1521                // a client
1522                if let Err(_e) = server_client_event_sender.send(ServerEventRaw::Receive { client_id, data: data_serialized }).await {
1523                    // Sending failed, disconnect the client
1524                    socket.shutdown().await?;
1525                    break;
1526                }
1527            }
1528            // Process a command sent to the client
1529            client_command_value = client_command_receiver.recv_command() => {
1530                // Handle the command, or lack thereof if the channel is closed
1531                match client_command_value {
1532                    Ok(client_command) => {
1533                        // Process the command
1534                        match client_command {
1535                            ServerClientCommand::Send { data } => {
1536                                let value = 'val: {
1537                                    // Encrypt the serialized data
1538                                    let encrypted_data_buffer = break_on_err!(aes_encrypt(aes_key, data).await, 'val);
1539                                    // Encode the message size to a buffer
1540                                    let size_buffer = encode_message_size(encrypted_data_buffer.len());
1541
1542                                    // Initialize the message buffer
1543                                    let mut buffer = vec![];
1544                                    // Extend the buffer to contain the payload
1545                                    // size
1546                                    buffer.extend_from_slice(&size_buffer);
1547                                    // Extend the buffer to contain the payload
1548                                    // data
1549                                    buffer.extend(&encrypted_data_buffer);
1550
1551                                    // Write the data to the client socket
1552                                    break_on_err!(socket.write_all(&buffer).await, 'val);
1553                                    // Flush the stream
1554                                    break_on_err!(socket.flush().await, 'val);
1555
1556                                    Ok(())
1557                                };
1558
1559                                let error_occurred = value.is_err();
1560
1561                                // Return the status of the send operation
1562                                if let Err(_e) = client_command_receiver.command_return(ServerClientCommandReturn::Send(value)).await {
1563                                    // Channel is closed, disconnect the client
1564                                    socket.shutdown().await?;
1565                                    break;
1566                                }
1567
1568                                // If the send failed, disconnect the client
1569                                if error_occurred {
1570                                    socket.shutdown().await?;
1571                                    break;
1572                                }
1573                            },
1574                            ServerClientCommand::GetAddr => {
1575                                // Get the client socket's address
1576                                let addr = socket.peer_addr();
1577
1578                                // Return the address
1579                                if let Err(_e) = client_command_receiver.command_return(ServerClientCommandReturn::GetAddr(addr.map_err(Into::into))).await {
1580                                    // Channel is closed, disconnect the client
1581                                    socket.shutdown().await?;
1582                                    break;
1583                                }
1584                            },
1585                            ServerClientCommand::Remove => {
1586                                // Disconnect the client
1587                                let value = socket.shutdown().await;
1588
1589                                // Return the status of the remove operation,
1590                                // ignoring failures, since a failure indicates
1591                                // that the client has probably already
1592                                // disconnected
1593                                _ = client_command_receiver.command_return(ServerClientCommandReturn::Remove(value.map_err(Into::into))).await;
1594
1595                                // Break the client loop
1596                                break;
1597                            },
1598                        }
1599                    },
1600                    Err(_e) => {
1601                        // Channel is closed, disconnect the client
1602                        socket.shutdown().await?;
1603                        break;
1604                    },
1605                }
1606            }
1607        }
1608    }
1609
1610    Ok(())
1611}
1612
1613/// Starts a server client loop in the background.
1614fn server_client_handler(
1615    client_id: usize,
1616    socket: TcpStream,
1617    server_client_event_sender: Sender<ServerEventRaw>,
1618    client_cleanup_sender: Sender<usize>,
1619) -> (
1620    CommandChannelSender<ServerClientCommand, ServerClientCommandReturn>,
1621    JoinHandle<Result<()>>,
1622) {
1623    // Channels for sending commands from the background server task to a background client task
1624    let (client_command_sender, client_command_receiver) = command_channel();
1625
1626    // Start a background client task, saving the join handle for when the
1627    // server is stopped
1628    let client_task_handle = tokio::spawn(async move {
1629        let res = server_client_loop(
1630            client_id,
1631            socket,
1632            server_client_event_sender,
1633            client_command_receiver,
1634        )
1635        .await;
1636
1637        // Tell the server to clean up after the client, ignoring failures,
1638        // since a failure indicates that the server has probably closed
1639        _ = client_cleanup_sender.send(client_id).await;
1640
1641        res
1642    });
1643
1644    (client_command_sender, client_task_handle)
1645}
1646
1647/// The server loop. Handles incoming connections and commands.
1648#[allow(clippy::too_many_lines)]
1649async fn server_loop(
1650    listener: TcpListener,
1651    server_event_sender: Sender<ServerEventRaw>,
1652    mut server_command_receiver: CommandChannelReceiver<ServerCommand, ServerCommandReturn>,
1653    client_command_senders: &mut HashMap<
1654        usize,
1655        CommandChannelSender<ServerClientCommand, ServerClientCommandReturn>,
1656    >,
1657    client_join_handles: &mut HashMap<usize, JoinHandle<Result<()>>>,
1658) -> Result<()> {
1659    // ID assigned to the next client
1660    let mut next_client_id = 0usize;
1661    // Channel for indicating that a client needs to be cleaned up after
1662    let (server_client_cleanup_sender, mut server_client_cleanup_receiver) =
1663        channel::<usize>(CHANNEL_BUFFER_SIZE);
1664
1665    // Server loop
1666    loop {
1667        // Await new clients connecting,
1668        // commands from the server handle,
1669        // and notifications of clients disconnecting
1670        tokio::select! {
1671            // Accept a connecting client
1672            accept_value = listener.accept() => {
1673                // Get the client socket, exiting if an error occurs
1674                let (socket, _) = accept_value?;
1675                // New client ID
1676                let client_id = next_client_id;
1677                // Increment next client ID
1678                next_client_id += 1;
1679                // Clone the event sender so the background client tasks can
1680                // send events
1681                let server_client_event_sender = server_event_sender.clone();
1682                // Clone the client cleanup sender to the background client
1683                // tasks can be cleaned up properly
1684                let client_cleanup_sender = server_client_cleanup_sender.clone();
1685
1686                // Handle the new connection
1687                let (client_command_sender, client_task_handle) = server_client_handler(client_id, socket, server_client_event_sender, client_cleanup_sender);
1688                // Keep track of client command senders
1689                client_command_senders.insert(client_id, client_command_sender);
1690                // Keep track of client task handles
1691                client_join_handles.insert(client_id, client_task_handle);
1692
1693                // Send an event to note that a client has connected
1694                // successfully
1695                if let Err(_e) = server_event_sender
1696                    .send(ServerEventRaw::Connect { client_id })
1697                    .await
1698                {
1699                    // Server is probably closed
1700                    break;
1701                }
1702            },
1703            // Process a command from the server handle
1704            command_value = server_command_receiver.recv_command() => {
1705                // Handle the command, or lack thereof if the channel is closed
1706                match command_value {
1707                    Ok(command) => {
1708                        match command {
1709                            ServerCommand::Stop => {
1710                                // If a command fails to send, the server has
1711                                // already closed, and the error can be ignored.
1712                                // It should be noted that this is not where the
1713                                // stop method actually returns its `Result`.
1714                                // This immediately returns with an `Ok` status.
1715                                // The real return value is the `Result`
1716                                // returned from the server task join handle.
1717                                _ = server_command_receiver.command_return(ServerCommandReturn::Stop(Ok(()))).await;
1718
1719                                // Break the server loop, the clients will be
1720                                // disconnected before the task ends
1721                                break;
1722                            },
1723                            ServerCommand::Send { client_id, data } => {
1724                                let value = match client_command_senders.get_mut(&client_id) {
1725                                    Some(client_command_sender) => {
1726                                        // Turn `Vec<u8>` into `Arc<[u8]>`,
1727                                        // making it more easily shareable
1728                                        let shareable_data = Arc::<[u8]>::from(data);
1729
1730                                        match client_command_sender.send_command(ServerClientCommand::Send { data: shareable_data }).await {
1731                                            Ok(return_value) => unwrap_enum!(return_value, ServerClientCommandReturn::Send),
1732                                            Err(_e) => {
1733                                                // The channel is closed, and
1734                                                // the client has probably been
1735                                                // disconnected, so the error
1736                                                // can be ignored
1737                                                Ok(())
1738                                            },
1739                                        }
1740                                    },
1741                                    None => Err(Error::InvalidClientId(client_id)),
1742                                };
1743
1744                                // If a command fails to send, the client has probably disconnected,
1745                                // and the error can be ignored
1746                                _ = server_command_receiver.command_return(ServerCommandReturn::Send(value)).await;
1747                            },
1748                            ServerCommand::SendAll { data } => {
1749                                let value = {
1750                                    // Turn `Vec<u8>` into `Arc<[u8]>`, making
1751                                    // it more easily shareable
1752                                    let shareable_data = Arc::<[u8]>::from(data);
1753
1754                                    let send_futures = client_command_senders.iter_mut().map(|(_client_id, client_command_sender)| async {
1755                                        match client_command_sender.send_command(ServerClientCommand::Send { data: Arc::clone(&shareable_data) }).await {
1756                                            Ok(return_value) => unwrap_enum!(return_value, ServerClientCommandReturn::Send),
1757                                            Err(_e) => {
1758                                                // The channel is closed, and
1759                                                // the client has probably been
1760                                                // disconnected, so the error
1761                                                // can be ignored
1762                                                Ok(())
1763                                            }
1764                                        }
1765                                    });
1766
1767                                    let resolved = futures::future::join_all(send_futures).await;
1768                                    resolved.into_iter().collect::<Result<Vec<_>>>().map(|_| ())
1769                                };
1770
1771                                // If a command fails to send, the client has
1772                                // probably disconnected, and the error can be
1773                                // ignored
1774                                _ = server_command_receiver.command_return(ServerCommandReturn::SendAll(value)).await;
1775                            },
1776                            ServerCommand::GetAddr => {
1777                                // Get the server listener's address
1778                                let addr = listener.local_addr();
1779
1780                                // If a command fails to send, the client has
1781                                // probably disconnected, and the error can be
1782                                // ignored
1783                                _ = server_command_receiver.command_return(ServerCommandReturn::GetAddr(addr.map_err(Into::into))).await;
1784                            },
1785                            ServerCommand::GetClientAddr { client_id } => {
1786                                let value = match client_command_senders.get_mut(&client_id) {
1787                                    Some(client_command_sender) => match client_command_sender.send_command(ServerClientCommand::GetAddr).await {
1788                                        Ok(return_value) => unwrap_enum!(return_value, ServerClientCommandReturn::GetAddr),
1789                                        Err(_e) => {
1790                                            // The channel is closed, and the
1791                                            // client has probably been
1792                                            // disconnected, so the error can be
1793                                            // treated as an invalid client
1794                                            // error
1795                                            Err(Error::InvalidClientId(client_id))
1796                                        },
1797                                    },
1798                                    None => Err(Error::InvalidClientId(client_id)),
1799                                };
1800
1801                                // If a command fails to send, the client has
1802                                // probably disconnected, and the error can be
1803                                // ignored
1804                                _ = server_command_receiver.command_return(ServerCommandReturn::GetClientAddr(value)).await;
1805                            },
1806                            ServerCommand::RemoveClient { client_id } => {
1807                                let value = match client_command_senders.get_mut(&client_id) {
1808                                    Some(client_command_sender) => match client_command_sender.send_command(ServerClientCommand::Remove).await {
1809                                        Ok(return_value) => unwrap_enum!(return_value, ServerClientCommandReturn::Remove),
1810                                        Err(_e) => {
1811                                            // The channel is closed, and the
1812                                            // client has probably been
1813                                            // disconnected, so the error can be
1814                                            // ignored
1815                                            Ok(())
1816                                        },
1817                                    },
1818                                    None => Err(Error::InvalidClientId(client_id)),
1819                                };
1820
1821                                // If a command fails to send, the client has
1822                                // probably disconnected already, and the error
1823                                // can be ignored
1824                                _ = server_command_receiver.command_return(ServerCommandReturn::RemoveClient(value)).await;
1825                            },
1826                        }
1827                    },
1828                    Err(_e) => {
1829                        // Server is probably closed, exit
1830                        break;
1831                    },
1832                }
1833            }
1834            // Clean up after a disconnecting client
1835            disconnecting_client_id = server_client_cleanup_receiver.recv() => {
1836                match disconnecting_client_id {
1837                    Some(client_id) => {
1838                        // Remove the client's command sender, which will be
1839                        // dropped after this block ends
1840                        client_command_senders.remove(&client_id);
1841
1842                        // Remove the client's join handle
1843                        if let Some(handle) = client_join_handles.remove(&client_id) {
1844                            // Join the client's handle
1845                            if let Err(e) = handle.await.unwrap() {
1846                                if cfg!(test) {
1847                                    // If testing, fail
1848                                    Err(e)?;
1849                                } else {
1850                                    // If not testing, ignore client handler
1851                                    // errors
1852                                }
1853                            }
1854                        }
1855
1856                        // Send an event to note that a client has disconnected
1857                        if let Err(_e) = server_event_sender.send(ServerEventRaw::Disconnect { client_id }).await {
1858                            // Server is probably closed, exit
1859                            break;
1860                        }
1861                    },
1862                    None => {
1863                        // Server is probably closed, exit
1864                        break;
1865                    },
1866                }
1867            }
1868        }
1869    }
1870
1871    Ok(())
1872}
1873
1874/// Starts the server loop task in the background.
1875async fn server_handler(
1876    listener: TcpListener,
1877    server_event_sender: Sender<ServerEventRaw>,
1878    server_command_receiver: CommandChannelReceiver<ServerCommand, ServerCommandReturn>,
1879) -> Result<()> {
1880    // Collection of channels for sending commands from the background server
1881    // task to a background client task
1882    let mut client_command_senders: HashMap<
1883        usize,
1884        CommandChannelSender<ServerClientCommand, ServerClientCommandReturn>,
1885    > = HashMap::new();
1886    // Background client task join handles
1887    let mut client_join_handles: HashMap<usize, JoinHandle<Result<()>>> = HashMap::new();
1888
1889    // Wrap server loop in a block to catch all exit scenarios
1890    let server_exit = server_loop(
1891        listener,
1892        server_event_sender.clone(),
1893        server_command_receiver,
1894        &mut client_command_senders,
1895        &mut client_join_handles,
1896    )
1897    .await;
1898
1899    // Send a remove command to all clients
1900    futures::future::join_all(client_command_senders.into_values().map(
1901        |mut client_command_sender| async move {
1902            // If a command fails to send, the client has probably disconnected
1903            // already, and the error can be ignored
1904            _ = client_command_sender
1905                .send_command(ServerClientCommand::Remove)
1906                .await;
1907        },
1908    ))
1909    .await;
1910
1911    // Join all background client tasks before exiting
1912    futures::future::join_all(client_join_handles.into_values().map(|handle| async move {
1913        if let Err(e) = handle.await.unwrap() {
1914            if cfg!(test) {
1915                // If testing, fail
1916                Err(e)?;
1917            } else {
1918                // If not testing, ignore client handler errors
1919            }
1920        }
1921
1922        Ok(())
1923    }))
1924    .await
1925    .into_iter()
1926    .collect::<Result<Vec<_>>>()?;
1927
1928    // Send a stop event, ignoring send errors
1929    _ = server_event_sender.send(ServerEventRaw::Stop).await;
1930
1931    // Return server loop result
1932    server_exit
1933}