1#![cfg_attr(
39 feature = "memquota",
40 doc = "let config = tor_memquota::Config::builder().max(1024*1024*1024).build().unwrap();",
41 doc = "let trk = MemoryQuotaTracker::new(&runtime, config).unwrap();"
42)]
43#![cfg_attr(
44 not(feature = "memquota"),
45 doc = "let trk = MemoryQuotaTracker::new_noop();"
46)]
47#![forbid(unsafe_code)] use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
73
74use crate::internal_prelude::*;
75
76use std::task::{Context, Poll, Poll::*};
77use tor_async_utils::{ErasedSinkTrySendError, SinkCloseChannel, SinkTrySend};
78
/// The sending end of a memory-quota-tracked MPSC channel.
///
/// Each item sent is wrapped in an [`Entry`] (payload plus enqueue timestamp)
/// and has its memory cost claimed from the quota before it is enqueued.
#[derive(Educe)]
#[educe(Debug, Clone(bound = "C::Sender<Entry<T>>: Clone"))]
pub struct Sender<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The raw channel sender, carrying `Entry<T>` rather than bare `T`.
    tx: C::Sender<Entry<T>>,

    /// Our handle on the memory-quota participation.
    /// Costs are claimed here on send; the `Receiver` releases them on dequeue.
    mq: TypedParticipation<Entry<T>>,

    /// Time provider used to timestamp entries as they are enqueued.
    /// (Excluded from `Debug` output: it is not useful to print.)
    #[educe(Debug(ignore))]
    runtime: DynTimeProvider,
}
98
/// The receiving end of a memory-quota-tracked MPSC channel.
#[derive(Educe)] #[educe(Debug)]
pub struct Receiver<T: Debug + Send + 'static, C: ChannelSpec> {
    /// Shared core.  Also held (as an `IsParticipant`) by the memory tracker,
    /// which is how reclamation can collapse the queue behind our back.
    inner: Arc<ReceiverInner<T, C>>,
}
126
/// Shared core of a [`Receiver`]; also registered with the memory tracker.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverInner<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The live state, or `Err(CollapsedDueToReclaim)` once the queue has
    /// been collapsed by memory reclamation.
    state: Mutex<Result<ReceiverState<T, C>, CollapsedDueToReclaim>>,
}
142
/// The meat of a live (non-collapsed) receiver.
#[derive(Educe)]
#[educe(Debug)]
struct ReceiverState<T: Debug + Send + 'static, C: ChannelSpec> {
    /// The raw receive stream, wrapped in a peeker so the memory tracker can
    /// inspect the oldest queued entry without consuming it (see `get_oldest`).
    rx: StreamUnobtrusivePeeker<C::Receiver<Entry<T>>>,

    /// Memory-quota participation: costs are released here as entries are
    /// dequeued, and the participant is destroyed when this state is dropped.
    mq: TypedParticipation<Entry<T>>,

    /// Callbacks to run when the queue collapses.
    /// Debug-printed as just a count, via the named helper function.
    #[educe(Debug(method = "receiver_state_debug_collapse_notify"))]
    collapse_callbacks: Vec<CollapseCallback>,
}
175
/// What actually sits on the queue: the caller's item plus bookkeeping.
#[derive(Debug)]
struct Entry<T> {
    /// The caller's payload.
    t: T,
    /// When the entry was enqueued.
    /// Used by `get_oldest` to report the age of the queue's head for reclaim.
    when: CoarseInstant,
}
186
/// Error returned when sending on a memory-quota-tracked channel fails.
#[derive(Error, Clone, Debug)]
#[non_exhaustive]
pub enum SendError<CE> {
    /// The underlying raw channel rejected the item.
    #[error("channel send failed")]
    Channel(#[source] CE),

    /// The memory-quota system refused the claim; the queue has been
    /// (or is being) reclaimed and collapsed.
    #[error("memory quota exhausted, queue reclaimed")]
    Memquota(#[from] Error),
}
206
/// Callback invoked when the queue collapses.
///
/// Registered with [`Receiver::register_collapse_hook`]; called at most once.
pub type CollapseCallback = Box<dyn FnOnce(CollapseReason) + Send + Sync + 'static>;

/// Why did the queue collapse?
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
#[non_exhaustive]
pub enum CollapseReason {
    /// The `Receiver` end was dropped.
    ReceiverDropped,

    /// The memory-quota tracker reclaimed this queue's memory.
    MemoryReclaimed,
}

/// Internal marker: the queue was collapsed by memory reclamation.
/// Stored as the `Err` variant in `ReceiverInner::state`.
#[derive(Debug, Clone, Copy)]
struct CollapsedDueToReclaim;
224
225pub trait ChannelSpec: Sealed + Sized + 'static {
245 type Sender<T: Debug + Send + 'static>: Sink<T, Error = Self::SendError>
259 + Debug + Unpin + Sized;
260
261 type Receiver<T: Debug + Send + 'static>: Stream<Item = T> + Debug + Unpin + Send + Sized;
263
264 type SendError: std::error::Error;
268
269 #[allow(clippy::type_complexity)] fn new_mq<T>(self, runtime: DynTimeProvider, account: &Account) -> crate::Result<(
276 Sender<T, Self>,
277 Receiver<T, Self>,
278 )>
279 where
280 T: HasMemoryCost + Debug + Send + 'static,
281 {
282 let (rx, (tx, mq)) = account.register_participant_with(
283 runtime.now_coarse(),
284 move |mq| {
285 let mq = TypedParticipation::new(mq);
286 let collapse_callbacks = vec![];
287 let (tx, rx) = self.raw_channel::<Entry<T>>();
288 let rx = StreamUnobtrusivePeeker::new(rx);
289 let state = ReceiverState { rx, mq: mq.clone(), collapse_callbacks };
290 let state = Mutex::new(Ok(state));
291 let inner = ReceiverInner { state };
292 Ok::<_, crate::Error>((inner.into(), (tx, mq)))
293 },
294 )??;
295
296 let runtime = runtime.clone();
297
298 let tx = Sender { runtime, tx, mq };
299 let rx = Receiver { inner: rx };
300
301 Ok((tx, rx))
302 }
303
304 fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>);
308
309 fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>);
314}
315
/// Specification for a bounded MPSC channel.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor)]
#[allow(clippy::exhaustive_structs)]
pub struct MpscSpec {
    /// Buffer size for the underlying channel (as for `mpsc::channel`).
    pub buffer: usize,
}
332
/// Specification for an unbounded MPSC channel (no inherent buffer limit).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor, Default)]
#[allow(clippy::exhaustive_structs)]
pub struct MpscUnboundedSpec;
343
// Mark our channel specifications as implementors of the sealed trait,
// permitting them to implement `ChannelSpec`.
impl Sealed for MpscSpec {}
impl Sealed for MpscUnboundedSpec {}
346
347impl ChannelSpec for MpscSpec {
348 type Sender<T: Debug + Send + 'static> = mpsc::Sender<T>;
349 type Receiver<T: Debug + Send + 'static> = mpsc::Receiver<T>;
350 type SendError = mpsc::SendError;
351
352 fn raw_channel<T: Debug + Send + 'static>(self) -> (mpsc::Sender<T>, mpsc::Receiver<T>) {
353 mpsc_channel_no_memquota(self.buffer)
354 }
355
356 fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
357 rx.close();
358 }
359}
360
361impl ChannelSpec for MpscUnboundedSpec {
362 type Sender<T: Debug + Send + 'static> = mpsc::UnboundedSender<T>;
363 type Receiver<T: Debug + Send + 'static> = mpsc::UnboundedReceiver<T>;
364 type SendError = mpsc::SendError;
365
366 fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>) {
367 mpsc::unbounded()
368 }
369
370 fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
371 rx.close();
372 }
373}
374
375impl<T, C> Sink<T> for Sender<T, C>
380where
381 T: HasMemoryCost + Debug + Send + 'static,
382 C: ChannelSpec,
383{
384 type Error = SendError<C::SendError>;
385
386 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
387 self.get_mut()
388 .tx
389 .poll_ready_unpin(cx)
390 .map_err(SendError::Channel)
391 }
392
393 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
394 let self_ = self.get_mut();
395 let item = Entry {
396 t: item,
397 when: self_.runtime.now_coarse(),
398 };
399 self_.mq.try_claim(item, |item| {
400 self_.tx.start_send_unpin(item).map_err(SendError::Channel)
401 })?
402 }
403
404 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
405 self.tx
406 .poll_flush_unpin(cx)
407 .map(|r| r.map_err(SendError::Channel))
408 }
409
410 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
411 self.tx
412 .poll_close_unpin(cx)
413 .map(|r| r.map_err(SendError::Channel))
414 }
415}
416
impl<T, C> SinkTrySend<T> for Sender<T, C>
where
    T: HasMemoryCost + Debug + Send + 'static,
    C: ChannelSpec,
    C::Sender<Entry<T>>: SinkTrySend<Entry<T>>,
    <C::Sender<Entry<T>> as SinkTrySend<Entry<T>>>::Error: Send + Sync,
{
    type Error = ErasedSinkTrySendError;

    /// Try to send `item` without waiting.
    ///
    /// On failure, the caller's original item is handed back alongside an
    /// erased error: a memory-quota failure becomes `ESTSE::Other`, while a
    /// raw-sink failure is converted via `ESTSE::from`.
    fn try_send_or_return(
        self: Pin<&mut Self>,
        item: T,
    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)> {
        let self_ = self.get_mut();
        // Wrap the payload with its enqueue timestamp, as in `start_send`.
        let item = Entry {
            t: item,
            when: self_.runtime.now_coarse(),
        };

        use ErasedSinkTrySendError as ESTSE;

        // Claim quota first; only on success do we attempt the raw send.
        // Both failure paths return `unsent.t` — the caller's original item,
        // unwrapped from the `Entry` again.
        self_
            .mq
            .try_claim_or_return(item, |item| {
                Pin::new(&mut self_.tx).try_send_or_return(item)
            })
            .map_err(|(mqe, unsent)| (ESTSE::Other(Arc::new(mqe)), unsent.t))?
            .map_err(|(tse, unsent)| (ESTSE::from(tse), unsent.t))
    }
}
446
447impl<T, C> SinkCloseChannel<T> for Sender<T, C>
448where
449 T: HasMemoryCost + Debug + Send, C: ChannelSpec,
451 C::Sender<Entry<T>>: SinkCloseChannel<Entry<T>>,
452{
453 fn close_channel(self: Pin<&mut Self>) {
454 Pin::new(&mut self.get_mut().tx).close_channel();
455 }
456}
457
impl<T, C> Sender<T, C>
where
    T: Debug + Send + 'static,
    C: ChannelSpec,
{
    /// Obtain the time provider this `Sender` uses to timestamp entries.
    ///
    /// Lets callers use a clock consistent with the queue's own timestamps.
    pub fn time_provider(&self) -> &DynTimeProvider {
        &self.runtime
    }
}
471
impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Stream for Receiver<T, C> {
    type Item = T;

    /// Poll for the next item.
    ///
    /// Reports end-of-stream if the queue was collapsed by memory reclaim.
    /// When an item is dequeued, its memory cost is released back to the
    /// quota tracker.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Hold the state lock for the whole poll; `get_oldest` and `reclaim`
        // contend on this same lock.
        let mut state = self.inner.lock();
        let state = match &mut *state {
            Ok(y) => y,
            // Collapsed queue: behave as a terminated stream.
            Err(CollapsedDueToReclaim) => return Ready(None),
        };
        let ret = state.rx.poll_next_unpin(cx);
        if let Ready(Some(item)) = &ret {
            // Release this entry's quota claim now that it has left the queue
            // (only when memory tracking is compiled in).
            if let Some(enabled) = EnabledToken::new_if_compiled_in() {
                let cost = item.typed_memory_cost(enabled);
                state.mq.release(&cost);
            }
        }
        // Strip the `Entry` wrapper, yielding just the caller's payload.
        ret.map(|r| r.map(|e| e.t))
    }
}
493
494impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> FusedStream for Receiver<T, C>
495where
496 C::Receiver<Entry<T>>: FusedStream,
497{
498 fn is_terminated(&self) -> bool {
499 match &*self.inner.lock() {
500 Ok(y) => y.rx.is_terminated(),
501 Err(CollapsedDueToReclaim) => true,
502 }
503 }
504}
505
impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Receiver<T, C> {
    /// Register a callback to run when the queue collapses.
    ///
    /// If the queue has *already* collapsed (due to memory reclamation),
    /// `call` is invoked immediately, on this thread, before we return.
    pub fn register_collapse_hook(&self, call: CollapseCallback) {
        let mut state = self.inner.lock();
        let state = match &mut *state {
            Ok(y) => y,
            Err(reason) => {
                let reason = (*reason).into();
                // Release the lock before running caller-supplied code,
                // so the callback cannot deadlock by touching this queue.
                drop::<MutexGuard<_>>(state);
                call(reason);
                return;
            }
        };
        // Not collapsed yet: stash the callback for later.
        state.collapse_callbacks.push(call);
    }
}
540
impl<T: Debug + Send + 'static, C: ChannelSpec> ReceiverInner<T, C> {
    /// Lock the state mutex.
    ///
    /// # Panics
    ///
    /// Panics (only) if the mutex has been poisoned.
    fn lock(&self) -> MutexGuard<Result<ReceiverState<T, C>, CollapsedDueToReclaim>> {
        self.state.lock().expect("mq_mpsc lock poisoned")
    }
}
547
impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> IsParticipant
    for ReceiverInner<T, C>
{
    /// Report the enqueue time of the oldest queued entry, if any.
    ///
    /// The tracker uses this to choose reclamation victims.  We peek via the
    /// `StreamUnobtrusivePeeker` wrapper rather than consuming the entry.
    fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
        let mut state = self.lock();
        let state = match &mut *state {
            Ok(y) => y,
            // Already collapsed: nothing is queued.
            Err(CollapsedDueToReclaim) => return None,
        };
        Pin::new(&mut state.rx)
            .unobtrusive_peek()
            .map(|peeked| peeked.when)
    }

    /// Collapse the queue to reclaim its memory.
    ///
    /// Swaps the state out for `Err(CollapsedDueToReclaim)` and, with the
    /// lock released, runs any collapse callbacks and drops the old state
    /// (which frees the queued entries and destroys the participant).
    fn reclaim(self: Arc<Self>, _: EnabledToken) -> mtracker::ReclaimFuture {
        Box::pin(async move {
            let reason = CollapsedDueToReclaim;
            let mut state_guard = self.lock();
            let state = mem::replace(&mut *state_guard, Err(reason));
            // Release the lock before running callbacks or destructors.
            drop::<MutexGuard<_>>(state_guard);
            #[allow(clippy::single_match)]
            match state {
                Ok(mut state) => {
                    for call in state.collapse_callbacks.drain(..) {
                        call(reason.into());
                    }
                    // Explicit drop for clarity: runs `ReceiverState::drop`,
                    // which discards queued entries.
                    drop::<ReceiverState<_, _>>(state);
                }
                // Someone else already collapsed it; nothing to do.
                Err(CollapsedDueToReclaim) => {}
            };
            mtracker::Reclaimed::Collapsing
        })
    }
}
582
impl<T: Debug + Send + 'static, C: ChannelSpec> Drop for ReceiverState<T, C> {
    fn drop(&mut self) {
        // Destroy our memory participant so the tracker stops counting us.
        // We can't move `self.mq` out of `&mut self`, so swap in a dangling
        // participation first.
        mem::replace(&mut self.mq, Participation::new_dangling().into())
            .into_raw()
            .destroy_participant();

        // Notify anyone who registered a collapse hook.
        for call in self.collapse_callbacks.drain(..) {
            call(CollapseReason::ReceiverDropped);
        }

        // Polling context with a no-op waker: we never wait, we only want
        // whatever is already buffered.
        let mut noop_cx = Context::from_waker(Waker::noop());

        // Close the raw receiver (if the peeker still holds it), so that the
        // drain loop below terminates rather than returning Pending forever.
        if let Some(mut rx_inner) =
            StreamUnobtrusivePeeker::as_raw_inner_pin_mut(Pin::new(&mut self.rx))
        {
            C::close_receiver(&mut rx_inner);
        }

        // Drain and drop every entry still queued.
        while let Ready(Some(item)) = self.rx.poll_next_unpin(&mut noop_cx) {
            drop::<Entry<T>>(item);
        }
    }
}
612
613fn receiver_state_debug_collapse_notify(
615 v: &[CollapseCallback],
616 f: &mut fmt::Formatter,
617) -> fmt::Result {
618 Debug::fmt(&v.len(), f)
619}
620
621impl<T: HasMemoryCost> HasMemoryCost for Entry<T> {
624 fn memory_cost(&self, enabled: EnabledToken) -> usize {
625 let time_size = std::alloc::Layout::new::<CoarseInstant>().size();
626 self.t.memory_cost(enabled).saturating_add(time_size)
627 }
628}
629
630impl From<CollapsedDueToReclaim> for CollapseReason {
631 fn from(CollapsedDueToReclaim: CollapsedDueToReclaim) -> CollapseReason {
632 CollapseReason::MemoryReclaimed
633 }
634}
635
636#[cfg(all(test, feature = "memquota", not(miri) ))]
637mod test {
638 #![allow(clippy::bool_assert_comparison)]
640 #![allow(clippy::clone_on_copy)]
641 #![allow(clippy::dbg_macro)]
642 #![allow(clippy::mixed_attributes_style)]
643 #![allow(clippy::print_stderr)]
644 #![allow(clippy::print_stdout)]
645 #![allow(clippy::single_char_pattern)]
646 #![allow(clippy::unwrap_used)]
647 #![allow(clippy::unchecked_time_subtraction)]
648 #![allow(clippy::useless_vec)]
649 #![allow(clippy::needless_pass_by_value)]
650 #![allow(clippy::arithmetic_side_effects)] use super::*;
654 use crate::mtracker::test::*;
655 use tor_rtmock::MockRuntime;
656 use tracing::debug;
657 use tracing_test::traced_test;
658
    /// Counts live test `Item`s, so tests can detect when items are dropped.
    #[derive(Default, Debug)]
    struct ItemTracker {
        state: Mutex<ItemTrackerState>,
    }
    /// Interior state for `ItemTracker`.
    #[derive(Default, Debug)]
    struct ItemTrackerState {
        /// Number of `Item`s currently alive.
        existing: usize,
        /// Id to assign to the next `Item` created.
        next_id: usize,
    }
668
    /// A test payload whose creation and destruction are counted.
    #[derive(Debug)]
    struct Item {
        /// Unique id, for debug logging.
        id: usize,
        /// The tracker that counts this item.
        tracker: Arc<ItemTracker>,
    }
674
675 impl ItemTracker {
676 fn new_item(self: &Arc<Self>) -> Item {
677 let mut state = self.lock();
678 let id = state.next_id;
679 state.existing += 1;
680 state.next_id += 1;
681 debug!("new {id}");
682 Item {
683 tracker: self.clone(),
684 id,
685 }
686 }
687
688 fn new_tracker() -> Arc<Self> {
689 Arc::default()
690 }
691
692 fn lock(&self) -> MutexGuard<ItemTrackerState> {
693 self.state.lock().unwrap()
694 }
695 }
696
697 impl Drop for Item {
698 fn drop(&mut self) {
699 debug!("old {}", self.id);
700 self.tracker.state.lock().unwrap().existing -= 1;
701 }
702 }
703
    impl HasMemoryCost for Item {
        /// Every test item reports a fixed cost of `mbytes(1)`.
        fn memory_cost(&self, _: EnabledToken) -> usize {
            mbytes(1)
        }
    }
709
    /// Common fixtures shared by the tests in this module.
    struct Setup {
        /// Time provider wrapping the mock runtime.
        dtp: DynTimeProvider,
        /// The memory-quota tracker under test.
        trk: Arc<mtracker::MemoryQuotaTracker>,
        /// An account on `trk`, used to create queues.
        acct: Account,
        /// Counts creation/destruction of test `Item`s.
        itrk: Arc<ItemTracker>,
    }
716
717 fn setup(rt: &MockRuntime) -> Setup {
718 let dtp = DynTimeProvider::new(rt.clone());
719 let trk = mk_tracker(rt);
720 let acct = trk.new_account(None).unwrap();
721 let itrk = ItemTracker::new_tracker();
722 Setup {
723 dtp,
724 trk,
725 acct,
726 itrk,
727 }
728 }
729
    /// A payload with a huge reported cost (`mbytes(100)`), used to
    /// deliberately exhaust the quota and trigger reclamation.
    #[derive(Debug)]
    struct Gigantic;
    impl HasMemoryCost for Gigantic {
        fn memory_cost(&self, _et: EnabledToken) -> usize {
            mbytes(100)
        }
    }
737
    impl Setup {
        /// Assert that (approximately) no memory is claimed.
        ///
        /// Exact zero isn't expected: the assertion allows
        /// `n_queues * 2 * MAX_CACHE` of slop — presumably per-participant
        /// caching in the tracker; matches the assertion's formula.
        fn check_zero_claimed(&self, n_queues: usize) {
            let used = self.trk.used_current_approx();
            debug!(
                "checking zero balance (with slop {n_queues} * 2 * {}; used={used:?}",
                *mtracker::MAX_CACHE,
            );
            assert!(used.unwrap() <= n_queues * 2 * *mtracker::MAX_CACHE);
        }
    }
753
    /// End-to-end lifecycle: send/receive, reclaim-induced collapse,
    /// then post-collapse behavior of both ends.
    #[traced_test]
    #[test]
    fn lifecycle() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            // Basic round trip.
            tx.send(s.itrk.new_item()).await.unwrap();
            let _: Item = rx.next().await.unwrap();

            // Queue up many items (each costing mbytes(1)); this is expected
            // to push us over quota and provoke reclamation.
            for _ in 0..20 {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            debug!("still existing items {}", s.itrk.lock().existing);

            // Let the tracker's reclamation task run.
            rt.advance_until_stalled().await;

            // Collapse should have dropped every queued item.
            assert!(s.itrk.lock().existing == 0);

            // The receiver now reports end-of-stream...
            assert!(rx.next().await.is_none());

            // ...and further sends fail.
            let _: SendError<_> = tx.send(s.itrk.new_item()).await.unwrap_err();
        });
    }
