Skip to main content

socketioxide_redis/
lib.rs

1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4//! # A redis/valkey adapter implementation for the socketioxide crate.
5//! The adapter is used to communicate with other nodes of the same application.
//! This makes it possible to broadcast messages to sockets connected to other servers,
7//! to get the list of rooms, to add or remove sockets from rooms, etc.
8//!
9//! To achieve this, the adapter uses a [pub/sub](https://redis.io/docs/latest/develop/interact/pubsub/) system
10//! through Redis to communicate with other servers.
11//!
12//! The [`Driver`] abstraction allows the use of any pub/sub client.
13//! Three implementations are provided:
14//! * [`RedisDriver`](crate::drivers::redis::RedisDriver) for the [`redis`] crate with a standalone redis.
15//! * [`ClusterDriver`](crate::drivers::redis::ClusterDriver) for the [`redis`] crate with a redis cluster.
16//! * [`FredDriver`](crate::drivers::fred::FredDriver) for the [`fred`] crate with a standalone/cluster redis.
17//!
18//! When using redis clusters, the drivers employ [sharded pub/sub](https://redis.io/docs/latest/develop/interact/pubsub/#sharded-pubsub)
19//! to distribute the load across Redis nodes.
20//!
21//! You can also implement your own driver by implementing the [`Driver`] trait.
22//!
23//! <div class="warning">
24//!     The provided driver implementations are using <code>RESP3</code> for efficiency purposes.
25//!     Make sure your redis server supports it (redis v7 and above).
26//!     If not, you can implement your own driver using the <code>RESP2</code> protocol.
27//! </div>
28//!
29//! <div class="warning">
30//!     Socketioxide-Redis is not compatible with <code>@socketio/redis-adapter</code>
31//!     and <code>@socketio/redis-emitter</code>. They use completely different protocols and
32//!     cannot be used together. Do not mix socket.io JS servers with socketioxide rust servers.
33//! </div>
34//!
35//! ## Example with the [`redis`] driver
36//! ```rust
37//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
38//! # use socketioxide_redis::{RedisAdapterCtr, RedisAdapter};
39//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
40//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
41//!     socket.join("room1");
42//!     socket.on("event", on_event);
43//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
44//! }
45//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
46//!
47//! let client = redis::Client::open("redis://127.0.0.1:6379?protocol=RESP3")?;
48//! let adapter = RedisAdapterCtr::new_with_redis(&client).await?;
49//! let (layer, io) = SocketIo::builder()
50//!     .with_adapter::<RedisAdapter<_>>(adapter)
51//!     .build_layer();
52//! Ok(())
53//! # }
54//! ```
55//!
56//!
57//! ## Example with the [`fred`] driver
58//! ```rust
59//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
60//! # use socketioxide_redis::{RedisAdapterCtr, FredAdapter};
61//! # use fred::types::RespVersion;
62//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
63//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
64//!     socket.join("room1");
65//!     socket.on("event", on_event);
66//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
67//! }
68//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
69//!
70//! let mut config = fred::prelude::Config::from_url("redis://127.0.0.1:6379?protocol=resp3")?;
71//! // We need to manually set the RESP3 version because
72//! // the fred crate does not parse the protocol query parameter.
73//! config.version = RespVersion::RESP3;
74//! let client = fred::prelude::Builder::from_config(config).build_subscriber_client()?;
75//! let adapter = RedisAdapterCtr::new_with_fred(client).await?;
76//! let (layer, io) = SocketIo::builder()
77//!     .with_adapter::<FredAdapter<_>>(adapter)
78//!     .build_layer();
79//! Ok(())
80//! # }
81//! ```
82//!
83//!
84//! ## Example with the [`redis`] cluster driver
85//! ```rust
86//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
87//! # use socketioxide_redis::{RedisAdapterCtr, ClusterAdapter};
88//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
89//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
90//!     socket.join("room1");
91//!     socket.on("event", on_event);
92//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
93//! }
94//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
95//!
96//! // single node cluster
97//! let client = redis::cluster::ClusterClient::new(["redis://127.0.0.1:6379?protocol=resp3"])?;
98//! let adapter = RedisAdapterCtr::new_with_cluster(&client).await?;
99//!
100//! let (layer, io) = SocketIo::builder()
101//!     .with_adapter::<ClusterAdapter<_>>(adapter)
102//!     .build_layer();
103//! Ok(())
104//! # }
105//! ```
106//!
107//! Check the [`chat example`](https://github.com/Totodore/socketioxide/tree/main/examples/chat)
108//! for more complete examples.
109//!
110//! ## How does it work?
111//!
//! An adapter is created for each namespace, and each one receives its own [`CoreLocalAdapter`].
//! The [`CoreLocalAdapter`] manages the local rooms and local sockets. The default `LocalAdapter`
114//! is simply a wrapper around this [`CoreLocalAdapter`].
115//!
116//! The adapter is then initialized with the [`RedisAdapter::init`] method.
117//! This will subscribe to 3 channels:
118//! * `"{prefix}-request#{namespace}#"`: A global channel to receive broadcasted requests.
119//! * `"{prefix}-request#{namespace}#{uid}#"`: A specific channel to receive requests only for this server.
120//! * `"{prefix}-response#{namespace}#{uid}#"`: A specific channel to receive responses only for this server.
//!   Messages sent to this channel will always be in the form `[req_id, data]`. This allows the adapter to extract the request id
122//!   and route the response to the appropriate stream before deserializing the data.
123//!
124//! All messages are encoded with msgpack.
125//!
126//! There are 7 types of requests:
127//! * Broadcast a packet to all the matching sockets.
128//! * Broadcast a packet to all the matching sockets and wait for a stream of acks.
129//! * Disconnect matching sockets.
130//! * Get all the rooms.
131//! * Add matching sockets to rooms.
//! * Remove matching sockets from rooms.
133//! * Fetch all the remote sockets matching the options.
134//!
135//! For ack streams, the adapter will first send a `BroadcastAckCount` response to the server that sent the request,
136//! and then send the acks as they are received (more details in [`RedisAdapter::broadcast_with_ack`] fn).
137//!
138//! On the other side, each time an action has to be performed on the local server, the adapter will
139//! first broadcast a request to all the servers and then perform the action locally.
140
141use std::{
142    borrow::Cow,
143    collections::HashMap,
144    fmt,
145    future::{self, Future},
146    pin::Pin,
147    sync::{Arc, Mutex},
148    task::{Context, Poll},
149    time::Duration,
150};
151
152use drivers::{ChanItem, Driver, MessageStream};
153use futures_core::Stream;
154use futures_util::StreamExt;
155use serde::{Serialize, de::DeserializeOwned};
156use socketioxide_core::adapter::remote_packet::{
157    RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId,
158};
159use socketioxide_core::{
160    Sid, Uid,
161    adapter::errors::{AdapterError, BroadcastError},
162    adapter::{
163        BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
164        RoomParam, SocketEmitter, Spawnable,
165    },
166    packet::Packet,
167};
168use stream::{AckStream, DropStream};
169use tokio::{sync::mpsc, time};
170
/// Drivers are an abstraction over the pub/sub backend used by the adapter.
/// You can use the provided implementation or implement your own.
pub mod drivers;

