1#![warn(missing_docs)]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3
4use std::{
142 borrow::Cow,
143 collections::HashMap,
144 fmt,
145 future::{self, Future},
146 pin::Pin,
147 sync::{Arc, Mutex},
148 task::{Context, Poll},
149 time::Duration,
150};
151
152use drivers::{ChanItem, Driver, MessageStream};
153use futures_core::Stream;
154use futures_util::StreamExt;
155use serde::{Serialize, de::DeserializeOwned};
156use socketioxide_core::adapter::remote_packet::{
157 RequestIn, RequestOut, RequestTypeIn, RequestTypeOut, Response, ResponseType, ResponseTypeId,
158};
159use socketioxide_core::{
160 Sid, Uid,
161 adapter::errors::{AdapterError, BroadcastError},
162 adapter::{
163 BroadcastOptions, CoreAdapter, CoreLocalAdapter, DefinedAdapter, RemoteSocketData, Room,
164 RoomParam, SocketEmitter, Spawnable,
165 },
166 packet::Packet,
167};
168use stream::{AckStream, DropStream};
169use tokio::{sync::mpsc, time};
170
171pub mod drivers;
174
175mod stream;
176
177#[derive(thiserror::Error)]
179pub enum Error<R: Driver> {
180 #[error("driver error: {0}")]
182 Driver(R::Error),
183 #[error("packet encoding error: {0}")]
185 Decode(#[from] rmp_serde::decode::Error),
186 #[error("packet decoding error: {0}")]
188 Encode(#[from] rmp_serde::encode::Error),
189}
190
191impl<R: Driver> Error<R> {
192 fn from_driver(err: R::Error) -> Self {
193 Self::Driver(err)
194 }
195}
196impl<R: Driver> fmt::Debug for Error<R> {
197 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198 match self {
199 Self::Driver(err) => write!(f, "Driver error: {err:?}"),
200 Self::Decode(err) => write!(f, "Decode error: {err:?}"),
201 Self::Encode(err) => write!(f, "Encode error: {err:?}"),
202 }
203 }
204}
205
206impl<R: Driver> From<Error<R>> for AdapterError {
207 fn from(err: Error<R>) -> Self {
208 AdapterError::from(Box::new(err) as Box<dyn std::error::Error + Send>)
209 }
210}
211
/// Configuration of the Redis adapter.
///
/// Start from [`RedisAdapterConfig::new`] (or [`Default::default`]) and adjust it with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RedisAdapterConfig {
    /// How long to wait for remote servers to answer a request.
    ///
    /// For broadcasts with acknowledgement it is added on top of the ack timeout.
    /// Defaults to 5 seconds.
    pub request_timeout: Duration,

    /// Prefix of every request/response channel name used by the adapter.
    /// Defaults to `"socket.io"`.
    pub prefix: Cow<'static, str>,

    /// Base capacity of the channel buffering remote responses of a broadcast with
    /// acknowledgement; the number of remote servers is added to it. Defaults to `255`.
    pub ack_response_buffer: usize,

    /// Buffer size handed to the driver when subscribing to a channel. Defaults to `1024`.
    pub stream_buffer: usize,
}
233impl RedisAdapterConfig {
234 pub fn new() -> Self {
236 Self::default()
237 }
238 pub fn with_request_timeout(mut self, timeout: Duration) -> Self {
240 self.request_timeout = timeout;
241 self
242 }
243
244 pub fn with_prefix(mut self, prefix: impl Into<Cow<'static, str>>) -> Self {
246 self.prefix = prefix.into();
247 self
248 }
249
250 pub fn with_ack_response_buffer(mut self, buffer: usize) -> Self {
255 assert!(buffer > 0, "buffer size must be greater than 0");
256 self.ack_response_buffer = buffer;
257 self
258 }
259
260 pub fn with_stream_buffer(mut self, buffer: usize) -> Self {
264 assert!(buffer > 0, "buffer size must be greater than 0");
265 self.stream_buffer = buffer;
266 self
267 }
268}
269
270impl Default for RedisAdapterConfig {
271 fn default() -> Self {
272 Self {
273 request_timeout: Duration::from_secs(5),
274 prefix: Cow::Borrowed("socket.io"),
275 ack_response_buffer: 255,
276 stream_buffer: 1024,
277 }
278 }
279}
280
281#[derive(Debug)]
284pub struct RedisAdapterCtr<R> {
285 driver: R,
286 config: RedisAdapterConfig,
287}
288
#[cfg(feature = "redis")]
impl RedisAdapterCtr<drivers::redis::RedisDriver> {
    /// Creates a constructor backed by a `redis::Client`, using the default
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    ///
    /// Returns the `redis` error raised while creating the driver from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis(client: &redis::Client) -> redis::RedisResult<Self> {
        Self::new_with_redis_config(client, Default::default()).await
    }

    /// Creates a constructor backed by a `redis::Client` with a custom configuration.
    ///
    /// # Errors
    ///
    /// Returns the `redis` error raised while creating the driver from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
    pub async fn new_with_redis_config(
        client: &redis::Client,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        let redis_driver = drivers::redis::RedisDriver::new(client).await?;
        Ok(RedisAdapterCtr {
            driver: redis_driver,
            config,
        })
    }
}
#[cfg(feature = "redis-cluster")]
impl RedisAdapterCtr<drivers::redis::ClusterDriver> {
    /// Creates a constructor backed by a `redis` cluster client, using the default
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    ///
    /// Returns the `redis` error raised while creating the cluster driver from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster(
        client: &redis::cluster::ClusterClient,
    ) -> redis::RedisResult<Self> {
        Self::new_with_cluster_config(client, Default::default()).await
    }

    /// Creates a constructor backed by a `redis` cluster client with a custom configuration.
    ///
    /// # Errors
    ///
    /// Returns the `redis` error raised while creating the cluster driver from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new_with_cluster_config(
        client: &redis::cluster::ClusterClient,
        config: RedisAdapterConfig,
    ) -> redis::RedisResult<Self> {
        let cluster_driver = drivers::redis::ClusterDriver::new(client).await?;
        Ok(RedisAdapterCtr {
            driver: cluster_driver,
            config,
        })
    }
}
#[cfg(feature = "fred")]
impl RedisAdapterCtr<drivers::fred::FredDriver> {
    /// Creates a constructor backed by a `fred` subscriber client, using the default
    /// [`RedisAdapterConfig`].
    ///
    /// # Errors
    ///
    /// Returns the `fred` error raised while creating the driver from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred(
        client: fred::clients::SubscriberClient,
    ) -> fred::prelude::FredResult<Self> {
        Self::new_with_fred_config(client, Default::default()).await
    }

    /// Creates a constructor backed by a `fred` subscriber client with a custom configuration.
    ///
    /// # Errors
    ///
    /// Returns the `fred` error raised while creating the driver from `client`.
    #[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
    pub async fn new_with_fred_config(
        client: fred::clients::SubscriberClient,
        config: RedisAdapterConfig,
    ) -> fred::prelude::FredResult<Self> {
        let fred_driver = drivers::fred::FredDriver::new(client).await?;
        Ok(RedisAdapterCtr {
            driver: fred_driver,
            config,
        })
    }
}
345impl<R: Driver> RedisAdapterCtr<R> {
346 pub fn new_with_driver(driver: R, config: RedisAdapterConfig) -> RedisAdapterCtr<R> {
351 RedisAdapterCtr { driver, config }
352 }
353}
354
355pub(crate) type ResponseHandlers = HashMap<Sid, mpsc::Sender<Vec<u8>>>;
356
357#[cfg_attr(docsrs, doc(cfg(feature = "fred")))]
359#[cfg(feature = "fred")]
360pub type FredAdapter<E> = CustomRedisAdapter<E, drivers::fred::FredDriver>;
361
362#[cfg_attr(docsrs, doc(cfg(feature = "redis")))]
364#[cfg(feature = "redis")]
365pub type RedisAdapter<E> = CustomRedisAdapter<E, drivers::redis::RedisDriver>;
366
367#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
369#[cfg(feature = "redis-cluster")]
370pub type ClusterAdapter<E> = CustomRedisAdapter<E, drivers::redis::ClusterDriver>;
371
372pub struct CustomRedisAdapter<E, R> {
377 driver: R,
380 config: RedisAdapterConfig,
382 uid: Uid,
384 local: CoreLocalAdapter<E>,
386 req_chan: String,
389 responses: Arc<Mutex<ResponseHandlers>>,
391}
392
393impl<E, R> DefinedAdapter for CustomRedisAdapter<E, R> {}
394impl<E: SocketEmitter, R: Driver> CoreAdapter<E> for CustomRedisAdapter<E, R> {
395 type Error = Error<R>;
396 type State = RedisAdapterCtr<R>;
397 type AckStream = AckStream<E::AckStream>;
398 type InitRes = InitRes<R>;
399
400 fn new(state: &Self::State, local: CoreLocalAdapter<E>) -> Self {
401 let req_chan = format!("{}-request#{}#", state.config.prefix, local.path());
402 let uid = local.server_id();
403 Self {
404 local,
405 req_chan,
406 uid,
407 driver: state.driver.clone(),
408 config: state.config.clone(),
409 responses: Arc::new(Mutex::new(HashMap::new())),
410 }
411 }
412
413 fn init(self: Arc<Self>, on_success: impl FnOnce() + Send + 'static) -> Self::InitRes {
414 let fut = async move {
415 check_ns(self.local.path())?;
416 let global_stream = self.subscribe(self.req_chan.clone()).await?;
417 let specific_stream = self.subscribe(self.get_req_chan(Some(self.uid))).await?;
418 let response_chan = format!(
419 "{}-response#{}#{}#",
420 &self.config.prefix,
421 self.local.path(),
422 self.uid
423 );
424
425 let response_stream = self.subscribe(response_chan.clone()).await?;
426 let stream = futures_util::stream::select(global_stream, specific_stream);
427 let stream = futures_util::stream::select(stream, response_stream);
428 tokio::spawn(self.pipe_stream(stream, response_chan));
429 on_success();
430 Ok(())
431 };
432 InitRes(Box::pin(fut))
433 }
434
435 async fn close(&self) -> Result<(), Self::Error> {
436 let response_chan = format!(
437 "{}-response#{}#{}#",
438 &self.config.prefix,
439 self.local.path(),
440 self.uid
441 );
442 tokio::try_join!(
443 self.driver.unsubscribe(self.req_chan.clone()),
444 self.driver.unsubscribe(self.get_req_chan(Some(self.uid))),
445 self.driver.unsubscribe(response_chan)
446 )
447 .map_err(Error::from_driver)?;
448
449 Ok(())
450 }
451
452 async fn server_count(&self) -> Result<u16, Self::Error> {
454 let count = self
455 .driver
456 .num_serv(&self.req_chan)
457 .await
458 .map_err(Error::from_driver)?;
459
460 Ok(count)
461 }
462
463 async fn broadcast(
465 &self,
466 packet: Packet,
467 opts: BroadcastOptions,
468 ) -> Result<(), BroadcastError> {
469 if !opts.is_local(self.uid) {
470 let req = RequestOut::new(self.uid, RequestTypeOut::Broadcast(&packet), &opts);
471 self.send_req(req, opts.server_id)
472 .await
473 .map_err(AdapterError::from)?;
474 }
475
476 self.local.broadcast(packet, opts)?;
477 Ok(())
478 }
479
480 async fn broadcast_with_ack(
508 &self,
509 packet: Packet,
510 opts: BroadcastOptions,
511 timeout: Option<Duration>,
512 ) -> Result<Self::AckStream, Self::Error> {
513 if opts.is_local(self.uid) {
514 tracing::debug!(?opts, "broadcast with ack is local");
515 let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
516 let stream = AckStream::new_local(local);
517 return Ok(stream);
518 }
519 let req = RequestOut::new(self.uid, RequestTypeOut::BroadcastWithAck(&packet), &opts);
520 let req_id = req.id;
521
522 let remote_serv_cnt = self.server_count().await?.saturating_sub(1);
523
524 let (tx, rx) = mpsc::channel(self.config.ack_response_buffer + remote_serv_cnt as usize);
525 self.responses.lock().unwrap().insert(req_id, tx);
526 let remote = MessageStream::new(rx);
527
528 self.send_req(req, opts.server_id).await?;
529 let (local, _) = self.local.broadcast_with_ack(packet, opts, timeout);
530
531 let timeout = self
533 .config
534 .request_timeout
535 .saturating_add(timeout.unwrap_or(self.local.ack_timeout()));
536
537 Ok(AckStream::new(
538 local,
539 remote,
540 timeout,
541 remote_serv_cnt,
542 req_id,
543 self.responses.clone(),
544 ))
545 }
546
547 async fn disconnect_socket(&self, opts: BroadcastOptions) -> Result<(), BroadcastError> {
548 if !opts.is_local(self.uid) {
549 let req = RequestOut::new(self.uid, RequestTypeOut::DisconnectSockets, &opts);
550 self.send_req(req, opts.server_id)
551 .await
552 .map_err(AdapterError::from)?;
553 }
554 self.local
555 .disconnect_socket(opts)
556 .map_err(BroadcastError::Socket)?;
557
558 Ok(())
559 }
560
561 async fn rooms(&self, opts: BroadcastOptions) -> Result<Vec<Room>, Self::Error> {
562 if opts.is_local(self.uid) {
563 return Ok(self.local.rooms(opts).into_iter().collect());
564 }
565 let req = RequestOut::new(self.uid, RequestTypeOut::AllRooms, &opts);
566 let req_id = req.id;
567
568 let stream = self
571 .get_res::<()>(req_id, ResponseTypeId::AllRooms, opts.server_id)
572 .await?;
573 self.send_req(req, opts.server_id).await?;
574 let local = self.local.rooms(opts);
575 let rooms = stream
576 .filter_map(|item| future::ready(item.into_rooms()))
577 .fold(local, async |mut acc, item| {
578 acc.extend(item);
579 acc
580 })
581 .await;
582 Ok(Vec::from_iter(rooms))
583 }
584
585 async fn add_sockets(
586 &self,
587 opts: BroadcastOptions,
588 rooms: impl RoomParam,
589 ) -> Result<(), Self::Error> {
590 let rooms: Vec<Room> = rooms.into_room_iter().collect();
591 if !opts.is_local(self.uid) {
592 let req = RequestOut::new(self.uid, RequestTypeOut::AddSockets(&rooms), &opts);
593 self.send_req(req, opts.server_id).await?;
594 }
595 self.local.add_sockets(opts, rooms);
596 Ok(())
597 }
598
599 async fn del_sockets(
600 &self,
601 opts: BroadcastOptions,
602 rooms: impl RoomParam,
603 ) -> Result<(), Self::Error> {
604 let rooms: Vec<Room> = rooms.into_room_iter().collect();
605 if !opts.is_local(self.uid) {
606 let req = RequestOut::new(self.uid, RequestTypeOut::DelSockets(&rooms), &opts);
607 self.send_req(req, opts.server_id).await?;
608 }
609 self.local.del_sockets(opts, rooms);
610 Ok(())
611 }
612
613 async fn fetch_sockets(
614 &self,
615 opts: BroadcastOptions,
616 ) -> Result<Vec<RemoteSocketData>, Self::Error> {
617 if opts.is_local(self.uid) {
618 return Ok(self.local.fetch_sockets(opts));
619 }
620 let req = RequestOut::new(self.uid, RequestTypeOut::FetchSockets, &opts);
621 let req_id = req.id;
622 let remote = self
625 .get_res::<RemoteSocketData>(req_id, ResponseTypeId::FetchSockets, opts.server_id)
626 .await?;
627
628 self.send_req(req, opts.server_id).await?;
629 let local = self.local.fetch_sockets(opts);
630 let sockets = remote
631 .filter_map(|item| future::ready(item.into_fetch_sockets()))
632 .fold(local, async |mut acc, item| {
633 acc.extend(item);
634 acc
635 })
636 .await;
637 Ok(sockets)
638 }
639
640 fn get_local(&self) -> &CoreLocalAdapter<E> {
641 &self.local
642 }
643}
644
645#[derive(thiserror::Error)]
647pub enum InitError<D: Driver> {
648 #[error("driver error: {0}")]
650 Driver(D::Error),
651 #[error("malformed namespace path, it must not contain '#'")]
653 MalformedNamespace,
654}
655impl<D: Driver> fmt::Debug for InitError<D> {
656 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
657 match self {
658 Self::Driver(err) => fmt::Debug::fmt(err, f),
659 Self::MalformedNamespace => write!(f, "Malformed namespace path"),
660 }
661 }
662}
663#[must_use = "futures do nothing unless you `.await` or poll them"]
665pub struct InitRes<D: Driver>(futures_core::future::BoxFuture<'static, Result<(), InitError<D>>>);
666
667impl<D: Driver> Future for InitRes<D> {
668 type Output = Result<(), InitError<D>>;
669
670 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
671 self.0.as_mut().poll(cx)
672 }
673}
674impl<D: Driver> Spawnable for InitRes<D> {
675 fn spawn(self) {
676 tokio::spawn(async move {
677 if let Err(e) = self.0.await {
678 tracing::error!("error initializing adapter: {e}");
679 }
680 });
681 }
682}
683
impl<E: SocketEmitter, R: Driver> CustomRedisAdapter<E, R> {
    /// Builds the response channel of server `uid` for this namespace:
    /// `{prefix}-response#{path}#{uid}#`.
    fn get_res_chan(&self, uid: Uid) -> String {
        let path = self.local.path();
        let prefix = &self.config.prefix;
        format!("{prefix}-response#{path}#{uid}#")
    }
    /// Builds a request channel name.
    ///
    /// `None` gives the namespace-wide request channel (`self.req_chan`). `Some(uid)` gives
    /// the request channel of server `uid`: `self.req_chan` followed by `{uid}#`.
    fn get_req_chan(&self, node_id: Option<Uid>) -> String {
        match node_id {
            Some(uid) => format!("{}{}#", self.req_chan, uid),
            None => self.req_chan.clone(),
        }
    }

