1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(
3 clippy::all,
4 clippy::todo,
5 clippy::empty_enum,
6 clippy::mem_forget,
7 clippy::unused_self,
8 clippy::filter_map_next,
9 clippy::needless_continue,
10 clippy::needless_borrow,
11 clippy::match_wildcard_for_single_variants,
12 clippy::if_let_mutex,
13 clippy::await_holding_lock,
14 clippy::match_on_vec_items,
15 clippy::imprecise_flops,
16 clippy::suboptimal_flops,
17 clippy::lossy_float_literal,
18 clippy::rest_pat_in_fully_bound_structs,
19 clippy::fn_params_excessive_bools,
20 clippy::exit,
21 clippy::inefficient_to_string,
22 clippy::linkedlist,
23 clippy::macro_use_imports,
24 clippy::option_option,
25 clippy::verbose_file_reads,
26 clippy::unnested_or_patterns,
27 rust_2018_idioms,
28 rust_2024_compatibility,
29 future_incompatible,
30 nonstandard_style,
31 missing_docs
32)]
33
34use std::{
172 borrow::Cow,
173 collections::HashMap,
174 fmt,
175 future::{self, Future},
176 pin::Pin,
177 sync::{Arc, Mutex},
178 task::{Context, Poll},
179 time::Duration,
180};
181
182use drivers::{ChanItem, Driver, MessageStream};
183use futures_core::Stream;
184use futures_util::StreamExt;
185use serde::{Serialize, de::DeserializeOwned};
186use socketioxide_core::adapter::remote_packet::{
187 RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId,
188};
189use socketioxide_core::{
190 Sid, Uid,
191 adapter::errors::{AdapterError, BroadcastError},
192 adapter::{
193 BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
194 RoomParam, SocketEmitter, Spawnable,
195 },
196 packet::Packet,
197};
198use stream::{AckStream, DropStream};
199use tokio::{sync::mpsc, time};
200
201pub mod drivers;
204
205mod stream;
206
207#[derive(thiserror::Error)]
209pub enum Error<R: Driver> {
210 #[error("driver error: {0}")]
212 Driver(R::Error),
213 #[error("packet encoding error: {0}")]
215 Decode(#[from] rmp_serde::decode::Error),
216 #[error("packet decoding error: {0}")]
218 Encode(#[from] rmp_serde::encode::Error),
219}
220
221impl<R: Driver> Error<R> {
222 fn from_driver(err: R::Error) -> Self {
223 Self::Driver(err)
224 }
225}
226impl<R: Driver> fmt::Debug for Error<R> {
227 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228 match self {
229 Self::Driver(err) => write!(f, "Driver error: {:?}", err),
230 Self::Decode(err) => write!(f, "Decode error: {:?}", err),
231 Self::Encode(err) => write!(f, "Encode error: {:?}", err),
232 }
233 }
234}
235
236impl<R: Driver> From<Error<R>> for AdapterError {
237 fn from(err: Error<R>) -> Self {
238 AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
239 }
240}
241
/// The configuration of the redis adapter ([`CustomRedisAdapter`]).
///
/// Build it with [`RedisAdapterConfig::new`] (or [`Default`]) and chain the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
    /// Maximum time to wait for responses from remote servers.
    ///
    /// Defaults to 5 seconds.
    pub request_timeout: Duration,

    /// Prefix of every pub/sub channel name used by the adapter. Servers only
    /// exchange messages on channels sharing the same prefix.
    ///
    /// Defaults to `"socket.io"`.
    pub prefix: Cow<'static, str>,

    /// Capacity of the channel buffering remote responses to a broadcast with
    /// acknowledgements. The number of remote servers is added on top of it.
    ///
    /// Defaults to 255.
    pub ack_response_buffer: usize,

    /// Buffer size requested from the driver for each channel subscription.
    ///
    /// Defaults to 1024.
    pub stream_buffer: usize,
}
impl RedisAdapterConfig {
    /// Create a new config with the default values (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the maximum time to wait for responses from remote servers.
    #[must_use]
    pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Set the prefix used for every channel name.
    #[must_use]
    pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
        self.prefix = prefix.into();
        self
    }

    /// Set the capacity of the ack response buffer.
    ///
    /// # Panics
    /// Panics if `buffer` is 0.
    #[must_use]
    pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.ack_response_buffer = buffer;
        self
    }

    /// Set the buffer size of each subscription stream.
    ///
    /// # Panics
    /// Panics if `buffer` is 0.
    #[must_use]
    pub fn with_stream_buffer(mut self, buffer: usize) -> Self {
        assert!(buffer > 0, "buffer size must be greater than 0");
        self.stream_buffer = buffer;
        self
    }
}

