socketioxide_redis/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(
3    clippy::all,
4    clippy::todo,
5    clippy::empty_enum,
6    clippy::mem_forget,
7    clippy::unused_self,
8    clippy::filter_map_next,
9    clippy::needless_continue,
10    clippy::needless_borrow,
11    clippy::match_wildcard_for_single_variants,
12    clippy::if_let_mutex,
13    clippy::await_holding_lock,
14    clippy::match_on_vec_items,
15    clippy::imprecise_flops,
16    clippy::suboptimal_flops,
17    clippy::lossy_float_literal,
18    clippy::rest_pat_in_fully_bound_structs,
19    clippy::fn_params_excessive_bools,
20    clippy::exit,
21    clippy::inefficient_to_string,
22    clippy::linkedlist,
23    clippy::macro_use_imports,
24    clippy::option_option,
25    clippy::verbose_file_reads,
26    clippy::unnested_or_patterns,
27    rust_2018_idioms,
28    future_incompatible,
29    nonstandard_style,
30    missing_docs
31)]
32
33//! # A redis/valkey adapter implementation for the socketioxide crate.
34//! The adapter is used to communicate with other nodes of the same application.
35//! This allows to broadcast messages to sockets connected on other servers,
36//! to get the list of rooms, to add or remove sockets from rooms, etc.
37//!
38//! To achieve this, the adapter uses a [pub/sub](https://redis.io/docs/latest/develop/interact/pubsub/) system
39//! through Redis to communicate with other servers.
40//!
41//! The [`Driver`] abstraction allows the use of any pub/sub client.
42//! Three implementations are provided:
43//! * [`RedisDriver`](crate::drivers::redis::RedisDriver) for the [`redis`] crate with a standalone redis.
44//! * [`ClusterDriver`](crate::drivers::redis::ClusterDriver) for the [`redis`] crate with a redis cluster.
45//! * [`FredDriver`](crate::drivers::fred::FredDriver) for the [`fred`] crate with a standalone/cluster redis.
46//!
47//! When using redis clusters, the drivers employ [sharded pub/sub](https://redis.io/docs/latest/develop/interact/pubsub/#sharded-pubsub)
48//! to distribute the load across Redis nodes.
49//!
50//! You can also implement your own driver by implementing the [`Driver`] trait.
51//!
52//! <div class="warning">
53//!     The provided driver implementations are using <code>RESP3</code> for efficiency purposes.
54//!     Make sure your redis server supports it (redis v7 and above).
55//!     If not, you can implement your own driver using the <code>RESP2</code> protocol.
56//! </div>
57//!
58//! <div class="warning">
59//!     Socketioxide-Redis is not compatible with <code>@socketio/redis-adapter</code>
60//!     and <code>@socketio/redis-emitter</code>. They use completely different protocols and
61//!     cannot be used together. Do not mix socket.io JS servers with socketioxide rust servers.
62//! </div>
63//!
64//! ## Example with the [`redis`] driver
65//! ```rust
66//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
67//! # use socketioxide_redis::{RedisAdapterCtr, RedisAdapter};
68//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
69//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
70//!     socket.join("room1");
71//!     socket.on("event", on_event);
72//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
73//! }
74//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
75//!
76//! let client = redis::Client::open("redis://127.0.0.1:6379?protocol=RESP3")?;
77//! let adapter = RedisAdapterCtr::new_with_redis(&client).await?;
78//! let (layer, io) = SocketIo::builder()
79//!     .with_adapter::<RedisAdapter<_>>(adapter)
80//!     .build_layer();
81//! Ok(())
82//! # }
83//! ```
84//!
85//!
86//! ## Example with the [`fred`] driver
87//! ```rust
88//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
89//! # use socketioxide_redis::{RedisAdapterCtr, FredAdapter};
90//! # use fred::types::RespVersion;
91//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
92//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
93//!     socket.join("room1");
94//!     socket.on("event", on_event);
95//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
96//! }
97//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
98//!
99//! let mut config = fred::prelude::Config::from_url("redis://127.0.0.1:6379?protocol=resp3")?;
100//! // We need to manually set the RESP3 version because
101//! // the fred crate does not parse the protocol query parameter.
102//! config.version = RespVersion::RESP3;
103//! let client = fred::prelude::Builder::from_config(config).build_subscriber_client()?;
104//! let adapter = RedisAdapterCtr::new_with_fred(client).await?;
105//! let (layer, io) = SocketIo::builder()
106//!     .with_adapter::<FredAdapter<_>>(adapter)
107//!     .build_layer();
108//! Ok(())
109//! # }
110//! ```
111//!
112//!
113//! ## Example with the [`redis`] cluster driver
114//! ```rust
115//! # use socketioxide::{SocketIo, extract::{SocketRef, Data}, adapter::Adapter};
116//! # use socketioxide_redis::{RedisAdapterCtr, ClusterAdapter};
117//! # async fn doc_main() -> Result<(), Box<dyn std::error::Error>> {
118//! async fn on_connect<A: Adapter>(socket: SocketRef<A>) {
119//!     socket.join("room1");
120//!     socket.on("event", on_event);
121//!     let _ = socket.broadcast().emit("hello", "world").await.ok();
122//! }
123//! async fn on_event<A: Adapter>(socket: SocketRef<A>, Data(data): Data<String>) {}
124//!
125//! // single node cluster
126//! let client = redis::cluster::ClusterClient::new(["redis://127.0.0.1:6379?protocol=resp3"])?;
127//! let adapter = RedisAdapterCtr::new_with_cluster(&client).await?;
128//!
129//! let (layer, io) = SocketIo::builder()
130//!     .with_adapter::<ClusterAdapter<_>>(adapter)
131//!     .build_layer();
132//! Ok(())
133//! # }
134//! ```
135//!
136//! Check the [`chat example`](https://github.com/Totodore/socketioxide/tree/main/examples/chat)
137//! for more complete examples.
138//!
139//! ## How does it work?
140//!
141//! An adapter is created for each created namespace and it takes a corresponding [`CoreLocalAdapter`].
142//! The [`CoreLocalAdapter`] allows to manage the local rooms and local sockets. The default `LocalAdapter`
143//! is simply a wrapper around this [`CoreLocalAdapter`].
144//!
145//! The adapter is then initialized with the [`RedisAdapter::init`] method.
146//! This will subscribe to 3 channels:
147//! * `"{prefix}-request#{namespace}#"`: A global channel to receive broadcasted requests.
148//! * `"{prefix}-request#{namespace}#{uid}#"`: A specific channel to receive requests only for this server.
149//! * `"{prefix}-response#{namespace}#{uid}#"`: A specific channel to receive responses only for this server.
150//!     Messages sent to this channel will be always in the form `[req_id, data]`. This will allow the adapter to extract the request id
//!     and route the response to the appropriate stream before deserializing the data.
152//!
153//! All messages are encoded with msgpack.
154//!
155//! There are 7 types of requests:
156//! * Broadcast a packet to all the matching sockets.
157//! * Broadcast a packet to all the matching sockets and wait for a stream of acks.
158//! * Disconnect matching sockets.
159//! * Get all the rooms.
160//! * Add matching sockets to rooms.
//! * Remove matching sockets from rooms.
162//! * Fetch all the remote sockets matching the options.
163//!
164//! For ack streams, the adapter will first send a `BroadcastAckCount` response to the server that sent the request,
165//! and then send the acks as they are received (more details in [`RedisAdapter::broadcast_with_ack`] fn).
166//!
167//! On the other side, each time an action has to be performed on the local server, the adapter will
168//! first broadcast a request to all the servers and then perform the action locally.
169
170use std::{
171    borrow::Cow,
172    collections::HashMap,
173    fmt,
174    future::{self, Future},
175    pin::Pin,
176    sync::{Arc, Mutex},
177    task::{Context, Poll},
178    time::Duration,
179};
180
181use drivers::{ChanItem, Driver, MessageStream};
182use futures_core::Stream;
183use futures_util::StreamExt;
184use request::{
185    read_req_id, RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType,
186};
187use serde::{de::DeserializeOwned, Serialize};
188use socketioxide_core::{
189    adapter::{
190        BroadcastFlags, BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter,
191        RemoteSocketData, Room, RoomParam, SocketEmitter, Spawnable,
192    },
193    errors::{AdapterError, BroadcastError},
194    packet::Packet,
195    Sid, Uid,
196};
197use stream::{AckStream, DropStream};
198use tokio::{sync::mpsc, time};
199
200/// Drivers are an abstraction over the pub/sub backend used by the adapter.
201/// You can use the provided implementation or implement your own.
202pub mod drivers;
203
204mod request;
205mod stream;
206
207/// Represent any error that might happen when using this adapter.
208#[derive(thiserror::Error)]
209pub enum Error<R: Driver> {
210    /// Redis driver error
211    #[error("driver error: {0}")]
212    Driver(R::Error),
213    /// Packet encoding error
214    #[error("packet encoding error: {0}")]
215    Decode(#[from] rmp_serde::decode::Error),
216    /// Packet decoding error
217    #[error("packet decoding error: {0}")]
218    Encode(#[from] rmp_serde::encode::Error),
219}
220
221impl<R: Driver> Error<R> {
222    fn from_driver(err: R::Error) -> Self {
223        Self::Driver(err)
224    }
225}
226impl<R: Driver> fmt::Debug for Error<R> {
227    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228        match self {
229            Self::Driver(err) => write!(f, "Driver error: {:?}", err),
230            Self::Decode(err) => write!(f, "Decode error: {:?}", err),
231            Self::Encode(err) => write!(f, "Encode error: {:?}", err),
232        }
233    }
234}
235
236impl<R: Driver> From<Error<R>> for AdapterError {
237    fn from(err: Error<R>) -> Self {
238        AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
239    }
240}
241
/// The configuration of the [`RedisAdapter`].
///
/// Build it with [`RedisAdapterConfig::new`] (or [`Default::default`]) and refine it
/// with the `with_*` methods, each of which consumes and returns the config.
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
    /// How long to wait for responses from the other servers, e.g. when using
    /// `broadcast_with_ack` or `rooms`. Default is 5 seconds.
    pub request_timeout: Duration,

    /// The prefix prepended to every channel name. Default is "socket.io".
    pub prefix: Cow<'static, str>,

    /// The size of the channel used to receive ack responses. Default is 255.
    ///
    /// Increase it if you have many servers/sockets and acknowledgements may arrive faster
    /// than you poll them from the returned stream.
    pub ack_response_buffer: usize,

    /// The size of the channel used to receive messages. Default is 1024.
    ///
    /// Increase it if your server is under heavy load.
    pub stream_buffer: usize,
}

