socketioxide_redis/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(
3    clippy::all,
4    clippy::todo,
5    clippy::empty_enum,
6    clippy::mem_forget,
7    clippy::unused_self,
8    clippy::filter_map_next,
9    clippy::needless_continue,
10    clippy::needless_borrow,
11    clippy::match_wildcard_for_single_variants,
12    clippy::if_let_mutex,
13    clippy::await_holding_lock,
14    clippy::match_on_vec_items,
15    clippy::imprecise_flops,
16    clippy::suboptimal_flops,
17    clippy::lossy_float_literal,
18    clippy::rest_pat_in_fully_bound_structs,
19    clippy::fn_params_excessive_bools,
20    clippy::exit,
21    clippy::inefficient_to_string,
22    clippy::linkedlist,
23    clippy::macro_use_imports,
24    clippy::option_option,
25    clippy::verbose_file_reads,
26    clippy::unnested_or_patterns,
27    rust_2018_idioms,
28    rust_2024_compatibility,
29    future_incompatible,
30    nonstandard_style,
31    missing_docs
32)]
33
34//! # A redis/valkey adapter implementation for the socketioxide crate.
35//! The adapter is used to communicate with other nodes of the same application.
//! This allows broadcasting messages to sockets connected on other servers,
//! getting the list of rooms, adding or removing sockets from rooms, etc.
38//!
39//! To achieve this, the adapter uses a [pub/sub](https://redis.io/docs/latest/develop/interact/pubsub/) system
40//! through Redis to communicate with other servers.
41//!
42//! The [`Driver`] abstraction allows the use of any pub/sub client.
43//! Three implementations are provided:
44//! * [`RedisDriver`](crate::drivers::redis::RedisDriver) for the [`redis`] crate with a standalone redis.
45//! * [`ClusterDriver`](crate::drivers::redis::ClusterDriver) for the [`redis`] crate with a redis cluster.
46//! * [`FredDriver`](crate::drivers::fred::FredDriver) for the [`fred`] crate with a standalone/cluster redis.
47//!
48//! When using redis clusters, the drivers employ [sharded pub/sub](https://redis.io/docs/latest/develop/interact/pubsub/#sharded-pubsub)
49//! to distribute the load across Redis nodes.
50//!
51//! You can also implement your own driver by implementing the [`Driver`] trait.
52//!
53//! <div class="warning">
//!     The provided driver implementations use <code>RESP3</code> for efficiency purposes.
55//!     Make sure your redis server supports it (redis v7 and above).
56//!     If not, you can implement your own driver using the <code>RESP2</code> protocol.
57//! </div>
58//!
59//! <div class="warning">
60//!     Socketioxide-Redis is not compatible with <code>@socketio/redis-adapter</code>
61//!     and <code>@socketio/redis-emitter</code>. They use completely different protocols and
62//!     cannot be used together. Do not mix socket.io JS servers with socketioxide rust servers.
63//! </div>
64//!
65//! ## Example with the [`redis`] driver
66//! ```rust
67//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
68//! # use socketioxide_redis::{RedisAdapterCtr, RedisAdapter};
69//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
70//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
71//!     socket.join("room1");
72//!     socket.on("event", on_event);
73//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
74//! }
75//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
76//!
77//! let client = redis::Client::open("redis://127.0.0.1:6379?protocol=RESP3")?;
78//! let adapter = RedisAdapterCtr::new_with_redis(&client).await?;
79//! let (layer, io) = SocketIo::builder()
80//!     .with_adapter::<RedisAdapter<_>>(adapter)
81//!     .build_layer();
82//! Ok(())
83//! # }
84//! ```
85//!
86//!
87//! ## Example with the [`fred`] driver
88//! ```rust
89//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
90//! # use socketioxide_redis::{RedisAdapterCtr, FredAdapter};
91//! # use fred::types::RespVersion;
92//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
93//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
94//!     socket.join("room1");
95//!     socket.on("event", on_event);
96//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
97//! }
98//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
99//!
100//! let mut config = fred::prelude::Config::from_url("redis://127.0.0.1:6379?protocol=resp3")?;
101//! // We need to manually set the RESP3 version because
102//! // the fred crate does not parse the protocol query parameter.
103//! config.version = RespVersion::RESP3;
104//! let client = fred::prelude::Builder::from_config(config).build_subscriber_client()?;
105//! let adapter = RedisAdapterCtr::new_with_fred(client).await?;
106//! let (layer, io) = SocketIo::builder()
107//!     .with_adapter::<FredAdapter<_>>(adapter)
108//!     .build_layer();
109//! Ok(())
110//! # }
111//! ```
112//!
113//!
114//! ## Example with the [`redis`] cluster driver
115//! ```rust
116//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
117//! # use socketioxide_redis::{RedisAdapterCtr, ClusterAdapter};
118//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
119//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
120//!     socket.join("room1");
121//!     socket.on("event", on_event);
122//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
123//! }
124//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
125//!
126//! // single node cluster
127//! let client = redis::cluster::ClusterClient::new(["redis://127.0.0.1:6379?protocol=resp3"])?;
128//! let adapter = RedisAdapterCtr::new_with_cluster(&client).await?;
129//!
130//! let (layer, io) = SocketIo::builder()
131//!     .with_adapter::<ClusterAdapter<_>>(adapter)
132//!     .build_layer();
133//! Ok(())
134//! # }
135//! ```
136//!
137//! Check the [`chat example`](https://github.com/Totodore/socketioxide/tree/main/examples/chat)
138//! for more complete examples.
139//!
140//! ## How does it work?
141//!
//! An adapter is created for each namespace, and it receives a corresponding [`CoreLocalAdapter`].
//! The [`CoreLocalAdapter`] manages the local rooms and local sockets. The default `LocalAdapter`
//! is simply a wrapper around this [`CoreLocalAdapter`].
145//!
146//! The adapter is then initialized with the [`RedisAdapter::init`] method.
147//! This will subscribe to 3 channels:
148//! * `"{prefix}-request#{namespace}#"`: A global channel to receive broadcasted requests.
149//! * `"{prefix}-request#{namespace}#{uid}#"`: A specific channel to receive requests only for this server.
150//! * `"{prefix}-response#{namespace}#{uid}#"`: A specific channel to receive responses only for this server.
//!   Messages sent to this channel will always be in the form `[req_id, data]`. This allows the adapter to extract the request id
//!   and route the response to the appropriate stream before deserializing the data.
153//!
154//! All messages are encoded with msgpack.
155//!
156//! There are 7 types of requests:
157//! * Broadcast a packet to all the matching sockets.
158//! * Broadcast a packet to all the matching sockets and wait for a stream of acks.
159//! * Disconnect matching sockets.
160//! * Get all the rooms.
161//! * Add matching sockets to rooms.
//! * Remove matching sockets from rooms.
163//! * Fetch all the remote sockets matching the options.
164//!
165//! For ack streams, the adapter will first send a `BroadcastAckCount` response to the server that sent the request,
166//! and then send the acks as they are received (more details in [`RedisAdapter::broadcast_with_ack`] fn).
167//!
168//! On the other side, each time an action has to be performed on the local server, the adapter will
169//! first broadcast a request to all the servers and then perform the action locally.
170
171use std::{
172    borrow::Cow,
173    collections::HashMap,
174    fmt,
175    future::{self, Future},
176    pin::Pin,
177    sync::{Arc, Mutex},
178    task::{Context, Poll},
179    time::Duration,
180};
181
182use drivers::{ChanItem, Driver, MessageStream};
183use futures_core::Stream;
184use futures_util::StreamExt;
185use serde::{Serialize, de::DeserializeOwned};
186use socketioxide_core::adapter::remote_packet::{
187    RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId,
188};
189use socketioxide_core::{
190    Sid, Uid,
191    adapter::errors::{AdapterError, BroadcastError},
192    adapter::{
193        BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
194        RoomParam, SocketEmitter, Spawnable,
195    },
196    packet::Packet,
197};
198use stream::{AckStream, DropStream};
199use tokio::{sync::mpsc, time};
200
201/// Drivers are an abstraction over the pub/sub backend used by the adapter.
202/// You can use the provided implementation or implement your own.
203pub mod drivers;
204
205mod stream;
206
207/// Represent any error that might happen when using this adapter.
208#[derive(thiserror::Error)]
209pub enum Error<R: Driver> {
210    /// Redis driver error
211    #[error("driver error: {0}")]
212    Driver(R::Error),
213    /// Packet encoding error
214    #[error("packet encoding error: {0}")]
215    Decode(#[from] rmp_serde::decode::Error),
216    /// Packet decoding error
217    #[error("packet decoding error: {0}")]
218    Encode(#[from] rmp_serde::encode::Error),
219}
220
221impl<R: Driver> Error<R> {
222    fn from_driver(err: R::Error) -> Self {
223        Self::Driver(err)
224    }
225}
226impl<R: Driver> fmt::Debug for Error<R> {
227    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228        match self {
229            Self::Driver(err) => write!(f, "Driver error: {:?}", err),
230            Self::Decode(err) => write!(f, "Decode error: {:?}", err),
231            Self::Encode(err) => write!(f, "Encode error: {:?}", err),
232        }
233    }
234}
235
236impl<R: Driver> From<Error<R>> for AdapterError {
237    fn from(err: Error<R>) -> Self {
238        AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
239    }
240}
241
242/// The configuration of the [`RedisAdapter`].
243#[derive(Debug, Clone)]
244pub struct RedisAdapterConfig {
245    /// The request timeout. It is mainly used when expecting response such as when using
246    /// `broadcast_with_ack` or `rooms`. Default is 5 seconds.
247    pub request_timeout: Duration,
248
249    /// The prefix used for the channels. Default is "socket.io".
250    pub prefix: Cow<'static, str>,
251
252    /// The channel size used to receive ack responses. Default is 255.
253    ///
254    /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster
255    /// than you poll them with the returned stream, you might want to increase this value.
256    pub ack_response_buffer: usize,
257
258    /// The channel size used to receive messages. Default is 1024.
259    ///
260    /// If your server is under heavy load, you might want to increase this value.
261    pub stream_buffer: usize,
262}
263impl RedisAdapterConfig {
264    /// Create a new config.
265    pub fn new() -> Self {
266        Self::default()
267    }
268    /// Set the request timeout. Default is 5 seconds.
269    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
270        self.request_timeout = timeout;
271        self
272    }
273
274    /// Set the prefix used for the channels. Default is "socket.io".
275    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
276        self.prefix = prefix.into();
277        self
278    }
279
280    /// Set the channel size used to send ack responses. Default is 255.
281    ///
282    /// If you have a lot of servers/sockets and that you may miss acknowledgement because they arrive faster
283    /// than you poll them with the returned stream, you might want to increase this value.
284    pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self {
285        assert!(buffer > 0, "buffer size must be greater than 0");
286        self.ack_response_buffer = buffer;
287        self
288    }
289
290    /// Set the channel size used to receive messages. Default is 1024.
291    ///
292    /// If your server is under heavy load, you might want to increase this value.
293    pub fn with_stream_buffer(mut self, buffer: usize) -> Self {
294        assert!(buffer > 0, "buffer size must be greater than 0");
295        self.stream_buffer = buffer;
296        self
297    }
298}
299
300impl Default for RedisAdapterConfig {
301    fn default() -> Self {
302        Self {
303            request_timeout: Duration::from_secs(5),
304            prefix: Cow::Borrowed("socket.io"),
305            ack_response_buffer: 255,
306            stream_buffer: 1024,
307        }
308    }
309}
310
311/// The adapter constructor. For each namespace you define, a new adapter instance is created
312/// from this constructor.
313#[derive(Debug)]
314pub struct RedisAdapterCtr<R> {
315    driver: R,
316    config: RedisAdapterConfig,
317}
318
319#[cfg(feature = "redis")]
320impl RedisAdapterCtr<drivers::redis::RedisDriver> {
321    /// Create a new adapter constructor with the [`redis`] driver and a default config.
322    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
323    pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
324        Self::new_with_redis_config(client, RedisAdapterConfig::default()).await
325    }
326    /// Create a new adapter constructor with the [`redis`] driver and a custom config.
327    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
328    pub async fn new_with_redis_config(
329        client: &redis::Client,
330        config: RedisAdapterConfig,
331    ) -> redis::RedisResult<Self> {
332        let driver = drivers::redis::RedisDriver::new(client).await?;
333        Ok(Self::new_with_driver(driver, config))
334    }
335}
336#[cfg(feature = "redis-cluster")]
337impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
338    /// Create a new adapter constructor with the [`redis`] driver and a default config.
339    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
340    pub async fn new_with_cluster(
341        client: &redis::cluster::ClusterClient,
342    ) -> redis::RedisResult<Self> {
343        Self::new_with_cluster_config(client, RedisAdapterConfig::default()).await
344    }
345
346    /// Create a new adapter constructor with the [`redis`] driver and a default config.
347    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
348    pub async fn new_with_cluster_config(
349        client: &redis::cluster::ClusterClient,
350        config: RedisAdapterConfig,
351    ) -> redis::RedisResult<Self> {
352        let driver = drivers::redis::ClusterDriver::new(client).await?;
353        Ok(Self::new_with_driver(driver, config))
354    }
355}
356#[cfg(feature = "fred")]
357impl RedisAdapterCtr<drivers::fred::FredDriver> {
358    /// Create a new adapter constructor with the default [`fred`] driver and a default config.
359    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
360    pub async fn new_with_fred(
361        client: fred::clients::SubscriberClient,
362    ) -> fred::prelude::FredResult<Self> {
363        Self::new_with_fred_config(client, RedisAdapterConfig::default()).await
364    }
365    /// Create a new adapter constructor with the default [`fred`] driver and a custom config.
366    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
367    pub async fn new_with_fred_config(
368        client: fred::clients::SubscriberClient,
369        config: RedisAdapterConfig,
370    ) -> fred::prelude::FredResult<Self> {
371        let driver = drivers::fred::FredDriver::new(client).await?;
372        Ok(Self::new_with_driver(driver, config))
373    }
374}
375impl<R: Driver> RedisAdapterCtr<R> {
376    /// Create a new adapter constructor with a custom redis/valkey driver and a config.
377    ///
378    /// You can implement your own driver by implementing the [`Driver`] trait with any redis/valkey client.
379    /// Check the [`drivers`] module for more information.
380    pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
381        RedisAdapterCtr { driver, config }
382    }
383}
384
385pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;
386
387/// The redis adapter with the fred driver.
388#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
389#[cfg(feature = "fred")]
390pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;
391
392/// The redis adapter with the redis driver.
393#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
394#[cfg(feature = "redis")]
395pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;
396
397/// The redis adapter with the redis cluster driver.
398#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
399#[cfg(feature = "redis-cluster")]
400pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
401
402/// The redis adapter implementation.
403/// It is generic over the [`Driver`] used to communicate with the redis server.
404/// And over the [`SocketEmitter`] used to communicate with the local server. This allows to
405/// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates.
406pub struct CustomRedisAdapter<E, R> {
407    /// The driver used by the adapter. This is used to communicate with the redis server.
408    /// All the redis adapter instances share the same driver.
409    driver: R,
410    /// The configuration of the adapter.
411    config: RedisAdapterConfig,
412    /// A unique identifier for the adapter to identify itself in the redis server.
413    uid: Uid,
414    /// The local adapter, used to manage local rooms and socket stores.
415    local: CoreLocalAdapter<E>,
416    /// The request channel used to broadcast requests to all the servers.
417    /// format: `{prefix}-request#{path}#`.
418    req_chan: String,
419    /// A map of response handlers used to await for responses from the remote servers.
420    responses: Arc<Mutex<ResponseHandlers>>,
421}
422
423impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
424impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
425    type Error = Error<R>;
426    type State = RedisAdapterCtr<R>;
427    type AckStream = AckStream<E::AckStream>;
428    type InitRes = InitRes<R>;
429
430    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
431        let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
432        let uid = local.server_id();
433        Self {
434            local,
435            req_chan,
436            uid,
437            driver: state.driver.clone(),
438            config: state.config.clone(),
439            responses: Arc::new(Mutex::new(HashMap::new())),
440        }
441    }
442
443    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
444        let fut = async move {
445            check_ns(self.local.path())?;
446            let global_stream = self.subscribe(self.req_chan.clone()).await?;
447            let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
448            let response_chan = format!(
449                "{}-response#{}#{}#",
450                &self.config.prefix,
451                self.local.path(),
452                self.uid
453            );
454
455            let response_stream = self.subscribe(response_chan.clone()).await?;
456            let stream = futures_util::stream::select(global_stream, specific_stream);
457            let stream = futures_util::stream::select(stream, response_stream);
458            tokio::spawn(self.pipe_stream(stream, response_chan));
459            on_success();
460            Ok(())
461        };
462        InitRes(Box::pin(fut))
463    }
464
465    async fn close(&self) -> Result<(), Self::Error> {
466        let response_chan = format!(
467            "{}-response#{}#{}#",
468            &self.config.prefix,
469            self.local.path(),
470            self.uid
471        );
472        tokio::try_join!(
473            self.driver.unsubscribe(self.req_chan.clone()),
474            self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
475            self.driver.unsubscribe(response_chan)
476        )
477        .map_err(Error::from_driver)?;
478
479        Ok(())
480    }
481
482    /// Get the number of servers by getting the number of subscribers to the request channel.
483    async fn server_count(&self) -> Result<u16, Self::Error> {
484        let count = self
485            .driver
486            .num_serv(&self.req_chan)
487            .await
488            .map_err(Error::from_driver)?;
489
490        Ok(count)
491    }
492
493    /// Broadcast a packet to all the servers to send them through their sockets.
494    async fn broadcast(
495        &self,
496        packet: Packet,
497        opts: BroadcastOptions,
498    ) -> Result<(), BroadcastError> {
499        if !opts.is_local(self.uid) {
500            let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
501            self.send_req(req, opts.server_id)
502                .await
503                .map_err(AdapterError::from)?;
504        }
505
506        self.local.broadcast(packet, opts)?;
507        Ok(())
508    }
509
510    /// Broadcast a packet to all the servers to send them through their sockets.
511    ///
512    /// Returns a Stream that is a combination of the local ack stream and a remote [`MessageStream`].
513    /// Here is a specific protocol in order to know how many message the server expect to close
514    /// the stream at the right time:
515    /// * Get the number `n` of remote servers.
516    /// * Send the broadcast request.
517    /// * Expect `n` `BroadcastAckCount` response in the stream to know the number `m` of expected ack responses.
518    /// * Expect `sum(m)` broadcast counts sent by the servers.
519    ///
520    /// Example with 3 remote servers (n = 3):
521    /// ```text
522    /// +---+                   +---+                   +---+
523    /// | A |                   | B |                   | C |
524    /// +---+                   +---+                   +---+
525    ///   |                       |                       |
526    ///   |---BroadcastWithAck--->|                       |
527    ///   |---BroadcastWithAck--------------------------->|
528    ///   |                       |                       |
529    ///   |<-BroadcastAckCount(2)-|     (n = 2; m = 2)    |
530    ///   |<-BroadcastAckCount(2)-------(n = 2; m = 4)----|
531    ///   |                       |                       |
532    ///   |<----------------Ack---------------------------|
533    ///   |<----------------Ack---|                       |
534    ///   |                       |                       |
535    ///   |<----------------Ack---------------------------|
536    ///   |<----------------Ack---|                       |
537    async fn broadcast_with_ack(
538        &self,
539        packet: Packet,
540        opts: BroadcastOptions,
541        timeout: Option<Duration>,
542    ) -> Result<Self::AckStream, Self::Error> {
543        if opts.is_local(self.uid) {
544            tracing::debug!(?opts, "broadcast with ack is local");
545            let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
546            let stream = AckStream::new_local(local);
547            return Ok(stream);
548        }
549        let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
550        let req_id = req.id;
551
552        let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
553
554        let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
555        self.responses.lock().unwrap().insert(req_id, tx);
556        let remote = MessageStream::new(rx);
557
558        self.send_req(req, opts.server_id).await?;
559        let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
560
561        Ok(AckStream::new(
562            local,
563            remote,
564            self.config.request_timeout,
565            remote_serv_cnt,
566            req_id,
567            self.responses.clone(),
568        ))
569    }
570
571    async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
572        if !opts.is_local(self.uid) {
573            let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
574            self.send_req(req, opts.server_id)
575                .await
576                .map_err(AdapterError::from)?;
577        }
578        self.local
579            .disconnect_socket(opts)
580            .map_err(BroadcastError::Socket)?;
581
582        Ok(())
583    }
584
585    async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
586        if opts.is_local(self.uid) {
587            return Ok(self.local.rooms(opts).into_iter().collect());
588        }
589        let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
590        let req_id = req.id;
591
592        // First get the remote stream because redis might send
593        // the responses before subscription is done.
594        let stream = self
595            .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
596            .await?;
597        self.send_req(req, opts.server_id).await?;
598        let local = self.local.rooms(opts);
599        let rooms = stream
600            .filter_map(|item| future::ready(item.into_rooms()))
601            .fold(local, async |mut acc, item| {
602                acc.extend(item);
603                acc
604            })
605            .await;
606        Ok(Vec::from_iter(rooms))
607    }
608
609    async fn add_sockets(
610        &self,
611        opts: BroadcastOptions,
612        rooms: impl RoomParam,
613    ) -> Result<(), Self::Error> {
614        let rooms: Vec<Room> = rooms.into_room_iter().collect();
615        if !opts.is_local(self.uid) {
616            let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
617            self.send_req(req, opts.server_id).await?;
618        }
619        self.local.add_sockets(opts, rooms);
620        Ok(())
621    }
622
623    async fn del_sockets(
624        &self,
625        opts: BroadcastOptions,
626        rooms: impl RoomParam,
627    ) -> Result<(), Self::Error> {
628        let rooms: Vec<Room> = rooms.into_room_iter().collect();
629        if !opts.is_local(self.uid) {
630            let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
631            self.send_req(req, opts.server_id).await?;
632        }
633        self.local.del_sockets(opts, rooms);
634        Ok(())
635    }
636
637    async fn fetch_sockets(
638        &self,
639        opts: BroadcastOptions,
640    ) -> Result<Vec<RemoteSocketData>, Self::Error> {
641        if opts.is_local(self.uid) {
642            return Ok(self.local.fetch_sockets(opts));
643        }
644        let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
645        let req_id = req.id;
646        // First get the remote stream because redis might send
647        // the responses before subscription is done.
648        let remote = self
649            .get_res::<RemoteSocketData>(req_id, ResponseTypeId::FetchSockets, opts.server_id)
650            .await?;
651
652        self.send_req(req, opts.server_id).await?;
653        let local = self.local.fetch_sockets(opts);
654        let sockets = remote
655            .filter_map(|item| future::ready(item.into_fetch_sockets()))
656            .fold(local, async |mut acc, item| {
657                acc.extend(item);
658                acc
659            })
660            .await;
661        Ok(sockets)
662    }
663
664    fn get_local(&self) -> &CoreLocalAdapter<E> {
665        &self.local
666    }
667}
668
669/// Error that can happen when initializing the adapter.
670#[derive(thiserror::Error)]
671pub enum InitError<D: Driver> {
672    /// Driver error.
673    #[error("driver error: {0}")]
674    Driver(D::Error),
675    /// Malformed namespace path.
676    #[error("malformed namespace path, it must not contain '#'")]
677    MalformedNamespace,
678}
679impl<D: Driver> fmt::Debug for InitError<D> {
680    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681        match self {
682            Self::Driver(err) => fmt::Debug::fmt(err, f),
683            Self::MalformedNamespace => write!(f, "Malformed namespace path"),
684        }
685    }
686}
687/// The result of the init future.
688#[must_use = "futures do nothing unless you `.await` or poll them"]
689pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
690
691impl<D: Driver> Future for InitRes<D> {
692    type Output = Result<(), InitError<D>>;
693
694    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
695        self.0.as_mut().poll(cx)
696    }
697}
698impl<D: Driver> Spawnable for InitRes<D> {
699    fn spawn(self) {
700        tokio::spawn(async move {
701            if let Err(e) = self.0.await {
702                tracing::error!("error initializing adapter: {e}");
703            }
704        });
705    }
706}
707
708impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
709    /// Build a response channel for a request.
710    ///
711    /// The uid is used to identify the server that sent the request.
712    /// The req_id is used to identify the request.
713    fn get_res_chan(&self, uid: Uid) -> String {
714        let path = self.local.path();
715        let prefix = &self.config.prefix;
716        format!("{}-response#{}#{}#", prefix, path, uid)
717    }
718    /// Build a request channel for a request.
719    ///
720    /// If we know the target server id, we can build a channel specific to this server.
721    /// Otherwise, we use the default request channel that will broadcast the request to all the servers.
722    fn get_req_chan(&self, node_id: Option<Uid>) -> String {
723        match node_id {
724            Some(uid) => format!("{}{}#", self.req_chan, uid),
725            None => self.req_chan.clone(),
726        }
727    }
728
729    async fn pipe_stream(
730        self: Arc<Self>,
731        mut stream: impl Stream<Item = ChanItem> + Unpin,
732        response_chan: String,
733    ) {
734        while let Some((chan, item)) = stream.next().await {
735            if chan.starts_with(&self.req_chan) {
736                if let Err(e) = self.recv_req(item) {
737                    let ns = self.local.path();
738                    let uid = self.uid;
739                    tracing::warn!(?uid, ?ns, "request handler error: {e}");
740                }
741            } else if chan == response_chan {
742                let req_id = read_req_id(&item);
743                tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
744                let handlers = self.responses.lock().unwrap();
745                if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
746                    if let Err(e) = tx.try_send(item) {
747                        tracing::warn!("error sending response to handler: {e}");
748                    }
749                } else {
750                    tracing::warn!(?req_id, "could not find req handler");
751                }
752            } else {
753                tracing::warn!("unexpected message/channel: {chan}");
754            }
755        }
756    }
757
758    /// Handle a generic request received from the request channel.
759    fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
760        let req: RequestIn = rmp_serde::from_slice(&item)?;
761        if req.node_id == self.uid {
762            return Ok(());
763        }
764
765        tracing::trace!(?req, "handling request");
766        let Some(opts) = req.opts else {
767            tracing::warn!(?req, "request is missing options");
768            return Ok(());
769        };
770
771        match req.r#type {
772            RequestTypeIn::Broadcast(p) => self.recv_broadcast(opts, p),
773            RequestTypeIn::BroadcastWithAck(p) => {
774                self.clone()
775                    .recv_broadcast_with_ack(req.node_id, req.id, p, opts)
776            }
777            RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(opts),
778            RequestTypeIn::AllRooms => self.recv_rooms(req.node_id, req.id, opts),
779            RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(opts, rooms),
780            RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(opts, rooms),
781            RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req.node_id, req.id, opts),
782            _ => (),
783        };
784        Ok(())
785    }
786
787    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
788        if let Err(e) = self.local.broadcast(packet, opts) {
789            let ns = self.local.path();
790            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
791        }
792    }
793
794    fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
795        if let Err(e) = self.local.disconnect_socket(opts) {
796            let ns = self.local.path();
797            tracing::warn!(
798                ?self.uid,
799                ?ns,
800                "remote request disconnect sockets handler: {:?}",
801                e
802            );
803        }
804    }
805
    /// Handle a remote `BroadcastWithAck` request.
    ///
    /// The packet is broadcast locally right away. A spawned task then publishes, to the
    /// response channel of `origin` and tagged with `req_id`:
    /// 1. a `BroadcastAckCount` response carrying the `count` returned by the local
    ///    broadcast, then
    /// 2. one `BroadcastAck` response per item yielded by the local ack stream.
    ///
    /// The first publish error is logged and stops the task; remaining acks are not sent.
    fn recv_broadcast_with_ack(
        self: Arc<Self>,
        origin: Uid,
        req_id: Sid,
        packet: Packet,
        opts: BroadcastOptions,
    ) {
        // NOTE(review): the third argument (`None`) presumably selects the default ack
        // timeout of the local adapter — confirm against `broadcast_with_ack`.
        let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
        tokio::spawn(async move {
            // Shared error reporter for both publish steps below.
            let on_err = |err| {
                let ns = self.local.path();
                tracing::warn!(
                    ?origin,
                    ?ns,
                    "remote request broadcast with ack handler errors: {:?}",
                    err
                );
            };
            // First send the count of expected acks to the server that sent the request.
            // This is used to keep track of the number of expected acks.
            let res = Response {
                r#type: ResponseType::<()>::BroadcastAckCount(count),
                node_id: self.uid,
            };
            if let Err(err) = self.send_res(origin, req_id, res).await {
                on_err(err);
                return;
            }

            // Then send the acks as they are received.
            // The ack stream is pinned on the task's stack so `next()` can be polled.
            futures_util::pin_mut!(stream);
            while let Some(ack) = stream.next().await {
                let res = Response {
                    r#type: ResponseType::BroadcastAck(ack),
                    node_id: self.uid,
                };
                if let Err(err) = self.send_res(origin, req_id, res).await {
                    on_err(err);
                    return;
                }
            }
        });
    }