    /// Dispatches every message received on the subscribed channels until `stream` ends.
    ///
    /// Request messages are handled by `recv_req`. Messages on `response_chan` are forwarded
    /// to the handler registered for their request id. Anything else is logged and dropped.
    async fn pipe_stream(
        self: Arc<Self>,
        mut stream: impl Stream<Item = ChanItem> + Unpin,
        response_chan: String,
    ) {
        while let Some((chan, item)) = stream.next().await {
            // Both the global and the server-specific request channels start with
            // `self.req_chan` (see `get_req_chan`).
            if chan.starts_with(&self.req_chan) {
                if let Err(e) = self.recv_req(item) {
                    let ns = self.local.path();
                    let uid = self.uid;
                    tracing::warn!(?uid, ?ns, "request handler error: {e}");
                }
            } else if chan == response_chan {
                // Only the request id is read here; the full response is decoded by the
                // consumer of the handler.
                let req_id = read_req_id(&item);
                tracing::trace!(?req_id, ?chan, ?response_chan, "extracted sid");
                // The lock is held only for the non-blocking `try_send` below and released
                // at the end of this branch, before the next `.await`.
                let handlers = self.responses.lock().unwrap();
                if let Some(tx) = req_id.and_then(|id| handlers.get(&id)) {
                    // If the handler's buffer is full or its receiver is gone, the response
                    // is dropped with a warning.
                    if let Err(e) = tx.try_send(item) {
                        tracing::warn!("error sending response to handler: {e}");
                    }
                } else {
                    tracing::warn!(?req_id, "could not find req handler");
                }
            } else {
                tracing::warn!("unexpected message/channel: {chan}");
            }
        }
    }

