1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(
3    clippy::all,
4    clippy::todo,
5    clippy::empty_enum,
6    clippy::mem_forget,
7    clippy::unused_self,
8    clippy::filter_map_next,
9    clippy::needless_continue,
10    clippy::needless_borrow,
11    clippy::match_wildcard_for_single_variants,
12    clippy::if_let_mutex,
13    clippy::await_holding_lock,
14    clippy::imprecise_flops,
15    clippy::suboptimal_flops,
16    clippy::lossy_float_literal,
17    clippy::rest_pat_in_fully_bound_structs,
18    clippy::fn_params_excessive_bools,
19    clippy::exit,
20    clippy::inefficient_to_string,
21    clippy::linkedlist,
22    clippy::macro_use_imports,
23    clippy::option_option,
24    clippy::verbose_file_reads,
25    clippy::unnested_or_patterns,
26    rust_2018_idioms,
27    rust_2024_compatibility,
28    future_incompatible,
29    nonstandard_style,
30    missing_docs
31)]
32
33use std::{
171    borrow::Cow,
172    collections::HashMap,
173    fmt,
174    future::{self, Future},
175    pin::Pin,
176    sync::{Arc, Mutex},
177    task::{Context, Poll},
178    time::Duration,
179};
180
181use drivers::{ChanItem, Driver, MessageStream};
182use futures_core::Stream;
183use futures_util::StreamExt;
184use serde::{Serialize, de::DeserializeOwned};
185use socketioxide_core::adapter::remote_packet::{
186    RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId,
187};
188use socketioxide_core::{
189    Sid, Uid,
190    adapter::errors::{AdapterError, BroadcastError},
191    adapter::{
192        BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
193        RoomParam, SocketEmitter, Spawnable,
194    },
195    packet::Packet,
196};
197use stream::{AckStream, DropStream};
198use tokio::{sync::mpsc, time};
199
/// Driver abstraction used by the adapter to publish and subscribe on the
/// pub/sub backend, together with its implementations.
pub mod drivers;

// Stream types (`AckStream`, `DropStream`) used to consume remote responses.
mod stream;
205
206#[derive(thiserror::Error)]
208pub enum Error<R: Driver> {
209    #[error("driver error: {0}")]
211    Driver(R::Error),
212    #[error("packet encoding error: {0}")]
214    Decode(#[from] rmp_serde::decode::Error),
215    #[error("packet decoding error: {0}")]
217    Encode(#[from] rmp_serde::encode::Error),
218}
219
220impl<R: Driver> Error<R> {
221    fn from_driver(err: R::Error) -> Self {
222        Self::Driver(err)
223    }
224}
225impl<R: Driver> fmt::Debug for Error<R> {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        match self {
228            Self::Driver(err) => write!(f, "Driver error: {err:?}"),
229            Self::Decode(err) => write!(f, "Decode error: {err:?}"),
230            Self::Encode(err) => write!(f, "Encode error: {err:?}"),
231        }
232    }
233}
234
235impl<R: Driver> From<Error<R>> for AdapterError {
236    fn from(err: Error<R>) -> Self {
237        AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
238    }
239}
240
/// Configuration of the redis adapter.
///
/// Start from [`RedisAdapterConfig::new`] (identical to [`Default::default`]) and
/// adjust it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
    /// How long a remote request (broadcast with acks, fetching rooms or
    /// sockets) waits for the responses of the other servers.
    /// Defaults to 5 seconds.
    pub request_timeout: Duration,

    /// Prefix of every channel name used by the adapter, e.g.
    /// `{prefix}-request#{namespace}#`. Defaults to `"socket.io"`.
    pub prefix: Cow<'static, str>,

    /// Base capacity of the channel buffering remote acknowledgement responses
    /// for a single broadcast with acks; the number of remote servers is added
    /// on top of it. Defaults to 255.
    pub ack_response_buffer: usize,

    /// Buffer size handed to the driver when subscribing to a channel.
    /// Defaults to 1024.
    pub stream_buffer: usize,
}