849
850    fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
851        let rooms = self.local.rooms(opts);
852        let res = Response {
853            r#type: ResponseType::<()>::AllRooms(rooms),
854            node_id: self.uid,
855        };
856        let fut = self.send_res(origin, req_id, res);
857        let ns = self.local.path().clone();
858        let uid = self.uid;
859        tokio::spawn(async move {
860            if let Err(err) = fut.await {
861                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
862            }
863        });
864    }
865
866    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
867        self.local.add_sockets(opts, rooms);
868    }
869
870    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
871        self.local.del_sockets(opts, rooms);
872    }
873    fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
874        let sockets = self.local.fetch_sockets(opts);
875        let res = Response {
876            node_id: self.uid,
877            r#type: ResponseType::FetchSockets(sockets),
878        };
879        let fut = self.send_res(origin, req_id, res);
880        let ns = self.local.path().clone();
881        let uid = self.uid;
882        tokio::spawn(async move {
883            if let Err(err) = fut.await {
884                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
885            }
886        });
887    }
888
889    async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
890        tracing::trace!(?req, "sending request");
891        let req = rmp_serde::to_vec(&req)?;
892        let chan = self.get_req_chan(target_uid);
893        self.driver
894            .publish(chan, req)
895            .await
896            .map_err(Error::from_driver)?;
897
898        Ok(())
899    }
900
901    fn send_res<D: Serialize + fmt::Debug>(
902        &self,
903        req_node_id: Uid,
904        req_id: Sid,
905        res: Response<D>,
906    ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
907        let chan = self.get_res_chan(req_node_id);
908        tracing::trace!(?res, "sending response to {}", &chan);
909        // We send the req_id separated from the response object.
910        // This allows to partially decode the response and route by the req_id
911        // before fully deserializing it.
912        let res = rmp_serde::to_vec(&(req_id, res));
913        let driver = self.driver.clone();
914        async move {
915            driver
916                .publish(chan, res?)
917                .await
918                .map_err(Error::from_driver)?;
919            Ok(())
920        }
921    }
922
    /// Await for all the responses from the remote servers.
    ///
    /// Registers a response handler for `req_id` in `self.responses` and returns a stream
    /// of the decoded responses routed to it. Only responses whose type matches
    /// `response_type` are kept.
    ///
    /// The stream ends after the expected number of responses, or when
    /// `config.request_timeout` elapses, whichever comes first. The expected number is 1
    /// when `target_uid` is set, and otherwise the server count minus one (presumably
    /// excluding this node).
    ///
    /// # Errors
    /// Propagates the error from `server_count` when no target is given.
    async fn get_res<D: DeserializeOwned + fmt::Debug>(
        &self,
        req_id: Sid,
        response_type: ResponseTypeId,
        target_uid: Option<Uid>,
    ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
        // Check for specific target node
        let remote_serv_cnt = if target_uid.is_none() {
            self.server_count().await?.saturating_sub(1) as usize
        } else {
            1
        };
        // `max(.., 1)` keeps the channel capacity non-zero even when no remote server is
        // expected (tokio's `mpsc::channel` rejects a zero capacity).
        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
        // The handler is registered only after `server_count` succeeded, so an early error
        // leaves nothing behind in `self.responses`.
        self.responses.lock().unwrap().insert(req_id, tx);
        let stream = MessageStream::new(rx)
            // Payloads are `(Sid, Response<D>)`; the leading id was already used for
            // routing, so only the response is kept. Undecodable items are logged and skipped.
            .filter_map(|item| {
                let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
                    Ok((_, data)) => Some(data),
                    Err(e) => {
                        tracing::warn!("error decoding response: {e}");
                        None
                    }
                };
                future::ready(data)
            })
            .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type))
            .take(remote_serv_cnt)
            // The timeout is created here, so it counts from when this stream is built,
            // not from the first poll.
            .take_until(time::sleep(self.config.request_timeout));
        // NOTE(review): `DropStream` presumably removes the `req_id` handler from
        // `self.responses` when dropped — confirm in its implementation.
        let stream = DropStream::new(stream, self.responses.clone(), req_id);
        Ok(stream)
    }
