1#![cfg_attr(docsrs, feature(doc_cfg))]
2#![warn(
3 clippy::all,
4 clippy::todo,
5 clippy::empty_enum,
6 clippy::mem_forget,
7 clippy::unused_self,
8 clippy::filter_map_next,
9 clippy::needless_continue,
10 clippy::needless_borrow,
11 clippy::match_wildcard_for_single_variants,
12 clippy::if_let_mutex,
13 clippy::await_holding_lock,
14 clippy::match_on_vec_items,
15 clippy::imprecise_flops,
16 clippy::suboptimal_flops,
17 clippy::lossy_float_literal,
18 clippy::rest_pat_in_fully_bound_structs,
19 clippy::fn_params_excessive_bools,
20 clippy::exit,
21 clippy::inefficient_to_string,
22 clippy::linkedlist,
23 clippy::macro_use_imports,
24 clippy::option_option,
25 clippy::verbose_file_reads,
26 clippy::unnested_or_patterns,
27 rust_2018_idioms,
28 future_incompatible,
29 nonstandard_style,
30 missing_docs
31)]
32
33use std::{
171 borrow::Cow,
172 collections::HashMap,
173 fmt,
174 future::{self, Future},
175 pin::Pin,
176 sync::{Arc, Mutex},
177 task::{Context, Poll},
178 time::Duration,
179};
180
181use drivers::{ChanItem, Driver, MessageStream};
182use futures_core::Stream;
183use futures_util::StreamExt;
184use request::{
185 read_req_id, RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType,
186};
187use serde::{de::DeserializeOwned, Serialize};
188use socketioxide_core::{
189 adapter::{
190 BroadcastFlags, BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter,
191 RemoteSocketData, Room, RoomParam, SocketEmitter, Spawnable,
192 },
193 errors::{AdapterError, BroadcastError},
194 packet::Packet,
195 Sid, Uid,
196};
197use stream::{AckStream, DropStream};
198use tokio::{sync::mpsc, time};
199
200pub mod drivers;
203
204mod request;
205mod stream;
206
207#[derive(thiserror::Error)]
209pub enum Error<R: Driver> {
210 #[error("driver error: {0}")]
212 Driver(R::Error),
213 #[error("packet encoding error: {0}")]
215 Decode(#[from] rmp_serde::decode::Error),
216 #[error("packet decoding error: {0}")]
218 Encode(#[from] rmp_serde::encode::Error),
219}
220
221impl<R: Driver> Error<R> {
222 fn from_driver(err: R::Error) -> Self {
223 Self::Driver(err)
224 }
225}
226impl<R: Driver> fmt::Debug for Error<R> {
227 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
228 match self {
229 Self::Driver(err) => write!(f, "Driver error: {:?}", err),
230 Self::Decode(err) => write!(f, "Decode error: {:?}", err),
231 Self::Encode(err) => write!(f, "Encode error: {:?}", err),
232 }
233 }
234}
235
236impl<R: Driver> From<Error<R>> for AdapterError {
237 fn from(err: Error<R>) -> Self {
238 AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
239 }
240}
241
/// Configuration of the redis adapter.
///
/// Build it with [`RedisAdapterConfig::new`] / [`Default::default`] and the `with_*` methods.
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
    /// How long to wait for the responses of remote servers to a request
    /// (rooms, fetch sockets, broadcast with ack). Defaults to 5 seconds.
    pub request_timeout: Duration,

    /// Prefix of every channel name used by the adapter
    /// (`{prefix}-request#…` and `{prefix}-response#…`). Defaults to `"socket.io"`.
    pub prefix: Cow<'static, str>,

    /// Added to the number of remote servers to size the channel that buffers the
    /// remote responses of a broadcast with ack. Defaults to 255.
    /// `with_ack_response_buffer` rejects 0.
    pub ack_response_buffer: usize,

    /// Buffer size handed to the driver when subscribing to a channel. Defaults to 1024.
    /// `with_stream_buffer` rejects 0.
    pub stream_buffer: usize,
}
263impl RedisAdapterConfig {
264 pub fn new() -> Self {
266 Self::default()
267 }
268 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
270 self.request_timeout = timeout;
271 self
272 }
273
274 pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
276 self.prefix = prefix.into();
277 self
278 }
279
280 pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self {
285 assert!(buffer > 0, "buffer size must be greater than 0");
286 self.ack_response_buffer = buffer;
287 self
288 }
289
290 pub fn with_stream_buffer(mut self, buffer: usize) -> Self {
294 assert!(buffer > 0, "buffer size must be greater than 0");
295 self.stream_buffer = buffer;
296 self
297 }
298}
299
300impl Default for RedisAdapterConfig {
301 fn default() -> Self {
302 Self {
303 request_timeout: Duration::from_secs(5),
304 prefix: Cow::Borrowed("socket.io"),
305 ack_response_buffer: 255,
306 stream_buffer: 1024,
307 }
308 }
309}
310
/// Constructor state of the redis adapter: the shared [`Driver`] and the
/// [`RedisAdapterConfig`].
///
/// Every namespace adapter clones both of them when it is created (`CoreAdapter::new`).
#[derive(Debug)]
pub struct RedisAdapterCtr<R> {
    driver: R,
    config: RedisAdapterConfig,
}
318
#[cfg(feature = "redis")]
impl RedisAdapterCtr<drivers::redis::RedisDriver> {
    /// Creates a constructor backed by a [`redis::Client`], using the default config.
    ///
    /// # Errors
    /// Propagates any error raised while creating the redis driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_redis_config(client, config).await
    }

    /// Creates a constructor backed by a [`redis::Client`] with a custom `config`.
    ///
    /// # Errors
    /// Propagates any error raised while creating the redis driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis_config(
        client: &redis::Client,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        drivers::redis::RedisDriver::new(client)
            .await
            .map(|driver| Self::new_with_driver(driver, config))
    }
}
#[cfg(feature = "redis-cluster")]
impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
    /// Creates a constructor backed by a redis [`ClusterClient`](redis::cluster::ClusterClient),
    /// using the default config.
    ///
    /// # Errors
    /// Propagates any error raised while creating the cluster driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster(
        client: &redis::cluster::ClusterClient,
    ) -> redis::RedisResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_cluster_config(client, config).await
    }

    /// Creates a constructor backed by a redis [`ClusterClient`](redis::cluster::ClusterClient)
    /// with a custom `config`.
    ///
    /// # Errors
    /// Propagates any error raised while creating the cluster driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster_config(
        client: &redis::cluster::ClusterClient,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        drivers::redis::ClusterDriver::new(client)
            .await
            .map(|driver| Self::new_with_driver(driver, config))
    }
}
#[cfg(feature = "fred")]
impl RedisAdapterCtr<drivers::fred::FredDriver> {
    /// Creates a constructor backed by a fred
    /// [`SubscriberClient`](fred::clients::SubscriberClient), using the default config.
    ///
    /// # Errors
    /// Propagates any error raised while creating the fred driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred(
        client: fred::clients::SubscriberClient,
    ) -> fred::prelude::FredResult<Self> {
        let config = RedisAdapterConfig::default();
        Self::new_with_fred_config(client, config).await
    }

    /// Creates a constructor backed by a fred
    /// [`SubscriberClient`](fred::clients::SubscriberClient) with a custom `config`.
    ///
    /// # Errors
    /// Propagates any error raised while creating the fred driver.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred_config(
        client: fred::clients::SubscriberClient,
        config: RedisAdapterConfig,
    ) -> fred::prelude::FredResult<Self> {
        drivers::fred::FredDriver::new(client)
            .await
            .map(|driver| Self::new_with_driver(driver, config))
    }
}
375impl<R: Driver> RedisAdapterCtr<R> {
376 pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
381 RedisAdapterCtr { driver, config }
382 }
383}
384
/// Pending response handlers keyed by request id.
///
/// Raw payloads read from this server's response channel are forwarded to the
/// sender registered for the request id they carry.
pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;
386
/// The redis adapter using the `fred` driver ([`drivers::fred::FredDriver`]).
#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
#[cfg(feature = "fred")]
pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;