impl RedisAdapterConfig {
    /// Create a config holding the default values.
    pub fn new() -> Self {
        Self {
            request_timeout: Duration::from_secs(5),
            prefix: Cow::Borrowed("socket.io"),
            ack_response_buffer: 255,
            stream_buffer: 1024,
        }
    }

    /// Set the request timeout. Default is 5 seconds.
    pub fn with_request_timeout(self, timeout: Duration) -> Self {
        Self {
            request_timeout: timeout,
            ..self
        }
    }

    /// Set the prefix used for the channels. Default is "socket.io".
    pub fn with_prefix(self, prefix: impl Into<Cow<'static, str>>) -> Self {
        Self {
            prefix: prefix.into(),
            ..self
        }
    }

    /// Set the size of the channel used to receive ack responses. Default is 255.
    ///
    /// Increase it if you have many servers/sockets and acknowledgements may arrive faster
    /// than you poll them from the returned stream.
    ///
    /// # Panics
    /// Panics if `buffer` is 0.
    pub fn with_ack_response_buffer(self, buffer: usize) -> Self {
        assert!(buffer != 0, "buffer size must be greater than 0");
        Self {
            ack_response_buffer: buffer,
            ..self
        }
    }

    /// Set the size of the channel used to receive messages. Default is 1024.
    ///
    /// Increase it if your server is under heavy load.
    ///
    /// # Panics
    /// Panics if `buffer` is 0.
    pub fn with_stream_buffer(self, buffer: usize) -> Self {
        assert!(buffer != 0, "buffer size must be greater than 0");
        Self {
            stream_buffer: buffer,
            ..self
        }
    }
}