// Private stream helpers used by the adapter (provides `AckStream` and `DropStream`).
mod stream;
176
177/// Represent any error that might happen when using this adapter.
178#[derive(thiserror::Error)]
179pub enum Error<R: Driver> {
180    /// Redis driver error
181    #[error("driver error: {0}")]
182    Driver(R::Error),
183    /// Packet encoding error
184    #[error("packet encoding error: {0}")]
185    Decode(#[from] rmp_serde::decode::Error),
186    /// Packet decoding error
187    #[error("packet decoding error: {0}")]
188    Encode(#[from] rmp_serde::encode::Error),
189}
190
191impl<R: Driver> Error<R> {
192    fn from_driver(err: R::Error) -> Self {
193        Self::Driver(err)
194    }
195}
196impl<R: Driver> fmt::Debug for Error<R> {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        match self {
199            Self::Driver(err) => write!(f, "Driver error: {err:?}"),
200            Self::Decode(err) => write!(f, "Decode error: {err:?}"),
201            Self::Encode(err) => write!(f, "Encode error: {err:?}"),
202        }
203    }
204}
205
206impl<R: Driver> From<Error<R>> for AdapterError {
207    fn from(err: Error<R>) -> Self {
208        AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
209    }
210}
211
/// The configuration of the [`RedisAdapter`].
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
    /// The request timeout. It is mainly used when expecting responses, such as when using
    /// `broadcast_with_ack` or `rooms`. Default is 5 seconds.
    pub request_timeout: Duration,

    /// The prefix used for the channels. Default is "socket.io".
    pub prefix: Cow<'static, str>,

    /// The channel size used to receive ack responses. Default is 255.
    ///
    /// If you have a lot of servers/sockets and you may miss acknowledgements because they arrive
    /// faster than you poll them with the returned stream, you might want to increase this value.
    /// Must be greater than 0 ([`RedisAdapterConfig::with_ack_response_buffer`] enforces this).
    pub ack_response_buffer: usize,

    /// The channel size used to receive messages. Default is 1024.
    ///
    /// If your server is under heavy load, you might want to increase this value.
    /// Must be greater than 0 ([`RedisAdapterConfig::with_stream_buffer`] enforces this).
    pub stream_buffer: usize,
}