/// The redis adapter using the `redis` crate driver ([`drivers::redis::RedisDriver`]).
#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
#[cfg(feature = "redis")]
pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;

/// The redis adapter using the `redis` crate cluster driver ([`drivers::redis::ClusterDriver`]).
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[cfg(feature = "redis-cluster")]
pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
401
/// A redis-backed adapter, generic over the [`Driver`] used to talk to the server.
///
/// It is built from a [`CoreLocalAdapter`], and the namespace path of that local
/// adapter is embedded in every channel name it uses.
pub struct CustomRedisAdapter<E, R> {
    /// Driver used to publish and (un)subscribe, cloned from the constructor state.
    driver: R,
    /// Adapter configuration, cloned from the constructor state.
    config: RedisAdapterConfig,
    /// Id of this server, taken from `local.server_id()`.
    uid: Uid,
    /// In-process adapter handling the sockets connected to this server.
    local: CoreLocalAdapter<E>,
    /// Global request channel of the namespace: `{prefix}-request#{path}#`.
    req_chan: String,
    /// Handlers waiting for responses to requests sent by this server.
    responses: Arc<Mutex<ResponseHandlers>>,
}
422
/// Empty marker impl: flags `CustomRedisAdapter` as a `DefinedAdapter`.
impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
424impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
425 type Error = Error<R>;
426 type State = RedisAdapterCtr<R>;
427 type AckStream = AckStream<E::AckStream>;
428 type InitRes = InitRes<R>;
429
430 fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
431 let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
432 let uid = local.server_id();
433 Self {
434 local,
435 req_chan,
436 uid,
437 driver: state.driver.clone(),
438 config: state.config.clone(),
439 responses: Arc::new(Mutex::new(HashMap::new())),
440 }
441 }
442
443 fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
444 let fut = async move {
445 check_ns(self.local.path())?;
446 let global_stream = self.subscribe(self.req_chan.clone()).await?;
447 let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
448 let response_chan = format!(
449 "{}-response#{}#{}#",
450 &self.config.prefix,
451 self.local.path(),
452 self.uid
453 );
454
455 let response_stream = self.subscribe(response_chan.clone()).await?;
456 let stream = futures_util::stream::select(global_stream, specific_stream);
457 let stream = futures_util::stream::select(stream, response_stream);
458 tokio::spawn(self.pipe_stream(stream, response_chan));
459 on_success();
460 Ok(())
461 };
462 InitRes(Box::pin(fut))
463 }
464
465 async fn close(&self) -> Result<(), Self::Error> {
466 let response_chan = format!(
467 "{}-response#{}#{}#",
468 &self.config.prefix,
469 self.local.path(),
470 self.uid
471 );
472 tokio::try_join!(
473 self.driver.unsubscribe(self.req_chan.clone()),
474 self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
475 self.driver.unsubscribe(response_chan)
476 )
477 .map_err(Error::from_driver)?;
478
479 Ok(())
480 }
481
482 async fn server_count(&self) -> Result<u16, Self::Error> {
484 let count = self
485 .driver
486 .num_serv(&self.req_chan)
487 .await
488 .map_err(Error::from_driver)?;
489
490 Ok(count)
491 }
492
493 async fn broadcast(
495 &self,
496 packet: Packet,
497 opts: BroadcastOptions,
498 ) -> Result<(), BroadcastError> {
499 if !is_local_op(self.uid, &opts) {
500 let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
501 self.send_req(req, opts.server_id)
502 .await
503 .map_err(AdapterError::from)?;
504 }
505
506 self.local.broadcast(packet, opts)?;
507 Ok(())
508 }
509
510 async fn broadcast_with_ack(
538 &self,
539 packet: Packet,
540 opts: BroadcastOptions,
541 timeout: Option<Duration>,
542 ) -> Result<Self::AckStream, Self::Error> {
543 if is_local_op(self.uid, &opts) {
544 tracing::debug!(?opts, "broadcast with ack is local");
545 let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
546 let stream = AckStream::new_local(local);
547 return Ok(stream);
548 }
549 let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
550 let req_id = req.id;
551
552 let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
553
554 let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
555 self.responses.lock().unwrap().insert(req_id, tx);
556 let remote = MessageStream::new(rx);
557
558 self.send_req(req, opts.server_id).await?;
559 let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
560
561 Ok(AckStream::new(
562 local,
563 remote,
564 self.config.request_timeout,
565 remote_serv_cnt,
566 req_id,
567 self.responses.clone(),
568 ))
569 }
570
571 async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
572 if !is_local_op(self.uid, &opts) {
573 let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
574 self.send_req(req, opts.server_id)
575 .await
576 .map_err(AdapterError::from)?;
577 }
578 self.local
579 .disconnect_socket(opts)
580 .map_err(BroadcastError::Socket)?;
581
582 Ok(())
583 }
584
585 async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
586 const PACKET_IDX: u8 = 2;
587
588 if is_local_op(self.uid, &opts) {
589 return Ok(self.local.rooms(opts).into_iter().collect());
590 }
591 let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
592 let req_id = req.id;
593
594 let stream = self
597 .get_res::<()>(req_id, PACKET_IDX, opts.server_id)
598 .await?;
599 self.send_req(req, opts.server_id).await?;
600 let local = self.local.rooms(opts);
601 let rooms = stream
602 .filter_map(|item| future::ready(item.into_rooms()))
603 .fold(local, |mut acc, item| async move {
604 acc.extend(item);
605 acc
606 })
607 .await;
608 Ok(Vec::from_iter(rooms))
609 }
610
611 async fn add_sockets(
612 &self,
613 opts: BroadcastOptions,
614 rooms: impl RoomParam,
615 ) -> Result<(), Self::Error> {
616 let rooms: Vec<Room> = rooms.into_room_iter().collect();
617 if !is_local_op(self.uid, &opts) {
618 let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
619 self.send_req(req, opts.server_id).await?;
620 }
621 self.local.add_sockets(opts, rooms);
622 Ok(())
623 }
624
625 async fn del_sockets(
626 &self,
627 opts: BroadcastOptions,
628 rooms: impl RoomParam,
629 ) -> Result<(), Self::Error> {
630 let rooms: Vec<Room> = rooms.into_room_iter().collect();
631 if !is_local_op(self.uid, &opts) {
632 let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
633 self.send_req(req, opts.server_id).await?;
634 }
635 self.local.del_sockets(opts, rooms);
636 Ok(())
637 }
638
639 async fn fetch_sockets(
640 &self,
641 opts: BroadcastOptions,
642 ) -> Result<Vec<RemoteSocketData>, Self::Error> {
643 if is_local_op(self.uid, &opts) {
644 return Ok(self.local.fetch_sockets(opts));
645 }
646 const PACKET_IDX: u8 = 3;
647 let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
648 let req_id = req.id;
649 let remote = self
652 .get_res::<RemoteSocketData>(req_id, PACKET_IDX, opts.server_id)
653 .await?;
654
655 self.send_req(req, opts.server_id).await?;
656 let local = self.local.fetch_sockets(opts);
657 let sockets = remote
658 .filter_map(|item| future::ready(item.into_fetch_sockets()))
659 .fold(local, |mut acc, item| async move {
660 acc.extend(item);
661 acc
662 })
663 .await;
664 Ok(sockets)
665 }
666
667 fn get_local(&self) -> &CoreLocalAdapter<E> {
668 &self.local
669 }
670}
671
/// Errors that can occur while initializing the adapter (the `CoreAdapter::init` future).
#[derive(thiserror::Error)]
pub enum InitError<D: Driver> {
    /// An error returned by the driver, e.g. while subscribing to a channel.
    #[error("driver error: {0}")]
    Driver(D::Error),
    /// The namespace path is empty or contains `#`, which is used as a separator
    /// inside channel names.
    #[error("malformed namespace path, it must not contain '#'")]
    MalformedNamespace,
}
682impl<D: Driver> fmt::Debug for InitError<D> {
683 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
684 match self {
685 Self::Driver(err) => fmt::Debug::fmt(err, f),
686 Self::MalformedNamespace => write!(f, "Malformed namespace path"),
687 }
688 }
689}
/// Future returned by `CoreAdapter::init`.
///
/// It resolves once the adapter has subscribed to its channels. It can be awaited,
/// or detached with [`Spawnable::spawn`], which logs any error.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
693
694impl<D: Driver> Future for InitRes<D> {
695 type Output = Result<(), InitError<D>>;
696
697 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
698 self.0.as_mut().poll(cx)
699 }
700}
701impl<D: Driver> Spawnable for InitRes<D> {
702 fn spawn(self) {
703 tokio::spawn(async move {
704 if let Err(e) = self.0.await {
705 tracing::error!("error initializing adapter: {e}");
706 }
707 });
708 }
709}
710
impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
    /// Returns the response channel of server `uid` for this namespace:
    /// `{prefix}-response#{path}#{uid}#`.
    fn get_res_chan(&self, uid: Uid) -> String {
        let path = self.local.path();
        let prefix = &self.config.prefix;
        format!("{}-response#{}#{}#", prefix, path, uid)
    }
    /// Returns the request channel to publish to.
    ///
    /// This is the server-specific channel `{req_chan}{uid}#` when `node_id` is set,
    /// and the global `req_chan` otherwise.
    fn get_req_chan(&self, node_id: Option<Uid>) -> String {
        match node_id {
            Some(uid) => format!("{}{}#", self.req_chan, uid),
            None => self.req_chan.clone(),
        }
    }

    /// Dispatches every message of the subscribed channels until `stream` ends.
    ///
    /// - Messages whose channel starts with `req_chan` are requests; this prefix also
    ///   covers the server-specific request channels. They are handled by `recv_req`.
    /// - Messages on `response_chan` are forwarded, without being decoded, to the
    ///   handler registered for the request id read from the payload.
    /// - Anything else is logged and ignored.
    async fn pipe_stream(
        self: Arc<Self>,
        mut stream: impl Stream<Item = ChanItem> + Unpin,
        response_chan: String,
    ) {
        while let Some((chan, item)) = stream.next().await {
            if chan.starts_with(&self.req_chan) {
                if let Err(e) = self.recv_req(item) {
                    let ns = self.local.path();
                    let uid = self.uid;
                    tracing::warn!(?uid, ?ns, "request handler error: {e}");
                }
            } else if chan == response_chan {
                // `read_req_id` yields an `Option`; presumably `None` when the payload
                // carries no readable request id.
                let req_id = read_req_id(&item);
                tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
                // The lock is held only within this branch, which never awaits. `try_send`
                // does not wait, so a full or closed handler channel drops the response
                // with a warning instead of stalling the loop.
                let handlers = self.responses.lock().unwrap();
                if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
                    if let Err(e) = tx.try_send(item) {
                        tracing::warn!("error sending response to handler: {e}");
                    }
                } else {
                    tracing::warn!(?req_id, "could not find req handler");
                }
            } else {
                tracing::warn!("unexpected message/channel: {chan}");
            }
        }
    }

    /// Decodes a request and runs the handler matching its type.
    ///
    /// Requests published by this server itself (`req.node_id == self.uid`) are ignored.
    ///
    /// # Errors
    /// Returns [`Error::Decode`] if the payload is not a valid `RequestIn`. Handler
    /// failures are logged by the handlers themselves and are not returned.
    fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
        let req: RequestIn = rmp_serde::from_slice(&item)?;
        if req.node_id == self.uid {
            return Ok(());
        }

        tracing::trace!(?req, "handling request");

        match req.r#type {
            RequestTypeIn::Broadcast(p) => self.recv_broadcast(req.opts, p),
            // Needs an owned `Arc` because the handler spawns a task using `self`.
            RequestTypeIn::BroadcastWithAck(_) => self.clone().recv_broadcast_with_ack(req),
            RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(req),
            RequestTypeIn::AllRooms => self.recv_rooms(req),
            RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(req.opts, rooms),
            RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(req.opts, rooms),
            RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req),
        };
        Ok(())
    }

    /// Broadcasts a packet received from a remote server to the local sockets.
    ///
    /// Errors are logged.
    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
        if let Err(e) = self.local.broadcast(packet, opts) {
            let ns = self.local.path();
            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
        }
    }

    /// Disconnects the local sockets matching a remote request.
    ///
    /// Errors are logged.
    fn recv_disconnect_sockets(&self, req: RequestIn) {
        if let Err(e) = self.local.disconnect_socket(req.opts) {
            let ns = self.local.path();
            tracing::warn!(
                ?self.uid,
                ?ns,
                "remote request disconnect sockets handler: {:?}",
                e
            );
        }
    }

    /// Handles a remote broadcast-with-ack request.
    ///
    /// The packet is broadcast to the local sockets with ack, with no timeout
    /// (`None`). A spawned task then answers the requester:
    /// - first with a `BroadcastAckCount(count)` response;
    /// - then with one `BroadcastAck` response per received ack.
    ///
    /// The task stops at the first send failure, which is logged.
    fn recv_broadcast_with_ack(self: Arc<Self>, req: RequestIn) {
        // `recv_req` only dispatches `BroadcastWithAck` requests here.
        let packet = match req.r#type {
            RequestTypeIn::BroadcastWithAck(p) => p,
            _ => unreachable!(),
        };
        let (stream, count) = self.local.broadcast_with_ack(packet, req.opts, None);
        tokio::spawn(async move {
            let on_err = |err| {
                let ns = self.local.path();
                tracing::warn!(
                    ?self.uid,
                    ?ns,
                    "remote request broadcast with ack handler errors: {:?}",
                    err
                );
            };
            let res = Response {
                r#type: ResponseType::<()>::BroadcastAckCount(count),
                node_id: self.uid,
            };
            if let Err(err) = self.send_res(req.node_id, req.id, res).await {
                on_err(err);
                return;
            }

            futures_util::pin_mut!(stream);
            while let Some(ack) = stream.next().await {
                let res = Response {
                    r#type: ResponseType::BroadcastAck(ack),
                    node_id: self.uid,
                };
                if let Err(err) = self.send_res(req.node_id, req.id, res).await {
                    on_err(err);
                    return;
                }
            }
        });
    }

    /// Sends the local rooms matching the request back to the requester.
    ///
    /// The publish runs in a spawned task, and errors are logged.
    fn recv_rooms(&self, req: RequestIn) {
        let rooms = self.local.rooms(req.opts);
        let res = Response {
            r#type: ResponseType::<()>::AllRooms(rooms),
            node_id: self.uid,
        };
        let fut = self.send_res(req.node_id, req.id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
            }
        });
    }

    /// Adds the matching local sockets to `rooms` on behalf of a remote server.
    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.add_sockets(opts, rooms);
    }

    /// Removes the matching local sockets from `rooms` on behalf of a remote server.
    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.del_sockets(opts, rooms);
    }
    /// Sends the data of the matching local sockets back to the requester.
    ///
    /// The publish runs in a spawned task, and errors are logged.
    fn recv_fetch_sockets(&self, req: RequestIn) {
        let sockets = self.local.fetch_sockets(req.opts);
        let res = Response {
            node_id: self.uid,
            r#type: ResponseType::FetchSockets(sockets),
        };
        let fut = self.send_res(req.node_id, req.id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
            }
        });
    }

    /// Encodes `req` with MessagePack and publishes it on a request channel.
    ///
    /// The channel is the global one, or the specific channel of `target_uid` if set.
    ///
    /// # Errors
    /// [`Error::Encode`] if serialization fails, [`Error::Driver`] if publishing fails.
    async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
        tracing::trace!(?req, "sending request");
        let req = rmp_serde::to_vec(&req)?;
        let chan = self.get_req_chan(target_uid);
        self.driver
            .publish(chan, req)
            .await
            .map_err(Error::from_driver)?;

        Ok(())
    }

    /// Builds a future that publishes `(req_id, res)`, MessagePack-encoded, on the
    /// response channel of the server `req_node_id`.
    ///
    /// Encoding happens immediately, but an encoding error is only reported when the
    /// future is awaited. The future is `'static` and owns a clone of the driver, so
    /// it can be spawned.
    fn send_res<D: Serialize + fmt::Debug>(
        &self,
        req_node_id: Uid,
        req_id: Sid,
        res: Response<D>,
    ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
        let chan = self.get_res_chan(req_node_id);
        tracing::trace!(?res, "sending response to {}", &chan);
        let res = rmp_serde::to_vec(&(req_id, res));
        let driver = self.driver.clone();
        async move {
            driver
                .publish(chan, res?)
                .await
                .map_err(Error::from_driver)?;
            Ok(())
        }
    }

    /// Registers a response handler for `req_id` and returns a stream of the decoded
    /// responses whose `ResponseType::to_u8()` tag equals `response_idx`.
    ///
    /// Number of responses expected by the stream:
    /// - `server_count - 1` when `target_uid` is `None`;
    /// - 1 when a specific server is targeted.
    ///
    /// The stream also ends once `request_timeout` has elapsed. The timer starts when
    /// this function runs, not at the first poll. Responses that cannot be decoded
    /// are logged and skipped. The stream is wrapped in a `DropStream` together with
    /// the handler map and `req_id`, presumably to unregister the handler when it is
    /// dropped (confirm in `stream.rs`).
    ///
    /// # Errors
    /// [`Error::Driver`] if the server count cannot be fetched.
    async fn get_res<D: DeserializeOwned + fmt::Debug>(
        &self,
        req_id: Sid,
        response_idx: u8,
        target_uid: Option<Uid>,
    ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
        let remote_serv_cnt = if target_uid.is_none() {
            self.server_count().await?.saturating_sub(1) as usize
        } else {
            1
        };
        // `mpsc::channel` panics on a zero capacity, hence the `max(.., 1)`.
        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
        self.responses.lock().unwrap().insert(req_id, tx);
        let stream = MessageStream::new(rx)
            .filter_map(|item| {
                let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
                    Ok((_, data)) => Some(data),
                    Err(e) => {
                        tracing::warn!("error decoding response: {e}");
                        None
                    }
                };
                future::ready(data)
            })
            .filter(move |item| future::ready(item.r#type.to_u8() == response_idx))
            .take(remote_serv_cnt)
            .take_until(time::sleep(self.config.request_timeout));
        let stream = DropStream::new(stream, self.responses.clone(), req_id);
        Ok(stream)
    }

    /// Subscribes to channel `pat` with the configured `stream_buffer`.
    ///
    /// # Errors
    /// Driver errors are mapped to [`InitError::Driver`].
    #[inline]
    async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
        tracing::trace!(?pat, "subscribing to");
        self.driver
            .subscribe(pat, self.config.stream_buffer)
            .await
            .map_err(InitError::Driver)
    }
}
959
960#[inline]
963fn is_local_op(uid: Uid, opts: &BroadcastOptions) -> bool {
964 if opts.has_flag(BroadcastFlags::Local)
965 || (!opts.has_flag(BroadcastFlags::Broadcast)
966 && opts.server_id == Some(uid)
967 && opts.rooms.is_empty()
968 && opts.sid.is_some())
969 {
970 tracing::debug!(?opts, "operation is local");
971 true
972 } else {
973 false
974 }
975}
976
977fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
980 if path.is_empty() || path.contains('#') {
981 Err(InitError::MalformedNamespace)
982 } else {
983 Ok(())
984 }
985}
986
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, FusedStream, StreamExt};
    use socketioxide_core::{adapter::AckStreamItem, Str, Value};
    use std::convert::Infallible;

    /// Driver that never fails and does nothing.
    ///
    /// Publishes are dropped, subscriptions yield empty streams, and it reports 0 servers.
    #[derive(Clone)]
    struct StubDriver;
    impl Driver for StubDriver {
        type Error = Infallible;

        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn subscribe(
            &self,
            _: String,
            _: usize,
        ) -> Result<MessageStream<ChanItem>, Self::Error> {
            Ok(MessageStream::new_empty())
        }

        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
            Ok(0)
        }
    }
    /// Builds an `AckStream` that:
    /// - has no local acks;
    /// - reads remote responses from `remote`;
    /// - expects 2 remote servers.
    fn new_stub_ack_stream(
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
        AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            timeout,
            2,
            Sid::new(),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    /// Both remote servers announce 2 acks each, then 4 acks arrive. The stream
    /// yields exactly 4 items and is then terminated.
    #[tokio::test]
    async fn ack_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_secs(10));
        let node_id = Uid::new();
        let req_id = Sid::new();

        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();

        let ack_res = Response::<String> {
            node_id,
            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
        };
        for _ in 0..4 {
            tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap())
                .unwrap();
        }
        futures_util::pin_mut!(stream);
        for _ in 0..4 {
            assert!(stream.next().await.is_some());
        }
        assert!(stream.is_terminated());
    }

    /// Only one of the two expected servers sends an ack count, and no ack arrives.
    /// Once the 50ms timeout has elapsed, the stream yields nothing and is terminated.
    #[tokio::test]
    async fn ack_stream_timeout() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_millis(50));
        let node_id = Uid::new();
        let req_id = Sid::new();
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap())
            .unwrap();

        futures_util::pin_mut!(stream);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(stream.next().await.is_none());
        assert!(stream.is_terminated());
    }

    /// Dropping an `AckStream` removes its response handler from the shared map.
    #[tokio::test]
    async fn ack_stream_drop() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        let id = Sid::new();
        handlers.lock().unwrap().insert(id, tx);
        let stream = AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            Duration::from_secs(10),
            2,
            id,
            handlers.clone(),
        );
        drop(stream);
        assert!(handlers.lock().unwrap().is_empty(),);
    }

    /// Options built for a single remote socket are local only on the server that
    /// owns it. Options built with `BroadcastOptions::new` are not local for an
    /// unrelated server id.
    #[test]
    fn test_is_local_op() {
        let server_id = Uid::new();
        let remote = RemoteSocketData {
            id: Sid::new(),
            server_id,
            ns: "/".into(),
        };
        let opts = BroadcastOptions::new_remote(&remote);
        assert!(is_local_op(server_id, &opts));
        assert!(!is_local_op(Uid::new(), &opts));
        let opts = BroadcastOptions::new(Sid::new());
        assert!(!is_local_op(Uid::new(), &opts));
    }

    /// Namespace paths containing `#`, and empty paths, are rejected.
    #[test]
    fn check_ns_error() {
        assert!(matches!(
            check_ns::<StubDriver>("#"),
            Err(InitError::MalformedNamespace)
        ));
        assert!(matches!(
            check_ns::<StubDriver>(""),
            Err(InitError::MalformedNamespace)
        ));
    }
}