impl Default for RedisAdapterConfig {
    /// Same values as [`RedisAdapterConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
310
/// The adapter constructor. For each namespace you define, a new adapter instance is created
/// from this constructor.
#[derive(Debug)]
pub struct RedisAdapterCtr<R> {
    /// The driver; every adapter instance created from this constructor gets a clone of it.
    driver: R,
    /// The configuration; every adapter instance created from this constructor gets a clone of it.
    config: RedisAdapterConfig,
}
318
#[cfg(feature = "redis")]
impl RedisAdapterCtr<drivers::redis::RedisDriver> {
    /// Create a new adapter constructor with the [`redis`] driver and a default config.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_redis_config(client, config).await
    }

    /// Create a new adapter constructor with the [`redis`] driver and a custom config.
    ///
    /// # Errors
    /// Returns the error reported by the driver if it cannot be created from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis_config(
        client: &redis::Client,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        let driver = drivers::redis::RedisDriver::new(client).await?;
        Ok(Self { driver, config })
    }
}
#[cfg(feature = "redis-cluster")]
impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
    /// Create a new adapter constructor with the [`redis`] cluster driver and a default config.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster(
        client: &redis::cluster::ClusterClient,
    ) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_cluster_config(client, config).await
    }

    /// Create a new adapter constructor with the [`redis`] cluster driver and a custom config.
    ///
    /// # Errors
    /// Returns the error reported by the driver if it cannot be created from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster_config(
        client: &redis::cluster::ClusterClient,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        let driver = drivers::redis::ClusterDriver::new(client).await?;
        Ok(Self { driver, config })
    }
}
#[cfg(feature = "fred")]
impl RedisAdapterCtr<drivers::fred::FredDriver> {
    /// Create a new adapter constructor with the default [`fred`] driver and a default config.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred(
        client: fred::clients::SubscriberClient,
    ) -> fred::prelude::FredResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_fred_config(client, config).await
    }

    /// Create a new adapter constructor with the default [`fred`] driver and a custom config.
    ///
    /// # Errors
    /// Returns the error reported by the driver if it cannot be created from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred_config(
        client: fred::clients::SubscriberClient,
        config: RedisAdapterConfig,
    ) -> fred::prelude::FredResult<Self> {
        let driver = drivers::fred::FredDriver::new(client).await?;
        Ok(Self { driver, config })
    }
}
375impl<R: Driver> RedisAdapterCtr<R> {
376    /// Create a new adapter constructor with a custom redis/valkey driver and a config.
377    ///
378    /// You can implement your own driver by implementing the [`Driver`] trait with any redis/valkey client.
379    /// Check the [`drivers`] module for more information.
380    pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
381        RedisAdapterCtr { driver, config }
382    }
383}
384
/// The pending response handlers, keyed by request id. The raw payload of each response
/// received for a request is forwarded through the sender registered under its id.
pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;
386
/// The redis adapter with the fred driver.
///
/// This is [`CustomRedisAdapter`] with its driver parameter fixed to the fred driver.
#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
#[cfg(feature = "fred")]
pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;