impl RedisAdapterConfig {
    /// Create a new config with the default values (see [`RedisAdapterConfig::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the request timeout. Default is 5 seconds.
    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Set the prefix used for the channels. Default is "socket.io".
    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
        self.prefix = prefix.into();
        self
    }

    /// Set the channel size used to receive ack responses. Default is 255.
    ///
    /// If you have a lot of servers/sockets and you may miss acknowledgements because they arrive
    /// faster than you poll them with the returned stream, you might want to increase this value.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is 0.
    pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.ack_response_buffer = buffer;
        self
    }

    /// Set the channel size used to receive messages. Default is 1024.
    ///
    /// If your server is under heavy load, you might want to increase this value.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is 0.
    pub fn with_stream_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.stream_buffer = buffer;
        self
    }
}

impl Default for RedisAdapterConfig {
    /// Defaults: 5 s request timeout, `"socket.io"` prefix, 255 ack buffer, 1024 stream buffer.
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(5),
            prefix: Cow::Borrowed("socket.io"),
            ack_response_buffer: 255,
            stream_buffer: 1024,
        }
    }
}
280
/// The adapter constructor. For each namespace you define, a new adapter instance is created
/// from this constructor.
#[derive(Debug)]
pub struct RedisAdapterCtr<R> {
    /// The pub/sub driver; every adapter built from this constructor receives a clone of it.
    driver: R,
    /// The adapter configuration; every adapter built from this constructor receives a clone of it.
    config: RedisAdapterConfig,
}
288
#[cfg(feature = "redis")]
impl RedisAdapterCtr<drivers::redis::RedisDriver> {
    /// Build an adapter constructor on top of the [`redis`] driver, using the default config.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_redis_config(client, config).await
    }
    /// Build an adapter constructor on top of the [`redis`] driver with the given config.
    ///
    /// Fails if the underlying redis driver cannot be created from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis_config(
        client: &redis::Client,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        let ctr = Self::new_with_driver(drivers::redis::RedisDriver::new(client).await?, config);
        Ok(ctr)
    }
}
#[cfg(feature = "redis-cluster")]
impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
    /// Create a new adapter constructor with the [`redis`] cluster driver and a default config.
    ///
    /// # Errors
    ///
    /// Returns an error if the cluster driver cannot be created from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster(
        client: &redis::cluster::ClusterClient,
    ) -> redis::RedisResult<Self> {
        Self::new_with_cluster_config(client, RedisAdapterConfig::default()).await
    }

    /// Create a new adapter constructor with the [`redis`] cluster driver and a custom config.
    ///
    /// # Errors
    ///
    /// Returns an error if the cluster driver cannot be created from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster_config(
        client: &redis::cluster::ClusterClient,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        let driver = drivers::redis::ClusterDriver::new(client).await?;
        Ok(Self::new_with_driver(driver, config))
    }
}
#[cfg(feature = "fred")]
impl RedisAdapterCtr<drivers::fred::FredDriver> {
    /// Build an adapter constructor on top of the default [`fred`] driver, using the default config.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred(
        client: fred::clients::SubscriberClient,
    ) -> fred::prelude::FredResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_fred_config(client, config).await
    }
    /// Build an adapter constructor on top of the default [`fred`] driver with the given config.
    ///
    /// Fails if the underlying fred driver cannot be created from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred_config(
        client: fred::clients::SubscriberClient,
        config: RedisAdapterConfig,
    ) -> fred::prelude::FredResult<Self> {
        let ctr = Self::new_with_driver(drivers::fred::FredDriver::new(client).await?, config);
        Ok(ctr)
    }
}
345impl<R: Driver> RedisAdapterCtr<R> {
346    /// Create a new adapter constructor with a custom redis/valkey driver and a config.
347    ///
348    /// You can implement your own driver by implementing the [`Driver`] trait with any redis/valkey client.
349    /// Check the [`drivers`] module for more information.
350    pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
351        RedisAdapterCtr { driver, config }
352    }
353}
354
/// Pending response handlers keyed by request id. Each sender forwards the raw payloads
/// received on this server's response channel to the stream waiting for that request.
pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;

