socketioxide_mongodb/
lib.rs

// When compiled with `--cfg docsrs` (presumably on docs.rs), enable the nightly
// `doc_auto_cfg` feature so feature-gated items are annotated in the docs.
1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
// Crate-wide lints, all reported as warnings.
2#![warn(
3    clippy::all,
4    clippy::todo,
5    clippy::empty_enum,
6    clippy::mem_forget,
7    clippy::unused_self,
8    clippy::filter_map_next,
9    clippy::needless_continue,
10    clippy::needless_borrow,
11    clippy::match_wildcard_for_single_variants,
12    clippy::if_let_mutex,
13    clippy::await_holding_lock,
14    clippy::match_on_vec_items,
15    clippy::imprecise_flops,
16    clippy::suboptimal_flops,
17    clippy::lossy_float_literal,
18    clippy::rest_pat_in_fully_bound_structs,
19    clippy::fn_params_excessive_bools,
20    clippy::exit,
21    clippy::inefficient_to_string,
22    clippy::linkedlist,
23    clippy::macro_use_imports,
24    clippy::option_option,
25    clippy::verbose_file_reads,
26    clippy::unnested_or_patterns,
27    rust_2018_idioms,
28    future_incompatible,
29    nonstandard_style,
30    missing_docs
31)]
32//! # A mongodb adapter implementation for the socketioxide crate.
33//! The adapter is used to communicate with other nodes of the same application.
34//! This makes it possible to broadcast messages to sockets connected to other servers,
35//! to get the list of rooms, to add or remove sockets from rooms, etc.
36//!
37//! To achieve this, the adapter uses [change streams](https://www.mongodb.com/docs/manual/changeStreams/)
38//! on a collection. The message expiration process is either handled with TTL-indexes or a capped collection.
39//! If you change the message expiration strategy, make sure to first drop the collection.
40//! MongoDB doesn't support switching from capped to TTL-indexes on an existing collection.
41//!
42//! The [`Driver`] abstraction allows the use of any mongodb client.
43//! One implementation is provided:
44//! * [`MongoDbDriver`](crate::drivers::mongodb::MongoDbDriver) for the [`mongodb`] crate.
45//!
46//! You can also implement your own driver by implementing the [`Driver`] trait.
47//!
48//! <div class="warning">
49//!     The provided driver implementation uses change streams.
50//!     They are only available on replica sets and sharded clusters.
51//!     Make sure your mongodb server/cluster is configured accordingly.
52//! </div>
53//!
54//! <div class="warning">
55//!     Socketioxide-mongodb is not compatible with <code>@socketio/mongodb-adapter</code>
56//!     and <code>@socketio/mongodb-emitter</code>. They use completely different protocols and
57//!     cannot be used together. Do not mix socket.io JS servers with socketioxide rust servers.
58//! </div>
59//!
60//! ## Example with the default mongodb driver
61//! ```rust
62//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
63//! # use socketioxide_mongodb::{MongoDbAdapterCtr, MongoDbAdapter};
64//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
65//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
66//!     socket.join("room1");
67//!     socket.on("event", on_event);
68//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
69//! }
70//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
71//!
72//! const URI: &str = "mongodb://127.0.0.1:27017/?replicaSet=rs0&directConnection=true";
73//! let client = mongodb::Client::with_uri_str(URI).await?;
74//! let adapter = MongoDbAdapterCtr::new_with_mongodb(client.database("test")).await?;
75//! let (layer, io) = SocketIo::builder()
76//!     .with_adapter::<MongoDbAdapter<_>>(adapter)
77//!     .build_layer();
78//! Ok(())
79//! # }
80//! ```
81//!
82//! Check the [`chat example`](https://github.com/Totodore/socketioxide/tree/main/examples/chat)
83//! for more complete examples.
84//!
85//! ## How does it work?
86//!
87//! The [`MongoDbAdapterCtr`] is a constructor for the [`MongoDbAdapter`] which is an implementation of
88//! the [`Adapter`](https://docs.rs/socketioxide/latest/socketioxide/adapter/trait.Adapter.html) trait.
89//! The constructor takes a [`mongodb::Database`] as an argument and will configure a collection
90//! according to the chosen message expiration strategy (TTL indexes or capped collection).
91//!
92//! Then, for each namespace, an adapter is created and it takes a corresponding [`CoreLocalAdapter`].
93//! The [`CoreLocalAdapter`] is used to manage the local rooms and local sockets. The default `LocalAdapter`
94//! is simply a wrapper around this [`CoreLocalAdapter`].
95//!
96//! Once it is created, the adapter is initialized with the [`MongoDbAdapter::init`] method.
97//! It listens to changes on the event collection and emits heartbeats.
98//! Messages are composed of a header (in bson) and a binary payload encoded with msgpack.
99//! Headers are used to filter and route messages to server/namespaces/event handlers.
100//!
101//! All messages are encoded with msgpack.
102//!
103//! There are 9 types of requests:
104//! * Broadcast a packet to all the matching sockets.
105//! * Broadcast a packet to all the matching sockets and wait for a stream of acks.
106//! * Disconnect matching sockets.
107//! * Get all the rooms.
108//! * Add matching sockets to rooms.
109//! * Remove matching sockets from rooms.
110//! * Fetch all the remote sockets matching the options.
111//! * Heartbeat
112//! * Initial heartbeat. When receiving an initial heartbeat, all other servers immediately reply with a heartbeat.
113//!
114//! For ack streams, the adapter will first send a `BroadcastAckCount` response to the server that sent the request,
115//! and then send the acks as they are received (more details in [`MongoDbAdapter::broadcast_with_ack`] fn).
116//!
117//! On the other side, each time an action has to be performed on the local server, the adapter will
118//! first broadcast a request to all the servers and then perform the action locally.
119
120use std::{
121    borrow::Cow,
122    collections::HashMap,
123    fmt, future,
124    pin::Pin,
125    sync::{Arc, Mutex},
126    task::{Context, Poll},
127    time::{Duration, Instant},
128};
129
130use futures_core::{Stream, future::Future};
131use futures_util::StreamExt;
132use serde::{Serialize, de::DeserializeOwned};
133use socketioxide_core::adapter::remote_packet::{
134    RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId,
135};
136use socketioxide_core::{
137    Sid, Uid,
138    adapter::errors::{AdapterError, BroadcastError},
139    adapter::{
140        BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
141        RoomParam, SocketEmitter, Spawnable,
142    },
143    packet::Packet,
144};
145use stream::{AckStream, ChanStream, DropStream};
146use tokio::sync::mpsc;
147
148use drivers::{Driver, Item, ItemHeader};
149
150/// Drivers are an abstraction over the mongodb client used by the adapter.
151/// You can use the provided implementation or implement your own.
152pub mod drivers;
153
// Private stream types used by the adapter (`AckStream`, `ChanStream`, `DropStream`).
154mod stream;
155
156/// The configuration of the [`MongoDbAdapter`].
157#[derive(Debug, Clone)]
158pub struct MongoDbAdapterConfig {
159    /// The heartbeat timeout duration. If a remote node does not respond within this duration,
160    /// it will be considered disconnected. Default is 60 seconds.
161    pub hb_timeout: Duration,
162    /// The heartbeat interval duration. The current node will broadcast a heartbeat to the
163    /// remote nodes at this interval. Default is 10 seconds.
164    pub hb_interval: Duration,
165    /// The request timeout. When expecting a response from remote nodes, if they do not respond within
166    /// this duration, the request will be considered failed. Default is 5 seconds.
167    pub request_timeout: Duration,
168    /// The channel size used to receive ack responses. Default is 255.
169    ///
170    /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster
171    /// than you poll them with the returned stream, you might want to increase this value.
172    pub ack_response_buffer: usize,
173    /// The collection name used to store socket.io data. Default is "socket.io-adapter".
174    pub collection: Cow<'static, str>,
175    /// The [`MessageExpirationStrategy`] used to remove old documents.
176    /// Default is `Ttl(Duration::from_secs(60))`.
177    pub expiration_strategy: MessageExpirationStrategy,
178}
179
/// The strategy used to remove old documents in the mongodb collection.
/// The default mongodb driver supports both [TTL indexes](https://www.mongodb.com/docs/manual/core/index-ttl/)
/// and [capped collections](https://www.mongodb.com/docs/manual/core/capped-collections/).
///
/// Prefer the `TtlIndex` strategy (available with the `ttl-index` feature) for better
/// performance and usability.
#[derive(Debug, Clone)]
pub enum MessageExpirationStrategy {
    /// Use a TTL index to expire documents after a certain duration.
    ///
    /// Only available with the `ttl-index` feature.
    #[cfg(feature = "ttl-index")]
    TtlIndex(Duration),
    /// Use a capped collection to limit the size in bytes of the collection. Older messages are removed.
    ///
    /// Be aware that if you send a message that is bigger than your capped collection's size,
    /// it will be rejected and won't be broadcast.
    CappedCollection(u64),
}
196
197impl MongoDbAdapterConfig {
198    /// Create a new [`MongoDbAdapterConfig`] with default values.
199    pub fn new() -> Self {
200        Self {
201            hb_timeout: Duration::from_secs(60),
202            hb_interval: Duration::from_secs(10),
203            request_timeout: Duration::from_secs(5),
204            ack_response_buffer: 255,
205            collection: Cow::Borrowed("socket.io-adapter"),
206            #[cfg(feature = "ttl-index")]
207            expiration_strategy: MessageExpirationStrategy::TtlIndex(Duration::from_secs(60)),
208            #[cfg(not(feature = "ttl-index"))]
209            expiration_strategy: MessageExpirationStrategy::CappedCollection(1024 * 1024), // 1MB
210        }
211    }
212    /// The heartbeat timeout duration. If a remote node does not respond within this duration,
213    /// it will be considered disconnected. Default is 60 seconds.
214    pub fn with_hb_timeout(mut self, hb_timeout: Duration) -> Self {
215        self.hb_timeout = hb_timeout;
216        self
217    }
218    /// The heartbeat interval duration. The current node will broadcast a heartbeat to the
219    /// remote nodes at this interval. Default is 10 seconds.
220    pub fn with_hb_interval(mut self, hb_interval: Duration) -> Self {
221        self.hb_interval = hb_interval;
222        self
223    }
224    /// The request timeout. When expecting a response from remote nodes, if they do not respond within
225    /// this duration, the request will be considered failed. Default is 5 seconds.
226    pub fn with_request_timeout(mut self, request_timeout: Duration) -> Self {
227        self.request_timeout = request_timeout;
228        self
229    }
230    /// The channel size used to receive ack responses. Default is 255.
231    ///
232    /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster
233    /// than you poll them with the returned stream, you might want to increase this value.
234    pub fn with_ack_response_buffer(mut self, ack_response_buffer: usize) -> Self {
235        self.ack_response_buffer = ack_response_buffer;
236        self
237    }
238    /// The collection name used to store socket.io data. Default is "socket.io-adapter".
239    pub fn with_collection(mut self, collection: impl Into<Cow<'static, str>>) -> Self {
240        self.collection = collection.into();
241        self
242    }
243    /// The [`MessageExpirationStrategy`] used to remove old documents.
244    /// Default is `TtlIndex(Duration::from_secs(60))` with the `ttl-index` feature enabled.
245    /// Otherwise it is `CappedCollection(1MB)`
246    pub fn with_expiration_strategy(
247        mut self,
248        expiration_strategy: MessageExpirationStrategy,
249    ) -> Self {
250        self.expiration_strategy = expiration_strategy;
251        self
252    }
253}
254
255impl Default for MongoDbAdapterConfig {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
261/// The adapter constructor. For each namespace you define, a new adapter instance is created
262/// from this constructor.
263#[derive(Debug, Clone)]
264pub struct MongoDbAdapterCtr<D> {
265    driver: D,
266    config: MongoDbAdapterConfig,
267}
268
#[cfg(feature = "mongodb")]
impl MongoDbAdapterCtr<drivers::mongodb::MongoDbDriver> {
    /// Build an adapter constructor backed by the [`mongodb`](drivers::mongodb) driver,
    /// using [`MongoDbAdapterConfig::default`] as configuration.
    ///
    /// # Errors
    /// Returns the mongodb error raised while creating the driver.
    pub async fn new_with_mongodb(
        db: mongodb::Database,
    ) -> Result<Self, drivers::mongodb::mongodb_client::error::Error> {
        let config = MongoDbAdapterConfig::default();
        Self::new_with_mongodb_config(db, config).await
    }

    /// Build an adapter constructor backed by the [`mongodb`](drivers::mongodb) driver
    /// with a custom `config`.
    ///
    /// The driver is created from `db`, using the collection name and the expiration
    /// strategy taken from `config`.
    ///
    /// # Errors
    /// Returns the mongodb error raised while creating the driver.
    pub async fn new_with_mongodb_config(
        db: mongodb::Database,
        config: MongoDbAdapterConfig,
    ) -> Result<Self, drivers::mongodb::mongodb_client::error::Error> {
        let collection = &config.collection;
        let strategy = &config.expiration_strategy;
        let driver = drivers::mongodb::MongoDbDriver::new(db, collection, strategy).await?;
        Ok(Self { driver, config })
    }
}
290impl<D: Driver> MongoDbAdapterCtr<D> {
291    /// Create a new adapter constructor with a custom mongodb driver and a config.
292    ///
293    /// You can implement your own driver by implementing the [`Driver`] trait with any mongodb client.
294    /// Check the [`drivers`] module for more information.
295    pub fn new_with_driver(driver: D, config: MongoDbAdapterConfig) -> Self {
296        Self { driver, config }
297    }
298}
299
300/// Represent any error that might happen when using this adapter.
301#[derive(thiserror::Error)]
302pub enum Error<D: Driver> {
303    /// Mongo driver error
304    #[error("driver error: {0}")]
305    Driver(D::Error),
306    /// Packet encoding error
307    #[error("packet encoding error: {0}")]
308    Encode(#[from] rmp_serde::encode::Error),
309    /// Packet decoding error
310    #[error("packet decoding error: {0}")]
311    Decode(#[from] rmp_serde::decode::Error),
312}
313
314impl<R: Driver> Error<R> {
315    fn from_driver(err: R::Error) -> Self {
316        Self::Driver(err)
317    }
318}
319impl<R: Driver> fmt::Debug for Error<R> {
320    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
321        match self {
322            Self::Driver(err) => write!(f, "Driver error: {:?}", err),
323            Self::Decode(err) => write!(f, "Decode error: {:?}", err),
324            Self::Encode(err) => write!(f, "Encode error: {:?}", err),
325        }
326    }
327}
328
329impl<R: Driver> From<Error<R>> for AdapterError {
330    fn from(err: Error<R>) -> Self {
331        AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
332    }
333}
334
335pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Item>>;
336
/// The mongodb adapter using the default [mongodb](drivers::mongodb::mongodb_client) driver.
///
/// This is [`CustomMongoDbAdapter`] specialized with [`drivers::mongodb::MongoDbDriver`].
#[cfg(feature = "mongodb")]
pub type MongoDbAdapter<E> = CustomMongoDbAdapter<E, drivers::mongodb::MongoDbDriver>;
340
341/// The mongodb adapter implementation.
342/// It is generic over the [`Driver`] used to communicate with the mongodb server.
343/// And over the [`SocketEmitter`] used to communicate with the local server. This allows to
344/// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates.
345pub struct CustomMongoDbAdapter<E, D> {
346    /// The driver used by the adapter. This is used to communicate with the mongodb server.
347    /// All the mongodb adapter instances share the same driver.
348    driver: D,
349    /// The configuration of the adapter.
350    config: MongoDbAdapterConfig,
351    /// A unique identifier for the adapter to identify itself in the mongodb server.
352    uid: Uid,
353    /// The local adapter, used to manage local rooms and socket stores.
354    local: CoreLocalAdapter<E>,
355    /// A map of nodes liveness, with the last time remote nodes were seen alive.
356    nodes_liveness: Mutex<Vec<(Uid, std::time::Instant)>>,
357    /// A map of response handlers used to await for responses from the remote servers.
358    responses: Arc<Mutex<ResponseHandlers>>,
359}
360
361impl<E, D> DefinedAdapter for CustomMongoDbAdapter<E, D> {}
362impl<E: SocketEmitter, D: Driver> CoreAdapter<E> for CustomMongoDbAdapter<E, D> {
363    type Error = Error<D>;
364    type State = MongoDbAdapterCtr<D>;
365    type AckStream = AckStream<E::AckStream>;
366    type InitRes = InitRes<D>;
367
368    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
369        let uid = local.server_id();
370        Self {
371            local,
372            uid,
373            driver: state.driver.clone(),
374            config: state.config.clone(),
375            nodes_liveness: Mutex::new(Vec::new()),
376            responses: Arc::new(Mutex::new(HashMap::new())),
377        }
378    }
379
380    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
381        let fut = async move {
382            let stream = self.driver.watch(self.uid, self.local.path()).await?;
383            tokio::spawn(self.clone().handle_ev_stream(stream));
384            tokio::spawn(self.clone().heartbeat_job());
385
386            // Send initial heartbeat when starting.
387            self.emit_init_heartbeat().await.map_err(|e| match e {
388                Error::Driver(e) => e,
389                Error::Encode(_) | Error::Decode(_) => unreachable!(),
390            })?;
391
392            on_success();
393            Ok(())
394        };
395        InitRes(Box::pin(fut))
396    }
397
398    async fn close(&self) -> Result<(), Self::Error> {
399        Ok(())
400    }
401
402    /// Get the number of servers by iterating over the node liveness heartbeats.
403    async fn server_count(&self) -> Result<u16, Self::Error> {
404        let treshold = std::time::Instant::now() - self.config.hb_timeout;
405        let mut nodes_liveness = self.nodes_liveness.lock().unwrap();
406        nodes_liveness.retain(|(_, v)| v > &treshold);
407        Ok((nodes_liveness.len() + 1) as u16)
408    }
409
410    /// Broadcast a packet to all the servers to send them through their sockets.
411    async fn broadcast(
412        &self,
413        packet: Packet,
414        opts: BroadcastOptions,
415    ) -> Result<(), BroadcastError> {
416        if !opts.is_local(self.uid) {
417            let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
418            self.send_req(req, None).await.map_err(AdapterError::from)?;
419        }
420
421        self.local.broadcast(packet, opts)?;
422        Ok(())
423    }
424
425    /// Broadcast a packet to all the servers to send them through their sockets.
426    ///
427    /// Returns a Stream that is a combination of the local ack stream and a remote ack stream.
428    /// Here is a specific protocol in order to know how many message the server expect to close
429    /// the stream at the right time:
430    /// * Get the number `n` of remote servers.
431    /// * Send the broadcast request.
432    /// * Expect `n` `BroadcastAckCount` response in the stream to know the number `m` of expected ack responses.
433    /// * Expect `sum(m)` broadcast counts sent by the servers.
434    ///
435    /// Example with 3 remote servers (n = 3):
436    /// ```text
437    /// +---+                   +---+                   +---+
438    /// | A |                   | B |                   | C |
439    /// +---+                   +---+                   +---+
440    ///   |                       |                       |
441    ///   |---BroadcastWithAck--->|                       |
442    ///   |---BroadcastWithAck--------------------------->|
443    ///   |                       |                       |
444    ///   |<-BroadcastAckCount(2)-|     (n = 2; m = 2)    |
445    ///   |<-BroadcastAckCount(2)-------(n = 2; m = 4)----|
446    ///   |                       |                       |
447    ///   |<----------------Ack---------------------------|
448    ///   |<----------------Ack---|                       |
449    ///   |                       |                       |
450    ///   |<----------------Ack---------------------------|
451    ///   |<----------------Ack---|                       |
452    async fn broadcast_with_ack(
453        &self,
454        packet: Packet,
455        opts: BroadcastOptions,
456        timeout: Option<Duration>,
457    ) -> Result<Self::AckStream, Self::Error> {
458        if opts.is_local(self.uid) {
459            tracing::debug!(?opts, "broadcast with ack is local");
460            let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
461            let stream = AckStream::new_local(local);
462            return Ok(stream);
463        }
464        let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
465        let req_id = req.id;
466
467        let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
468        tracing::trace!(?remote_serv_cnt, "expecting acks from remote servers");
469
470        let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
471        self.responses.lock().unwrap().insert(req_id, tx);
472        self.send_req(req, None).await?;
473        let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
474
475        Ok(AckStream::new(
476            local,
477            rx,
478            self.config.request_timeout,
479            remote_serv_cnt,
480            req_id,
481            self.responses.clone(),
482        ))
483    }
484
485    async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
486        if !opts.is_local(self.uid) {
487            let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
488            self.send_req(req, None).await.map_err(AdapterError::from)?;
489        }
490        self.local
491            .disconnect_socket(opts)
492            .map_err(BroadcastError::Socket)?;
493
494        Ok(())
495    }
496
497    async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
498        if opts.is_local(self.uid) {
499            return Ok(self.local.rooms(opts).into_iter().collect());
500        }
501        let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
502        let req_id = req.id;
503
504        // First get the remote stream because mongodb might send
505        // the responses before subscription is done.
506        let stream = self
507            .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
508            .await?;
509        self.send_req(req, opts.server_id).await?;
510        let local = self.local.rooms(opts);
511        let rooms = stream
512            .filter_map(|item| future::ready(item.into_rooms()))
513            .fold(local, |mut acc, item| async move {
514                acc.extend(item);
515                acc
516            })
517            .await;
518        Ok(Vec::from_iter(rooms))
519    }
520
521    async fn add_sockets(
522        &self,
523        opts: BroadcastOptions,
524        rooms: impl RoomParam,
525    ) -> Result<(), Self::Error> {
526        let rooms: Vec<Room> = rooms.into_room_iter().collect();
527        if !opts.is_local(self.uid) {
528            let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
529            self.send_req(req, opts.server_id).await?;
530        }
531        self.local.add_sockets(opts, rooms);
532        Ok(())
533    }
534
535    async fn del_sockets(
536        &self,
537        opts: BroadcastOptions,
538        rooms: impl RoomParam,
539    ) -> Result<(), Self::Error> {
540        let rooms: Vec<Room> = rooms.into_room_iter().collect();
541        if !opts.is_local(self.uid) {
542            let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
543            self.send_req(req, opts.server_id).await?;
544        }
545        self.local.del_sockets(opts, rooms);
546        Ok(())
547    }
548
549    async fn fetch_sockets(
550        &self,
551        opts: BroadcastOptions,
552    ) -> Result<Vec<RemoteSocketData>, Self::Error> {
553        if opts.is_local(self.uid) {
554            return Ok(self.local.fetch_sockets(opts));
555        }
556        let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
557        // First get the remote stream because mongodb might send
558        // the responses before subscription is done.
559        let remote = self
560            .get_res::<RemoteSocketData>(req.id, ResponseTypeId::FetchSockets, opts.server_id)
561            .await?;
562
563        self.send_req(req, opts.server_id).await?;
564        let local = self.local.fetch_sockets(opts);
565        let sockets = remote
566            .filter_map(|item| future::ready(item.into_fetch_sockets()))
567            .fold(local, |mut acc, item| async move {
568                acc.extend(item);
569                acc
570            })
571            .await;
572        Ok(sockets)
573    }
574
575    fn get_local(&self) -> &CoreLocalAdapter<E> {
576        &self.local
577    }
578}
579
580impl<E: SocketEmitter, D: Driver> CustomMongoDbAdapter<E, D> {
581    async fn heartbeat_job(self: Arc<Self>) -> Result<(), Error<D>> {
582        let mut interval = tokio::time::interval(self.config.hb_interval);
583        interval.tick().await; // first tick yields immediately
584        loop {
585            interval.tick().await;
586            self.emit_heartbeat(None).await?;
587        }
588    }
589
590    async fn handle_ev_stream(
591        self: Arc<Self>,
592        mut stream: impl Stream<Item = Result<Item, D::Error>> + Unpin,
593    ) {
594        while let Some(item) = stream.next().await {
595            match item {
596                Ok(Item {
597                    header: ItemHeader::Req { target, .. },
598                    data,
599                    ..
600                }) if target.is_none_or(|id| id == self.uid) => {
601                    tracing::debug!(?target, "request header");
602                    if let Err(e) = self.recv_req(data).await {
603                        tracing::warn!("error receiving request from driver: {e}");
604                    }
605                }
606                Ok(Item {
607                    header: ItemHeader::Req { target, .. },
608                    ..
609                }) => {
610                    tracing::debug!(
611                        ?target,
612                        "receiving request which is not for us, skipping..."
613                    );
614                }
615                Ok(
616                    item @ Item {
617                        header: ItemHeader::Res { request, .. },
618                        ..
619                    },
620                ) => {
621                    tracing::trace!(?request, "received response");
622                    let handlers = self.responses.lock().unwrap();
623                    if let Some(tx) = handlers.get(&request) {
624                        if let Err(e) = tx.try_send(item) {
625                            tracing::warn!("error sending response to handler: {e}");
626                        }
627                    } else {
628                        tracing::warn!(?request, ?handlers, "could not find req handler");
629                    }
630                }
631                Err(e) => {
632                    tracing::warn!("error receiving event from driver: {e}");
633                }
634            }
635        }
636    }
637
638    async fn recv_req(self: &Arc<Self>, req: Vec<u8>) -> Result<(), Error<D>> {
639        let req = rmp_serde::from_slice::<RequestIn>(&req)?;
640        tracing::trace!(?req, "incoming request");
641        match (req.r#type, req.opts) {
642            (RequestTypeIn::Broadcast(p), Some(opts)) => self.recv_broadcast(opts, p),
643            (RequestTypeIn::BroadcastWithAck(p), Some(opts)) => self
644                .clone()
645                .recv_broadcast_with_ack(req.node_id, req.id, p, opts),
646            (RequestTypeIn::DisconnectSockets, Some(opts)) => self.recv_disconnect_sockets(opts),
647            (RequestTypeIn::AllRooms, Some(opts)) => self.recv_rooms(req.node_id, req.id, opts),
648            (RequestTypeIn::AddSockets(rooms), Some(opts)) => self.recv_add_sockets(opts, rooms),
649            (RequestTypeIn::DelSockets(rooms), Some(opts)) => self.recv_del_sockets(opts, rooms),
650            (RequestTypeIn::FetchSockets, Some(opts)) => {
651                self.recv_fetch_sockets(req.node_id, req.id, opts)
652            }
653            req_type @ (RequestTypeIn::Heartbeat | RequestTypeIn::InitHeartbeat, _) => {
654                self.recv_heartbeat(req_type.0, req.node_id)
655            }
656            _ => (),
657        }
658        Ok(())
659    }
660
661    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
662        tracing::trace!(?opts, "incoming broadcast");
663        if let Err(e) = self.local.broadcast(packet, opts) {
664            let ns = self.local.path();
665            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
666        }
667    }
668
669    fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
670        if let Err(e) = self.local.disconnect_socket(opts) {
671            let ns = self.local.path();
672            tracing::warn!(
673                ?self.uid,
674                ?ns,
675                "remote request disconnect sockets handler: {:?}",
676                e
677            );
678        }
679    }
680
681    fn recv_broadcast_with_ack(
682        self: Arc<Self>,
683        origin: Uid,
684        req_id: Sid,
685        packet: Packet,
686        opts: BroadcastOptions,
687    ) {
688        let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
689        tokio::spawn(async move {
690            let on_err = |err| {
691                let ns = self.local.path();
692                tracing::warn!(
693                    ?self.uid,
694                    ?ns,
695                    "remote request broadcast with ack handler errors: {:?}",
696                    err
697                );
698            };
699            // First send the count of expected acks to the server that sent the request.
700            // This is used to keep track of the number of expected acks.
701            let res = Response {
702                r#type: ResponseType::<()>::BroadcastAckCount(count),
703                node_id: self.uid,
704            };
705            if let Err(err) = self.send_res(req_id, origin, res).await {
706                on_err(err);
707                return;
708            }
709
710            // Then send the acks as they are received.
711            futures_util::pin_mut!(stream);
712            while let Some(ack) = stream.next().await {
713                let res = Response {
714                    r#type: ResponseType::BroadcastAck(ack),
715                    node_id: self.uid,
716                };
717                if let Err(err) = self.send_res(req_id, origin, res).await {
718                    on_err(err);
719                    return;
720                }
721            }
722        });
723    }
724
725    fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
726        let rooms = self.local.rooms(opts);
727        let res = Response {
728            r#type: ResponseType::<()>::AllRooms(rooms),
729            node_id: self.uid,
730        };
731        let fut = self.send_res(req_id, origin, res);
732        let ns = self.local.path().clone();
733        let uid = self.uid;
734        tokio::spawn(async move {
735            if let Err(err) = fut.await {
736                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
737            }
738        });
739    }
740
741    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
742        self.local.add_sockets(opts, rooms);
743    }
744
745    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
746        self.local.del_sockets(opts, rooms);
747    }
748    fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
749        let sockets = self.local.fetch_sockets(opts);
750        let res = Response {
751            node_id: self.uid,
752            r#type: ResponseType::FetchSockets(sockets),
753        };
754        let fut = self.send_res(req_id, origin, res);
755        let ns = self.local.path().clone();
756        let uid = self.uid;
757        tokio::spawn(async move {
758            if let Err(err) = fut.await {
759                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
760            }
761        });
762    }
763
764    /// Receive a heartbeat from a remote node.
765    /// It might be a FirstHeartbeat packet, in which case we are re-emitting a heartbeat to the remote node.
766    fn recv_heartbeat(self: &Arc<Self>, req_type: RequestTypeIn, origin: Uid) {
767        tracing::debug!(?req_type, "{:?} received", req_type);
768        let mut node_liveness = self.nodes_liveness.lock().unwrap();
769        // Even with a FirstHeartbeat packet we first consume the node liveness to
770        // ensure that the node is not already in the list.
771        for (id, liveness) in node_liveness.iter_mut() {
772            if *id == origin {
773                *liveness = Instant::now();
774                return;
775            }
776        }
777
778        node_liveness.push((origin, Instant::now()));
779
780        if matches!(req_type, RequestTypeIn::InitHeartbeat) {
781            tracing::debug!(
782                ?origin,
783                "initial heartbeat detected, saying hello to the new node"
784            );
785
786            let this = self.clone();
787            tokio::spawn(async move {
788                if let Err(err) = this.emit_heartbeat(Some(origin)).await {
789                    tracing::warn!(
790                        "could not re-emit heartbeat after new node detection: {:?}",
791                        err
792                    );
793                }
794            });
795        }
796    }
797
798    /// Send a request to a specific target node or broadcast it to all nodes if no target is specified.
799    async fn send_req(&self, req: RequestOut<'_>, target: Option<Uid>) -> Result<(), Error<D>> {
800        tracing::trace!(?req, "sending request");
801        let head = ItemHeader::Req { target };
802        let req = self.new_packet(head, &req)?;
803        self.driver.emit(&req).await.map_err(Error::from_driver)?;
804        Ok(())
805    }
806
807    /// Send a response to the node that sent the request.
808    fn send_res<T: Serialize + fmt::Debug>(
809        &self,
810        req_id: Sid,
811        req_origin: Uid,
812        res: Response<T>,
813    ) -> impl Future<Output = Result<(), Error<D>>> + Send + 'static {
814        tracing::trace!(?res, "sending response for {req_id} req to {req_origin}");
815        let driver = self.driver.clone();
816        let head = ItemHeader::Res {
817            request: req_id,
818            target: req_origin,
819        };
820        let res = self.new_packet(head, &res);
821
822        async move {
823            driver.emit(&res?).await.map_err(Error::from_driver)?;
824            Ok(())
825        }
826    }
827
828    /// Await for all the responses from the remote servers.
829    /// If the target node is specified, only await for the response from that node.
830    async fn get_res<T: DeserializeOwned + fmt::Debug>(
831        &self,
832        req_id: Sid,
833        response_type: ResponseTypeId,
834        target: Option<Uid>,
835    ) -> Result<impl Stream<Item = Response<T>>, Error<D>> {
836        // Check for specific target node
837        let remote_serv_cnt = if target.is_none() {
838            self.server_count().await?.saturating_sub(1) as usize
839        } else {
840            1
841        };
842        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
843        self.responses.lock().unwrap().insert(req_id, tx);
844        let stream = ChanStream::new(rx)
845            .filter_map(|Item { header, data, .. }| {
846                let data = match rmp_serde::from_slice::<Response<T>>(&data) {
847                    Ok(data) => Some(data),
848                    Err(e) => {
849                        tracing::warn!(header = ?header, "error decoding response: {e}");
850                        None
851                    }
852                };
853                future::ready(data)
854            })
855            .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type))
856            .take(remote_serv_cnt)
857            .take_until(tokio::time::sleep(self.config.request_timeout));
858        let stream = DropStream::new(stream, self.responses.clone(), req_id);
859        Ok(stream)
860    }
861
862    /// Emit a heartbeat to the specified target node or broadcast to all nodes.
863    async fn emit_heartbeat(&self, target: Option<Uid>) -> Result<(), Error<D>> {
864        // Send heartbeat when starting.
865        self.send_req(
866            RequestOut::new_empty(self.uid, RequestTypeOut::Heartbeat),
867            target,
868        )
869        .await
870    }
871
872    /// Emit an initial heartbeat to all nodes.
873    async fn emit_init_heartbeat(&self) -> Result<(), Error<D>> {
874        // Send initial heartbeat when starting.
875        self.send_req(
876            RequestOut::new_empty(self.uid, RequestTypeOut::InitHeartbeat),
877            None,
878        )
879        .await
880    }
881    fn new_packet(&self, head: ItemHeader, data: &impl Serialize) -> Result<Item, Error<D>> {
882        let ns = &self.local.path();
883        let uid = self.uid;
884        match self.config.expiration_strategy {
885            #[cfg(feature = "ttl-index")]
886            MessageExpirationStrategy::TtlIndex(_) => Ok(Item::new_ttl(head, data, uid, ns)?),
887            MessageExpirationStrategy::CappedCollection(_) => Ok(Item::new(head, data, uid, ns)?),
888        }
889    }
890}
891
892/// The result of the init future.
893#[must_use = "futures do nothing unless you `.await` or poll them"]
894pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), D::Error>>);
895
896impl<D: Driver> Future for InitRes<D> {
897    type Output = Result<(), D::Error>;
898
899    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
900        self.0.as_mut().poll(cx)
901    }
902}
903impl<D: Driver> Spawnable for InitRes<D> {
904    fn spawn(self) {
905        tokio::spawn(async move {
906            if let Err(e) = self.0.await {
907                tracing::error!("error initializing adapter: {e}");
908            }
909        });
910    }
911}