/// The redis adapter with the redis driver.
///
/// This is [`CustomRedisAdapter`] with its driver parameter fixed to the standalone redis driver.
#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
#[cfg(feature = "redis")]
pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;

/// The redis adapter with the redis cluster driver.
///
/// This is [`CustomRedisAdapter`] with its driver parameter fixed to the redis cluster driver.
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[cfg(feature = "redis-cluster")]
pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
401
/// The redis adapter implementation.
/// It is generic over the [`Driver`] used to communicate with the redis server.
/// And over the [`SocketEmitter`] used to communicate with the local server. This allows to
/// avoid cyclic dependencies between the adapter, `socketioxide-core` and `socketioxide` crates.
pub struct CustomRedisAdapter<E, R> {
    /// The driver used by the adapter. This is used to communicate with the redis server.
    /// All the redis adapter instances share the same driver.
    driver: R,
    /// The configuration of the adapter.
    config: RedisAdapterConfig,
    /// A unique identifier for the adapter to identify itself in the redis server.
    /// It is the server id reported by the local adapter.
    uid: Uid,
    /// The local adapter, used to manage local rooms and socket stores.
    local: CoreLocalAdapter<E>,
    /// The request channel used to broadcast requests to all the servers.
    /// format: `{prefix}-request#{path}#`.
    req_chan: String,
    /// A map of response handlers used to await for responses from the remote servers.
    /// Entries are keyed by request id and looked up when a message arrives on the
    /// response channel.
    responses: Arc<Mutex<ResponseHandlers>>,
}