impl RedisAdapterConfig {
    /// Creates a configuration with the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`request_timeout`](Self::request_timeout).
    #[must_use]
    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Sets the channel name [`prefix`](Self::prefix).
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
        self.prefix = prefix.into();
        self
    }

    /// Sets [`ack_response_buffer`](Self::ack_response_buffer).
    ///
    /// # Panics
    /// Panics if `buffer` is 0: the value is used as a channel capacity.
    #[must_use]
    pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.ack_response_buffer = buffer;
        self
    }

    /// Sets [`stream_buffer`](Self::stream_buffer).
    ///
    /// # Panics
    /// Panics if `buffer` is 0.
    #[must_use]
    pub fn with_stream_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.stream_buffer = buffer;
        self
    }
}

impl Default for RedisAdapterConfig {
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(5),
            prefix: Cow::Borrowed("socket.io"),
            ack_response_buffer: 255,
            stream_buffer: 1024,
        }
    }
}
309
/// Constructor state for [`CustomRedisAdapter`]: a [`Driver`] and a
/// [`RedisAdapterConfig`].
///
/// Every adapter instance built from it receives a clone of both.
#[derive(Debug)]
pub struct RedisAdapterCtr<R> {
    driver: R,
    config: RedisAdapterConfig,
}
317
#[cfg(feature = "redis")]
impl RedisAdapterCtr<drivers::redis::RedisDriver> {
    /// Creates a constructor from a [`redis::Client`] with the default
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Propagates the error returned while creating the redis driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_redis_config(client, config).await
    }

    /// Creates a constructor from a [`redis::Client`] and a custom
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Propagates the error returned while creating the redis driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis_config(
        client: &redis::Client,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        Ok(Self::new_with_driver(
            drivers::redis::RedisDriver::new(client).await?,
            config,
        ))
    }
}
#[cfg(feature = "redis-cluster")]
impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
    /// Creates a constructor from a redis [`ClusterClient`](redis::cluster::ClusterClient)
    /// with the default [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Propagates the error returned while creating the cluster driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster(
        client: &redis::cluster::ClusterClient,
    ) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_cluster_config(client, config).await
    }

    /// Creates a constructor from a redis [`ClusterClient`](redis::cluster::ClusterClient)
    /// and a custom [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Propagates the error returned while creating the cluster driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster_config(
        client: &redis::cluster::ClusterClient,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        Ok(Self::new_with_driver(
            drivers::redis::ClusterDriver::new(client).await?,
            config,
        ))
    }
}
#[cfg(feature = "fred")]
impl RedisAdapterCtr<drivers::fred::FredDriver> {
    /// Creates a constructor from a fred `SubscriberClient` with the default
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Propagates the error returned while creating the fred driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred(
        client: fred::clients::SubscriberClient,
    ) -> fred::prelude::FredResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_fred_config(client, config).await
    }

    /// Creates a constructor from a fred `SubscriberClient` and a custom
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    /// Propagates the error returned while creating the fred driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred_config(
        client: fred::clients::SubscriberClient,
        config: RedisAdapterConfig,
    ) -> fred::prelude::FredResult<Self> {
        Ok(Self::new_with_driver(
            drivers::fred::FredDriver::new(client).await?,
            config,
        ))
    }
}
374impl<R: Driver> RedisAdapterCtr<R> {
375    pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
380        RedisAdapterCtr { driver, config }
381    }
382}
383
/// Pending response handlers: maps a request id to the sender feeding the
/// stream that consumes the responses to that request.
pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;
385
#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
/// A [`CustomRedisAdapter`] using the fred driver
/// ([`drivers::fred::FredDriver`]).
#[cfg(feature = "fred")]
pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;

#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
/// A [`CustomRedisAdapter`] using the redis driver
/// ([`drivers::redis::RedisDriver`]).
#[cfg(feature = "redis")]
pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;