impl Default for RedisAdapterConfig {
    fn default() -> Self {
        Self {
            request_timeout: Duration::from_secs(5),
            prefix: Cow::Borrowed("socket.io"),
            ack_response_buffer: 255,
            stream_buffer: 1024,
        }
    }
}
310
311#[derive(Debug)]
314pub struct RedisAdapterCtr<R> {
315 driver: R,
316 config: RedisAdapterConfig,
317}
318
319#[cfg(feature = "redis")]
320impl RedisAdapterCtr<drivers::redis::RedisDriver> {
321 #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
323 pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
324 Self::new_with_redis_config(client, RedisAdapterConfig::default()).await
325 }
326 #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
328 pub async fn new_with_redis_config(
329 client: &redis::Client,
330 config: RedisAdapterConfig,
331 ) -> redis::RedisResult<Self> {
332 let driver = drivers::redis::RedisDriver::new(client).await?;
333 Ok(Self::new_with_driver(driver, config))
334 }
335}
336#[cfg(feature = "redis-cluster")]
337impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
338 #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
340 pub async fn new_with_cluster(
341 client: &redis::cluster::ClusterClient,
342 ) -> redis::RedisResult<Self> {
343 Self::new_with_cluster_config(client, RedisAdapterConfig::default()).await
344 }
345
346 #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
348 pub async fn new_with_cluster_config(
349 client: &redis::cluster::ClusterClient,
350 config: RedisAdapterConfig,
351 ) -> redis::RedisResult<Self> {
352 let driver = drivers::redis::ClusterDriver::new(client).await?;
353 Ok(Self::new_with_driver(driver, config))
354 }
355}
356#[cfg(feature = "fred")]
357impl RedisAdapterCtr<drivers::fred::FredDriver> {
358 #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
360 pub async fn new_with_fred(
361 client: fred::clients::SubscriberClient,
362 ) -> fred::prelude::FredResult<Self> {
363 Self::new_with_fred_config(client, RedisAdapterConfig::default()).await
364 }
365 #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
367 pub async fn new_with_fred_config(
368 client: fred::clients::SubscriberClient,
369 config: RedisAdapterConfig,
370 ) -> fred::prelude::FredResult<Self> {
371 let driver = drivers::fred::FredDriver::new(client).await?;
372 Ok(Self::new_with_driver(driver, config))
373 }
374}
375impl<R: Driver> RedisAdapterCtr<R> {
376 pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
381 RedisAdapterCtr { driver, config }
382 }
383}
384
/// Pending response handlers, keyed by request id.
///
/// Each entry holds the sending half of the channel through which raw response
/// payloads for that request are forwarded.
pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;

#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
/// The redis adapter using the fred driver.
#[cfg(feature = "fred")]
pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;

#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
/// The redis adapter using the `redis` crate driver.
#[cfg(feature = "redis")]
pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;

#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
/// The redis adapter using the `redis` crate cluster driver.
#[cfg(feature = "redis-cluster")]
pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
401
/// A socket.io adapter that relays operations between servers through
/// pub/sub channels of the driver `R`, and applies them locally through a
/// [`CoreLocalAdapter`] over the emitter `E`.
pub struct CustomRedisAdapter<E, R> {
    /// The pub/sub driver used to publish and subscribe to channels.
    driver: R,
    /// The adapter configuration (timeouts, channel prefix, buffer sizes).
    config: RedisAdapterConfig,
    /// The id of this server, taken from the local adapter's `server_id()`.
    uid: Uid,
    /// The local adapter handling the sockets connected to this server.
    local: CoreLocalAdapter<E>,
    /// The broadcast request channel of this namespace: `{prefix}-request#{path}#`.
    req_chan: String,
    /// Handlers waiting for responses to requests sent by this server. The map
    /// is shared (via `Arc`) with the response streams returned to callers.
    responses: Arc<Mutex<ResponseHandlers>>,
}