    /// Decodes and handles a request coming from another server.
    ///
    /// These requests are ignored: requests sent by this server itself, requests without
    /// options (logged), and request types not handled below.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Decode`] if `item` is not a valid MessagePack request.
    fn recv_req(self: &Arc<Self>, item: Vec<u8>) -> Result<(), Error<R>> {
        let req: RequestIn = rmp_serde::from_slice(&item)?;
        if req.node_id == self.uid {
            return Ok(());
        }

        tracing::trace!(?req, "handling request");
        let Some(opts) = req.opts else {
            tracing::warn!(?req, "request is missing options");
            return Ok(());
        };

        match req.r#type {
            RequestTypeIn::Broadcast(p) => self.recv_broadcast(opts, p),
            // Needs an owned `Arc` because the ack forwarding runs in a spawned task.
            RequestTypeIn::BroadcastWithAck(p) => {
                self.clone()
                    .recv_broadcast_with_ack(req.node_id, req.id, p, opts)
            }
            RequestTypeIn::DisconnectSockets => self.recv_disconnect_sockets(opts),
            RequestTypeIn::AllRooms => self.recv_rooms(req.node_id, req.id, opts),
            RequestTypeIn::AddSockets(rooms) => self.recv_add_sockets(opts, rooms),
            RequestTypeIn::DelSockets(rooms) => self.recv_del_sockets(opts, rooms),
            RequestTypeIn::FetchSockets => self.recv_fetch_sockets(req.node_id, req.id, opts),
            _ => (),
        };
        Ok(())
    }

    /// Applies a remote broadcast on the local adapter; failures are only logged.
    fn recv_broadcast(&self, opts: BroadcastOptions, packet: Packet) {
        if let Err(e) = self.local.broadcast(packet, opts) {
            let ns = self.local.path();
            tracing::warn!(?self.uid, ?ns, "remote request broadcast handler: {:?}", e);
        }
    }

    /// Applies a remote socket disconnection on the local adapter; failures are only logged.
    fn recv_disconnect_sockets(&self, opts: BroadcastOptions) {
        if let Err(e) = self.local.disconnect_socket(opts) {
            let ns = self.local.path();
            tracing::warn!(
                ?self.uid,
                ?ns,
                "remote request disconnect sockets handler: {:?}",
                e
            );
        }
    }

    /// Broadcasts a remote packet locally and streams the acknowledgements back to `origin`.
    ///
    /// A spawned task first sends a `BroadcastAckCount` response carrying the `count`
    /// returned by the local broadcast. Presumably this is how the origin knows how many
    /// acks to expect. The task then sends one `BroadcastAck` response per local ack and
    /// stops at the first failed send, which is logged.
    fn recv_broadcast_with_ack(
        self: Arc<Self>,
        origin: Uid,
        req_id: Sid,
        packet: Packet,
        opts: BroadcastOptions,
    ) {
        let (stream, count) = self.local.broadcast_with_ack(packet, opts, None);
        tokio::spawn(async move {
            let on_err = |err| {
                let ns = self.local.path();
                tracing::warn!(
                    ?origin,
                    ?ns,
                    "remote request broadcast with ack handler errors: {:?}",
                    err
                );
            };
            let res = Response {
                r#type: ResponseType::<()>::BroadcastAckCount(count),
                node_id: self.uid,
            };
            if let Err(err) = self.send_res(origin, req_id, res).await {
                on_err(err);
                return;
            }

            futures_util::pin_mut!(stream);
            while let Some(ack) = stream.next().await {
                let res = Response {
                    r#type: ResponseType::BroadcastAck(ack),
                    node_id: self.uid,
                };
                if let Err(err) = self.send_res(origin, req_id, res).await {
                    on_err(err);
                    return;
                }
            }
        });
    }

    /// Answers a remote "all rooms" request with the local rooms.
    ///
    /// The response is encoded synchronously by `send_res` and published from a spawned
    /// task, so the dispatch loop is not blocked; publish errors are logged.
    fn recv_rooms(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
        let rooms = self.local.rooms(opts);
        let res = Response {
            r#type: ResponseType::<()>::AllRooms(rooms),
            node_id: self.uid,
        };
        let fut = self.send_res(origin, req_id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request rooms handler: {:?}", err);
            }
        });
    }

    /// Applies a remote "add sockets to rooms" request on the local adapter.
    fn recv_add_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.add_sockets(opts, rooms);
    }

    /// Applies a remote "remove sockets from rooms" request on the local adapter.
    fn recv_del_sockets(&self, opts: BroadcastOptions, rooms: Vec<Room>) {
        self.local.del_sockets(opts, rooms);
    }
    /// Answers a remote "fetch sockets" request with the local sockets, publishing from a
    /// spawned task (see `recv_rooms`); publish errors are logged.
    fn recv_fetch_sockets(&self, origin: Uid, req_id: Sid, opts: BroadcastOptions) {
        let sockets = self.local.fetch_sockets(opts);
        let res = Response {
            node_id: self.uid,
            r#type: ResponseType::FetchSockets(sockets),
        };
        let fut = self.send_res(origin, req_id, res);
        let ns = self.local.path().clone();
        let uid = self.uid;
        tokio::spawn(async move {
            if let Err(err) = fut.await {
                tracing::warn!(?uid, ?ns, "remote request fetch sockets handler: {:?}", err);
            }
        });
    }

    /// Encodes `req` as MessagePack and publishes it on the request channel.
    ///
    /// The request goes to the namespace-wide channel when `target_uid` is `None`, otherwise
    /// to the channel of that server.
    ///
    /// # Errors
    ///
    /// [`Error::Encode`] if encoding fails, [`Error::Driver`] if publishing fails.
    async fn send_req(&self, req: RequestOut<'_>, target_uid: Option<Uid>) -> Result<(), Error<R>> {
        tracing::trace!(?req, "sending request");
        let req = rmp_serde::to_vec(&req)?;
        let chan = self.get_req_chan(target_uid);
        self.driver
            .publish(chan, req)
            .await
            .map_err(Error::from_driver)?;

        Ok(())
    }

    /// Returns a future publishing `(req_id, res)` on the response channel of `req_node_id`.
    ///
    /// Encoding happens eagerly, and the future only owns a driver clone, the channel name
    /// and the encoded bytes. That is why it can be `'static` and spawned. An encoding error
    /// is reported when the future is awaited.
    fn send_res<D: Serialize + fmt::Debug>(
        &self,
        req_node_id: Uid,
        req_id: Sid,
        res: Response<D>,
    ) -> impl Future<Output = Result<(), Error<R>>> + Send + 'static {
        let chan = self.get_res_chan(req_node_id);
        tracing::trace!(?res, "sending response to {}", &chan);
        let res = rmp_serde::to_vec(&(req_id, res));
        let driver = self.driver.clone();
        async move {
            driver
                .publish(chan, res?)
                .await
                .map_err(Error::from_driver)?;
            Ok(())
        }
    }

    /// Registers a response handler for `req_id` and returns the stream of its decoded
    /// responses. Call it before sending the request so that no response is missed.
    ///
    /// Only responses whose type matches `response_type` are yielded; undecodable responses
    /// are logged and skipped. The stream ends after one response per expected server or
    /// when `request_timeout` elapses, whichever comes first. The expected count is every
    /// other server when `target_uid` is `None`, otherwise 1. The timeout is measured from
    /// this call.
    ///
    /// The handler map and `req_id` are handed to `DropStream`, presumably so the handler is
    /// unregistered when the stream is dropped (NOTE(review): confirm in `stream.rs`).
    async fn get_res<D: DeserializeOwned + fmt::Debug>(
        &self,
        req_id: Sid,
        response_type: ResponseTypeId,
        target_uid: Option<Uid>,
    ) -> Result<impl Stream<Item = Response<D>>, Error<R>> {
        let remote_serv_cnt = if target_uid.is_none() {
            self.server_count().await?.saturating_sub(1) as usize
        } else {
            1
        };
        // tokio's bounded channel panics on a zero capacity, hence the minimum of 1.
        let (tx, rx) = mpsc::channel(std::cmp::max(remote_serv_cnt, 1));
        self.responses.lock().unwrap().insert(req_id, tx);
        let stream = MessageStream::new(rx)
            .filter_map(|item| {
                let data = match rmp_serde::from_slice::<(Sid, Response<D>)>(&item) {
                    Ok((_, data)) => Some(data),
                    Err(e) => {
                        tracing::warn!("error decoding response: {e}");
                        None
                    }
                };
                future::ready(data)
            })
            .filter(move |item| future::ready(ResponseTypeId::from(&item.r#type) == response_type))
            .take(remote_serv_cnt)
            .take_until(time::sleep(self.config.request_timeout));
        let stream = DropStream::new(stream, self.responses.clone(), req_id);
        Ok(stream)
    }

    /// Subscribes to the channel `pat` through the driver, using the configured stream buffer.
    ///
    /// # Errors
    ///
    /// [`InitError::Driver`] if the driver fails to subscribe.
    #[inline]
    async fn subscribe(&self, pat: String) -> Result<MessageStream<ChanItem>, InitError<R>> {
        tracing::trace!(?pat, "subscribing to");
        self.driver
            .subscribe(pat, self.config.stream_buffer)
            .await
            .map_err(InitError::Driver)
    }
}
942
943fn check_ns<D: Driver>(path: &str) -> Result<(), InitError<D>> {
946 if path.is_empty() || path.contains('#') {
947 Err(InitError::MalformedNamespace)
948 } else {
949 Ok(())
950 }
951}
952
953pub fn read_req_id(data: &[u8]) -> Option<Sid> {
955 use std::str::FromStr;
956 let mut rd = data;
957 let len = rmp::decode::read_array_len(&mut rd).ok()?;
958 if len < 1 {
959 return None;
960 }
961
962 let mut buff = [0u8; Sid::ZERO.as_str().len()];
963 let str = rmp::decode::read_str(&mut rd, &mut buff).ok()?;
964 Sid::from_str(str).ok()
965}
966
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::stream::{self, FusedStream, StreamExt};
    use socketioxide_core::{Str, Value, adapter::AckStreamItem};
    use std::convert::Infallible;

    /// Driver that never fails, never receives anything and reports no server.
    #[derive(Clone)]
    struct StubDriver;
    impl Driver for StubDriver {
        type Error = Infallible;

        async fn publish(&self, _: String, _: Vec<u8>) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn subscribe(
            &self,
            _: String,
            _: usize,
        ) -> Result<MessageStream<ChanItem>, Self::Error> {
            Ok(MessageStream::new_empty())
        }

        async fn unsubscribe(&self, _: String) -> Result<(), Self::Error> {
            Ok(())
        }

        async fn num_serv(&self, _: &str) -> Result<u16, Self::Error> {
            Ok(0)
        }
    }

    /// Builds an ack stream without local acks that expects answers from 2 remote servers.
    fn new_stub_ack_stream(
        remote: MessageStream<Vec<u8>>,
        timeout: Duration,
    ) -> AckStream<stream::Empty<AckStreamItem<()>>> {
        AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            timeout,
            2,
            Sid::new(),
            Arc::new(Mutex::new(HashMap::new())),
        )
    }

    #[tokio::test]
    async fn ack_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_secs(10));
        let node_id = Uid::new();
        let req_id = Sid::new();

        // Both remote servers announce 2 acks each.
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();
        tx.try_send(rmp_serde::to_vec(&(req_id, &ack_cnt_res)).unwrap())
            .unwrap();

        let ack_res = Response::<String> {
            node_id,
            r#type: ResponseType::BroadcastAck((Sid::new(), Ok(Value::Str(Str::from(""), None)))),
        };
        for _ in 0..4 {
            tx.try_send(rmp_serde::to_vec(&(req_id, &ack_res)).unwrap())
                .unwrap();
        }
        futures_util::pin_mut!(stream);
        for _ in 0..4 {
            assert!(stream.next().await.is_some());
        }
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_timeout() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let stream = new_stub_ack_stream(remote, Duration::from_millis(50));
        let node_id = Uid::new();
        let req_id = Sid::new();
        // Only one of the two expected servers answers.
        let ack_cnt_res = Response::<()> {
            node_id,
            r#type: ResponseType::BroadcastAckCount(2),
        };
        tx.try_send(rmp_serde::to_vec(&(req_id, ack_cnt_res)).unwrap())
            .unwrap();

        futures_util::pin_mut!(stream);
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert!(stream.next().await.is_none());
        assert!(stream.is_terminated());
    }

    #[tokio::test]
    async fn ack_stream_drop() {
        let (tx, rx) = tokio::sync::mpsc::channel(255);
        let remote = MessageStream::new(rx);
        let handlers = Arc::new(Mutex::new(HashMap::new()));
        let id = Sid::new();
        handlers.lock().unwrap().insert(id, tx);
        let stream = AckStream::new(
            stream::empty::<AckStreamItem<()>>(),
            remote,
            Duration::from_secs(10),
            2,
            id,
            handlers.clone(),
        );
        drop(stream);
        assert!(handlers.lock().unwrap().is_empty());
    }

    #[test]
    fn check_ns_error() {
        assert!(matches!(
            check_ns::<StubDriver>("#"),
            Err(InitError::MalformedNamespace)
        ));
        assert!(matches!(
            check_ns::<StubDriver>(""),
            Err(InitError::MalformedNamespace)
        ));
    }

    #[test]
    fn check_ns_ok() {
        assert!(check_ns::<StubDriver>("/").is_ok());
        assert!(check_ns::<StubDriver>("/admin").is_ok());
    }

    #[test]
    fn read_req_id_roundtrip() {
        // Same encoding as the one used to publish responses: `(req_id, response)`.
        let req_id = Sid::new();
        let res = Response::<()> {
            node_id: Uid::new(),
            r#type: ResponseType::BroadcastAckCount(2),
        };
        let data = rmp_serde::to_vec(&(req_id, &res)).unwrap();
        assert_eq!(read_req_id(&data), Some(req_id));
    }

    #[test]
    fn read_req_id_invalid() {
        // Empty input, an empty MessagePack array (0x90) and a nil value (0xc0).
        assert_eq!(read_req_id(&[]), None);
        assert_eq!(read_req_id(&[0x90]), None);
        assert_eq!(read_req_id(&[0xc0]), None);
    }
}