#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
/// A [`CustomRedisAdapter`] using the redis cluster driver
/// ([`drivers::redis::ClusterDriver`]).
#[cfg(feature = "redis-cluster")]
pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
400
/// The redis adapter, generic over the socket emitter `E` and the [`Driver`] `R`.
///
/// An instance is bound to a single namespace: its local adapter's path is
/// baked into `req_chan`. Operations are forwarded to the other servers through
/// the driver and applied to this server's sockets through [`CoreLocalAdapter`].
pub struct CustomRedisAdapter<E, R> {
    /// Driver used to publish and subscribe to channels.
    driver: R,
    /// Adapter configuration, cloned from the constructor state.
    config: RedisAdapterConfig,
    /// Id of this server, obtained from the local adapter.
    uid: Uid,
    /// Local adapter operating on the sockets of this server.
    local: CoreLocalAdapter<E>,
    /// Namespace-wide request channel: `{prefix}-request#{path}#`.
    req_chan: String,
    /// Handlers waiting for responses to requests sent by this server, keyed by
    /// request id.
    responses: Arc<Mutex<ResponseHandlers>>,
}
421
422impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
423impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
424    type Error = Error<R>;
425    type State = RedisAdapterCtr<R>;
426    type AckStream = AckStream<E::AckStream>;
427    type InitRes = InitRes<R>;
428
429    fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
430        let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
431        let uid = local.server_id();
432        Self {
433            local,
434            req_chan,
435            uid,
436            driver: state.driver.clone(),
437            config: state.config.clone(),
438            responses: Arc::new(Mutex::new(HashMap::new())),
439        }
440    }
441
442    fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
443        let fut = async move {
444            check_ns(self.local.path())?;
445            let global_stream = self.subscribe(self.req_chan.clone()).await?;
446            let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
447            let response_chan = format!(
448                "{}-response#{}#{}#",
449                &self.config.prefix,
450                self.local.path(),
451                self.uid
452            );
453
454            let response_stream = self.subscribe(response_chan.clone()).await?;
455            let stream = futures_util::stream::select(global_stream, specific_stream);
456            let stream = futures_util::stream::select(stream, response_stream);
457            tokio::spawn(self.pipe_stream(stream, response_chan));
458            on_success();
459            Ok(())
460        };
461        InitRes(Box::pin(fut))
462    }
463
464    async fn close(&self) -> Result<(), Self::Error> {
465        let response_chan = format!(
466            "{}-response#{}#{}#",
467            &self.config.prefix,
468            self.local.path(),
469            self.uid
470        );
471        tokio::try_join!(
472            self.driver.unsubscribe(self.req_chan.clone()),
473            self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
474            self.driver.unsubscribe(response_chan)
475        )
476        .map_err(Error::from_driver)?;
477
478        Ok(())
479    }
480
481    async fn server_count(&self) -> Result<u16, Self::Error> {
483        let count = self
484            .driver
485            .num_serv(&self.req_chan)
486            .await
487            .map_err(Error::from_driver)?;
488
489        Ok(count)
490    }
491
492    async fn broadcast(
494        &self,
495        packet: Packet,
496        opts: BroadcastOptions,
497    ) -> Result<(), BroadcastError> {
498        if !opts.is_local(self.uid) {
499            let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
500            self.send_req(req, opts.server_id)
501                .await
502                .map_err(AdapterError::from)?;
503        }
504
505        self.local.broadcast(packet, opts)?;
506        Ok(())
507    }
508
509    async fn broadcast_with_ack(
537        &self,
538        packet: Packet,
539        opts: BroadcastOptions,
540        timeout: Option<Duration>,
541    ) -> Result<Self::AckStream, Self::Error> {
542        if opts.is_local(self.uid) {
543            tracing::debug!(?opts, "broadcast with ack is local");
544            let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
545            let stream = AckStream::new_local(local);
546            return Ok(stream);
547        }
548        let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
549        let req_id = req.id;
550
551        let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
552
553        let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
554        self.responses.lock().unwrap().insert(req_id, tx);
555        let remote = MessageStream::new(rx);
556
557        self.send_req(req, opts.server_id).await?;
558        let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
559
560        Ok(AckStream::new(
561            local,
562            remote,
563            self.config.request_timeout,
564            remote_serv_cnt,
565            req_id,
566            self.responses.clone(),
567        ))
568    }
569
570    async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
571        if !opts.is_local(self.uid) {
572            let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
573            self.send_req(req, opts.server_id)
574                .await
575                .map_err(AdapterError::from)?;
576        }
577        self.local
578            .disconnect_socket(opts)
579            .map_err(BroadcastError::Socket)?;
580
581        Ok(())
582    }
583
584    async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
585        if opts.is_local(self.uid) {
586            return Ok(self.local.rooms(opts).into_iter().collect());
587        }
588        let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
589        let req_id = req.id;
590
591        let stream = self
594            .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
595            .await?;
596        self.send_req(req, opts.server_id).await?;
597        let local = self.local.rooms(opts);
598        let rooms = stream
599            .filter_map(|item| future::ready(item.into_rooms()))
600            .fold(local, async |mut acc, item| {
601                acc.extend(item);
602                acc
603            })
604            .await;
605        Ok(Vec::from_iter(rooms))
606    }
607
608    async fn add_sockets(
609        &self,
610        opts: BroadcastOptions,
611        rooms: impl RoomParam,
612    ) -> Result<(), Self::Error> {
613        let rooms: Vec<Room> = rooms.into_room_iter().collect();
614        if !opts.is_local(self.uid) {
615            let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
616            self.send_req(req, opts.server_id).await?;
617        }
618        self.local.add_sockets(opts, rooms);
619        Ok(())
620    }
621
622    async fn del_sockets(
623        &self,
624        opts: BroadcastOptions,
625        rooms: impl RoomParam,
626    ) -> Result<(), Self::Error> {
627        let rooms: Vec<Room> = rooms.into_room_iter().collect();
628        if !opts.is_local(self.uid) {
629            let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
630            self.send_req(req, opts.server_id).await?;
631        }
632        self.local.del_sockets(opts, rooms);
633        Ok(())
634    }
635
636    async fn fetch_sockets(
637        &self,
638        opts: BroadcastOptions,
639    ) -> Result<Vec<RemoteSocketData>, Self::Error> {
640        if opts.is_local(self.uid) {
641            return Ok(self.local.fetch_sockets(opts));
642        }
643        let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
644        let req_id = req.id;
645        let remote = self
648            .get_res::<RemoteSocketData>(req_id, ResponseTypeId::FetchSockets, opts.server_id)
649            .await?;
650
651        self.send_req(req, opts.server_id).await?;
652        let local = self.local.fetch_sockets(opts);
653        let sockets = remote
654            .filter_map(|item| future::ready(item.into_fetch_sockets()))
655            .fold(local, async |mut acc, item| {
656                acc.extend(item);
657                acc
658            })
659            .await;
660        Ok(sockets)
661    }
662
663    fn get_local(&self) -> &CoreLocalAdapter<E> {
664        &self.local
665    }
666}
667
/// Errors that can occur while initializing the adapter.
#[derive(thiserror::Error)]
pub enum InitError<D: Driver> {
    /// The driver failed to subscribe to one of the adapter channels.
    #[error("driver error: {0}")]
    Driver(D::Error),
    /// The namespace path is empty or contains `#`, the character used as a
    /// separator in channel names.
    #[error("malformed namespace path, it must not contain '#'")]
    MalformedNamespace,
}
678impl<D: Driver> fmt::Debug for InitError<D> {
679    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
680        match self {
681            Self::Driver(err) => fmt::Debug::fmt(err, f),
682            Self::MalformedNamespace => write!(f, "Malformed namespace path"),
683        }
684    }
685}
686#[must_use = "futures do nothing unless you `.await` or poll them"]
688pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
689
690impl<D: Driver> Future for InitRes<D> {
691    type Output = Result<(), InitError<D>>;
692
693    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
694        self.0.as_mut().poll(cx)
695    }
696}
697impl<D: Driver> Spawnable for InitRes<D> {
698    fn spawn(self) {
699        tokio::spawn(async move {
700            if let Err(e) = self.0.await {
701                tracing::error!("error initializing adapter: {e}");
702            }
703        });
704    }
705}
706
// Internal helpers: channel naming, incoming message routing, handlers for
// requests coming from other servers, and request/response publishing.
impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
    /// Returns the response channel of server `uid` for this namespace:
    /// `{prefix}-response#{path}#{uid}#`.
    fn get_res_chan(&self, uid: Uid) -> String {
        let path = self.local.path();
        let prefix = &self.config.prefix;
        format!("{prefix}-response#{path}#{uid}#")
    }
    /// Returns the request channel to publish to. This is the namespace-wide
    /// channel when `node_id` is `None`. Otherwise it is the channel targeting
    /// that single server (`{req_chan}{uid}#`).
    fn get_req_chan(&self, node_id: Option<Uid>) -> String {
        match node_id {
            Some(uid) => format!("{}{}#", self.req_chan, uid),
            None => self.req_chan.clone(),
        }
    }

    /// Routes every message received on the subscribed channels until the
    /// stream ends.
    ///
    /// - Messages on a request channel (any channel starting with `req_chan`)
    ///   are handled by `recv_req`.
    /// - Messages on `response_chan` are forwarded to the handler registered in
    ///   `responses` for their request id.
    /// - Anything else is logged and ignored.
    async fn pipe_stream(
        self: Arc<Self>,
        mut stream: impl Stream<Item = ChanItem> + Unpin,
        response_chan: String,
    ) {
        while let Some((chan, item)) = stream.next().await {
            if chan.starts_with(&self.req_chan) {
                if let Err(e) = self.recv_req(item) {
                    let ns = self.local.path();
                    let uid = self.uid;
                    tracing::warn!(?uid, ?ns, "request handler error: {e}");
                }
            } else if chan == response_chan {
                let req_id = read_req_id(&item);
                tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
                // The guard lives until the end of this branch, which contains no
                // `.await`, so the std mutex is never held across a suspension point.
                let handlers = self.responses.lock().unwrap();
                if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
                    // Non-blocking send: if the handler channel is full or closed,
                    // the response is dropped with a warning instead of stalling
                    // the router.
                    if let Err(e) = tx.try_send(item) {
                        tracing::warn!("error sending response to handler: {e}");
                    }
                } else {
                    tracing::warn!(?req_id, "could not find req handler");
                }
            } else {
                tracing::warn!("unexpected message/channel: {chan}");
            }
        }
    }

    /// Decodes a request published by a server and dispatches it to the
    /// matching `recv_*` handler.
    ///
    /// Requests sent by this server itself (`node_id == self.uid`) are ignored,
    /// and so are requests without broadcast options.
    ///
    /// # Errors
    /// Returns [`Error::Decode`] if `item` cannot be decoded as a request.
    fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
        let req: RequestIn = rmp_serde::from_slice(&item)?;
        if req.node_id == self.uid {
            return Ok(());
        }

        tracing::trace!(?req, "handling request");
        let Some(opts) = req.opts else {
            tracing::warn!(?req, "request is missing options");
            return Ok(());
        };

        match req.r#type {
            RequestTypeIn::Broadcast(p) => self.recv_broadcast(opts, p),
            RequestTypeIn::BroadcastWithAck(p) => {
                self.clone()
                    .recv_broadcast_with_ack(req.node_id, req.id, p, opts)
            }
            RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(opts),
            RequestTypeIn::AllRooms => self.recv_rooms(req.node_id, req.id, opts),
            RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(opts, rooms),
            RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(opts, rooms),
            RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req.node_id, req.id, opts),
            // Any other request type is silently ignored by this adapter.
            _ => (),
        };
        Ok(())
    }

    /// Applies a remote broadcast to local sockets. Failures are only logged.
    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
        if let Err(e) = self.local.broadcast(packet, opts) {
            let ns = self.local.path();
            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
        }
    }

    /// Applies a remote disconnect request to local sockets. Failures are only
    /// logged.
    fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
        if let Err(e) = self.local.disconnect_socket(opts) {
            let ns = self.local.path();
            tracing::warn!(
                ?self.uid,
                ?ns,
                "remote request disconnect sockets handler: {:?}",
                e
            );
        }
    }

    /// Broadcasts locally for a remote request and streams the results back to
    /// `origin`.
    ///
    /// The task first sends the number of acks to expect
    /// (`BroadcastAckCount`), then each ack (`BroadcastAck`) as it arrives. It
    /// stops at the first publishing error, which is logged.
    fn recv_broadcast_with_ack(
        self: Arc<Self>,
        origin: Uid,
        req_id: Sid,
        packet: Packet,
        opts: BroadcastOptions,
    ) {
        let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
        tokio::spawn(async move {
            let on_err = |err| {
                let ns = self.local.path();
                tracing::warn!(
                    ?origin,
                    ?ns,
                    "remote request broadcast with ack handler errors: {:?}",
                    err
                );
            };
            let res = Response {
                r#type: ResponseType::<()>::BroadcastAckCount(count),
                node_id: self.uid,
            };
            if let Err(err) = self.send_res(origin, req_id, res).await {
                on_err(err);
                return;
            }

            futures_util::pin_mut!(stream);
            while let Some(ack) = stream.next().await {
                let res = Response {
                    r#type: ResponseType::BroadcastAck(ack),
                    node_id: self.uid,
                };
                if let Err(err) = self.send_res(origin, req_id, res).await {
                    on_err(err);
                    return;
                }
            }
        });
    }

    /// Replies to `origin` with this server's rooms matching `opts`.
    ///
    /// The `'static` response future is built before spawning, so the spawned
    /// task only captures owned values and does not borrow `self`.
    fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
        let rooms = self.local.rooms(opts);
        let res = Response {
            r#type: ResponseType::<()>::AllRooms(rooms),
            node_id: self.uid,
        };
        let fut = self.send_res(origin, req_id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
            }
        });
    }

    /// Makes the local sockets matching `opts` join `rooms` (remote request).
    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.add_sockets(opts, rooms);
    }

    /// Makes the local sockets matching `opts` leave `rooms` (remote request).
    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.del_sockets(opts, rooms);
    }
    /// Replies to `origin` with the data of this server's sockets matching
    /// `opts`. Like `recv_rooms`, the response future is built before spawning.
    fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
        let sockets = self.local.fetch_sockets(opts);
        let res = Response {
            node_id: self.uid,
            r#type: ResponseType::FetchSockets(sockets),
        };
        let fut = self.send_res(origin, req_id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
            }
        });
    }

    /// Serializes `req` with msgpack and publishes it. It goes to the
    /// namespace-wide request channel, or to the channel of `target_uid` when
    /// one is given.
    ///
    /// # Errors
    /// Returns [`Error::Encode`] if serialization fails, or [`Error::Driver`] if
    /// publishing fails.
    async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
        tracing::trace!(?req, "sending request");
        let req = rmp_serde::to_vec(&req)?;
        let chan = self.get_req_chan(target_uid);
        self.driver
            .publish(chan, req)
            .await
            .map_err(Error::from_driver)?;

        Ok(())
    }

    /// Prepares the publication of `(req_id, res)` on the response channel of
    /// `req_node_id`.
    ///
    /// Serialization happens eagerly, but a serialization error is only
    /// reported when the returned future is awaited. The future is `'static`
    /// (it owns a driver clone), so callers can spawn it without borrowing `self`.
    fn send_res<D: Serialize + fmt::Debug>(
        &self,
        req_node_id: Uid,
        req_id: Sid,
        res: Response<D>,
    ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
        let chan = self.get_res_chan(req_node_id);
        tracing::trace!(?res, "sending response to {}", &chan);
        let res = rmp_serde::to_vec(&(req_id, res));
        let driver = self.driver.clone();
        async move {
            driver
                .publish(chan, res?)
                .await
                .map_err(Error::from_driver)?;
            Ok(())
        }
    }

    /// Registers a response handler for `req_id` and returns a stream of the
    /// decoded responses whose type is `response_type`.
    ///
    /// The stream ends after one response per expected server, or when
    /// `request_timeout` elapses, whichever comes first. The expected count is
    /// every other server, or 1 when `target_uid` is set. The timeout is
    /// measured from this call. Undecodable responses are logged and skipped.
    ///
    /// Callers must call this before publishing the request: `pipe_stream`
    /// drops responses that have no registered handler. `DropStream` is given
    /// `responses` and `req_id`, presumably to deregister the handler when the
    /// stream is dropped.
    async fn get_res<D: DeserializeOwned + fmt::Debug>(
        &self,
        req_id: Sid,
        response_type: ResponseTypeId,
        target_uid: Option<Uid>,
    ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
        let remote_serv_cnt = if target_uid.is_none() {
            self.server_count().await?.saturating_sub(1) as usize
        } else {
            1
        };
        // Channel capacity must be non-zero even when no other server exists.
        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
        self.responses.lock().unwrap().insert(req_id, tx);
        let stream = MessageStream::new(rx)
            .filter_map(|item| {
                let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
                    Ok((_, data)) => Some(data),
                    Err(e) => {
                        tracing::warn!("error decoding response: {e}");
                        None
                    }
                };
                future::ready(data)
            })
            .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type))
            .take(remote_serv_cnt)
            .take_until(time::sleep(self.config.request_timeout));
        let stream = DropStream::new(stream, self.responses.clone(), req_id);
        Ok(stream)
    }

    /// Subscribes to `pat` through the driver with the configured
    /// `stream_buffer`. Driver failures are mapped to [`InitError::Driver`].
    #[inline]
    async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
        tracing::trace!(?pat, "subscribing to");
        self.driver
            .subscribe(pat, self.config.stream_buffer)
            .await
            .map_err(InitError::Driver)
    }
}
965
966fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
969    if path.is_empty() || path.contains('#') {
970        Err(InitError::MalformedNamespace)
971    } else {
972        Ok(())
973    }
974}
975
976pub fn read_req_id(data: &[u8]) -> Option<Sid> {
978    use std::str::FromStr;
979    let mut rd = data;
980    let len = rmp::decode::read_array_len(&mut rd).ok()?;
981    if len < 1 {
982        return None;
983    }
984
985    let mut buff = [0u8; Sid::ZERO.as_str().len()];
986    let str = rmp::decode::read_str(&mut rd, &mut buff).ok()?;
987    Sid::from_str(str).ok()
988}
989
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, FusedStream, StreamExt};
    use socketioxide_core::{Str, Value, adapter::AckStreamItem};
    use std::convert::Infallible;

    /// A driver that never fails and never receives anything, used wherever a
    /// `Driver` type parameter is required.
    #[derive(Clone)]
    struct StubDriver;
    impl Driver for StubDriver {
        type Error = Infallible;

        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn subscribe(
            &self,
            _: String,
            _: usize,
        ) -> Result<MessageStream<ChanItem>, Self::Error> {
            Ok(MessageStream::new_empty())
        }

        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
            Ok(0)
        }
    }

    /// Builds an `AckStream` with no local acks that expects 2 remote servers.
    fn new_stub_ack_stream(
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
        AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            timeout,
            2,
            Sid::new(),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    #[tokio::test]
    async fn ack_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_secs(10));
        let node_id = Uid::new();
        let req_id = Sid::new();

        // Two servers each announce 2 acks...
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();

        // ...and then deliver them.
        let ack_res = Response::<String> {
            node_id,
            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
        };
        for _ in 0..4 {
            tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap())
                .unwrap();
        }
        futures_util::pin_mut!(stream);
        for _ in 0..4 {
            assert!(stream.next().await.is_some());
        }
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_timeout() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_millis(50));
        let node_id = Uid::new();
        let req_id = Sid::new();
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap())
            .unwrap();

        futures_util::pin_mut!(stream);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(stream.next().await.is_none());
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_drop() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        let id = Sid::new();
        handlers.lock().unwrap().insert(id, tx);
        let stream = AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            Duration::from_secs(10),
            2,
            id,
            handlers.clone(),
        );
        drop(stream);
        assert!(handlers.lock().unwrap().is_empty());
    }

    #[test]
    fn check_ns_error() {
        assert!(matches!(
            check_ns::<StubDriver>("#"),
            Err(InitError::MalformedNamespace)
        ));
        assert!(matches!(
            check_ns::<StubDriver>(""),
            Err(InitError::MalformedNamespace)
        ));
    }

    #[test]
    fn check_ns_valid() {
        assert!(check_ns::<StubDriver>("/").is_ok());
        assert!(check_ns::<StubDriver>("/admin").is_ok());
    }

    #[test]
    fn read_req_id_roundtrip() {
        let req_id = Sid::new();
        let res = Response::<()> {
            node_id: Uid::new(),
            r#type: ResponseType::BroadcastAckCount(1),
        };
        let data = rmp_serde::to_vec(&(req_id, &res)).unwrap();
        assert_eq!(read_req_id(&data), Some(req_id));
    }

    #[test]
    fn read_req_id_invalid() {
        assert_eq!(read_req_id(&[]), None);
        let not_an_array = rmp_serde::to_vec(&42u8).unwrap();
        assert_eq!(read_req_id(&not_an_array), None);
    }

    #[test]
    #[should_panic(expected = "buffer size must be greater than 0")]
    fn config_rejects_zero_stream_buffer() {
        let _ = RedisAdapterConfig::new().with_stream_buffer(0);
    }

    #[test]
    #[should_panic(expected = "buffer size must be greater than 0")]
    fn config_rejects_zero_ack_response_buffer() {
        let _ = RedisAdapterConfig::new().with_ack_response_buffer(0);
    }
}