783
    /// Fill the queue, drain it completely, and check the claimed memory
    /// returns to (approximately) zero.
    #[traced_test]
    #[test]
    fn fill_and_empty() {
        MockRuntime::test_with_various(|rt| async move {
            let s = setup(&rt);
            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();

            // Number of items to send and then receive.
            const COUNT: usize = 19;

            for _ in 0..COUNT {
                tx.send(s.itrk.new_item()).await.unwrap();
            }

            rt.advance_until_stalled().await;

            // Drain everything we sent.
            for _ in 0..COUNT {
                let _: Item = rx.next().await.unwrap();
            }

            rt.advance_until_stalled().await;

            // One queue's worth of cache slop is allowed.
            s.check_zero_claimed(1);
        });
    }
809
    /// Exercise the error paths: raw-sink failures (both "disconnected" and
    /// "other"), and memory-quota failures, via both `Sink::send` and
    /// `SinkTrySend::try_send_or_return`.
    #[traced_test]
    #[test]
    fn sink_error() {
        /// A sink that always fails with its configured `BustedError`.
        #[derive(Debug, Copy, Clone)]
        struct BustedSink {
            error: BustedError,
        }

        impl<T> Sink<T> for BustedSink {
            type Error = BustedError;

            // Always refuse at the readiness stage...
            fn poll_ready(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Err(self.error))
            }
            // ...so actually sending would be a bug in the caller.
            fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
                panic!("poll_ready always gives error, start_send should not be called");
            }
            fn poll_flush(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Ok(()))
            }
            fn poll_close(
                self: Pin<&mut Self>,
                _: &mut Context<'_>,
            ) -> Poll<Result<(), Self::Error>> {
                Ready(Ok(()))
            }
        }

        impl<T> SinkTrySend<T> for BustedSink {
            type Error = BustedError;

            // Always fail, returning the item to the caller.
            fn try_send_or_return(self: Pin<&mut Self>, item: T) -> Result<(), (BustedError, T)> {
                Err((self.error, item))
            }
        }

        impl tor_async_utils::SinkTrySendError for BustedError {
            fn is_disconnected(&self) -> bool {
                self.is_disconnected
            }
            fn is_full(&self) -> bool {
                false
            }
        }

        /// Test error, configurable to report as "disconnected" or not.
        #[derive(Error, Debug, Clone, Copy)]
        #[error("busted, for testing, dc={is_disconnected:?}")]
        struct BustedError {
            is_disconnected: bool,
        }

        /// ChannelSpec whose sender always fails and whose receiver never yields.
        struct BustedQueueSpec {
            error: BustedError,
        }
        impl Sealed for BustedQueueSpec {}
        impl ChannelSpec for BustedQueueSpec {
            type Sender<T: Debug + Send + 'static> = BustedSink;
            type Receiver<T: Debug + Send + 'static> = futures::stream::Pending<T>;
            type SendError = BustedError;
            fn raw_channel<T: Debug + Send + 'static>(self) -> (BustedSink, Self::Receiver<T>) {
                (BustedSink { error: self.error }, futures::stream::pending())
            }
            fn close_receiver<T: Debug + Send + 'static>(_rx: &mut Self::Receiver<T>) {}
        }

        use ErasedSinkTrySendError as ESTSE;

        MockRuntime::test_with_various(|rt| async move {
            let error = BustedError {
                is_disconnected: true,
            };

            let s = setup(&rt);
            let (mut tx, _rx) = BustedQueueSpec { error }
                .new_mq(s.dtp.clone(), &s.acct)
                .unwrap();

            // `Sink::send` surfaces the raw sink's error as SendError::Channel.
            let e = tx.send(s.itrk.new_item()).await.unwrap_err();
            assert!(matches!(e, SendError::Channel(BustedError { .. })));

            // The failed item must have been dropped, not leaked.
            assert_eq!(s.itrk.lock().existing, 0);

            // Helper: assert `e` is ESTSE::Other wrapping an E.
            fn error_is_other_of<E>(e: ESTSE) -> Result<(), impl Debug>
            where
                E: std::error::Error + 'static,
            {
                match e {
                    ESTSE::Other(e) if e.is::<E>() => Ok(()),
                    other => Err(other),
                }
            }

            let item = s.itrk.new_item();

            // A "disconnected" raw-sink error maps to ESTSE::Disconnected,
            // and the item is handed back.
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            assert!(matches!(e, ESTSE::Disconnected), "{e:?}");

            // A non-disconnected raw-sink error maps to ESTSE::Other.
            let error = BustedError {
                is_disconnected: false,
            };
            let (mut tx, _rx) = BustedQueueSpec { error }
                .new_mq(s.dtp.clone(), &s.acct)
                .unwrap();
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            error_is_other_of::<BustedError>(e).unwrap();

            s.check_zero_claimed(1);

            // Exhaust the quota with a Gigantic item on another queue,
            // collapsing this account's participants...
            {
                let (mut tx, _rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
                tx.send(Gigantic).await.unwrap();
                rt.advance_until_stalled().await;
            }

            // ...after which try_send fails with a memquota (crate::Error) error.
            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
            error_is_other_of::<crate::Error>(e).unwrap();

            drop::<Item>(item);
        });
    }
947}