// Empty `DefinedAdapter` impl; the trait comes from `socketioxide_core`.
impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
424impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
425 type Error = Error<R>;
426 type State = RedisAdapterCtr<R>;
427 type AckStream = AckStream<E::AckStream>;
428 type InitRes = InitRes<R>;
429
430 fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
431 let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
432 let uid = local.server_id();
433 Self {
434 local,
435 req_chan,
436 uid,
437 driver: state.driver.clone(),
438 config: state.config.clone(),
439 responses: Arc::new(Mutex::new(HashMap::new())),
440 }
441 }
442
443 fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
444 let fut = async move {
445 check_ns(self.local.path())?;
446 let global_stream = self.subscribe(self.req_chan.clone()).await?;
447 let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
448 let response_chan = format!(
449 "{}-response#{}#{}#",
450 &self.config.prefix,
451 self.local.path(),
452 self.uid
453 );
454
455 let response_stream = self.subscribe(response_chan.clone()).await?;
456 let stream = futures_util::stream::select(global_stream, specific_stream);
457 let stream = futures_util::stream::select(stream, response_stream);
458 tokio::spawn(self.pipe_stream(stream, response_chan));
459 on_success();
460 Ok(())
461 };
462 InitRes(Box::pin(fut))
463 }
464
465 async fn close(&self) -> Result<(), Self::Error> {
466 let response_chan = format!(
467 "{}-response#{}#{}#",
468 &self.config.prefix,
469 self.local.path(),
470 self.uid
471 );
472 tokio::try_join!(
473 self.driver.unsubscribe(self.req_chan.clone()),
474 self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
475 self.driver.unsubscribe(response_chan)
476 )
477 .map_err(Error::from_driver)?;
478
479 Ok(())
480 }
481
482 async fn server_count(&self) -> Result<u16, Self::Error> {
484 let count = self
485 .driver
486 .num_serv(&self.req_chan)
487 .await
488 .map_err(Error::from_driver)?;
489
490 Ok(count)
491 }
492
493 async fn broadcast(
495 &self,
496 packet: Packet,
497 opts: BroadcastOptions,
498 ) -> Result<(), BroadcastError> {
499 if !opts.is_local(self.uid) {
500 let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
501 self.send_req(req, opts.server_id)
502 .await
503 .map_err(AdapterError::from)?;
504 }
505
506 self.local.broadcast(packet, opts)?;
507 Ok(())
508 }
509
510 async fn broadcast_with_ack(
538 &self,
539 packet: Packet,
540 opts: BroadcastOptions,
541 timeout: Option<Duration>,
542 ) -> Result<Self::AckStream, Self::Error> {
543 if opts.is_local(self.uid) {
544 tracing::debug!(?opts, "broadcast with ack is local");
545 let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
546 let stream = AckStream::new_local(local);
547 return Ok(stream);
548 }
549 let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
550 let req_id = req.id;
551
552 let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
553
554 let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
555 self.responses.lock().unwrap().insert(req_id, tx);
556 let remote = MessageStream::new(rx);
557
558 self.send_req(req, opts.server_id).await?;
559 let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
560
561 Ok(AckStream::new(
562 local,
563 remote,
564 self.config.request_timeout,
565 remote_serv_cnt,
566 req_id,
567 self.responses.clone(),
568 ))
569 }
570
571 async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
572 if !opts.is_local(self.uid) {
573 let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
574 self.send_req(req, opts.server_id)
575 .await
576 .map_err(AdapterError::from)?;
577 }
578 self.local
579 .disconnect_socket(opts)
580 .map_err(BroadcastError::Socket)?;
581
582 Ok(())
583 }
584
585 async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
586 if opts.is_local(self.uid) {
587 return Ok(self.local.rooms(opts).into_iter().collect());
588 }
589 let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
590 let req_id = req.id;
591
592 let stream = self
595 .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
596 .await?;
597 self.send_req(req, opts.server_id).await?;
598 let local = self.local.rooms(opts);
599 let rooms = stream
600 .filter_map(|item| future::ready(item.into_rooms()))
601 .fold(local, async |mut acc, item| {
602 acc.extend(item);
603 acc
604 })
605 .await;
606 Ok(Vec::from_iter(rooms))
607 }
608
609 async fn add_sockets(
610 &self,
611 opts: BroadcastOptions,
612 rooms: impl RoomParam,
613 ) -> Result<(), Self::Error> {
614 let rooms: Vec<Room> = rooms.into_room_iter().collect();
615 if !opts.is_local(self.uid) {
616 let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
617 self.send_req(req, opts.server_id).await?;
618 }
619 self.local.add_sockets(opts, rooms);
620 Ok(())
621 }
622
623 async fn del_sockets(
624 &self,
625 opts: BroadcastOptions,
626 rooms: impl RoomParam,
627 ) -> Result<(), Self::Error> {
628 let rooms: Vec<Room> = rooms.into_room_iter().collect();
629 if !opts.is_local(self.uid) {
630 let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
631 self.send_req(req, opts.server_id).await?;
632 }
633 self.local.del_sockets(opts, rooms);
634 Ok(())
635 }
636
637 async fn fetch_sockets(
638 &self,
639 opts: BroadcastOptions,
640 ) -> Result<Vec<RemoteSocketData>, Self::Error> {
641 if opts.is_local(self.uid) {
642 return Ok(self.local.fetch_sockets(opts));
643 }
644 let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
645 let req_id = req.id;
646 let remote = self
649 .get_res::<RemoteSocketData>(req_id, ResponseTypeId::FetchSockets, opts.server_id)
650 .await?;
651
652 self.send_req(req, opts.server_id).await?;
653 let local = self.local.fetch_sockets(opts);
654 let sockets = remote
655 .filter_map(|item| future::ready(item.into_fetch_sockets()))
656 .fold(local, async |mut acc, item| {
657 acc.extend(item);
658 acc
659 })
660 .await;
661 Ok(sockets)
662 }
663
664 fn get_local(&self) -> &CoreLocalAdapter<E> {
665 &self.local
666 }
667}
668
669#[derive(thiserror::Error)]
671pub enum InitError<D: Driver> {
672 #[error("driver error: {0}")]
674 Driver(D::Error),
675 #[error("malformed namespace path, it must not contain '#'")]
677 MalformedNamespace,
678}
679impl<D: Driver> fmt::Debug for InitError<D> {
680 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
681 match self {
682 Self::Driver(err) => fmt::Debug::fmt(err, f),
683 Self::MalformedNamespace => write!(f, "Malformed namespace path"),
684 }
685 }
686}
687#[must_use = "futures do nothing unless you `.await` or poll them"]
689pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
690
691impl<D: Driver> Future for InitRes<D> {
692 type Output = Result<(), InitError<D>>;
693
694 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
695 self.0.as_mut().poll(cx)
696 }
697}
698impl<D: Driver> Spawnable for InitRes<D> {
699 fn spawn(self) {
700 tokio::spawn(async move {
701 if let Err(e) = self.0.await {
702 tracing::error!("error initializing adapter: {e}");
703 }
704 });
705 }
706}
707
impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
    /// Response channel of the server `uid` for this namespace:
    /// `{prefix}-response#{path}#{uid}#`.
    fn get_res_chan(&self, uid: Uid) -> String {
        let path = self.local.path();
        let prefix = &self.config.prefix;
        format!("{}-response#{}#{}#", prefix, path, uid)
    }
    /// Request channel for this namespace.
    ///
    /// With `Some(uid)` this is the channel only the server `uid` subscribes to
    /// (`{req_chan}{uid}#`); with `None` it is the broadcast request channel.
    fn get_req_chan(&self, node_id: Option<Uid>) -> String {
        match node_id {
            Some(uid) => format!("{}{}#", self.req_chan, uid),
            None => self.req_chan.clone(),
        }
    }

    /// Dispatch every message received on the subscribed channels until `stream` ends.
    ///
    /// Messages on a request channel (any channel starting with `req_chan`) go to
    /// `recv_req`. Messages on `response_chan` are forwarded, by request id, to the
    /// matching handler in `responses`. Anything else is logged and dropped.
    async fn pipe_stream(
        self: Arc<Self>,
        mut stream: impl Stream<Item = ChanItem> + Unpin,
        response_chan: String,
    ) {
        while let Some((chan, item)) = stream.next().await {
            if chan.starts_with(&self.req_chan) {
                if let Err(e) = self.recv_req(item) {
                    let ns = self.local.path();
                    let uid = self.uid;
                    tracing::warn!(?uid, ?ns, "request handler error: {e}");
                }
            } else if chan == response_chan {
                let req_id = read_req_id(&item);
                tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
                // The std mutex guard is dropped at the end of this branch, before
                // the next `.await` of the loop.
                let handlers = self.responses.lock().unwrap();
                if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
                    // Non-blocking send: if the handler channel is full or closed,
                    // the response is dropped (and logged) rather than stalling
                    // the dispatch loop.
                    if let Err(e) = tx.try_send(item) {
                        tracing::warn!("error sending response to handler: {e}");
                    }
                } else {
                    tracing::warn!(?req_id, "could not find req handler");
                }
            } else {
                tracing::warn!("unexpected message/channel: {chan}");
            }
        }
    }

    /// Decode and handle a request received on a request channel.
    ///
    /// Requests issued by this server, requests without broadcast options and
    /// request types not listed below are ignored.
    ///
    /// # Errors
    /// Returns [`Error::Decode`] if `item` is not a valid msgpack request.
    fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
        let req: RequestIn = rmp_serde::from_slice(&item)?;
        // This server also subscribes to the broadcast request channel it
        // publishes on, so skip its own requests.
        if req.node_id == self.uid {
            return Ok(());
        }

        tracing::trace!(?req, "handling request");
        let Some(opts) = req.opts else {
            tracing::warn!(?req, "request is missing options");
            return Ok(());
        };

        match req.r#type {
            RequestTypeIn::Broadcast(p) => self.recv_broadcast(opts, p),
            RequestTypeIn::BroadcastWithAck(p) => {
                self.clone()
                    .recv_broadcast_with_ack(req.node_id, req.id, p, opts)
            }
            RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(opts),
            RequestTypeIn::AllRooms => self.recv_rooms(req.node_id, req.id, opts),
            RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(opts, rooms),
            RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(opts, rooms),
            RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req.node_id, req.id, opts),
            _ => (),
        };
        Ok(())
    }

    /// Broadcast a packet received from another server to the local sockets.
    /// Errors are logged, not returned.
    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
        if let Err(e) = self.local.broadcast(packet, opts) {
            let ns = self.local.path();
            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
        }
    }

    /// Disconnect the local sockets selected by a remote request.
    /// Errors are logged, not returned.
    fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
        if let Err(e) = self.local.disconnect_socket(opts) {
            let ns = self.local.path();
            tracing::warn!(
                ?self.uid,
                ?ns,
                "remote request disconnect sockets handler: {:?}",
                e
            );
        }
    }

    /// Broadcast a packet with acks to the local sockets and stream the results
    /// back to the `origin` server.
    ///
    /// In a spawned task, first sends a `BroadcastAckCount` response carrying the
    /// count returned by the local broadcast, then one `BroadcastAck` response per
    /// ack. The task stops at the first send error, which is logged.
    fn recv_broadcast_with_ack(
        self: Arc<Self>,
        origin: Uid,
        req_id: Sid,
        packet: Packet,
        opts: BroadcastOptions,
    ) {
        let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
        tokio::spawn(async move {
            let on_err = |err| {
                let ns = self.local.path();
                tracing::warn!(
                    ?origin,
                    ?ns,
                    "remote request broadcast with ack handler errors: {:?}",
                    err
                );
            };
            let res = Response {
                r#type: ResponseType::<()>::BroadcastAckCount(count),
                node_id: self.uid,
            };
            if let Err(err) = self.send_res(origin, req_id, res).await {
                on_err(err);
                return;
            }

            futures_util::pin_mut!(stream);
            while let Some(ack) = stream.next().await {
                let res = Response {
                    r#type: ResponseType::BroadcastAck(ack),
                    node_id: self.uid,
                };
                if let Err(err) = self.send_res(origin, req_id, res).await {
                    on_err(err);
                    return;
                }
            }
        });
    }

    /// Collect the local rooms selected by `opts` and send them back to `origin`
    /// from a spawned task. Send errors are logged.
    fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
        let rooms = self.local.rooms(opts);
        let res = Response {
            r#type: ResponseType::<()>::AllRooms(rooms),
            node_id: self.uid,
        };
        let fut = self.send_res(origin, req_id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
            }
        });
    }

    /// Add the local sockets selected by a remote request to `rooms`.
    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.add_sockets(opts, rooms);
    }

    /// Remove the local sockets selected by a remote request from `rooms`.
    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.del_sockets(opts, rooms);
    }
    /// Collect the data of the local sockets selected by `opts` and send it back
    /// to `origin` from a spawned task. Send errors are logged.
    fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
        let sockets = self.local.fetch_sockets(opts);
        let res = Response {
            node_id: self.uid,
            r#type: ResponseType::FetchSockets(sockets),
        };
        let fut = self.send_res(origin, req_id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
            }
        });
    }

    /// Encode `req` with msgpack and publish it on the broadcast request channel,
    /// or on the request channel of `target_uid` when given.
    async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
        tracing::trace!(?req, "sending request");
        let req = rmp_serde::to_vec(&req)?;
        let chan = self.get_req_chan(target_uid);
        self.driver
            .publish(chan, req)
            .await
            .map_err(Error::from_driver)?;

        Ok(())
    }

    /// Publish the response `(req_id, res)` on the response channel of `req_node_id`.
    ///
    /// Encoding happens eagerly, but any encoding error is only reported when the
    /// returned future is awaited. The future does not borrow `self` (it owns a
    /// clone of the driver), so it can be moved into a spawned task.
    fn send_res<D: Serialize + fmt::Debug>(
        &self,
        req_node_id: Uid,
        req_id: Sid,
        res: Response<D>,
    ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
        let chan = self.get_res_chan(req_node_id);
        tracing::trace!(?res, "sending response to {}", &chan);
        let res = rmp_serde::to_vec(&(req_id, res));
        let driver = self.driver.clone();
        async move {
            driver
                .publish(chan, res?)
                .await
                .map_err(Error::from_driver)?;
            Ok(())
        }
    }

    /// Register a response handler for `req_id` and return the stream of decoded
    /// responses whose type matches `response_type`.
    ///
    /// Callers register the handler before sending the request; otherwise responses
    /// arriving early would find no handler in `pipe_stream`. The stream ends after
    /// one matching response for a targeted request (`target_uid` is `Some`), after
    /// `server_count - 1` matching responses otherwise, or after `request_timeout`.
    /// Undecodable responses are logged and skipped. The handler map and `req_id`
    /// are handed to `DropStream` (presumably to unregister the handler on drop —
    /// confirm in the `stream` module).
    async fn get_res<D: DeserializeOwned + fmt::Debug>(
        &self,
        req_id: Sid,
        response_type: ResponseTypeId,
        target_uid: Option<Uid>,
    ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
        let remote_serv_cnt = if target_uid.is_none() {
            self.server_count().await?.saturating_sub(1) as usize
        } else {
            1
        };
        // mpsc channels need a non-zero capacity.
        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
        self.responses.lock().unwrap().insert(req_id, tx);
        let stream = MessageStream::new(rx)
            .filter_map(|item| {
                let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
                    Ok((_, data)) => Some(data),
                    Err(e) => {
                        tracing::warn!("error decoding response: {e}");
                        None
                    }
                };
                future::ready(data)
            })
            .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type))
            .take(remote_serv_cnt)
            .take_until(time::sleep(self.config.request_timeout));
        let stream = DropStream::new(stream, self.responses.clone(), req_id);
        Ok(stream)
    }

    /// Subscribe to the channel `pat` with the configured stream buffer size.
    ///
    /// # Errors
    /// Driver errors are returned as [`InitError::Driver`].
    #[inline]
    async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
        tracing::trace!(?pat, "subscribing to");
        self.driver
            .subscribe(pat, self.config.stream_buffer)
            .await
            .map_err(InitError::Driver)
    }
}
966
967fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
970 if path.is_empty() || path.contains('#') {
971 Err(InitError::MalformedNamespace)
972 } else {
973 Ok(())
974 }
975}
976
977pub fn read_req_id(data: &[u8]) -> Option<Sid> {
979 use std::str::FromStr;
980 let mut rd = data;
981 let len = rmp::decode::read_array_len(&mut rd).ok()?;
982 if len < 1 {
983 return None;
984 }
985
986 let mut buff = [0u8; Sid::ZERO.as_str().len()];
987 let str = rmp::decode::read_str(&mut rd, &mut buff).ok()?;
988 Sid::from_str(str).ok()
989}
990
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, FusedStream, StreamExt};
    use socketioxide_core::{Str, Value, adapter::AckStreamItem};
    use std::convert::Infallible;

    /// Driver that accepts every call, never fails and never delivers a message.
    #[derive(Clone)]
    struct StubDriver;
    impl Driver for StubDriver {
        type Error = Infallible;

        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn subscribe(
            &self,
            _: String,
            _: usize,
        ) -> Result<MessageStream<ChanItem>, Self::Error> {
            Ok(MessageStream::new_empty())
        }

        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
            Ok(0)
        }
    }

    /// Build an `AckStream` with no local acks that expects answers from two
    /// remote servers on `remote`.
    fn new_stub_ack_stream(
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
        AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            timeout,
            2,
            Sid::new(),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    #[tokio::test]
    async fn ack_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_secs(10));
        let node_id = Uid::new();
        let req_id = Sid::new();

        // Each of the two remote servers announces 2 acks.
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();

        let ack_res = Response::<String> {
            node_id,
            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
        };
        for _ in 0..4 {
            tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap())
                .unwrap();
        }
        futures_util::pin_mut!(stream);
        for _ in 0..4 {
            assert!(stream.next().await.is_some());
        }
        // All announced acks were received: the stream must be finished.
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_timeout() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_millis(50));
        let node_id = Uid::new();
        let req_id = Sid::new();
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap())
            .unwrap();

        futures_util::pin_mut!(stream);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(stream.next().await.is_none());
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_drop() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        let id = Sid::new();
        handlers.lock().unwrap().insert(id, tx);
        let stream = AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            Duration::from_secs(10),
            2,
            id,
            handlers.clone(),
        );
        drop(stream);
        assert!(handlers.lock().unwrap().is_empty());
    }

    #[test]
    fn check_ns_error() {
        assert!(matches!(
            check_ns::<StubDriver>("#"),
            Err(InitError::MalformedNamespace)
        ));
        assert!(matches!(
            check_ns::<StubDriver>(""),
            Err(InitError::MalformedNamespace)
        ));
    }

    #[test]
    fn check_ns_ok() {
        assert!(check_ns::<StubDriver>("/").is_ok());
        assert!(check_ns::<StubDriver>("/admin").is_ok());
    }

    #[test]
    fn read_req_id_roundtrip() {
        let req_id = Sid::new();
        let data = rmp_serde::to_vec(&(req_id, ())).unwrap();
        assert_eq!(read_req_id(&data), Some(req_id));
    }

    #[test]
    fn read_req_id_invalid() {
        assert_eq!(read_req_id(&[]), None);
    }
}