rustdtp/server.rs
1//! Protocol server implementation.
2
3use super::command_channel::*;
4use super::timeout::*;
5use crate::crypto::*;
6use crate::error::{Error, Result};
7use crate::util::*;
8use rsa::pkcs8::EncodePublicKey;
9use serde::de::DeserializeOwned;
10use serde::ser::Serialize;
11use std::collections::HashMap;
12use std::future::Future;
13use std::marker::PhantomData;
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::sync::Arc;
17use tokio::io::{AsyncReadExt, AsyncWriteExt};
18use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
19use tokio::sync::mpsc::{channel, Receiver, Sender};
20use tokio::task::JoinHandle;
21
22/// Configuration for a server's event callbacks.
23///
24/// # Events
25///
26/// There are four events for which callbacks can be registered:
27///
28/// - `connect`
29/// - `disconnect`
30/// - `receive`
31/// - `stop`
32///
33/// All callbacks are optional, and can be registered for any combination of
34/// these events. Note that each callback must be provided as a function or
35/// closure returning a thread-safe future. The future will be awaited by the
36/// runtime.
37///
38/// # Example
39///
40/// ```no_run
41/// # use rustdtp::prelude::*;
42///
43/// # #[tokio::main]
44/// # async fn main() {
45/// let server = Server::builder()
46/// .sending::<usize>()
47/// .receiving::<String>()
48/// .with_event_callbacks(
49/// ServerEventCallbacks::new()
50/// .on_connect(move |client_id| async move {
51/// // some async operation...
52/// println!("Client with ID {} connected", client_id);
53/// })
54/// .on_disconnect(move |client_id| async move {
55/// // some async operation...
56/// println!("Client with ID {} disconnected", client_id);
57/// })
58/// .on_receive(move |client_id, data| async move {
59/// // some async operation...
60/// println!("Received data from client with ID {}: {}", client_id, data);
61/// })
62/// .on_stop(move || async move {
63/// // some async operation...
64/// println!("Server closed");
65/// })
66/// )
67/// .start(("0.0.0.0", 0))
68/// .await
69/// .unwrap();
70/// # }
71/// ```
72#[allow(clippy::type_complexity)]
73#[must_use = "event callbacks do nothing unless you configure them for a server"]
74pub struct ServerEventCallbacks<R>
75where
76 R: DeserializeOwned + 'static,
77{
78 /// The `connect` event callback.
79 connect: Option<Arc<dyn Fn(usize) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>>,
80 /// The `disconnect` event callback.
81 disconnect:
82 Option<Arc<dyn Fn(usize) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>>,
83 /// The `receive` event callback.
84 receive:
85 Option<Arc<dyn Fn(usize, R) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>>,
86 /// The `stop` event callback.
87 stop: Option<Arc<dyn Fn() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>>,
88}
89
90impl<R> ServerEventCallbacks<R>
91where
92 R: DeserializeOwned + 'static,
93{
94 /// Creates a new server event callbacks configuration with all callbacks
95 /// empty.
96 pub const fn new() -> Self {
97 Self {
98 connect: None,
99 disconnect: None,
100 receive: None,
101 stop: None,
102 }
103 }
104
105 /// Registers a callback on the `connect` event.
106 pub fn on_connect<C, F>(mut self, callback: C) -> Self
107 where
108 C: Fn(usize) -> F + Send + Sync + 'static,
109 F: Future<Output = ()> + Send + 'static,
110 {
111 self.connect = Some(Arc::new(move |client_id| Box::pin((callback)(client_id))));
112 self
113 }
114
115 /// Registers a callback on the `disconnect` event.
116 pub fn on_disconnect<C, F>(mut self, callback: C) -> Self
117 where
118 C: Fn(usize) -> F + Send + Sync + 'static,
119 F: Future<Output = ()> + Send + 'static,
120 {
121 self.disconnect = Some(Arc::new(move |client_id| Box::pin((callback)(client_id))));
122 self
123 }
124
125 /// Registers a callback on the `receive` event.
126 pub fn on_receive<C, F>(mut self, callback: C) -> Self
127 where
128 C: Fn(usize, R) -> F + Send + Sync + 'static,
129 F: Future<Output = ()> + Send + 'static,
130 {
131 self.receive = Some(Arc::new(move |client_id, data| {
132 Box::pin((callback)(client_id, data))
133 }));
134 self
135 }
136
137 /// Registers a callback on the `stop` event.
138 pub fn on_stop<C, F>(mut self, callback: C) -> Self
139 where
140 C: Fn() -> F + Send + Sync + 'static,
141 F: Future<Output = ()> + Send + 'static,
142 {
143 self.stop = Some(Arc::new(move || Box::pin((callback)())));
144 self
145 }
146}
147
148impl<R> Default for ServerEventCallbacks<R>
149where
150 R: DeserializeOwned + 'static,
151{
152 fn default() -> Self {
153 Self::new()
154 }
155}
156
157/// An event handling trait for the server.
158///
159/// # Events
160///
161/// There are four events for which methods can be implemented:
162///
163/// - `connect`
164/// - `disconnect`
165/// - `receive`
166/// - `stop`
167///
168/// All method implementations are optional, and can be registered for any
169/// combination of these events. Note that the type that implements the trait
170/// must be `Send + Sync`, and that all event method futures must be `Send`.
171///
172/// # Example
173///
174/// ```no_run
175/// # use rustdtp::prelude::*;
176///
177/// # #[tokio::main]
178/// # async fn main() {
179/// struct MyServerHandler;
180///
181/// impl ServerEventHandler<String> for MyServerHandler {
182/// async fn on_connect(&self, client_id: usize) {
183/// // some async operation...
184/// println!("Client with ID {} connected", client_id);
185/// }
186///
187/// async fn on_disconnect(&self, client_id: usize) {
188/// // some async operation...
189/// println!("Client with ID {} disconnected", client_id);
190/// }
191///
192/// async fn on_receive(&self, client_id: usize, data: String) {
193/// // some async operation...
194/// println!("Received data from client with ID {}: {}", client_id, data);
195/// }
196///
197/// async fn on_stop(&self) {
198/// // some async operation...
199/// println!("Server closed");
200/// }
201/// }
202///
203/// let server = Server::builder()
204/// .sending::<usize>()
205/// .receiving::<String>()
206/// .with_event_handler(MyServerHandler)
207/// .start(("0.0.0.0", 0))
208/// .await
209/// .unwrap();
210/// # }
211/// ```
212pub trait ServerEventHandler<R>
213where
214 Self: Send + Sync,
215 R: DeserializeOwned + 'static,
216{
217 /// Handles the `connect` event.
218 #[allow(unused_variables)]
219 fn on_connect(&self, client_id: usize) -> impl Future<Output = ()> + Send {
220 async {}
221 }
222
223 /// Handles the `disconnect` event.
224 #[allow(unused_variables)]
225 fn on_disconnect(&self, client_id: usize) -> impl Future<Output = ()> + Send {
226 async {}
227 }
228
229 /// Handles the `receive` event.
230 #[allow(unused_variables)]
231 fn on_receive(&self, client_id: usize, data: R) -> impl Future<Output = ()> + Send {
232 async {}
233 }
234
235 /// Handles the `stop` event.
236 fn on_stop(&self) -> impl Future<Output = ()> + Send {
237 async {}
238 }
239}
240
241/// Unknown server sending type.
242pub struct ServerSendingUnknown;
243
244/// Known server sending type, stored as the type parameter `S`.
245pub struct ServerSending<S>(PhantomData<fn() -> S>)
246where
247 S: Serialize + 'static;
248
249/// A server sending marker trait.
250trait ServerSendingConfig {}
251
252impl ServerSendingConfig for ServerSendingUnknown {}
253
254impl<S> ServerSendingConfig for ServerSending<S> where S: Serialize + 'static {}
255
256/// Unknown server receiving type.
257pub struct ServerReceivingUnknown;
258
259/// Known server receiving type, stored as the type parameter `R`.
260pub struct ServerReceiving<R>(PhantomData<fn() -> R>)
261where
262 R: DeserializeOwned + 'static;
263
264/// A server receiving marker trait.
265trait ServerReceivingConfig {}
266
267impl ServerReceivingConfig for ServerReceivingUnknown {}
268
269impl<R> ServerReceivingConfig for ServerReceiving<R> where R: DeserializeOwned + 'static {}
270
271/// Unknown server event reporting type.
272pub struct ServerEventReportingUnknown;
273
274/// Known server event reporting type, stored as the type parameter `E`.
275pub struct ServerEventReporting<E>(E);
276
277/// Server event reporting via callbacks.
278pub struct ServerEventReportingCallbacks<R>(ServerEventCallbacks<R>)
279where
280 R: DeserializeOwned + 'static;
281
282/// Server event reporting via an event handler.
283pub struct ServerEventReportingHandler<R, H>
284where
285 R: DeserializeOwned + 'static,
286 H: ServerEventHandler<R>,
287{
288 /// The event handler instance.
289 handler: H,
290 /// Phantom `R` owner.
291 phantom_receive: PhantomData<fn() -> R>,
292}
293
294/// Server event reporting via a channel.
295pub struct ServerEventReportingChannel;
296
297/// A server event reporting marker trait.
298trait ServerEventReportingConfig {}
299
300impl ServerEventReportingConfig for ServerEventReportingUnknown {}
301
302impl<R> ServerEventReportingConfig for ServerEventReporting<ServerEventReportingCallbacks<R>> where
303 R: DeserializeOwned + 'static
304{
305}
306
307impl<R, H> ServerEventReportingConfig for ServerEventReporting<ServerEventReportingHandler<R, H>>
308where
309 R: DeserializeOwned + 'static,
310 H: ServerEventHandler<R>,
311{
312}
313
314impl ServerEventReportingConfig for ServerEventReporting<ServerEventReportingChannel> {}
315
316/// A builder for the [`Server`].
317///
318/// An instance of this can be constructed using `ServerBuilder::new()` or
319/// `Server::builder()`. The configuration information exists primarily at the
320/// type-level, so it is impossible to misconfigure this.
321///
322/// This method of configuration is technically not necessary, but it is far
323/// clearer and more explicit than simply configuring the `Server` type. Plus,
324/// it provides additional ways of detecting events.
325///
326/// # Configuration
327///
328/// To configure the server, first provide the types that will be sent and
329/// received through the server using the `.sending::<...>()` and
330/// `.receiving::<...>()` methods. Then specify the way in which events will
331/// be detected. There are three methods of receiving events:
332///
333/// - via callback functions (`.with_event_callbacks(...)`)
334/// - via implementation of a handler trait (`.with_event_handler(...)`)
335/// - via a channel (`.with_event_channel()`)
336///
337/// The channel method is the most versatile, hence why it's the `Server`'s
338/// default implementation. The other methods are provided to support a
339/// greater variety of program architectures.
340///
341/// Once configured, the `.start(...)` method, which is effectively identical
342/// to the `Server::start(...)` method, can be called to start the server.
343///
344/// # Example
345///
346/// ```no_run
347/// # use rustdtp::prelude::*;
348///
349/// # #[tokio::main]
350/// # async fn main() {
351/// let (server, server_events) = Server::builder()
352/// .sending::<usize>()
353/// .receiving::<String>()
354/// .with_event_channel()
355/// .start(("0.0.0.0", 0))
356/// .await
357/// .unwrap();
358/// # }
359/// ```
360#[allow(private_bounds)]
361#[must_use = "server builders do nothing unless `start` is called"]
362pub struct ServerBuilder<SC, RC, EC>
363where
364 SC: ServerSendingConfig,
365 RC: ServerReceivingConfig,
366 EC: ServerEventReportingConfig,
367{
368 /// Phantom marker for `SC` and `RC`.
369 marker: PhantomData<fn() -> (SC, RC)>,
370 /// The event reporting configuration.
371 event_reporting: EC,
372}
373
374impl ServerBuilder<ServerSendingUnknown, ServerReceivingUnknown, ServerEventReportingUnknown> {
375 /// Creates a new server builder.
376 pub const fn new() -> Self {
377 Self {
378 marker: PhantomData,
379 event_reporting: ServerEventReportingUnknown,
380 }
381 }
382}
383
384impl Default
385 for ServerBuilder<ServerSendingUnknown, ServerReceivingUnknown, ServerEventReportingUnknown>
386{
387 fn default() -> Self {
388 Self::new()
389 }
390}
391
392#[allow(private_bounds)]
393impl<RC, EC> ServerBuilder<ServerSendingUnknown, RC, EC>
394where
395 RC: ServerReceivingConfig,
396 EC: ServerEventReportingConfig,
397{
398 /// Configures the type of data the server intends to send to clients.
399 pub fn sending<S>(self) -> ServerBuilder<ServerSending<S>, RC, EC>
400 where
401 S: Serialize + 'static,
402 {
403 ServerBuilder {
404 marker: PhantomData,
405 event_reporting: self.event_reporting,
406 }
407 }
408}
409
410#[allow(private_bounds)]
411impl<SC, EC> ServerBuilder<SC, ServerReceivingUnknown, EC>
412where
413 SC: ServerSendingConfig,
414 EC: ServerEventReportingConfig,
415{
416 /// Configures the type of data the server intends to receive from
417 /// clients.
418 pub fn receiving<R>(self) -> ServerBuilder<SC, ServerReceiving<R>, EC>
419 where
420 R: DeserializeOwned + 'static,
421 {
422 ServerBuilder {
423 marker: PhantomData,
424 event_reporting: self.event_reporting,
425 }
426 }
427}
428
429impl<S, R> ServerBuilder<ServerSending<S>, ServerReceiving<R>, ServerEventReportingUnknown>
430where
431 S: Serialize + 'static,
432 R: DeserializeOwned + 'static,
433{
434 /// Configures the server to receive events via callbacks.
435 ///
436 /// Using callbacks is typically considered an anti-pattern in Rust, so
437 /// this should only be used if it makes sense in the context of the
438 /// design of the code utilizing this API.
439 ///
440 /// See [`ServerEventCallbacks`] for more information and examples.
441 pub fn with_event_callbacks(
442 self,
443 callbacks: ServerEventCallbacks<R>,
444 ) -> ServerBuilder<
445 ServerSending<S>,
446 ServerReceiving<R>,
447 ServerEventReporting<ServerEventReportingCallbacks<R>>,
448 >
449 where
450 R: DeserializeOwned + 'static,
451 {
452 ServerBuilder {
453 marker: PhantomData,
454 event_reporting: ServerEventReporting(ServerEventReportingCallbacks(callbacks)),
455 }
456 }
457
458 /// Configures the server to receive events via a trait implementation.
459 ///
460 /// This provides an approach to event handling that closely aligns with
461 /// object-oriented practices.
462 ///
463 /// See [`ServerEventHandler`] for more information and examples.
464 pub fn with_event_handler<H>(
465 self,
466 handler: H,
467 ) -> ServerBuilder<
468 ServerSending<S>,
469 ServerReceiving<R>,
470 ServerEventReporting<ServerEventReportingHandler<R, H>>,
471 >
472 where
473 H: ServerEventHandler<R>,
474 {
475 ServerBuilder {
476 marker: PhantomData,
477 event_reporting: ServerEventReporting(ServerEventReportingHandler {
478 handler,
479 phantom_receive: PhantomData,
480 }),
481 }
482 }
483
484 /// Configures the server to receive events via a channel.
485 ///
486 /// This is the most versatile event handling strategy. In fact, all other
487 /// event handling options use this implementation under the hood.
488 /// Because of its flexibility, this will typically be the desired
489 /// approach.
490 pub fn with_event_channel(
491 self,
492 ) -> ServerBuilder<
493 ServerSending<S>,
494 ServerReceiving<R>,
495 ServerEventReporting<ServerEventReportingChannel>,
496 > {
497 ServerBuilder {
498 marker: PhantomData,
499 event_reporting: ServerEventReporting(ServerEventReportingChannel),
500 }
501 }
502}
503
504impl<S, R>
505 ServerBuilder<
506 ServerSending<S>,
507 ServerReceiving<R>,
508 ServerEventReporting<ServerEventReportingCallbacks<R>>,
509 >
510where
511 S: Serialize + 'static,
512 R: DeserializeOwned + 'static,
513{
514 /// Starts the server. This is effectively identical to [`Server::start`].
515 ///
516 /// # Errors
517 ///
518 /// The set of errors that can occur are identical to that of
519 /// [`Server::start`].
520 #[allow(clippy::future_not_send)]
521 pub async fn start<A>(self, addr: A) -> Result<ServerHandle<S>>
522 where
523 A: ToSocketAddrs,
524 {
525 let (server, mut server_events) = Server::<S, R>::start(addr).await?;
526 let callbacks = self.event_reporting.0 .0;
527
528 tokio::spawn(async move {
529 while let Ok(event) = server_events.next_raw().await {
530 match event {
531 ServerEventRawSafe::Connect { client_id } => {
532 if let Some(ref connect) = callbacks.connect {
533 let connect = Arc::clone(connect);
534 tokio::spawn(async move {
535 (*connect)(client_id).await;
536 });
537 }
538 }
539 ServerEventRawSafe::Disconnect { client_id } => {
540 if let Some(ref disconnect) = callbacks.disconnect {
541 let disconnect = Arc::clone(disconnect);
542 tokio::spawn(async move {
543 (*disconnect)(client_id).await;
544 });
545 }
546 }
547 ServerEventRawSafe::Receive { client_id, data } => {
548 if let Some(ref receive) = callbacks.receive {
549 let receive = Arc::clone(receive);
550 tokio::spawn(async move {
551 let data = data.deserialize();
552 (*receive)(client_id, data).await;
553 });
554 }
555 }
556 ServerEventRawSafe::Stop => {
557 if let Some(ref stop) = callbacks.stop {
558 let stop = Arc::clone(stop);
559 tokio::spawn(async move {
560 (*stop)().await;
561 });
562 }
563 }
564 }
565 }
566 });
567
568 Ok(server)
569 }
570}
571
572impl<S, R, H>
573 ServerBuilder<
574 ServerSending<S>,
575 ServerReceiving<R>,
576 ServerEventReporting<ServerEventReportingHandler<R, H>>,
577 >
578where
579 S: Serialize + 'static,
580 R: DeserializeOwned + 'static,
581 H: ServerEventHandler<R> + 'static,
582{
583 /// Starts the server. This is effectively identical to [`Server::start`].
584 ///
585 /// # Errors
586 ///
587 /// The set of errors that can occur are identical to that of
588 /// [`Server::start`].
589 #[allow(clippy::future_not_send)]
590 pub async fn start<A>(self, addr: A) -> Result<ServerHandle<S>>
591 where
592 A: ToSocketAddrs,
593 {
594 let (server, mut server_events) = Server::<S, R>::start(addr).await?;
595 let handler = Arc::new(self.event_reporting.0.handler);
596
597 tokio::spawn(async move {
598 while let Ok(event) = server_events.next_raw().await {
599 match event {
600 ServerEventRawSafe::Connect { client_id } => {
601 let handler = Arc::clone(&handler);
602 tokio::spawn(async move {
603 handler.on_connect(client_id).await;
604 });
605 }
606 ServerEventRawSafe::Disconnect { client_id } => {
607 let handler = Arc::clone(&handler);
608 tokio::spawn(async move {
609 handler.on_disconnect(client_id).await;
610 });
611 }
612 ServerEventRawSafe::Receive { client_id, data } => {
613 let handler = Arc::clone(&handler);
614 tokio::spawn(async move {
615 let data = data.deserialize();
616 handler.on_receive(client_id, data).await;
617 });
618 }
619 ServerEventRawSafe::Stop => {
620 let handler = Arc::clone(&handler);
621 tokio::spawn(async move {
622 handler.on_stop().await;
623 });
624 }
625 }
626 }
627 });
628
629 Ok(server)
630 }
631}
632
633impl<S, R>
634 ServerBuilder<
635 ServerSending<S>,
636 ServerReceiving<R>,
637 ServerEventReporting<ServerEventReportingChannel>,
638 >
639where
640 S: Serialize + 'static,
641 R: DeserializeOwned + 'static,
642{
643 /// Starts the server. This is effectively identical to [`Server::start`].
644 ///
645 /// # Errors
646 ///
647 /// The set of errors that can occur are identical to that of
648 /// [`Server::start`].
649 #[allow(clippy::future_not_send)]
650 pub async fn start<A>(self, addr: A) -> Result<(ServerHandle<S>, ServerEventStream<R>)>
651 where
652 A: ToSocketAddrs,
653 {
654 Server::<S, R>::start(addr).await
655 }
656}
657
/// A request sent from a [`ServerHandle`] to the background server task.
pub enum ServerCommand {
    /// Shut the server down.
    Stop,
    /// Deliver already-serialized data to one client.
    Send {
        /// The recipient client's ID.
        client_id: usize,
        /// The serialized payload.
        data: Vec<u8>,
    },
    /// Deliver already-serialized data to every client.
    SendAll {
        /// The serialized payload.
        data: Vec<u8>,
    },
    /// Ask for the server's local address.
    GetAddr,
    /// Ask for a client's address.
    GetClientAddr {
        /// The client's ID.
        client_id: usize,
    },
    /// Disconnect a client.
    RemoveClient {
        /// The client's ID.
        client_id: usize,
    },
}
687
688/// The return value of a command executed on the background server task.
689pub enum ServerCommandReturn {
690 /// Stop return value.
691 Stop(Result<()>),
692 /// Sent data return value.
693 Send(Result<()>),
694 /// Sent data to all return value.
695 SendAll(Result<()>),
696 /// Local server address return value.
697 GetAddr(Result<SocketAddr>),
698 /// Client address return value.
699 GetClientAddr(Result<SocketAddr>),
700 /// Disconnect client return value.
701 RemoveClient(Result<()>),
702}
703
/// A request sent from the background server task to a per-client task.
pub enum ServerClientCommand {
    /// Write a serialized payload to the client. The payload is an
    /// `Arc<[u8]>`, so one buffer can be shared between recipients.
    Send {
        /// The serialized payload.
        data: Arc<[u8]>,
    },
    /// Ask for the client's address.
    GetAddr,
    /// Disconnect the client.
    Remove,
}
716
717/// The return value of a command executed on a client background task.
718pub enum ServerClientCommandReturn {
719 /// Send data return value.
720 Send(Result<()>),
721 /// Client address return value.
722 GetAddr(Result<SocketAddr>),
723 /// Disconnect client return value.
724 Remove(Result<()>),
725}
726
727/// An event from the server.
728///
729/// ```no_run
730/// use rustdtp::prelude::*;
731///
732/// #[tokio::main]
733/// async fn main() {
734/// // Create the server
735/// let (mut server, mut server_events) = Server::builder()
736/// .sending::<()>()
737/// .receiving::<String>()
738/// .with_event_channel()
739/// .start(("0.0.0.0", 0))
740/// .await
741/// .unwrap();
742///
743/// // Iterate over events
744/// while let Ok(event) = server_events.next().await {
745/// match event {
746/// ServerEvent::Connect { client_id } => {
747/// println!("Client with ID {} connected", client_id);
748/// }
749/// ServerEvent::Disconnect { client_id } => {
750/// println!("Client with ID {} disconnected", client_id);
751/// }
752/// ServerEvent::Receive { client_id, data } => {
753/// println!("Client with ID {} sent: {}", client_id, data);
754/// }
755/// ServerEvent::Stop => {
756/// // No more events will be sent, and the loop will end
757/// println!("Server closed");
758/// }
759/// }
760/// }
761/// }
762/// ```
763#[derive(Debug, Clone)]
764pub enum ServerEvent<R>
765where
766 R: DeserializeOwned + 'static,
767{
768 /// A client connected.
769 Connect {
770 /// The ID of the client that connected.
771 client_id: usize,
772 },
773 /// A client disconnected.
774 Disconnect {
775 /// The ID of the client that disconnected.
776 client_id: usize,
777 },
778 /// Data received from a client.
779 Receive {
780 /// The ID of the client that sent the data.
781 client_id: usize,
782 /// The data itself.
783 data: R,
784 },
785 /// Server stopped.
786 Stop,
787}
788
/// The untyped counterpart of `ServerEvent`: a `Receive` payload is kept as
/// the raw bytes received instead of a deserialized value.
#[derive(Debug, Clone)]
enum ServerEventRaw {
    /// A client has connected.
    Connect {
        /// The connecting client's ID.
        client_id: usize,
    },
    /// A client has disconnected.
    Disconnect {
        /// The disconnecting client's ID.
        client_id: usize,
    },
    /// A client sent data.
    Receive {
        /// The sending client's ID.
        client_id: usize,
        /// The payload bytes, not yet deserialized.
        data: Vec<u8>,
    },
    /// The server has stopped.
    Stop,
}
812
813impl ServerEventRaw {
814 /// Deserializes this instance into a `ServerEvent`.
815 fn deserialize<R>(&self) -> Result<ServerEvent<R>>
816 where
817 R: DeserializeOwned + 'static,
818 {
819 match self {
820 Self::Connect { client_id } => Ok(ServerEvent::Connect {
821 client_id: *client_id,
822 }),
823 Self::Disconnect { client_id } => Ok(ServerEvent::Disconnect {
824 client_id: *client_id,
825 }),
826 Self::Receive { client_id, data } => {
827 Ok(
828 serde_json::from_slice(data).map(|data| ServerEvent::Receive {
829 client_id: *client_id,
830 data,
831 })?,
832 )
833 }
834 Self::Stop => Ok(ServerEvent::Stop),
835 }
836 }
837}
838
839/// The serialized data component of a server receive event. The data is
840/// guaranteed to be deserializable into an instance of `R`.
841#[derive(Debug, Clone)]
842struct ServerEventRawSafeData<R>
843where
844 R: DeserializeOwned + 'static,
845{
846 /// The raw data.
847 data: Vec<u8>,
848 /// Phantom marker for `R`.
849 marker: PhantomData<fn() -> R>,
850}
851
852/// Identical to `ServerEventRaw`, but with the guarantee that the data can be
853/// deserialized into an instance of `R`.
854#[derive(Debug, Clone)]
855enum ServerEventRawSafe<R>
856where
857 R: DeserializeOwned + 'static,
858{
859 /// A client connected.
860 Connect {
861 /// The ID of the client that connected.
862 client_id: usize,
863 },
864 /// A client disconnected.
865 Disconnect {
866 /// The ID of the client that disconnected.
867 client_id: usize,
868 },
869 /// Data received from a client.
870 Receive {
871 /// The ID of the client that sent the data.
872 client_id: usize,
873 /// The data itself.
874 data: ServerEventRawSafeData<R>,
875 },
876 /// Server stopped.
877 Stop,
878}
879
880impl<R> TryFrom<ServerEventRaw> for ServerEventRawSafe<R>
881where
882 R: DeserializeOwned + 'static,
883{
884 type Error = Error;
885
886 fn try_from(value: ServerEventRaw) -> std::result::Result<Self, Self::Error> {
887 value.deserialize::<R>()?;
888
889 Ok(match value {
890 ServerEventRaw::Connect { client_id } => Self::Connect { client_id },
891 ServerEventRaw::Disconnect { client_id } => Self::Disconnect { client_id },
892 ServerEventRaw::Receive { client_id, data } => Self::Receive {
893 client_id,
894 data: ServerEventRawSafeData {
895 data,
896 marker: PhantomData,
897 },
898 },
899 ServerEventRaw::Stop => Self::Stop,
900 })
901 }
902}
903
904impl<R> ServerEventRawSafeData<R>
905where
906 R: DeserializeOwned + 'static,
907{
908 /// Deserialize the raw data into an instance of `R`. This is guaranteed to
909 /// succeed.
910 fn deserialize(&self) -> R {
911 serde_json::from_slice(&self.data).unwrap()
912 }
913}
914
915impl<R> ServerEventRawSafe<R>
916where
917 R: DeserializeOwned + 'static,
918{
919 /// Deserializes this instance into a `ServerEvent`.
920 #[allow(dead_code)]
921 fn deserialize(&self) -> ServerEvent<R> {
922 match self {
923 Self::Connect { client_id } => ServerEvent::Connect {
924 client_id: *client_id,
925 },
926 Self::Disconnect { client_id } => ServerEvent::Disconnect {
927 client_id: *client_id,
928 },
929 Self::Receive { client_id, data } => ServerEvent::Receive {
930 client_id: *client_id,
931 data: data.deserialize(),
932 },
933 Self::Stop => ServerEvent::Stop,
934 }
935 }
936}
937
938/// An asynchronous stream of server events.
939pub struct ServerEventStream<R>
940where
941 R: DeserializeOwned + 'static,
942{
943 /// The event receiver channel.
944 event_receiver: Receiver<ServerEventRaw>,
945 /// Phantom marker for `R`.
946 marker: PhantomData<fn() -> R>,
947}
948
949impl<R> ServerEventStream<R>
950where
951 R: DeserializeOwned + 'static,
952{
953 /// Consumes and returns the next value in the stream.
954 ///
955 /// # Errors
956 ///
957 /// This will return an error if the stream is closed, or if there was an
958 /// error while deserializing data received.
959 pub async fn next(&mut self) -> Result<ServerEvent<R>> {
960 match self.event_receiver.recv().await {
961 Some(serialized_event) => serialized_event.deserialize(),
962 None => Err(Error::ConnectionClosed),
963 }
964 }
965
966 /// Identical to `next`, but doesn't deserialize the event. It does,
967 /// however, validate that the event can be deserialized without error.
968 async fn next_raw(&mut self) -> Result<ServerEventRawSafe<R>> {
969 match self.event_receiver.recv().await {
970 Some(serialized_event) => serialized_event.try_into(),
971 None => Err(Error::ConnectionClosed),
972 }
973 }
974}
975
976/// A handle to the server.
977pub struct ServerHandle<S>
978where
979 S: Serialize + 'static,
980{
981 /// The channel through which commands can be sent to the background task.
982 server_command_sender: CommandChannelSender<ServerCommand, ServerCommandReturn>,
983 /// The handle to the background task.
984 server_task_handle: JoinHandle<Result<()>>,
985 /// Phantom marker for `S`.
986 marker: PhantomData<fn() -> S>,
987}
988
989impl<S> ServerHandle<S>
990where
991 S: Serialize + 'static,
992{
993 /// Stop the server, disconnect all clients, and shut down all network
994 /// interfaces.
995 ///
996 /// Returns a result of the error variant if an error occurred while
997 /// disconnecting clients.
998 ///
999 /// ```no_run
1000 /// use rustdtp::prelude::*;
1001 ///
1002 /// #[tokio::main]
1003 /// async fn main() {
1004 /// // Create the server
1005 /// let (mut server, mut server_events) = Server::builder()
1006 /// .sending::<()>()
1007 /// .receiving::<String>()
1008 /// .with_event_channel()
1009 /// .start(("0.0.0.0", 0))
1010 /// .await
1011 /// .unwrap();
1012 ///
1013 /// // Wait for events until a client requests the server be stopped
1014 /// while let Ok(event) = server_events.next().await {
1015 /// match event {
1016 /// // Stop the server when a client requests it be stopped
1017 /// ServerEvent::Receive { client_id, data } => {
1018 /// if data.as_str() == "Stop the server!" {
1019 /// println!("Server stop requested");
1020 /// server.stop().await.unwrap();
1021 /// break;
1022 /// }
1023 /// }
1024 /// _ => {} // Do nothing for other events
1025 /// }
1026 /// }
1027 ///
1028 /// // The last event should be a stop event
1029 /// assert!(matches!(server_events.next().await.unwrap(), ServerEvent::Stop));
1030 /// }
1031 /// ```
1032 ///
1033 /// # Errors
1034 ///
1035 /// This will return an error if the server socket has already closed, or if
1036 /// the underlying server loop returned an error.
1037 #[allow(clippy::missing_panics_doc)]
1038 pub async fn stop(mut self) -> Result<()> {
1039 let value = self
1040 .server_command_sender
1041 .send_command(ServerCommand::Stop)
1042 .await?;
1043 // `unwrap` is allowed, as an error is returned only when the underlying
1044 // task panics, which it never should
1045 self.server_task_handle.await.unwrap()?;
1046 unwrap_enum!(value, ServerCommandReturn::Stop)
1047 }
1048
1049 /// Send data to a client.
1050 ///
1051 /// - `client_id`: the ID of the client to send the data to.
1052 /// - `data`: the data to send.
1053 ///
1054 /// Returns a result of the error variant if an error occurred while
1055 /// sending.
1056 ///
1057 /// ```no_run
1058 /// use rustdtp::prelude::*;
1059 ///
1060 /// #[tokio::main]
1061 /// async fn main() {
1062 /// // Create the server
1063 /// let (mut server, mut server_events) = Server::builder()
1064 /// .sending::<String>()
1065 /// .receiving::<()>()
1066 /// .with_event_channel()
1067 /// .start(("0.0.0.0", 0))
1068 /// .await
1069 /// .unwrap();
1070 ///
1071 /// // Iterate over events
1072 /// while let Ok(event) = server_events.next().await {
1073 /// match event {
1074 /// // When a client connects, send a greeting
1075 /// ServerEvent::Connect { client_id } => {
1076 /// server.send(client_id, format!("Hello, client {}!", client_id)).await.unwrap();
1077 /// }
1078 /// _ => {} // Do nothing for other events
1079 /// }
1080 /// }
1081 /// }
1082 /// ```
1083 ///
1084 /// # Errors
1085 ///
1086 /// This will return an error if the server socket has closed, or if data
1087 /// serialization fails.
1088 #[allow(clippy::future_not_send)]
1089 pub async fn send(&mut self, client_id: usize, data: S) -> Result<()> {
1090 let data_serialized = serde_json::to_vec(&data)?;
1091 let value = self
1092 .server_command_sender
1093 .send_command(ServerCommand::Send {
1094 client_id,
1095 data: data_serialized,
1096 })
1097 .await?;
1098 unwrap_enum!(value, ServerCommandReturn::Send)
1099 }
1100
1101 /// Send data to all clients.
1102 ///
1103 /// - `data`: the data to send.
1104 ///
1105 /// Returns a result of the error variant if an error occurred while
1106 /// sending.
1107 ///
1108 /// ```no_run
1109 /// use rustdtp::prelude::*;
1110 ///
1111 /// #[tokio::main]
1112 /// async fn main() {
1113 /// // Create the server
1114 /// let (mut server, mut server_events) = Server::builder()
1115 /// .sending::<String>()
1116 /// .receiving::<()>()
1117 /// .with_event_channel()
1118 /// .start(("0.0.0.0", 0))
1119 /// .await
1120 /// .unwrap();
1121 ///
1122 /// // Iterate over events
1123 /// while let Ok(event) = server_events.next().await {
1124 /// match event {
1125 /// // When a client connects, notify all clients
1126 /// ServerEvent::Connect { client_id } => {
1127 /// server.send_all(format!("A new client with ID {} has joined!", client_id)).await.unwrap();
1128 /// }
1129 /// _ => {} // Do nothing for other events
1130 /// }
1131 /// }
1132 /// }
1133 /// ```
1134 ///
1135 /// # Errors
1136 ///
1137 /// This will return an error if the server socket has closed, or if data
1138 /// serialization fails.
1139 #[allow(clippy::future_not_send)]
1140 pub async fn send_all(&mut self, data: S) -> Result<()> {
1141 let data_serialized = serde_json::to_vec(&data)?;
1142 let value = self
1143 .server_command_sender
1144 .send_command(ServerCommand::SendAll {
1145 data: data_serialized,
1146 })
1147 .await?;
1148 unwrap_enum!(value, ServerCommandReturn::SendAll)
1149 }
1150
1151 /// Get the address the server is listening on.
1152 ///
1153 /// Returns a result containing the address the server is listening on, or
1154 /// the error variant if an error occurred.
1155 ///
1156 /// ```no_run
1157 /// use rustdtp::prelude::*;
1158 ///
1159 /// #[tokio::main]
1160 /// async fn main() {
1161 /// // Create the server
1162 /// let (mut server, mut server_events) = Server::builder()
1163 /// .sending::<()>()
1164 /// .receiving::<()>()
1165 /// .with_event_channel()
1166 /// .start(("0.0.0.0", 0))
1167 /// .await
1168 /// .unwrap();
1169 ///
1170 /// // Get the server address
1171 /// let addr = server.get_addr().await.unwrap();
1172 /// println!("Server listening on {}", addr);
1173 /// }
1174 /// ```
1175 ///
1176 /// # Errors
1177 ///
1178 /// This will return an error if the server socket has closed.
1179 pub async fn get_addr(&mut self) -> Result<SocketAddr> {
1180 let value = self
1181 .server_command_sender
1182 .send_command(ServerCommand::GetAddr)
1183 .await?;
1184 unwrap_enum!(value, ServerCommandReturn::GetAddr)
1185 }
1186
1187 /// Get the address of a connected client.
1188 ///
1189 /// - `client_id`: the ID of the client.
1190 ///
1191 /// Returns a result containing the address of the client, or the error
1192 /// variant if the client ID is invalid.
1193 ///
1194 /// ```no_run
1195 /// use rustdtp::prelude::*;
1196 ///
1197 /// #[tokio::main]
1198 /// async fn main() {
1199 /// // Create the server
1200 /// let (mut server, mut server_events) = Server::builder()
1201 /// .sending::<()>()
1202 /// .receiving::<()>()
1203 /// .with_event_channel()
1204 /// .start(("0.0.0.0", 0))
1205 /// .await
1206 /// .unwrap();
1207 ///
1208 /// // Iterate over events
1209 /// while let Ok(event) = server_events.next().await {
1210 /// match event {
1211 /// // When a client connects, get their address
1212 /// ServerEvent::Connect { client_id } => {
1213 /// let addr = server.get_client_addr(client_id).await.unwrap();
1214 /// println!("Client with ID {} connected from {}", client_id, addr);
1215 /// }
1216 /// _ => {} // Do nothing for other events
1217 /// }
1218 /// }
1219 /// }
1220 /// ```
1221 ///
1222 /// # Errors
1223 ///
1224 /// This will return an error if the server socket has closed, or if the
1225 /// client ID is invalid.
1226 pub async fn get_client_addr(&mut self, client_id: usize) -> Result<SocketAddr> {
1227 let value = self
1228 .server_command_sender
1229 .send_command(ServerCommand::GetClientAddr { client_id })
1230 .await?;
1231 unwrap_enum!(value, ServerCommandReturn::GetClientAddr)
1232 }
1233
1234 /// Disconnect a client from the server.
1235 ///
1236 /// - `client_id`: the ID of the client.
1237 ///
1238 /// Returns a result of the error variant if an error occurred while
1239 /// disconnecting the client, or if the client ID is invalid.
1240 ///
1241 /// ```no_run
1242 /// use rustdtp::prelude::*;
1243 ///
1244 /// #[tokio::main]
1245 /// async fn main() {
1246 /// // Create the server
1247 /// let (mut server, mut server_events) = Server::builder()
1248 /// .sending::<String>()
1249 /// .receiving::<i32>()
1250 /// .with_event_channel()
1251 /// .start(("0.0.0.0", 0))
1252 /// .await
1253 /// .unwrap();
1254 ///
1255 /// // Iterate over events
1256 /// while let Ok(event) = server_events.next().await {
1257 /// match event {
1258 /// // Disconnect a client if they send an even number
1259 /// ServerEvent::Receive { client_id, data } => {
1260 /// if data % 2 == 0 {
1261 /// println!("Disconnecting client with ID {}", client_id);
1262 /// server.send(client_id, "Even numbers are not allowed".to_owned()).await.unwrap();
1263 /// server.remove_client(client_id).await.unwrap();
1264 /// }
1265 /// }
1266 /// _ => {} // Do nothing for other events
1267 /// }
1268 /// }
1269 ///
1270 /// // The last event should be a stop event
1271 /// assert!(matches!(server_events.next().await.unwrap(), ServerEvent::Stop));
1272 /// }
1273 /// ```
1274 ///
1275 /// # Errors
1276 ///
1277 /// This will return an error if the server socket has closed, or if the
1278 /// client ID is invalid.
1279 pub async fn remove_client(&mut self, client_id: usize) -> Result<()> {
1280 let value = self
1281 .server_command_sender
1282 .send_command(ServerCommand::RemoveClient { client_id })
1283 .await?;
1284 unwrap_enum!(value, ServerCommandReturn::RemoveClient)
1285 }
1286}
1287
1288/// A socket server.
1289///
1290/// The server takes two generic parameters:
1291///
1292/// - `S`: the type of data that will be **sent** to clients.
1293/// - `R`: the type of data that will be **received** from clients.
1294///
1295/// Both types must be serializable in order to be sent through the socket. When
1296/// creating clients, the types should be swapped, since the server's send type will be the client's receive type and vice versa.
1297///
1298/// ```no_run
1299/// use rustdtp::prelude::*;
1300///
1301/// #[tokio::main]
1302/// async fn main() {
1303/// // Create a server that receives strings and returns the length of each string
1304/// let (mut server, mut server_events) = Server::builder()
1305/// .sending::<usize>()
1306/// .receiving::<String>()
1307/// .with_event_channel()
1308/// .start(("0.0.0.0", 0))
1309/// .await
1310/// .unwrap();
1311///
1312/// // Iterate over events
1313/// while let Ok(event) = server_events.next().await {
1314/// match event {
1315/// ServerEvent::Connect { client_id } => {
1316/// println!("Client with ID {} connected", client_id);
1317/// }
1318/// ServerEvent::Disconnect { client_id } => {
1319/// println!("Client with ID {} disconnected", client_id);
1320/// }
1321/// ServerEvent::Receive { client_id, data } => {
1322/// // Send back the length of the string
1323/// server.send(client_id, data.len()).await.unwrap();
1324/// }
1325/// ServerEvent::Stop => {
1326/// // No more events will be sent, and the loop will end
1327/// println!("Server closed");
1328/// }
1329/// }
1330/// }
1331/// }
1332/// ```
1333pub struct Server<S, R>
1334where
1335 S: Serialize + 'static,
1336 R: DeserializeOwned + 'static,
1337{
1338 /// Phantom marker for `S` and `R`.
1339 marker: PhantomData<fn() -> (S, R)>,
1340}
1341
1342impl Server<(), ()> {
1343 /// Constructs a server builder. Use this for a clearer, more explicit,
1344 /// and more featureful server configuration. See [`ServerBuilder`] for
1345 /// more information.
1346 pub const fn builder(
1347 ) -> ServerBuilder<ServerSendingUnknown, ServerReceivingUnknown, ServerEventReportingUnknown>
1348 {
1349 ServerBuilder::new()
1350 }
1351}
1352
1353impl<S, R> Server<S, R>
1354where
1355 S: Serialize + 'static,
1356 R: DeserializeOwned + 'static,
1357{
1358 /// Start a socket server.
1359 ///
1360 /// - `addr`: the address for the server to listen on.
1361 ///
1362 /// Returns a result containing a handle to the server and a channel from
1363 /// which to receive server events, or the error variant if an error
1364 /// occurred while starting the server.
1365 ///
1366 /// ```no_run
1367 /// use rustdtp::prelude::*;
1368 ///
1369 /// #[tokio::main]
1370 /// async fn main() {
1371 /// let (mut server, mut server_events) = Server::builder()
1372 /// .sending::<()>()
1373 /// .receiving::<()>()
1374 /// .with_event_channel()
1375 /// .start(("0.0.0.0", 0))
1376 /// .await
1377 /// .unwrap();
1378 /// }
1379 /// ```
1380 ///
1381 /// Neither the server handle nor the event receiver should be dropped until
1382 /// the server has been stopped. Prematurely dropping either one can cause
1383 /// unintended behavior.
1384 ///
1385 /// # Errors
1386 ///
1387 /// This will return an error if a TCP listener cannot be bound to the
1388 /// provided address.
1389 #[allow(clippy::future_not_send)]
1390 pub async fn start<A>(addr: A) -> Result<(ServerHandle<S>, ServerEventStream<R>)>
1391 where
1392 A: ToSocketAddrs,
1393 {
1394 // Server TCP listener
1395 let listener = TcpListener::bind(addr).await?;
1396 // Channels for sending commands from the server handle to the background server task
1397 let (server_command_sender, server_command_receiver) = command_channel();
1398 // Channels for sending event notifications from the background server task
1399 let (server_event_sender, server_event_receiver) = channel(CHANNEL_BUFFER_SIZE);
1400
1401 // Start the background server task, saving the join handle for when the server is stopped
1402 let server_task_handle = tokio::spawn(server_handler(
1403 listener,
1404 server_event_sender,
1405 server_command_receiver,
1406 ));
1407
1408 // Create a handle for the server
1409 let server_handle = ServerHandle {
1410 server_command_sender,
1411 server_task_handle,
1412 marker: PhantomData,
1413 };
1414
1415 // Create an event stream for the server
1416 let server_event_stream = ServerEventStream {
1417 event_receiver: server_event_receiver,
1418 marker: PhantomData,
1419 };
1420
1421 Ok((server_handle, server_event_stream))
1422 }
1423}
1424
1425/// The server client loop. Handles received data and commands.
1426#[allow(clippy::too_many_lines)]
1427async fn server_client_loop(
1428 client_id: usize,
1429 mut socket: TcpStream,
1430 server_client_event_sender: Sender<ServerEventRaw>,
1431 mut client_command_receiver: CommandChannelReceiver<
1432 ServerClientCommand,
1433 ServerClientCommandReturn,
1434 >,
1435) -> Result<()> {
1436 // Generate RSA keys
1437 let (rsa_pub, rsa_priv) = rsa_keys().await?;
1438 // Convert the RSA public key into a string...
1439 let rsa_pub_str = rsa_pub
1440 .to_public_key_pem(rsa::pkcs1::LineEnding::LF)
1441 .map_err(|_| Error::InvalidRsaKeyEncoding)?;
1442 // ...and then into bytes
1443 let rsa_pub_bytes = rsa_pub_str.as_bytes();
1444 // Create the buffer containing the RSA public key and its size
1445 let mut rsa_pub_buffer = encode_message_size(rsa_pub_bytes.len()).to_vec();
1446 // Extend the buffer with the RSA public key bytes
1447 rsa_pub_buffer.extend(rsa_pub_bytes);
1448 // Send the RSA public key to the client
1449 socket.write_all(&rsa_pub_buffer).await?;
1450 // Flush the stream
1451 socket.flush().await?;
1452
1453 // Buffer in which to receive the size portion of the AES key
1454 let mut aes_key_size_buffer = [0; LEN_SIZE];
1455 // Read the AES key from the client
1456 handshake_timeout! {
1457 socket.read_exact(&mut aes_key_size_buffer[..])
1458 }??;
1459
1460 // Decode the size portion of the AES key
1461 let aes_key_size = decode_message_size(&aes_key_size_buffer);
1462 // Initialize the buffer for the AES key
1463 let mut aes_key_buffer = vec![0; aes_key_size];
1464
1465 // Read the AES key portion from the client socket, returning an error if
1466 // the socket could not be read
1467 data_read_timeout! {
1468 socket.read_exact(&mut aes_key_buffer[..])
1469 }??;
1470
1471 // Decrypt the AES key
1472 let aes_key_decrypted = rsa_decrypt(rsa_priv, aes_key_buffer.into()).await?;
1473
1474 // Assert that the AES key is the correct size
1475 let aes_key: [u8; AES_KEY_SIZE] = aes_key_decrypted
1476 .try_into()
1477 .map_err(|_| Error::InvalidAesKeySize)?;
1478
1479 // Buffer in which to receive the size portion of a message
1480 let mut size_buffer = [0; LEN_SIZE];
1481
1482 // Client loop
1483 loop {
1484 // Await messages from the client
1485 // and commands from the background server task
1486 tokio::select! {
1487 // Read the size portion from the client socket
1488 read_value = socket.read(&mut size_buffer[..]) => {
1489 // Return an error if the socket could not be read
1490 let n_size = read_value?;
1491
1492 // If there were no bytes read, or if there were fewer bytes
1493 // read than there should have been, close the socket
1494 if n_size != LEN_SIZE {
1495 socket.shutdown().await?;
1496 break;
1497 };
1498
1499 // Decode the size portion of the message
1500 let encrypted_data_size = decode_message_size(&size_buffer);
1501 // Initialize the buffer for the data portion of the message
1502 let mut encrypted_data_buffer = vec![0; encrypted_data_size];
1503
1504 // Read the data portion from the client socket, returning an
1505 // error if the socket could not be read
1506 let n_data = data_read_timeout! {
1507 socket.read_exact(&mut encrypted_data_buffer[..])
1508 }??;
1509
1510 // If there were no bytes read, or if there were fewer bytes
1511 // read than there should have been, close the socket
1512 if n_data != encrypted_data_size {
1513 socket.shutdown().await?;
1514 break;
1515 }
1516
1517 // Decrypt the data
1518 let data_serialized = aes_decrypt(aes_key, encrypted_data_buffer.into()).await?;
1519
1520 // Send an event to note that a piece of data has been received from
1521 // a client
1522 if let Err(_e) = server_client_event_sender.send(ServerEventRaw::Receive { client_id, data: data_serialized }).await {
1523 // Sending failed, disconnect the client
1524 socket.shutdown().await?;
1525 break;
1526 }
1527 }
1528 // Process a command sent to the client
1529 client_command_value = client_command_receiver.recv_command() => {
1530 // Handle the command, or lack thereof if the channel is closed
1531 match client_command_value {
1532 Ok(client_command) => {
1533 // Process the command
1534 match client_command {
1535 ServerClientCommand::Send { data } => {
1536 let value = 'val: {
1537 // Encrypt the serialized data
1538 let encrypted_data_buffer = break_on_err!(aes_encrypt(aes_key, data).await, 'val);
1539 // Encode the message size to a buffer
1540 let size_buffer = encode_message_size(encrypted_data_buffer.len());
1541
1542 // Initialize the message buffer
1543 let mut buffer = vec![];
1544 // Extend the buffer to contain the payload
1545 // size
1546 buffer.extend_from_slice(&size_buffer);
1547 // Extend the buffer to contain the payload
1548 // data
1549 buffer.extend(&encrypted_data_buffer);
1550
1551 // Write the data to the client socket
1552 break_on_err!(socket.write_all(&buffer).await, 'val);
1553 // Flush the stream
1554 break_on_err!(socket.flush().await, 'val);
1555
1556 Ok(())
1557 };
1558
1559 let error_occurred = value.is_err();
1560
1561 // Return the status of the send operation
1562 if let Err(_e) = client_command_receiver.command_return(ServerClientCommandReturn::Send(value)).await {
1563 // Channel is closed, disconnect the client
1564 socket.shutdown().await?;
1565 break;
1566 }
1567
1568 // If the send failed, disconnect the client
1569 if error_occurred {
1570 socket.shutdown().await?;
1571 break;
1572 }
1573 },
1574 ServerClientCommand::GetAddr => {
1575 // Get the client socket's address
1576 let addr = socket.peer_addr();
1577
1578 // Return the address
1579 if let Err(_e) = client_command_receiver.command_return(ServerClientCommandReturn::GetAddr(addr.map_err(Into::into))).await {
1580 // Channel is closed, disconnect the client
1581 socket.shutdown().await?;
1582 break;
1583 }
1584 },
1585 ServerClientCommand::Remove => {
1586 // Disconnect the client
1587 let value = socket.shutdown().await;
1588
1589 // Return the status of the remove operation,
1590 // ignoring failures, since a failure indicates
1591 // that the client has probably already
1592 // disconnected
1593 _ = client_command_receiver.command_return(ServerClientCommandReturn::Remove(value.map_err(Into::into))).await;
1594
1595 // Break the client loop
1596 break;
1597 },
1598 }
1599 },
1600 Err(_e) => {
1601 // Channel is closed, disconnect the client
1602 socket.shutdown().await?;
1603 break;
1604 },
1605 }
1606 }
1607 }
1608 }
1609
1610 Ok(())
1611}
1612
1613/// Starts a server client loop in the background.
1614fn server_client_handler(
1615 client_id: usize,
1616 socket: TcpStream,
1617 server_client_event_sender: Sender<ServerEventRaw>,
1618 client_cleanup_sender: Sender<usize>,
1619) -> (
1620 CommandChannelSender<ServerClientCommand, ServerClientCommandReturn>,
1621 JoinHandle<Result<()>>,
1622) {
1623 // Channels for sending commands from the background server task to a background client task
1624 let (client_command_sender, client_command_receiver) = command_channel();
1625
1626 // Start a background client task, saving the join handle for when the
1627 // server is stopped
1628 let client_task_handle = tokio::spawn(async move {
1629 let res = server_client_loop(
1630 client_id,
1631 socket,
1632 server_client_event_sender,
1633 client_command_receiver,
1634 )
1635 .await;
1636
1637 // Tell the server to clean up after the client, ignoring failures,
1638 // since a failure indicates that the server has probably closed
1639 _ = client_cleanup_sender.send(client_id).await;
1640
1641 res
1642 });
1643
1644 (client_command_sender, client_task_handle)
1645}
1646
1647/// The server loop. Handles incoming connections and commands.
1648#[allow(clippy::too_many_lines)]
1649async fn server_loop(
1650 listener: TcpListener,
1651 server_event_sender: Sender<ServerEventRaw>,
1652 mut server_command_receiver: CommandChannelReceiver<ServerCommand, ServerCommandReturn>,
1653 client_command_senders: &mut HashMap<
1654 usize,
1655 CommandChannelSender<ServerClientCommand, ServerClientCommandReturn>,
1656 >,
1657 client_join_handles: &mut HashMap<usize, JoinHandle<Result<()>>>,
1658) -> Result<()> {
1659 // ID assigned to the next client
1660 let mut next_client_id = 0usize;
1661 // Channel for indicating that a client needs to be cleaned up after
1662 let (server_client_cleanup_sender, mut server_client_cleanup_receiver) =
1663 channel::<usize>(CHANNEL_BUFFER_SIZE);
1664
1665 // Server loop
1666 loop {
1667 // Await new clients connecting,
1668 // commands from the server handle,
1669 // and notifications of clients disconnecting
1670 tokio::select! {
1671 // Accept a connecting client
1672 accept_value = listener.accept() => {
1673 // Get the client socket, exiting if an error occurs
1674 let (socket, _) = accept_value?;
1675 // New client ID
1676 let client_id = next_client_id;
1677 // Increment next client ID
1678 next_client_id += 1;
1679 // Clone the event sender so the background client tasks can
1680 // send events
1681 let server_client_event_sender = server_event_sender.clone();
1682 // Clone the client cleanup sender to the background client
1683 // tasks can be cleaned up properly
1684 let client_cleanup_sender = server_client_cleanup_sender.clone();
1685
1686 // Handle the new connection
1687 let (client_command_sender, client_task_handle) = server_client_handler(client_id, socket, server_client_event_sender, client_cleanup_sender);
1688 // Keep track of client command senders
1689 client_command_senders.insert(client_id, client_command_sender);
1690 // Keep track of client task handles
1691 client_join_handles.insert(client_id, client_task_handle);
1692
1693 // Send an event to note that a client has connected
1694 // successfully
1695 if let Err(_e) = server_event_sender
1696 .send(ServerEventRaw::Connect { client_id })
1697 .await
1698 {
1699 // Server is probably closed
1700 break;
1701 }
1702 },
1703 // Process a command from the server handle
1704 command_value = server_command_receiver.recv_command() => {
1705 // Handle the command, or lack thereof if the channel is closed
1706 match command_value {
1707 Ok(command) => {
1708 match command {
1709 ServerCommand::Stop => {
1710 // If a command fails to send, the server has
1711 // already closed, and the error can be ignored.
1712 // It should be noted that this is not where the
1713 // stop method actually returns its `Result`.
1714 // This immediately returns with an `Ok` status.
1715 // The real return value is the `Result`
1716 // returned from the server task join handle.
1717 _ = server_command_receiver.command_return(ServerCommandReturn::Stop(Ok(()))).await;
1718
1719 // Break the server loop, the clients will be
1720 // disconnected before the task ends
1721 break;
1722 },
1723 ServerCommand::Send { client_id, data } => {
1724 let value = match client_command_senders.get_mut(&client_id) {
1725 Some(client_command_sender) => {
1726 // Turn `Vec<u8>` into `Arc<[u8]>`,
1727 // making it more easily shareable
1728 let shareable_data = Arc::<[u8]>::from(data);
1729
1730 match client_command_sender.send_command(ServerClientCommand::Send { data: shareable_data }).await {
1731 Ok(return_value) => unwrap_enum!(return_value, ServerClientCommandReturn::Send),
1732 Err(_e) => {
1733 // The channel is closed, and
1734 // the client has probably been
1735 // disconnected, so the error
1736 // can be ignored
1737 Ok(())
1738 },
1739 }
1740 },
1741 None => Err(Error::InvalidClientId(client_id)),
1742 };
1743
1744 // If a command fails to send, the client has probably disconnected,
1745 // and the error can be ignored
1746 _ = server_command_receiver.command_return(ServerCommandReturn::Send(value)).await;
1747 },
1748 ServerCommand::SendAll { data } => {
1749 let value = {
1750 // Turn `Vec<u8>` into `Arc<[u8]>`, making
1751 // it more easily shareable
1752 let shareable_data = Arc::<[u8]>::from(data);
1753
1754 let send_futures = client_command_senders.iter_mut().map(|(_client_id, client_command_sender)| async {
1755 match client_command_sender.send_command(ServerClientCommand::Send { data: Arc::clone(&shareable_data) }).await {
1756 Ok(return_value) => unwrap_enum!(return_value, ServerClientCommandReturn::Send),
1757 Err(_e) => {
1758 // The channel is closed, and
1759 // the client has probably been
1760 // disconnected, so the error
1761 // can be ignored
1762 Ok(())
1763 }
1764 }
1765 });
1766
1767 let resolved = futures::future::join_all(send_futures).await;
1768 resolved.into_iter().collect::<Result<Vec<_>>>().map(|_| ())
1769 };
1770
1771 // If a command fails to send, the client has
1772 // probably disconnected, and the error can be
1773 // ignored
1774 _ = server_command_receiver.command_return(ServerCommandReturn::SendAll(value)).await;
1775 },
1776 ServerCommand::GetAddr => {
1777 // Get the server listener's address
1778 let addr = listener.local_addr();
1779
1780 // If a command fails to send, the client has
1781 // probably disconnected, and the error can be
1782 // ignored
1783 _ = server_command_receiver.command_return(ServerCommandReturn::GetAddr(addr.map_err(Into::into))).await;
1784 },
1785 ServerCommand::GetClientAddr { client_id } => {
1786 let value = match client_command_senders.get_mut(&client_id) {
1787 Some(client_command_sender) => match client_command_sender.send_command(ServerClientCommand::GetAddr).await {
1788 Ok(return_value) => unwrap_enum!(return_value, ServerClientCommandReturn::GetAddr),
1789 Err(_e) => {
1790 // The channel is closed, and the
1791 // client has probably been
1792 // disconnected, so the error can be
1793 // treated as an invalid client
1794 // error
1795 Err(Error::InvalidClientId(client_id))
1796 },
1797 },
1798 None => Err(Error::InvalidClientId(client_id)),
1799 };
1800
1801 // If a command fails to send, the client has
1802 // probably disconnected, and the error can be
1803 // ignored
1804 _ = server_command_receiver.command_return(ServerCommandReturn::GetClientAddr(value)).await;
1805 },
1806 ServerCommand::RemoveClient { client_id } => {
1807 let value = match client_command_senders.get_mut(&client_id) {
1808 Some(client_command_sender) => match client_command_sender.send_command(ServerClientCommand::Remove).await {
1809 Ok(return_value) => unwrap_enum!(return_value, ServerClientCommandReturn::Remove),
1810 Err(_e) => {
1811 // The channel is closed, and the
1812 // client has probably been
1813 // disconnected, so the error can be
1814 // ignored
1815 Ok(())
1816 },
1817 },
1818 None => Err(Error::InvalidClientId(client_id)),
1819 };
1820
1821 // If a command fails to send, the client has
1822 // probably disconnected already, and the error
1823 // can be ignored
1824 _ = server_command_receiver.command_return(ServerCommandReturn::RemoveClient(value)).await;
1825 },
1826 }
1827 },
1828 Err(_e) => {
1829 // Server is probably closed, exit
1830 break;
1831 },
1832 }
1833 }
1834 // Clean up after a disconnecting client
1835 disconnecting_client_id = server_client_cleanup_receiver.recv() => {
1836 match disconnecting_client_id {
1837 Some(client_id) => {
1838 // Remove the client's command sender, which will be
1839 // dropped after this block ends
1840 client_command_senders.remove(&client_id);
1841
1842 // Remove the client's join handle
1843 if let Some(handle) = client_join_handles.remove(&client_id) {
1844 // Join the client's handle
1845 if let Err(e) = handle.await.unwrap() {
1846 if cfg!(test) {
1847 // If testing, fail
1848 Err(e)?;
1849 } else {
1850 // If not testing, ignore client handler
1851 // errors
1852 }
1853 }
1854 }
1855
1856 // Send an event to note that a client has disconnected
1857 if let Err(_e) = server_event_sender.send(ServerEventRaw::Disconnect { client_id }).await {
1858 // Server is probably closed, exit
1859 break;
1860 }
1861 },
1862 None => {
1863 // Server is probably closed, exit
1864 break;
1865 },
1866 }
1867 }
1868 }
1869 }
1870
1871 Ok(())
1872}
1873
1874/// Starts the server loop task in the background.
1875async fn server_handler(
1876 listener: TcpListener,
1877 server_event_sender: Sender<ServerEventRaw>,
1878 server_command_receiver: CommandChannelReceiver<ServerCommand, ServerCommandReturn>,
1879) -> Result<()> {
1880 // Collection of channels for sending commands from the background server
1881 // task to a background client task
1882 let mut client_command_senders: HashMap<
1883 usize,
1884 CommandChannelSender<ServerClientCommand, ServerClientCommandReturn>,
1885 > = HashMap::new();
1886 // Background client task join handles
1887 let mut client_join_handles: HashMap<usize, JoinHandle<Result<()>>> = HashMap::new();
1888
1889 // Wrap server loop in a block to catch all exit scenarios
1890 let server_exit = server_loop(
1891 listener,
1892 server_event_sender.clone(),
1893 server_command_receiver,
1894 &mut client_command_senders,
1895 &mut client_join_handles,
1896 )
1897 .await;
1898
1899 // Send a remove command to all clients
1900 futures::future::join_all(client_command_senders.into_values().map(
1901 |mut client_command_sender| async move {
1902 // If a command fails to send, the client has probably disconnected
1903 // already, and the error can be ignored
1904 _ = client_command_sender
1905 .send_command(ServerClientCommand::Remove)
1906 .await;
1907 },
1908 ))
1909 .await;
1910
1911 // Join all background client tasks before exiting
1912 futures::future::join_all(client_join_handles.into_values().map(|handle| async move {
1913 if let Err(e) = handle.await.unwrap() {
1914 if cfg!(test) {
1915 // If testing, fail
1916 Err(e)?;
1917 } else {
1918 // If not testing, ignore client handler errors
1919 }
1920 }
1921
1922 Ok(())
1923 }))
1924 .await
1925 .into_iter()
1926 .collect::<Result<Vec<_>>>()?;
1927
1928 // Send a stop event, ignoring send errors
1929 _ = server_event_sender.send(ServerEventRaw::Stop).await;
1930
1931 // Return server loop result
1932 server_exit
1933}