Skip to main content

tor_memquota/
mq_queue.rs

1//! Queues that participate in the memory quota system
2//!
3//! Wraps a communication channel, such as [`futures::channel::mpsc`],
4//! tracks the memory use of the queue,
5//! and participates in the memory quota system.
6//!
7//! Each item in the queue must know its memory cost,
8//! and provide it via [`HasMemoryCost`].
9//!
10//! New queues are created by calling the [`new_mq`](ChannelSpec::new_mq) method
11//! on a [`ChannelSpec`],
12//! for example [`MpscSpec`] or [`MpscUnboundedSpec`].
13//!
14//! The ends implement [`Stream`] and [`Sink`].
15//! If the underlying channel's sender is `Clone`,
16//! for example with an MPSC queue, the returned sender is also `Clone`.
17//!
18//! Note that the [`Sender`] and [`Receiver`] only hold weak references to the `Account`.
//! I.e., the queue is not the accountholder.
20//! The caller should keep a separate copy of the account.
21//!
22//! # Example
23//!
24//! ```
25//! use tor_memquota::{MemoryQuotaTracker, HasMemoryCost, EnabledToken};
26//! use tor_rtcompat::{DynTimeProvider, PreferredRuntime};
27//! use tor_memquota::mq_queue::{MpscSpec, ChannelSpec as _};
28//! # fn m() -> tor_memquota::Result<()> {
29//!
30//! #[derive(Debug)]
31//! struct Message(String);
32//! impl HasMemoryCost for Message {
33//!     fn memory_cost(&self, _: EnabledToken) -> usize { self.0.len() }
34//! }
35//!
36//! let runtime = PreferredRuntime::create().unwrap();
37//! let time_prov = DynTimeProvider::new(runtime.clone());
38#![cfg_attr(
39    feature = "memquota",
40    doc = "let config  = tor_memquota::Config::builder().max(1024*1024*1024).build().unwrap();",
41    doc = "let trk = MemoryQuotaTracker::new(&runtime, config).unwrap();"
42)]
43#![cfg_attr(
44    not(feature = "memquota"),
45    doc = "let trk = MemoryQuotaTracker::new_noop();"
46)]
47//! let account = trk.new_account(None).unwrap();
48//!
49//! let (tx, rx) = MpscSpec { buffer: 10 }.new_mq::<Message>(time_prov, &account)?;
50//! #
51//! # Ok(())
52//! # }
53//! # m().unwrap();
54//! ```
55//!
56//! # Caveat
57//!
58//! The memory use tracking is based on external observations,
59//! i.e., items inserted and removed.
60//!
61//! How well this reflects the actual memory use of the channel
62//! depends on the channel's implementation.
63//!
64//! For example, if the channel uses a single contiguous buffer
65//! containing the unboxed items, and that buffer doesn't shrink,
66//! then the memory tracking can be based on an underestimate.
67//! (This is significantly mitigated if the bulk of the memory use
68//! for each item is separately boxed.)
69
70#![forbid(unsafe_code)] // if you remove this, enable (or write) miri tests (git grep miri)
71
72use tor_async_utils::peekable_stream::UnobtrusivePeekableStream;
73
74use crate::internal_prelude::*;
75
76use std::task::{Context, Poll, Poll::*};
77use tor_async_utils::{ErasedSinkTrySendError, SinkCloseChannel, SinkTrySend};
78
79//---------- Sender ----------
80
81/// Sender for a channel that participates in the memory quota system
82///
83/// Returned by [`ChannelSpec::new_mq`], a method on `C`.
84/// See the [module-level docs](crate::mq_queue).
85#[derive(Educe)]
86#[educe(Debug, Clone(bound = "C::Sender<Entry<T>>: Clone"))]
87pub struct Sender<T: Debug + Send + 'static, C: ChannelSpec> {
88    /// The inner sink
89    tx: C::Sender<Entry<T>>,
90
91    /// Our clone of the `Participation`, for memory accounting
92    mq: TypedParticipation<Entry<T>>,
93
94    /// Time provider for getting the data age
95    #[educe(Debug(ignore))] // CoarseTimeProvider isn't Debug
96    runtime: DynTimeProvider,
97}
98
99//---------- Receiver ----------
100
101/// Receiver for a channel that participates in the memory quota system
102///
103/// Returned by [`ChannelSpec::new_mq`], a method on `C`.
104/// See the [module-level docs](crate::mq_queue).
105#[derive(Educe)] // not Clone, see below
106#[educe(Debug)]
107pub struct Receiver<T: Debug + Send + 'static, C: ChannelSpec> {
108    /// Payload
109    //
110    // We don't make this an "exposed" `Arc`,
111    // because that would allow the caller to clone it -
112    // but we don't promise we're a multi-consumer queue even if `C::Receiver` is.
113    //
114    // Despite the in-principle Clone-ability of our `Receiver`,
115    // we're not a working multi-consumer queue, even if the underlying channel is,
116    // because StreamUnobtrusivePeeker isn't multi-consumer.
117    //
118    // Providing the multi-consumer feature would perhaps involve StreamUnobtrusivePeeker
119    // handling multiple wakers, and then `impl Clone for Receiver where C::Receiver: Clone`.
120    // (and writing a bunch of tests).
121    //
122    // This would all be useless without also `impl ChannelSpec`
123    // for a multi-consumer queue.
124    inner: Arc<ReceiverInner<T, C>>,
125}
126
127/// Payload of `Receiver`, that's within the `Arc`, but contains the `Mutex`.
128///
129/// This is a separate type because
130/// it's what we need to implement [`IsParticipant`] for.
131#[derive(Educe)]
132#[educe(Debug)]
133struct ReceiverInner<T: Debug + Send + 'static, C: ChannelSpec> {
134    /// Mutable state
135    ///
136    /// If we have collapsed due to memory reclaim, state is replaced by an `Err`.
137    /// In that case the caller mostly can't send on the Sender either,
138    /// because we'll have torn down the Participant,
139    /// so claims (beyond the cache in the `Sender`'s `Participation`) will fail.
140    state: Mutex<Result<ReceiverState<T, C>, CollapsedDueToReclaim>>,
141}
142
143/// Mutable state of a `Receiver`
144///
145/// Normally the mutex is only locked by the receiving task.
146/// On memory pressure, mutex is acquired by the memory system,
147/// which has a clone of the `Arc<ReceiverInner>`.
148///
149/// Within `Arc<Mutex<Result<, >>>`.
150#[derive(Educe)]
151#[educe(Debug)]
152struct ReceiverState<T: Debug + Send + 'static, C: ChannelSpec> {
153    /// The inner stream, but with an unobtrusive peek for getting the oldest data age
154    rx: StreamUnobtrusivePeeker<C::Receiver<Entry<T>>>,
155
156    /// The `Participation`, which we use for memory accounting
157    ///
158    /// ### Performance and locality
159    ///
160    /// We have separate [`Participation`]s for rx and tx.
161    /// The tx is constantly claiming and the rx releasing;
162    /// at least each MAX_CACHE, they must balance out
163    /// via the (fairly globally shared) `MemoryQuotaTracker`.
164    ///
165    /// If this turns out to be a problem,
166    /// we could arrange to share a `Participation`.
167    mq: TypedParticipation<Entry<T>>,
168
169    /// Hooks passed to [`Receiver::register_collapse_hook`]
170    ///
171    /// When receiver dropped, or memory reclaimed, we call all of these.
172    #[educe(Debug(method = "receiver_state_debug_collapse_notify"))]
173    collapse_callbacks: Vec<CollapseCallback>,
174}
175
176//---------- other types ----------
177
178/// Entry in the inner queue
179#[derive(Debug)]
180struct Entry<T> {
181    /// The actual entry
182    t: T,
183    /// The data age - when it was inserted into the queue
184    when: CoarseInstant,
185}
186
187/// Error returned when trying to write to a [`Sender`]
188#[derive(Error, Clone, Debug)]
189#[non_exhaustive]
190pub enum SendError<CE> {
191    /// The underlying channel rejected the message
192    // Can't be `#[from]` because rustc can't see that C::SendError isn't SendError<C>
193    #[error("channel send failed")]
194    Channel(#[source] CE),
195
196    /// The memory quota system prevented the send
197    ///
198    /// NB: when the channel is torn down due to memory pressure,
199    /// the inner receiver is also torn down.
200    /// This means that this variant is not always reported:
201    /// sending on the sender in this situation
202    /// may give [`SendError::Channel`] instead.
203    #[error("memory quota exhausted, queue reclaimed")]
204    Memquota(#[from] Error),
205}
206
207/// Callback passed to `Receiver::register_collapse_hook`
208pub type CollapseCallback = Box<dyn FnOnce(CollapseReason) + Send + Sync + 'static>;
209
210/// Argument to `CollapseCallback`: why are we collapsing?
211#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
212#[non_exhaustive]
213pub enum CollapseReason {
214    /// The `Receiver` was dropped
215    ReceiverDropped,
216
217    /// The memory quota system asked us to reclaim memory
218    MemoryReclaimed,
219}
220
221/// Marker, appears in state as `Err` to mean "we have collapsed"
222#[derive(Debug, Clone, Copy)]
223struct CollapsedDueToReclaim;
224
225//==================== Channel ====================
226
227/// Specification for a communication channel
228///
229/// Implemented for [`MpscSpec`] and [`MpscUnboundedSpec`].
230//
231// # Correctness (uncomment this if this trait is made unsealed)
232//
233// It is a requirement that this object really is some kind of channel.
234// Specifically:
235//
236//  * Things that get put into the `Sender` must eventually emerge from the `Receiver`.
237//  * Nothing may emerge from the `Receiver` that wasn't put into the `Sender`.
238//  * If the `Sender` and `Receiver` are dropped, the items must also get dropped.
239//
240// If these requirements are violated, it could result in corruption of the memory accounts
241//
242// Ideally, if the `Receiver` is dropped, most of the items are dropped soon.
243//
244pub trait ChannelSpec: Sealed /* see Correctness, above */ + Sized + 'static {
245    /// The sending [`Sink`] for items of type `T`.
246    //
247    // Right now we insist that everything is Unpin.
248    // futures::channel::mpsc's types all are.
249    // If we wanted to support !Unpin channels, that would be possible,
250    // but we would have some work to do.
251    //
252    // We also insist that everything is Debug.  That means `T: Debug`,
253    // as well as the channels.  We could avoid that, but it would involve
254    // skipping debug of important fields, or pervasive complex trait bounds
255    // (Eg `#[educe(Debug(bound = "C::Receiver<Entry<T>>: Debug"))]` or worse.)
256    //
257    // This is a GAT because we need to instantiate it with T=Entry<_>.
258    type Sender<T: Debug + Send + 'static>: Sink<T, Error = Self::SendError>
259        + Debug + Unpin + Sized;
260
261    /// The receiving [`Stream`] for items of type `T`.
262    type Receiver<T: Debug + Send + 'static>: Stream<Item = T> + Debug + Unpin + Send + Sized;
263
264    /// The error type `<Receiver<_> as Stream>::Error`.
265    ///
266    /// (For this trait to be implemented, it is not allowed to depend on `T`.)
267    type SendError: std::error::Error;
268
269    /// Create a new channel, based on the spec `self`, that participates in the memory quota
270    ///
271    /// See the [module-level docs](crate::mq_queue) for an example.
272    //
273    // This method is supposed to be called by the user, not overridden.
274    #[allow(clippy::type_complexity)] // the Result; not sensibly reducible or aliasable
275    fn new_mq<T>(self, runtime: DynTimeProvider, account: &Account) -> crate::Result<(
276        Sender<T, Self>,
277        Receiver<T, Self>,
278    )>
279    where
280        T: HasMemoryCost + Debug + Send + 'static,
281    {
282        let (rx, (tx, mq)) = account.register_participant_with(
283            runtime.now_coarse(),
284            move |mq| {
285                let mq = TypedParticipation::new(mq);
286                let collapse_callbacks = vec![];
287                let (tx, rx) = self.raw_channel::<Entry<T>>();
288                let rx = StreamUnobtrusivePeeker::new(rx);
289                let state = ReceiverState { rx, mq: mq.clone(), collapse_callbacks };
290                let state = Mutex::new(Ok(state));
291                let inner = ReceiverInner { state };
292                Ok::<_, crate::Error>((inner.into(), (tx, mq)))
293            },
294        )??;
295
296        let runtime = runtime.clone();
297
298        let tx = Sender { runtime, tx, mq };
299        let rx = Receiver { inner: rx };
300
301        Ok((tx, rx))
302    }
303
304    /// Create a new raw channel as specified by `self`
305    //
306    // This is called by `mq_queue`.
307    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>);
308
309    /// Close the receiver, preventing further sends
310    ///
311    /// This should ensure that only a smallish bounded number of further items
312    /// can be sent, before errors start being returned.
313    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>);
314}
315
316//---------- impls of Channel ----------
317
318/// Specification for a (bounded) MPSC channel
319///
320/// Corresponds to the constructor [`futures::channel::mpsc::channel`].
321///
322/// Call [`new_mq`](ChannelSpec::new_mq) on a value of this type.
323///
324/// (The [`new`](MpscUnboundedSpec::new) method is provided for convenience;
325/// you may also construct the value directly.)
326#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor)]
327#[allow(clippy::exhaustive_structs)] // This is precisely the arguments to mpsc::channel
328pub struct MpscSpec {
329    /// Buffer size; see [`futures::channel::mpsc::channel`].
330    pub buffer: usize,
331}
332
333/// Specification for an unbounded MPSC channel
334///
335/// Corresponds to the constructor [`futures::channel::mpsc::unbounded`].
336///
337/// Call [`new_mq`](ChannelSpec::new_mq) on a value of this unit type.
338///
339/// (The [`new`](MpscUnboundedSpec::new) method is provided for orthogonality.)
340#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Constructor, Default)]
341#[allow(clippy::exhaustive_structs)] // This is precisely the arguments to mpsc::unbounded
342pub struct MpscUnboundedSpec;
343
344impl Sealed for MpscSpec {}
345impl Sealed for MpscUnboundedSpec {}
346
347impl ChannelSpec for MpscSpec {
348    type Sender<T: Debug + Send + 'static> = mpsc::Sender<T>;
349    type Receiver<T: Debug + Send + 'static> = mpsc::Receiver<T>;
350    type SendError = mpsc::SendError;
351
352    fn raw_channel<T: Debug + Send + 'static>(self) -> (mpsc::Sender<T>, mpsc::Receiver<T>) {
353        mpsc_channel_no_memquota(self.buffer)
354    }
355
356    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
357        rx.close();
358    }
359}
360
361impl ChannelSpec for MpscUnboundedSpec {
362    type Sender<T: Debug + Send + 'static> = mpsc::UnboundedSender<T>;
363    type Receiver<T: Debug + Send + 'static> = mpsc::UnboundedReceiver<T>;
364    type SendError = mpsc::SendError;
365
366    fn raw_channel<T: Debug + Send + 'static>(self) -> (Self::Sender<T>, Self::Receiver<T>) {
367        mpsc::unbounded()
368    }
369
370    fn close_receiver<T: Debug + Send + 'static>(rx: &mut Self::Receiver<T>) {
371        rx.close();
372    }
373}
374
375//==================== implementations ====================
376
377//---------- Sender ----------
378
379impl<T, C> Sink<T> for Sender<T, C>
380where
381    T: HasMemoryCost + Debug + Send + 'static,
382    C: ChannelSpec,
383{
384    type Error = SendError<C::SendError>;
385
386    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
387        self.get_mut()
388            .tx
389            .poll_ready_unpin(cx)
390            .map_err(SendError::Channel)
391    }
392
393    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
394        let self_ = self.get_mut();
395        let item = Entry {
396            t: item,
397            when: self_.runtime.now_coarse(),
398        };
399        self_.mq.try_claim(item, |item| {
400            self_.tx.start_send_unpin(item).map_err(SendError::Channel)
401        })?
402    }
403
404    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
405        self.tx
406            .poll_flush_unpin(cx)
407            .map(|r| r.map_err(SendError::Channel))
408    }
409
410    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
411        self.tx
412            .poll_close_unpin(cx)
413            .map(|r| r.map_err(SendError::Channel))
414    }
415}
416
417impl<T, C> SinkTrySend<T> for Sender<T, C>
418where
419    T: HasMemoryCost + Debug + Send + 'static,
420    C: ChannelSpec,
421    C::Sender<Entry<T>>: SinkTrySend<Entry<T>>,
422    <C::Sender<Entry<T>> as SinkTrySend<Entry<T>>>::Error: Send + Sync,
423{
424    type Error = ErasedSinkTrySendError;
425    fn try_send_or_return(
426        self: Pin<&mut Self>,
427        item: T,
428    ) -> Result<(), (<Self as SinkTrySend<T>>::Error, T)> {
429        let self_ = self.get_mut();
430        let item = Entry {
431            t: item,
432            when: self_.runtime.now_coarse(),
433        };
434
435        use ErasedSinkTrySendError as ESTSE;
436
437        self_
438            .mq
439            .try_claim_or_return(item, |item| {
440                Pin::new(&mut self_.tx).try_send_or_return(item)
441            })
442            .map_err(|(mqe, unsent)| (ESTSE::Other(Arc::new(mqe)), unsent.t))?
443            .map_err(|(tse, unsent)| (ESTSE::from(tse), unsent.t))
444    }
445}
446
447impl<T, C> SinkCloseChannel<T> for Sender<T, C>
448where
449    T: HasMemoryCost + Debug + Send, //Debug + 'static,
450    C: ChannelSpec,
451    C::Sender<Entry<T>>: SinkCloseChannel<Entry<T>>,
452{
453    fn close_channel(self: Pin<&mut Self>) {
454        Pin::new(&mut self.get_mut().tx).close_channel();
455    }
456}
457
458impl<T, C> Sender<T, C>
459where
460    T: Debug + Send + 'static,
461    C: ChannelSpec,
462{
463    /// Obtain a reference to the `Sender`'s [`DynTimeProvider`]
464    ///
465    /// (This can sometimes be used to avoid having to keep
466    /// a separate clone of the time provider.)
467    pub fn time_provider(&self) -> &DynTimeProvider {
468        &self.runtime
469    }
470}
471
472//---------- Receiver ----------
473
474impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Stream for Receiver<T, C> {
475    type Item = T;
476
477    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
478        let mut state = self.inner.lock();
479        let state = match &mut *state {
480            Ok(y) => y,
481            Err(CollapsedDueToReclaim) => return Ready(None),
482        };
483        let ret = state.rx.poll_next_unpin(cx);
484        if let Ready(Some(item)) = &ret {
485            if let Some(enabled) = EnabledToken::new_if_compiled_in() {
486                let cost = item.typed_memory_cost(enabled);
487                state.mq.release(&cost);
488            }
489        }
490        ret.map(|r| r.map(|e| e.t))
491    }
492}
493
494impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> FusedStream for Receiver<T, C>
495where
496    C::Receiver<Entry<T>>: FusedStream,
497{
498    fn is_terminated(&self) -> bool {
499        match &*self.inner.lock() {
500            Ok(y) => y.rx.is_terminated(),
501            Err(CollapsedDueToReclaim) => true,
502        }
503    }
504}
505
506// TODO: When we have a trait for peekable streams, Receiver should implement it
507
508impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> Receiver<T, C> {
509    /// Register a callback, called when we tear the channel down
510    ///
511    /// This will be called when the `Receiver` is dropped,
512    /// or if we tear down because the memory system asks us to reclaim.
513    ///
514    /// `call` might be called at any time, from any thread, but
515    /// it won't be holding any locks relating to memory quota or the queue.
516    ///
517    /// If `self` is *already* in the process of being torn down,
518    /// `call` might be called immediately, reentrantly!
519    //
520    // This callback is nicer than us handing out an mpsc rx
521    // which user must read and convert items from.
522    //
523    // This method is on Receiver because that has the State,
524    // but could be called during setup to hook both sender's and
525    // receiver's shutdown mechanisms.
526    pub fn register_collapse_hook(&self, call: CollapseCallback) {
527        let mut state = self.inner.lock();
528        let state = match &mut *state {
529            Ok(y) => y,
530            Err(reason) => {
531                let reason = (*reason).into();
532                drop::<MutexGuard<_>>(state);
533                call(reason);
534                return;
535            }
536        };
537        state.collapse_callbacks.push(call);
538    }
539}
540
541impl<T: Debug + Send + 'static, C: ChannelSpec> ReceiverInner<T, C> {
542    /// Convenience function to take the lock
543    fn lock(&self) -> MutexGuard<Result<ReceiverState<T, C>, CollapsedDueToReclaim>> {
544        self.state.lock().expect("mq_mpsc lock poisoned")
545    }
546}
547
548impl<T: HasMemoryCost + Debug + Send + 'static, C: ChannelSpec> IsParticipant
549    for ReceiverInner<T, C>
550{
551    fn get_oldest(&self, _: EnabledToken) -> Option<CoarseInstant> {
552        let mut state = self.lock();
553        let state = match &mut *state {
554            Ok(y) => y,
555            Err(CollapsedDueToReclaim) => return None,
556        };
557        Pin::new(&mut state.rx)
558            .unobtrusive_peek()
559            .map(|peeked| peeked.when)
560    }
561
562    fn reclaim(self: Arc<Self>, _: EnabledToken) -> mtracker::ReclaimFuture {
563        Box::pin(async move {
564            let reason = CollapsedDueToReclaim;
565            let mut state_guard = self.lock();
566            let state = mem::replace(&mut *state_guard, Err(reason));
567            drop::<MutexGuard<_>>(state_guard);
568            #[allow(clippy::single_match)] // pattern is intentional.
569            match state {
570                Ok(mut state) => {
571                    for call in state.collapse_callbacks.drain(..) {
572                        call(reason.into());
573                    }
574                    drop::<ReceiverState<_, _>>(state); // will drain queue, too
575                }
576                Err(CollapsedDueToReclaim) => {}
577            };
578            mtracker::Reclaimed::Collapsing
579        })
580    }
581}
582
583impl<T: Debug + Send + 'static, C: ChannelSpec> Drop for ReceiverState<T, C> {
584    fn drop(&mut self) {
585        // If there's a mutex, we're in its drop
586
587        // `destroy_participant` prevents the sender from making further non-cached claims
588        mem::replace(&mut self.mq, Participation::new_dangling().into())
589            .into_raw()
590            .destroy_participant();
591
592        for call in self.collapse_callbacks.drain(..) {
593            call(CollapseReason::ReceiverDropped);
594        }
595
596        // try to free whatever is in the queue, in case the stream doesn't do that itself
597        // No-one can poll us any more, so we are no longer interested in wakeups
598        let mut noop_cx = Context::from_waker(Waker::noop());
599
600        // prevent further sends, so that our drain doesn't race indefinitely with the sender
601        if let Some(mut rx_inner) =
602            StreamUnobtrusivePeeker::as_raw_inner_pin_mut(Pin::new(&mut self.rx))
603        {
604            C::close_receiver(&mut rx_inner);
605        }
606
607        while let Ready(Some(item)) = self.rx.poll_next_unpin(&mut noop_cx) {
608            drop::<Entry<T>>(item);
609        }
610    }
611}
612
613/// Method for educe's Debug impl for `ReceiverState.collapse_callbacks`
614fn receiver_state_debug_collapse_notify(
615    v: &[CollapseCallback],
616    f: &mut fmt::Formatter,
617) -> fmt::Result {
618    Debug::fmt(&v.len(), f)
619}
620
621//---------- misc ----------
622
623impl<T: HasMemoryCost> HasMemoryCost for Entry<T> {
624    fn memory_cost(&self, enabled: EnabledToken) -> usize {
625        let time_size = std::alloc::Layout::new::<CoarseInstant>().size();
626        self.t.memory_cost(enabled).saturating_add(time_size)
627    }
628}
629
630impl From<CollapsedDueToReclaim> for CollapseReason {
631    fn from(CollapsedDueToReclaim: CollapsedDueToReclaim) -> CollapseReason {
632        CollapseReason::MemoryReclaimed
633    }
634}
635
636#[cfg(all(test, feature = "memquota", not(miri) /* coarsetime */))]
637mod test {
638    // @@ begin test lint list maintained by maint/add_warning @@
639    #![allow(clippy::bool_assert_comparison)]
640    #![allow(clippy::clone_on_copy)]
641    #![allow(clippy::dbg_macro)]
642    #![allow(clippy::mixed_attributes_style)]
643    #![allow(clippy::print_stderr)]
644    #![allow(clippy::print_stdout)]
645    #![allow(clippy::single_char_pattern)]
646    #![allow(clippy::unwrap_used)]
647    #![allow(clippy::unchecked_time_subtraction)]
648    #![allow(clippy::useless_vec)]
649    #![allow(clippy::needless_pass_by_value)]
650    //! <!-- @@ end test lint list maintained by maint/add_warning @@ -->
651    #![allow(clippy::arithmetic_side_effects)] // don't mind potential panicking ops in tests
652
653    use super::*;
654    use crate::mtracker::test::*;
655    use tor_rtmock::MockRuntime;
656    use tracing::debug;
657    use tracing_test::traced_test;
658
659    #[derive(Default, Debug)]
660    struct ItemTracker {
661        state: Mutex<ItemTrackerState>,
662    }
663    #[derive(Default, Debug)]
664    struct ItemTrackerState {
665        existing: usize,
666        next_id: usize,
667    }
668
669    #[derive(Debug)]
670    struct Item {
671        id: usize,
672        tracker: Arc<ItemTracker>,
673    }
674
675    impl ItemTracker {
676        fn new_item(self: &Arc<Self>) -> Item {
677            let mut state = self.lock();
678            let id = state.next_id;
679            state.existing += 1;
680            state.next_id += 1;
681            debug!("new {id}");
682            Item {
683                tracker: self.clone(),
684                id,
685            }
686        }
687
688        fn new_tracker() -> Arc<Self> {
689            Arc::default()
690        }
691
692        fn lock(&self) -> MutexGuard<ItemTrackerState> {
693            self.state.lock().unwrap()
694        }
695    }
696
697    impl Drop for Item {
698        fn drop(&mut self) {
699            debug!("old {}", self.id);
700            self.tracker.state.lock().unwrap().existing -= 1;
701        }
702    }
703
704    impl HasMemoryCost for Item {
705        fn memory_cost(&self, _: EnabledToken) -> usize {
706            mbytes(1)
707        }
708    }
709
710    struct Setup {
711        dtp: DynTimeProvider,
712        trk: Arc<mtracker::MemoryQuotaTracker>,
713        acct: Account,
714        itrk: Arc<ItemTracker>,
715    }
716
717    fn setup(rt: &MockRuntime) -> Setup {
718        let dtp = DynTimeProvider::new(rt.clone());
719        let trk = mk_tracker(rt);
720        let acct = trk.new_account(None).unwrap();
721        let itrk = ItemTracker::new_tracker();
722        Setup {
723            dtp,
724            trk,
725            acct,
726            itrk,
727        }
728    }
729
730    #[derive(Debug)]
731    struct Gigantic;
732    impl HasMemoryCost for Gigantic {
733        fn memory_cost(&self, _et: EnabledToken) -> usize {
734            mbytes(100)
735        }
736    }
737
738    impl Setup {
739        /// Check that claims and releases have balanced out
740        ///
741        /// `n_queues` is the number of queues that exist.
742        /// This is used to provide some slop, since each queue has two [`Participation`]s
743        /// each of which can have some cached claim.
744        fn check_zero_claimed(&self, n_queues: usize) {
745            let used = self.trk.used_current_approx();
746            debug!(
747                "checking zero balance (with slop {n_queues} * 2 * {}; used={used:?}",
748                *mtracker::MAX_CACHE,
749            );
750            assert!(used.unwrap() <= n_queues * 2 * *mtracker::MAX_CACHE);
751        }
752    }
753
754    #[traced_test]
755    #[test]
756    fn lifecycle() {
757        MockRuntime::test_with_various(|rt| async move {
758            let s = setup(&rt);
759            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
760
761            tx.send(s.itrk.new_item()).await.unwrap();
762            let _: Item = rx.next().await.unwrap();
763
764            for _ in 0..20 {
765                tx.send(s.itrk.new_item()).await.unwrap();
766            }
767
768            // reclaim task hasn't had a chance to run
769            debug!("still existing items {}", s.itrk.lock().existing);
770
771            rt.advance_until_stalled().await;
772
773            // reclaim task should have torn everything down
774            assert!(s.itrk.lock().existing == 0);
775
776            assert!(rx.next().await.is_none());
777
778            // Empirically, this is a "disconnected" error from the inner mpsc,
779            // but let's not assert that.
780            let _: SendError<_> = tx.send(s.itrk.new_item()).await.unwrap_err();
781        });
782    }
783
784    #[traced_test]
785    #[test]
786    fn fill_and_empty() {
787        MockRuntime::test_with_various(|rt| async move {
788            let s = setup(&rt);
789            let (mut tx, mut rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
790
791            const COUNT: usize = 19;
792
793            for _ in 0..COUNT {
794                tx.send(s.itrk.new_item()).await.unwrap();
795            }
796
797            rt.advance_until_stalled().await;
798
799            for _ in 0..COUNT {
800                let _: Item = rx.next().await.unwrap();
801            }
802
803            rt.advance_until_stalled().await;
804
805            // no memory should be claimed
806            s.check_zero_claimed(1);
807        });
808    }
809
810    #[traced_test]
811    #[test]
812    fn sink_error() {
813        #[derive(Debug, Copy, Clone)]
814        struct BustedSink {
815            error: BustedError,
816        }
817
818        impl<T> Sink<T> for BustedSink {
819            type Error = BustedError;
820
821            fn poll_ready(
822                self: Pin<&mut Self>,
823                _: &mut Context<'_>,
824            ) -> Poll<Result<(), Self::Error>> {
825                Ready(Err(self.error))
826            }
827            fn start_send(self: Pin<&mut Self>, _item: T) -> Result<(), Self::Error> {
828                panic!("poll_ready always gives error, start_send should not be called");
829            }
830            fn poll_flush(
831                self: Pin<&mut Self>,
832                _: &mut Context<'_>,
833            ) -> Poll<Result<(), Self::Error>> {
834                Ready(Ok(()))
835            }
836            fn poll_close(
837                self: Pin<&mut Self>,
838                _: &mut Context<'_>,
839            ) -> Poll<Result<(), Self::Error>> {
840                Ready(Ok(()))
841            }
842        }
843
844        impl<T> SinkTrySend<T> for BustedSink {
845            type Error = BustedError;
846
847            fn try_send_or_return(self: Pin<&mut Self>, item: T) -> Result<(), (BustedError, T)> {
848                Err((self.error, item))
849            }
850        }
851
852        impl tor_async_utils::SinkTrySendError for BustedError {
853            fn is_disconnected(&self) -> bool {
854                self.is_disconnected
855            }
856            fn is_full(&self) -> bool {
857                false
858            }
859        }
860
861        #[derive(Error, Debug, Clone, Copy)]
862        #[error("busted, for testing, dc={is_disconnected:?}")]
863        struct BustedError {
864            is_disconnected: bool,
865        }
866
867        struct BustedQueueSpec {
868            error: BustedError,
869        }
870        impl Sealed for BustedQueueSpec {}
871        impl ChannelSpec for BustedQueueSpec {
872            type Sender<T: Debug + Send + 'static> = BustedSink;
873            type Receiver<T: Debug + Send + 'static> = futures::stream::Pending<T>;
874            type SendError = BustedError;
875            fn raw_channel<T: Debug + Send + 'static>(self) -> (BustedSink, Self::Receiver<T>) {
876                (BustedSink { error: self.error }, futures::stream::pending())
877            }
878            fn close_receiver<T: Debug + Send + 'static>(_rx: &mut Self::Receiver<T>) {}
879        }
880
881        use ErasedSinkTrySendError as ESTSE;
882
883        MockRuntime::test_with_various(|rt| async move {
884            let error = BustedError {
885                is_disconnected: true,
886            };
887
888            let s = setup(&rt);
889            let (mut tx, _rx) = BustedQueueSpec { error }
890                .new_mq(s.dtp.clone(), &s.acct)
891                .unwrap();
892
893            let e = tx.send(s.itrk.new_item()).await.unwrap_err();
894            assert!(matches!(e, SendError::Channel(BustedError { .. })));
895
896            // item should have been destroyed
897            assert_eq!(s.itrk.lock().existing, 0);
898
899            // ---- Test try_send error handling ----
900
901            fn error_is_other_of<E>(e: ESTSE) -> Result<(), impl Debug>
902            where
903                E: std::error::Error + 'static,
904            {
905                match e {
906                    ESTSE::Other(e) if e.is::<E>() => Ok(()),
907                    other => Err(other),
908                }
909            }
910
911            let item = s.itrk.new_item();
912
913            // Test try_send failure due to BustedError, is_disconnected: true
914
915            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
916            assert!(matches!(e, ESTSE::Disconnected), "{e:?}");
917
918            // Test try_send failure due to BustedError, is_disconnected: false (ie, Other)
919
920            let error = BustedError {
921                is_disconnected: false,
922            };
923            let (mut tx, _rx) = BustedQueueSpec { error }
924                .new_mq(s.dtp.clone(), &s.acct)
925                .unwrap();
926            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
927            error_is_other_of::<BustedError>(e).unwrap();
928
929            // no memory should be claimed
930            s.check_zero_claimed(1);
931
932            // Test try_send failure due to memory quota collapse
933
934            // cause reclaim
935            {
936                let (mut tx, _rx) = MpscUnboundedSpec.new_mq(s.dtp.clone(), &s.acct).unwrap();
937                tx.send(Gigantic).await.unwrap();
938                rt.advance_until_stalled().await;
939            }
940
941            let (e, item) = Pin::new(&mut tx).try_send_or_return(item).unwrap_err();
942            error_is_other_of::<crate::Error>(e).unwrap();
943
944            drop::<Item>(item);
945        });
946    }
947}