/// The redis adapter with the fred driver.
#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
#[cfg(feature = "fred")]
pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;

/// The redis adapter with the redis driver.
#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
#[cfg(feature = "redis")]
pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;

/// The redis adapter with the redis cluster driver.
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[cfg(feature = "redis-cluster")]
pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
371
/// The redis adapter implementation.
/// It is generic over the [`Driver`] used to communicate with the redis server,
/// and over the [`SocketEmitter`] used to communicate with the local server. This allows us to
/// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates.
pub struct CustomRedisAdapter<E, R> {
    /// The driver used by the adapter. This is used to communicate with the redis server.
    /// All the redis adapter instances share the same driver.
    driver: R,
    /// The configuration of the adapter.
    config: RedisAdapterConfig,
    /// The id of this server (taken from the local adapter's `server_id`). It tags outgoing
    /// requests and is used to build the server-specific request/response channels.
    uid: Uid,
    /// The local adapter, used to manage local rooms and socket stores.
    local: CoreLocalAdapter<E>,
    /// The request channel used to broadcast requests to all the servers.
    /// format: `{prefix}-request#{path}#`.
    req_chan: String,
    /// A map of response handlers used to await responses from the remote servers.
    /// Shared (through the `Arc`) with the ack streams returned by `broadcast_with_ack`.
    responses: Arc<Mutex<ResponseHandlers>>,
}
392
393impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
394impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
395    type Error = Error<R>;
396    type State = RedisAdapterCtr<R>;
397    type AckStream = AckStream<E::AckStream>;
398    type InitRes = InitRes<R>;
399
400    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
401        let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
402        let uid = local.server_id();
403        Self {
404            local,
405            req_chan,
406            uid,
407            driver: state.driver.clone(),
408            config: state.config.clone(),
409            responses: Arc::new(Mutex::new(HashMap::new())),
410        }
411    }
412
413    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
414        let fut = async move {
415            check_ns(self.local.path())?;
416            let global_stream = self.subscribe(self.req_chan.clone()).await?;
417            let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
418            let response_chan = format!(
419                "{}-response#{}#{}#",
420                &self.config.prefix,
421                self.local.path(),
422                self.uid
423            );
424
425            let response_stream = self.subscribe(response_chan.clone()).await?;
426            let stream = futures_util::stream::select(global_stream, specific_stream);
427            let stream = futures_util::stream::select(stream, response_stream);
428            tokio::spawn(self.pipe_stream(stream, response_chan));
429            on_success();
430            Ok(())
431        };
432        InitRes(Box::pin(fut))
433    }
434
435    async fn close(&self) -> Result<(), Self::Error> {
436        let response_chan = format!(
437            "{}-response#{}#{}#",
438            &self.config.prefix,
439            self.local.path(),
440            self.uid
441        );
442        tokio::try_join!(
443            self.driver.unsubscribe(self.req_chan.clone()),
444            self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
445            self.driver.unsubscribe(response_chan)
446        )
447        .map_err(Error::from_driver)?;
448
449        Ok(())
450    }
451
452    /// Get the number of servers by getting the number of subscribers to the request channel.
453    async fn server_count(&self) -> Result<u16, Self::Error> {
454        let count = self
455            .driver
456            .num_serv(&self.req_chan)
457            .await
458            .map_err(Error::from_driver)?;
459
460        Ok(count)
461    }
462
463    /// Broadcast a packet to all the servers to send them through their sockets.
464    async fn broadcast(
465        &self,
466        packet: Packet,
467        opts: BroadcastOptions,
468    ) -> Result<(), BroadcastError> {
469        if !opts.is_local(self.uid) {
470            let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
471            self.send_req(req, opts.server_id)
472                .await
473                .map_err(AdapterError::from)?;
474        }
475
476        self.local.broadcast(packet, opts)?;
477        Ok(())
478    }
479
480    /// Broadcast a packet to all the servers to send them through their sockets.
481    ///
482    /// Returns a Stream that is a combination of the local ack stream and a remote [`MessageStream`].
483    /// Here is a specific protocol in order to know how many message the server expect to close
484    /// the stream at the right time:
485    /// * Get the number `n` of remote servers.
486    /// * Send the broadcast request.
487    /// * Expect `n` `BroadcastAckCount` response in the stream to know the number `m` of expected ack responses.
488    /// * Expect `sum(m)` broadcast counts sent by the servers.
489    ///
490    /// Example with 3 remote servers (n = 3):
491    /// ```text
492    /// +---+                   +---+                   +---+
493    /// | A |                   | B |                   | C |
494    /// +---+                   +---+                   +---+
495    ///   |                       |                       |
496    ///   |---BroadcastWithAck--->|                       |
497    ///   |---BroadcastWithAck--------------------------->|
498    ///   |                       |                       |
499    ///   |<-BroadcastAckCount(2)-|     (n = 2; m = 2)    |
500    ///   |<-BroadcastAckCount(2)-------(n = 2; m = 4)----|
501    ///   |                       |                       |
502    ///   |<----------------Ack---------------------------|
503    ///   |<----------------Ack---|                       |
504    ///   |                       |                       |
505    ///   |<----------------Ack---------------------------|
506    ///   |<----------------Ack---|                       |
507    async fn broadcast_with_ack(
508        &self,
509        packet: Packet,
510        opts: BroadcastOptions,
511        timeout: Option<Duration>,
512    ) -> Result<Self::AckStream, Self::Error> {
513        if opts.is_local(self.uid) {
514            tracing::debug!(?opts, "broadcast with ack is local");
515            let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
516            let stream = AckStream::new_local(local);
517            return Ok(stream);
518        }
519        let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
520        let req_id = req.id;
521
522        let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
523
524        let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
525        self.responses.lock().unwrap().insert(req_id, tx);
526        let remote = MessageStream::new(rx);
527
528        self.send_req(req, opts.server_id).await?;
529        let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
530
531        // we wait for the configured ack timeout + the adapter request timeout
532        let timeout = self
533            .config
534            .request_timeout
535            .saturating_add(timeout.unwrap_or(self.local.ack_timeout()));
536
537        Ok(AckStream::new(
538            local,
539            remote,
540            timeout,
541            remote_serv_cnt,
542            req_id,
543            self.responses.clone(),
544        ))
545    }
546
547    async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
548        if !opts.is_local(self.uid) {
549            let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
550            self.send_req(req, opts.server_id)
551                .await
552                .map_err(AdapterError::from)?;
553        }
554        self.local
555            .disconnect_socket(opts)
556            .map_err(BroadcastError::Socket)?;
557
558        Ok(())
559    }
560
561    async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
562        if opts.is_local(self.uid) {
563            return Ok(self.local.rooms(opts).into_iter().collect());
564        }
565        let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
566        let req_id = req.id;
567
568        // First get the remote stream because redis might send
569        // the responses before subscription is done.
570        let stream = self
571            .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
572            .await?;
573        self.send_req(req, opts.server_id).await?;
574        let local = self.local.rooms(opts);
575        let rooms = stream
576            .filter_map(|item| future::ready(item.into_rooms()))
577            .fold(local, async |mut acc, item| {
578                acc.extend(item);
579                acc
580            })
581            .await;
582        Ok(Vec::from_iter(rooms))
583    }
584
585    async fn add_sockets(
586        &self,
587        opts: BroadcastOptions,
588        rooms: impl RoomParam,
589    ) -> Result<(), Self::Error> {
590        let rooms: Vec<Room> = rooms.into_room_iter().collect();
591        if !opts.is_local(self.uid) {
592            let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
593            self.send_req(req, opts.server_id).await?;
594        }
595        self.local.add_sockets(opts, rooms);
596        Ok(())
597    }
598
599    async fn del_sockets(
600        &self,
601        opts: BroadcastOptions,
602        rooms: impl RoomParam,
603    ) -> Result<(), Self::Error> {
604        let rooms: Vec<Room> = rooms.into_room_iter().collect();
605        if !opts.is_local(self.uid) {
606            let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
607            self.send_req(req, opts.server_id).await?;
608        }
609        self.local.del_sockets(opts, rooms);
610        Ok(())
611    }
612
613    async fn fetch_sockets(
614        &self,
615        opts: BroadcastOptions,
616    ) -> Result<Vec<RemoteSocketData>, Self::Error> {
617        if opts.is_local(self.uid) {
618            return Ok(self.local.fetch_sockets(opts));
619        }
620        let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
621        let req_id = req.id;
622        // First get the remote stream because redis might send
623        // the responses before subscription is done.
624        let remote = self
625            .get_res::<RemoteSocketData>(req_id, ResponseTypeId::FetchSockets, opts.server_id)
626            .await?;
627
628        self.send_req(req, opts.server_id).await?;
629        let local = self.local.fetch_sockets(opts);
630        let sockets = remote
631            .filter_map(|item| future::ready(item.into_fetch_sockets()))
632            .fold(local, async |mut acc, item| {
633                acc.extend(item);
634                acc
635            })
636            .await;
637        Ok(sockets)
638    }
639
640    fn get_local(&self) -> &CoreLocalAdapter<E> {
641        &self.local
642    }
643}
644
645/// Error that can happen when initializing the adapter.
646#[derive(thiserror::Error)]
647pub enum InitError<D: Driver> {
648    /// Driver error.
649    #[error("driver error: {0}")]
650    Driver(D::Error),
651    /// Malformed namespace path.
652    #[error("malformed namespace path, it must not contain '#'")]
653    MalformedNamespace,
654}
655impl<D: Driver> fmt::Debug for InitError<D> {
656    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
657        match self {
658            Self::Driver(err) => fmt::Debug::fmt(err, f),
659            Self::MalformedNamespace => write!(f, "Malformed namespace path"),
660        }
661    }
662}
663/// The result of the init future.
664#[must_use = "futures do nothing unless you `.await` or poll them"]
665pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
666
667impl<D: Driver> Future for InitRes<D> {
668    type Output = Result<(), InitError<D>>;
669
670    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
671        self.0.as_mut().poll(cx)
672    }
673}
674impl<D: Driver> Spawnable for InitRes<D> {
675    fn spawn(self) {
676        tokio::spawn(async move {
677            if let Err(e) = self.0.await {
678                tracing::error!("error initializing adapter: {e}");
679            }
680        });
681    }
682}
683
684impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
685    /// Build a response channel for a request.
686    ///
687    /// The uid is used to identify the server that sent the request.
688    /// The req_id is used to identify the request.
689    fn get_res_chan(&self, uid: Uid) -> String {
690        let path = self.local.path();
691        let prefix = &self.config.prefix;
692        format!("{prefix}-response#{path}#{uid}#")
693    }
694    /// Build a request channel for a request.
695    ///
696    /// If we know the target server id, we can build a channel specific to this server.
697    /// Otherwise, we use the default request channel that will broadcast the request to all the servers.
698    fn get_req_chan(&self, node_id: Option<Uid>) -> String {
699        match node_id {
700            Some(uid) => format!("{}{}#", self.req_chan, uid),
701            None => self.req_chan.clone(),
702        }
703    }
704
705    async fn pipe_stream(
706        self: Arc<Self>,
707        mut stream: impl Stream<Item = ChanItem> + Unpin,
708        response_chan: String,
709    ) {
710        while let Some((chan, item)) = stream.next().await {
711            if chan.starts_with(&self.req_chan) {
712                if let Err(e) = self.recv_req(item) {
713                    let ns = self.local.path();
714                    let uid = self.uid;
715                    tracing::warn!(?uid, ?ns, "request handler error: {e}");
716                }
717            } else if chan == response_chan {
718                let req_id = read_req_id(&item);
719                tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
720                let handlers = self.responses.lock().unwrap();
721                if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
722                    if let Err(e) = tx.try_send(item) {
723                        tracing::warn!("error sending response to handler: {e}");
724                    }
725                } else {
726                    tracing::warn!(?req_id, "could not find req handler");
727                }
728            } else {
729                tracing::warn!("unexpected message/channel: {chan}");
730            }
731        }
732    }
733
734    /// Handle a generic request received from the request channel.
735    fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
736        let req: RequestIn = rmp_serde::from_slice(&item)?;
737        if req.node_id == self.uid {
738            return Ok(());
739        }
740
741        tracing::trace!(?req, "handling request");
742        let Some(opts) = req.opts else {
743            tracing::warn!(?req, "request is missing options");
744            return Ok(());
745        };
746
747        match req.r#type {
748            RequestTypeIn::Broadcast(p) => self.recv_broadcast(opts, p),
749            RequestTypeIn::BroadcastWithAck(p) => {
750                self.clone()
751                    .recv_broadcast_with_ack(req.node_id, req.id, p, opts)
752            }
753            RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(opts),
754            RequestTypeIn::AllRooms => self.recv_rooms(req.node_id, req.id, opts),
755            RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(opts, rooms),
756            RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(opts, rooms),
757            RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req.node_id, req.id, opts),
758            _ => (),
759        };
760        Ok(())
761    }
762
763    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
764        if let Err(e) = self.local.broadcast(packet, opts) {
765            let ns = self.local.path();
766            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
767        }
768    }
769
770    fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
771        if let Err(e) = self.local.disconnect_socket(opts) {
772            let ns = self.local.path();
773            tracing::warn!(
774                ?self.uid,
775                ?ns,
776                "remote request disconnect sockets handler: {:?}",
777                e
778            );
779        }
780    }
781
782    fn recv_broadcast_with_ack(
783        self: Arc<Self>,
784        origin: Uid,
785        req_id: Sid,
786        packet: Packet,
787        opts: BroadcastOptions,
788    ) {
789        let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
790        tokio::spawn(async move {
791            let on_err = |err| {
792                let ns = self.local.path();
793                tracing::warn!(
794                    ?origin,
795                    ?ns,
796                    "remote request broadcast with ack handler errors: {:?}",
797                    err
798                );
799            };
800            // First send the count of expected acks to the server that sent the request.
801            // This is used to keep track of the number of expected acks.
802            let res = Response {
803                r#type: ResponseType::<()>::BroadcastAckCount(count),
804                node_id: self.uid,
805            };
806            if let Err(err) = self.send_res(origin, req_id, res).await {
807                on_err(err);
808                return;
809            }
810
811            // Then send the acks as they are received.
812            futures_util::pin_mut!(stream);
813            while let Some(ack) = stream.next().await {
814                let res = Response {
815                    r#type: ResponseType::BroadcastAck(ack),
816                    node_id: self.uid,
817                };
818                if let Err(err) = self.send_res(origin, req_id, res).await {
819                    on_err(err);
820                    return;
821                }
822            }
823        });
824    }
825
826    fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
827        let rooms = self.local.rooms(opts);
828        let res = Response {
829            r#type: ResponseType::<()>::AllRooms(rooms),
830            node_id: self.uid,
831        };
832        let fut = self.send_res(origin, req_id, res);
833        let ns = self.local.path().clone();
834        let uid = self.uid;
835        tokio::spawn(async move {
836            if let Err(err) = fut.await {
837                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
838            }
839        });
840    }
841
842    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
843        self.local.add_sockets(opts, rooms);
844    }
845
846    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
847        self.local.del_sockets(opts, rooms);
848    }
849    fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
850        let sockets = self.local.fetch_sockets(opts);
851        let res = Response {
852            node_id: self.uid,
853            r#type: ResponseType::FetchSockets(sockets),
854        };
855        let fut = self.send_res(origin, req_id, res);
856        let ns = self.local.path().clone();
857        let uid = self.uid;
858        tokio::spawn(async move {
859            if let Err(err) = fut.await {
860                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
861            }
862        });
863    }
864
865    async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
866        tracing::trace!(?req, "sending request");
867        let req = rmp_serde::to_vec(&req)?;
868        let chan = self.get_req_chan(target_uid);
869        self.driver
870            .publish(chan, req)
871            .await
872            .map_err(Error::from_driver)?;
873
874        Ok(())
875    }
876
877    fn send_res<D: Serialize + fmt::Debug>(
878        &self,
879        req_node_id: Uid,
880        req_id: Sid,
881        res: Response<D>,
882    ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
883        let chan = self.get_res_chan(req_node_id);
884        tracing::trace!(?res, "sending response to {}", &chan);
885        // We send the req_id separated from the response object.
886        // This allows to partially decode the response and route by the req_id
887        // before fully deserializing it.
888        let res = rmp_serde::to_vec(&(req_id, res));
889        let driver = self.driver.clone();
890        async move {
891            driver
892                .publish(chan, res?)
893                .await
894                .map_err(Error::from_driver)?;
895            Ok(())
896        }
897    }
898
899    /// Await for all the responses from the remote servers.
900    async fn get_res<D: DeserializeOwned + fmt::Debug>(
901        &self,
902        req_id: Sid,
903        response_type: ResponseTypeId,
904        target_uid: Option<Uid>,
905    ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
906        // Check for specific target node
907        let remote_serv_cnt = if target_uid.is_none() {
908            self.server_count().await?.saturating_sub(1) as usize
909        } else {
910            1
911        };
912        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
913        self.responses.lock().unwrap().insert(req_id, tx);
914        let stream = MessageStream::new(rx)
915            .filter_map(|item| {
916                let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
917                    Ok((_, data)) => Some(data),
918                    Err(e) => {
919                        tracing::warn!("error decoding response: {e}");
920                        None
921                    }
922                };
923                future::ready(data)
924            })
925            .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type))
926            .take(remote_serv_cnt)
927            .take_until(time::sleep(self.config.request_timeout));
928        let stream = DropStream::new(stream, self.responses.clone(), req_id);
929        Ok(stream)
930    }
931
932    /// Little wrapper to map the error type.
933    #[inline]
934    async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
935        tracing::trace!(?pat, "subscribing to");
936        self.driver
937            .subscribe(pat, self.config.stream_buffer)
938            .await
939            .map_err(InitError::Driver)
940    }
941}
942
943/// Checks if the namespace path is valid
944/// Panics if the path is empty or contains a `#`
945fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
946    if path.is_empty() || path.contains('#') {
947        Err(InitError::MalformedNamespace)
948    } else {
949        Ok(())
950    }
951}
952
953/// Extract the request id from a data encoded as `[Sid, ...]`
954pub fn read_req_id(data: &[u8]) -> Option<Sid> {
955    use std::str::FromStr;
956    let mut rd = data;
957    let len = rmp::decode::read_array_len(&mut rd).ok()?;
958    if len < 1 {
959        return None;
960    }
961
962    let mut buff = [0u8; Sid::ZERO.as_str().len()];
963    let str = rmp::decode::read_str(&mut rd, &mut buff).ok()?;
964    Sid::from_str(str).ok()
965}
966
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, FusedStream, StreamExt};
    use socketioxide_core::{Str, Value, adapter::AckStreamItem};
    use std::convert::Infallible;

    /// A driver that accepts every call and never delivers any message.
    #[derive(Clone)]
    struct StubDriver;
    impl Driver for StubDriver {
        type Error = Infallible;

        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn subscribe(
            &self,
            _: String,
            _: usize,
        ) -> Result<MessageStream<ChanItem>, Self::Error> {
            Ok(MessageStream::new_empty())
        }

        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
            Ok(0)
        }
    }

    /// Build an [`AckStream`] without local acks, fed only by `remote`.
    /// The `2` passed below is presumably the number of remote servers expected to answer.
    fn new_stub_ack_stream(
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
        let local = stream::empty::<AckStreamItem<()>>();
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        AckStream::new(local, remote, timeout, 2, Sid::new(), handlers)
    }

    /// Encode a response as `send_res` does: a `(req_id, response)` MessagePack pair.
    fn encode<T: Serialize>(req_id: Sid, res: &T) -> Vec<u8> {
        rmp_serde::to_vec(&(req_id, res)).unwrap()
    }

    //TODO: test weird behaviours, packets out of orders, etc
    #[tokio::test]
    async fn ack_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let stream = new_stub_ack_stream(MessageStream::new(rx), Duration::from_secs(10));
        let node_id = Uid::new();
        let req_id = Sid::new();

        // Both remote servers announce that they will send 2 acks each...
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        for _ in 0..2 {
            tx.try_send(encode(req_id, &ack_cnt_res)).unwrap();
        }

        // ...and then the 4 acks arrive.
        let ack_res = Response::<String> {
            node_id,
            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
        };
        for _ in 0..4 {
            tx.try_send(encode(req_id, &ack_res)).unwrap();
        }

        futures_util::pin_mut!(stream);
        for _ in 0..4 {
            assert!(stream.next().await.is_some());
        }
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_timeout() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let stream = new_stub_ack_stream(MessageStream::new(rx), Duration::from_millis(50));
        let node_id = Uid::new();
        let req_id = Sid::new();

        // Only one ack count is ever sent, so the stream can only end by timing out.
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(encode(req_id, &ack_cnt_res)).unwrap();

        futures_util::pin_mut!(stream);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(stream.next().await.is_none());
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_drop() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        let id = Sid::new();
        handlers.lock().unwrap().insert(id, tx);

        let stream = AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            MessageStream::new(rx),
            Duration::from_secs(10),
            2,
            id,
            handlers.clone(),
        );
        // Dropping the stream must unregister its response handler.
        drop(stream);
        assert!(handlers.lock().unwrap().is_empty());
    }

    #[test]
    fn check_ns_error() {
        for path in ["#", ""] {
            assert!(matches!(
                check_ns::<StubDriver>(path),
                Err(InitError::MalformedNamespace)
            ));
        }
    }
}