// No items are needed for this impl.
// NOTE(review): `DefinedAdapter` is presumably a marker trait from `socketioxide_core` — confirm.
impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
424impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
425    type Error = Error<R>;
426    type State = RedisAdapterCtr<R>;
427    type AckStream = AckStream<E::AckStream>;
428    type InitRes = InitRes<R>;
429
430    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
431        let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
432        let uid = local.server_id();
433        Self {
434            local,
435            req_chan,
436            uid,
437            driver: state.driver.clone(),
438            config: state.config.clone(),
439            responses: Arc::new(Mutex::new(HashMap::new())),
440        }
441    }
442
443    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
444        let fut = async move {
445            check_ns(self.local.path())?;
446            let global_stream = self.subscribe(self.req_chan.clone()).await?;
447            let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
448            let response_chan = format!(
449                "{}-response#{}#{}#",
450                &self.config.prefix,
451                self.local.path(),
452                self.uid
453            );
454
455            let response_stream = self.subscribe(response_chan.clone()).await?;
456            let stream = futures_util::stream::select(global_stream, specific_stream);
457            let stream = futures_util::stream::select(stream, response_stream);
458            tokio::spawn(self.pipe_stream(stream, response_chan));
459            on_success();
460            Ok(())
461        };
462        InitRes(Box::pin(fut))
463    }
464
465    async fn close(&self) -> Result<(), Self::Error> {
466        let response_chan = format!(
467            "{}-response#{}#{}#",
468            &self.config.prefix,
469            self.local.path(),
470            self.uid
471        );
472        tokio::try_join!(
473            self.driver.unsubscribe(self.req_chan.clone()),
474            self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
475            self.driver.unsubscribe(response_chan)
476        )
477        .map_err(Error::from_driver)?;
478
479        Ok(())
480    }
481
482    /// Get the number of servers by getting the number of subscribers to the request channel.
483    async fn server_count(&self) -> Result<u16, Self::Error> {
484        let count = self
485            .driver
486            .num_serv(&self.req_chan)
487            .await
488            .map_err(Error::from_driver)?;
489
490        Ok(count)
491    }
492
493    /// Broadcast a packet to all the servers to send them through their sockets.
494    async fn broadcast(
495        &self,
496        packet: Packet,
497        opts: BroadcastOptions,
498    ) -> Result<(), BroadcastError> {
499        if !is_local_op(self.uid, &opts) {
500            let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
501            self.send_req(req, opts.server_id)
502                .await
503                .map_err(AdapterError::from)?;
504        }
505
506        self.local.broadcast(packet, opts)?;
507        Ok(())
508    }
509
510    /// Broadcast a packet to all the servers to send them through their sockets.
511    ///
512    /// Returns a Stream that is a combination of the local ack stream and a remote [`MessageStream`].
513    /// Here is a specific protocol in order to know how many message the server expect to close
514    /// the stream at the right time:
515    /// * Get the number `n` of remote servers.
516    /// * Send the broadcast request.
517    /// * Expect `n` `BroadcastAckCount` response in the stream to know the number `m` of expected ack responses.
518    /// * Expect `sum(m)` broadcast counts sent by the servers.
519    ///
520    /// Example with 3 remote servers (n = 3):
521    /// ```text
522    /// +---+                   +---+                   +---+
523    /// | A |                   | B |                   | C |
524    /// +---+                   +---+                   +---+
525    ///   |                       |                       |
526    ///   |---BroadcastWithAck--->|                       |
527    ///   |---BroadcastWithAck--------------------------->|
528    ///   |                       |                       |
529    ///   |<-BroadcastAckCount(2)-|     (n = 2; m = 2)    |
530    ///   |<-BroadcastAckCount(2)-------(n = 2; m = 4)----|
531    ///   |                       |                       |
532    ///   |<----------------Ack---------------------------|
533    ///   |<----------------Ack---|                       |
534    ///   |                       |                       |
535    ///   |<----------------Ack---------------------------|
536    ///   |<----------------Ack---|                       |
537    async fn broadcast_with_ack(
538        &self,
539        packet: Packet,
540        opts: BroadcastOptions,
541        timeout: Option<Duration>,
542    ) -> Result<Self::AckStream, Self::Error> {
543        if is_local_op(self.uid, &opts) {
544            tracing::debug!(?opts, "broadcast with ack is local");
545            let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
546            let stream = AckStream::new_local(local);
547            return Ok(stream);
548        }
549        let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
550        let req_id = req.id;
551
552        let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
553
554        let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
555        self.responses.lock().unwrap().insert(req_id, tx);
556        let remote = MessageStream::new(rx);
557
558        self.send_req(req, opts.server_id).await?;
559        let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
560
561        Ok(AckStream::new(
562            local,
563            remote,
564            self.config.request_timeout,
565            remote_serv_cnt,
566            req_id,
567            self.responses.clone(),
568        ))
569    }
570
571    async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
572        if !is_local_op(self.uid, &opts) {
573            let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
574            self.send_req(req, opts.server_id)
575                .await
576                .map_err(AdapterError::from)?;
577        }
578        self.local
579            .disconnect_socket(opts)
580            .map_err(BroadcastError::Socket)?;
581
582        Ok(())
583    }
584
585    async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
586        const PACKET_IDX: u8 = 2;
587
588        if is_local_op(self.uid, &opts) {
589            return Ok(self.local.rooms(opts).into_iter().collect());
590        }
591        let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
592        let req_id = req.id;
593
594        // First get the remote stream because redis might send
595        // the responses before subscription is done.
596        let stream = self
597            .get_res::<()>(req_id, PACKET_IDX, opts.server_id)
598            .await?;
599        self.send_req(req, opts.server_id).await?;
600        let local = self.local.rooms(opts);
601        let rooms = stream
602            .filter_map(|item| future::ready(item.into_rooms()))
603            .fold(local, |mut acc, item| async move {
604                acc.extend(item);
605                acc
606            })
607            .await;
608        Ok(Vec::from_iter(rooms))
609    }
610
611    async fn add_sockets(
612        &self,
613        opts: BroadcastOptions,
614        rooms: impl RoomParam,
615    ) -> Result<(), Self::Error> {
616        let rooms: Vec<Room> = rooms.into_room_iter().collect();
617        if !is_local_op(self.uid, &opts) {
618            let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
619            self.send_req(req, opts.server_id).await?;
620        }
621        self.local.add_sockets(opts, rooms);
622        Ok(())
623    }
624
625    async fn del_sockets(
626        &self,
627        opts: BroadcastOptions,
628        rooms: impl RoomParam,
629    ) -> Result<(), Self::Error> {
630        let rooms: Vec<Room> = rooms.into_room_iter().collect();
631        if !is_local_op(self.uid, &opts) {
632            let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
633            self.send_req(req, opts.server_id).await?;
634        }
635        self.local.del_sockets(opts, rooms);
636        Ok(())
637    }
638
639    async fn fetch_sockets(
640        &self,
641        opts: BroadcastOptions,
642    ) -> Result<Vec<RemoteSocketData>, Self::Error> {
643        if is_local_op(self.uid, &opts) {
644            return Ok(self.local.fetch_sockets(opts));
645        }
646        const PACKET_IDX: u8 = 3;
647        let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
648        let req_id = req.id;
649        // First get the remote stream because redis might send
650        // the responses before subscription is done.
651        let remote = self
652            .get_res::<RemoteSocketData>(req_id, PACKET_IDX, opts.server_id)
653            .await?;
654
655        self.send_req(req, opts.server_id).await?;
656        let local = self.local.fetch_sockets(opts);
657        let sockets = remote
658            .filter_map(|item| future::ready(item.into_fetch_sockets()))
659            .fold(local, |mut acc, item| async move {
660                acc.extend(item);
661                acc
662            })
663            .await;
664        Ok(sockets)
665    }
666
    /// Get a reference to the local adapter wrapped by this adapter.
    fn get_local(&self) -> &CoreLocalAdapter<E> {
        &self.local
    }