955
956    /// Little wrapper to map the error type.
957    #[inline]
958    async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
959        tracing::trace!(?pat, "subscribing to");
960        self.driver
961            .subscribe(pat, self.config.stream_buffer)
962            .await
963            .map_err(InitError::Driver)
964    }
965}
966
967/// Checks if the namespace path is valid
968/// Panics if the path is empty or contains a `#`
969fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
970    if path.is_empty() || path.contains('#') {
971        Err(InitError::MalformedNamespace)
972    } else {
973        Ok(())
974    }
975}
976
977/// Extract the request id from a data encoded as `[Sid, ...]`
978pub fn read_req_id(data: &[u8]) -> Option<Sid> {
979    use std::str::FromStr;
980    let mut rd = data;
981    let len = rmp::decode::read_array_len(&mut rd).ok()?;
982    if len < 1 {
983        return None;
984    }
985
986    let mut buff = [0u8; Sid::ZERO.as_str().len()];
987    let str = rmp::decode::read_str(&mut rd, &mut buff).ok()?;
988    Sid::from_str(str).ok()
989}
990
#[cfg(test)]
mod tests {
    // Unit tests for the ack stream handling and namespace validation of this adapter.
    use super::*;
    use futures_util::stream::{self, FusedStream, StreamExt};
    use socketioxide_core::{Str, Value, adapter::AckStreamItem};
    use std::convert::Infallible;

    /// No-op driver used to instantiate generic items in tests: publishing and
    /// unsubscribing always succeed, subscriptions yield an empty stream and the
    /// server count is always 0.
    #[derive(Clone)]
    struct StubDriver;
    impl Driver for StubDriver {
        type Error = Infallible;

        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn subscribe(
            &self,
            _: String,
            _: usize,
        ) -> Result<MessageStream<ChanItem>, Self::Error> {
            Ok(MessageStream::new_empty())
        }

        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
            Ok(0)
        }
    }
    /// Build an `AckStream` with no local acks, fed only by `remote`, using a fresh
    /// request id and a fresh (empty) handler map.
    // NOTE(review): the literal `2` presumably is the number of remote servers expected to
    // answer (matching the two `BroadcastAckCount` responses sent in `ack_stream`) —
    // confirm against `AckStream::new`.
    fn new_stub_ack_stream(
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
        AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            timeout,
            2,
            Sid::new(),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    //TODO: test weird behaviours, packets out of orders, etc
    /// With two ack-count responses of 2 each followed by 4 acks, the stream yields
    /// exactly 4 items and then reports itself terminated.
    #[tokio::test]
    async fn ack_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_secs(10));
        let node_id = Uid::new();
        let req_id = Sid::new();

        // The two servers will send 2 acks each.
        // Both count responses reuse the same `node_id`.
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();

        let ack_res = Response::<String> {
            node_id,
            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
        };
        for _ in 0..4 {
            tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap())
                .unwrap();
        }
        futures_util::pin_mut!(stream);
        for _ in 0..4 {
            assert!(stream.next().await.is_some());
        }
        assert!(stream.is_terminated());
    }

    /// When only one of the expected servers reports an ack count and no ack arrives,
    /// the stream ends (yields `None`) once the timeout has elapsed.
    #[tokio::test]
    async fn ack_stream_timeout() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_millis(50));
        let node_id = Uid::new();
        let req_id = Sid::new();
        // There will be only one ack count and then the stream will timeout.
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap())
            .unwrap();

        futures_util::pin_mut!(stream);
        // Sleep for the full timeout so the next poll observes the expired deadline.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(stream.next().await.is_none());
        assert!(stream.is_terminated());
    }

    /// Dropping an `AckStream` removes its request id from the shared handler map.
    #[tokio::test]
    async fn ack_stream_drop() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        let id = Sid::new();
        handlers.lock().unwrap().insert(id, tx);
        let stream = AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            Duration::from_secs(10),
            2,
            id,
            handlers.clone(),
        );
        drop(stream);
        assert!(handlers.lock().unwrap().is_empty(),);
    }

    /// `check_ns` rejects both a path containing `#` and an empty path.
    #[test]
    fn check_ns_error() {
        assert!(matches!(
            check_ns::<StubDriver>("#"),
            Err(InitError::MalformedNamespace)
        ));
        assert!(matches!(
            check_ns::<StubDriver>(""),
            Err(InitError::MalformedNamespace)
        ));
    }
}