670}
671
/// Error that can happen when initializing the adapter.
#[derive(thiserror::Error)]
pub enum InitError<D: Driver> {
    /// Driver error.
    #[error("driver error: {0}")]
    Driver(D::Error),
    /// Malformed namespace path: the path must not contain the `#` character.
    #[error("malformed namespace path, it must not contain '#'")]
    MalformedNamespace,
}
682impl<D: Driver> fmt::Debug for InitError<D> {
683    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
684        match self {
685            Self::Driver(err) => fmt::Debug::fmt(err, f),
686            Self::MalformedNamespace => write!(f, "Malformed namespace path"),
687        }
688    }
689}
/// The result of the init future.
///
/// It wraps the boxed initialization future: it can either be awaited directly
/// or be spawned as a background task.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
693
694impl<D: Driver> Future for InitRes<D> {
695    type Output = Result<(), InitError<D>>;
696
697    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
698        self.0.as_mut().poll(cx)
699    }
700}
701impl<D: Driver> Spawnable for InitRes<D> {
702    fn spawn(self) {
703        tokio::spawn(async move {
704            if let Err(e) = self.0.await {
705                tracing::error!("error initializing adapter: {e}");
706            }
707        });
708    }
709}
710
711impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
712    /// Build a response channel for a request.
713    ///
714    /// The uid is used to identify the server that sent the request.
715    /// The req_id is used to identify the request.
716    fn get_res_chan(&self, uid: Uid) -> String {
717        let path = self.local.path();
718        let prefix = &self.config.prefix;
719        format!("{}-response#{}#{}#", prefix, path, uid)
720    }
721    /// Build a request channel for a request.
722    ///
723    /// If we know the target server id, we can build a channel specific to this server.
724    /// Otherwise, we use the default request channel that will broadcast the request to all the servers.
725    fn get_req_chan(&self, node_id: Option<Uid>) -> String {
726        match node_id {
727            Some(uid) => format!("{}{}#", self.req_chan, uid),
728            None => self.req_chan.clone(),
729        }
730    }
731
732    async fn pipe_stream(
733        self: Arc<Self>,
734        mut stream: impl Stream<Item = ChanItem> + Unpin,
735        response_chan: String,
736    ) {
737        while let Some((chan, item)) = stream.next().await {
738            if chan.starts_with(&self.req_chan) {
739                if let Err(e) = self.recv_req(item) {
740                    let ns = self.local.path();
741                    let uid = self.uid;
742                    tracing::warn!(?uid, ?ns, "request handler error: {e}");
743                }
744            } else if chan == response_chan {
745                let req_id = read_req_id(&item);
746                tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
747                let handlers = self.responses.lock().unwrap();
748                if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
749                    if let Err(e) = tx.try_send(item) {
750                        tracing::warn!("error sending response to handler: {e}");
751                    }
752                } else {
753                    tracing::warn!(?req_id, "could not find req handler");
754                }
755            } else {
756                tracing::warn!("unexpected message/channel: {chan}");
757            }
758        }
759    }
760
761    /// Handle a generic request received from the request channel.
762    fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
763        let req: RequestIn = rmp_serde::from_slice(&item)?;
764        if req.node_id == self.uid {
765            return Ok(());
766        }
767
768        tracing::trace!(?req, "handling request");
769
770        match req.r#type {
771            RequestTypeIn::Broadcast(p) => self.recv_broadcast(req.opts, p),
772            RequestTypeIn::BroadcastWithAck(_) => self.clone().recv_broadcast_with_ack(req),
773            RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(req),
774            RequestTypeIn::AllRooms => self.recv_rooms(req),
775            RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(req.opts, rooms),
776            RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(req.opts, rooms),
777            RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req),
778        };
779        Ok(())
780    }
781
782    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
783        if let Err(e) = self.local.broadcast(packet, opts) {
784            let ns = self.local.path();
785            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
786        }
787    }
788
789    fn recv_disconnect_sockets(&self, req: RequestIn) {
790        if let Err(e) = self.local.disconnect_socket(req.opts) {
791            let ns = self.local.path();
792            tracing::warn!(
793                ?self.uid,
794                ?ns,
795                "remote request disconnect sockets handler: {:?}",
796                e
797            );
798        }
799    }
800
801    fn recv_broadcast_with_ack(self: Arc<Self>, req: RequestIn) {
802        let packet = match req.r#type {
803            RequestTypeIn::BroadcastWithAck(p) => p,
804            _ => unreachable!(),
805        };
806        let (stream, count) = self.local.broadcast_with_ack(packet, req.opts, None);
807        tokio::spawn(async move {
808            let on_err = |err| {
809                let ns = self.local.path();
810                tracing::warn!(
811                    ?self.uid,
812                    ?ns,
813                    "remote request broadcast with ack handler errors: {:?}",
814                    err
815                );
816            };
817            // First send the count of expected acks to the server that sent the request.
818            // This is used to keep track of the number of expected acks.
819            let res = Response {
820                r#type: ResponseType::<()>::BroadcastAckCount(count),
821                node_id: self.uid,
822            };
823            if let Err(err) = self.send_res(req.node_id, req.id, res).await {
824                on_err(err);
825                return;
826            }
827
828            // Then send the acks as they are received.
829            futures_util::pin_mut!(stream);
830            while let Some(ack) = stream.next().await {
831                let res = Response {
832                    r#type: ResponseType::BroadcastAck(ack),
833                    node_id: self.uid,
834                };
835                if let Err(err) = self.send_res(req.node_id, req.id, res).await {
836                    on_err(err);
837                    return;
838                }
839            }
840        });
841    }
842
843    fn recv_rooms(&self, req: RequestIn) {
844        let rooms = self.local.rooms(req.opts);
845        let res = Response {
846            r#type: ResponseType::<()>::AllRooms(rooms),
847            node_id: self.uid,
848        };
849        let fut = self.send_res(req.node_id, req.id, res);
850        let ns = self.local.path().clone();
851        let uid = self.uid;
852        tokio::spawn(async move {
853            if let Err(err) = fut.await {
854                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
855            }
856        });
857    }
858
859    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
860        self.local.add_sockets(opts, rooms);
861    }
862
863    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
864        self.local.del_sockets(opts, rooms);
865    }
866    fn recv_fetch_sockets(&self, req: RequestIn) {
867        let sockets = self.local.fetch_sockets(req.opts);
868        let res = Response {
869            node_id: self.uid,
870            r#type: ResponseType::FetchSockets(sockets),
871        };
872        let fut = self.send_res(req.node_id, req.id, res);
873        let ns = self.local.path().clone();
874        let uid = self.uid;
875        tokio::spawn(async move {
876            if let Err(err) = fut.await {
877                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
878            }
879        });
880    }
881
882    async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
883        tracing::trace!(?req, "sending request");
884        let req = rmp_serde::to_vec(&req)?;
885        let chan = self.get_req_chan(target_uid);
886        self.driver
887            .publish(chan, req)
888            .await
889            .map_err(Error::from_driver)?;
890
891        Ok(())
892    }
893
894    fn send_res<D: Serialize + fmt::Debug>(
895        &self,
896        req_node_id: Uid,
897        req_id: Sid,
898        res: Response<D>,
899    ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
900        let chan = self.get_res_chan(req_node_id);
901        tracing::trace!(?res, "sending response to {}", &chan);
902        // We send the req_id separated from the response object.
903        // This allows to partially decode the response and route by the req_id
904        // before fully deserializing it.
905        let res = rmp_serde::to_vec(&(req_id, res));
906        let driver = self.driver.clone();
907        async move {
908            driver
909                .publish(chan, res?)
910                .await
911                .map_err(Error::from_driver)?;
912            Ok(())
913        }
914    }
915
916    /// Await for all the responses from the remote servers.
917    async fn get_res<D: DeserializeOwned + fmt::Debug>(
918        &self,
919        req_id: Sid,
920        response_idx: u8,
921        target_uid: Option<Uid>,
922    ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
923        // Check for specific target node
924        let remote_serv_cnt = if target_uid.is_none() {
925            self.server_count().await?.saturating_sub(1) as usize
926        } else {
927            1
928        };
929        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
930        self.responses.lock().unwrap().insert(req_id, tx);
931        let stream = MessageStream::new(rx)
932            .filter_map(|item| {
933                let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
934                    Ok((_, data)) => Some(data),
935                    Err(e) => {
936                        tracing::warn!("error decoding response: {e}");
937                        None
938                    }
939                };
940                future::ready(data)
941            })
942            .filter(move |item| future::ready(item.r#type.to_u8() == response_idx))
943            .take(remote_serv_cnt)
944            .take_until(time::sleep(self.config.request_timeout));
945        let stream = DropStream::new(stream, self.responses.clone(), req_id);
946        Ok(stream)
947    }
948
949    /// Little wrapper to map the error type.
950    #[inline]
951    async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
952        tracing::trace!(?pat, "subscribing to");
953        self.driver
954            .subscribe(pat, self.config.stream_buffer)
955            .await
956            .map_err(InitError::Driver)
957    }
958}
959
960/// A local operator is either something that is flagged as local or a request that should be specifically
961/// sent to the current server.
962#[inline]
963fn is_local_op(uid: Uid, opts: &BroadcastOptions) -> bool {
964    if opts.has_flag(BroadcastFlags::Local)
965        || (!opts.has_flag(BroadcastFlags::Broadcast)
966            && opts.server_id == Some(uid)
967            && opts.rooms.is_empty()
968            && opts.sid.is_some())
969    {
970        tracing::debug!(?opts, "operation is local");
971        true
972    } else {
973        false
974    }
975}
976
977/// Checks if the namespace path is valid
978/// Panics if the path is empty or contains a `#`
979fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
980    if path.is_empty() || path.contains('#') {
981        Err(InitError::MalformedNamespace)
982    } else {
983        Ok(())
984    }
985}
986
987#[cfg(test)]
988mod tests {
989    use super::*;
990    use futures_util::stream::{self, FusedStream, StreamExt};
991    use socketioxide_core::{adapter::AckStreamItem, Str, Value};
992    use std::convert::Infallible;
993
994    #[derive(Clone)]
995    struct StubDriver;
996    impl Driver for StubDriver {
997        type Error = Infallible;
998
999        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
1000            Ok(())
1001        }
1002
1003        async fn subscribe(
1004            &self,
1005            _: String,
1006            _: usize,
1007        ) -> Result<MessageStream<ChanItem>, Self::Error> {
1008            Ok(MessageStream::new_empty())
1009        }
1010
1011        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
1012            Ok(())
1013        }
1014
1015        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
1016            Ok(0)
1017        }
1018    }
1019    fn new_stub_ack_stream(
1020        remote: MessageStream<Vec<u8>>,
1021        timeout: Duration,
1022    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
1023        AckStream::new(
1024            stream::empty::<AckStreamItem<()>>(),
1025            remote,
1026            timeout,
1027            2,
1028            Sid::new(),
1029            Arc::new(Mutex::new(HashMap::new())),
1030        )
1031    }
1032
1033    //TODO: test weird behaviours, packets out of orders, etc
1034    #[tokio::test]
1035    async fn ack_stream() {
1036        let (tx, rx) = tokio::sync::mpsc::channel(255);
1037        let remote = MessageStream::new(rx);
1038        let stream = new_stub_ack_stream(remote, Duration::from_secs(10));
1039        let node_id = Uid::new();
1040        let req_id = Sid::new();
1041
1042        // The two servers will send 2 acks each.
1043        let ack_cnt_res = Response::<()> {
1044            node_id,
1045            r#type: ResponseType::BroadcastAckCount(2),
1046        };
1047        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
1048            .unwrap();
1049        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
1050            .unwrap();
1051
1052        let ack_res = Response::<String> {
1053            node_id,
1054            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
1055        };
1056        for _ in 0..4 {
1057            tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap())
1058                .unwrap();
1059        }
1060        futures_util::pin_mut!(stream);
1061        for _ in 0..4 {
1062            assert!(stream.next().await.is_some());
1063        }
1064        assert!(stream.is_terminated());
1065    }
1066
1067    #[tokio::test]
1068    async fn ack_stream_timeout() {
1069        let (tx, rx) = tokio::sync::mpsc::channel(255);
1070        let remote = MessageStream::new(rx);
1071        let stream = new_stub_ack_stream(remote, Duration::from_millis(50));
1072        let node_id = Uid::new();
1073        let req_id = Sid::new();
1074        // There will be only one ack count and then the stream will timeout.
1075        let ack_cnt_res = Response::<()> {
1076            node_id,
1077            r#type: ResponseType::BroadcastAckCount(2),
1078        };
1079        tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap())
1080            .unwrap();
1081
1082        futures_util::pin_mut!(stream);
1083        tokio::time::sleep(Duration::from_millis(50)).await;
1084        assert!(stream.next().await.is_none());
1085        assert!(stream.is_terminated());
1086    }
1087
1088    #[tokio::test]
1089    async fn ack_stream_drop() {
1090        let (tx, rx) = tokio::sync::mpsc::channel(255);
1091        let remote = MessageStream::new(rx);
1092        let handlers = Arc::new(Mutex::new(HashMap::new()));
1093        let id = Sid::new();
1094        handlers.lock().unwrap().insert(id, tx);
1095        let stream = AckStream::new(
1096            stream::empty::<AckStreamItem<()>>(),
1097            remote,
1098            Duration::from_secs(10),
1099            2,
1100            id,
1101            handlers.clone(),
1102        );
1103        drop(stream);
1104        assert!(handlers.lock().unwrap().is_empty(),);
1105    }
1106
/// Options built from a remote socket are local only for the server owning that socket;
/// options built from a plain socket id are not local for an unrelated server.
#[test]
fn test_is_local_op() {
    let server_id = Uid::new();
    let remote = RemoteSocketData {
        server_id,
        id: Sid::new(),
        ns: "/".into(),
    };

    let remote_opts = BroadcastOptions::new_remote(&remote);
    assert!(is_local_op(server_id, &remote_opts));
    assert!(!is_local_op(Uid::new(), &remote_opts));

    let plain_opts = BroadcastOptions::new(Sid::new());
    assert!(!is_local_op(Uid::new(), &plain_opts));
}
1121
/// A path containing `#` and an empty path are both rejected as malformed namespaces.
#[test]
fn check_ns_error() {
    for path in ["#", ""] {
        let res = check_ns::<StubDriver>(path);
        assert!(matches!(res, Err(InitError::MalformedNamespace)));
    }
}
1133}