Skip to main content

stream_transfer_limit/
transfer_limit.rs

1use crate::{ChunkLength, TransferCounter, TransferLimitError};
2use futures::{Stream, StreamExt, TryStream, TryStreamExt, stream};
3
/// Progress callback type installed when the caller has not supplied one.
///
/// It is a plain function pointer over the counter type, so it carries no
/// captured state and is trivially `Copy`.
pub type NoopProgress<C = usize> = fn(C);

/// Progress callback that discards every update it receives.
fn noop_progress<C>(_bytes_seen: C) {}
8
9/// Builder for applying byte-count transfer limits to fallible streams.
10#[derive(Debug, Clone)]
11pub struct TransferLimit<C = usize, P = NoopProgress<C>> {
12    limit: Option<C>,
13    bytes_seen: C,
14    failed: bool,
15    on_progress: P,
16}
17
18impl<C> Default for TransferLimit<C, NoopProgress<C>>
19where
20    C: TransferCounter,
21{
22    fn default() -> Self {
23        Self::from_optional_limit(None)
24    }
25}
26
27impl TransferLimit<usize, NoopProgress<usize>> {
28    /// Create a transfer limit that allows at most `limit` bytes.
29    pub fn new(limit: usize) -> Self {
30        Self::from_limit(limit)
31    }
32
33    /// Create a transfer limit from an optional byte limit.
34    pub fn optional(limit: Option<usize>) -> Self {
35        Self::from_optional_limit(limit)
36    }
37
38    /// Create a transfer tracker without a byte limit.
39    pub fn unlimited() -> Self {
40        Self::from_optional_limit(None)
41    }
42}
43
44impl<C> TransferLimit<C, NoopProgress<C>>
45where
46    C: TransferCounter,
47{
48    /// Create a transfer limit using an explicit counter type.
49    pub fn from_limit(limit: C) -> Self {
50        Self::from_optional_limit(Some(limit))
51    }
52
53    /// Create a transfer limit from an optional byte limit using an explicit
54    /// counter type.
55    pub fn from_optional_limit(limit: Option<C>) -> Self {
56        Self {
57            limit,
58            bytes_seen: C::ZERO,
59            failed: false,
60            on_progress: noop_progress,
61        }
62    }
63}
64
65impl<C, P> TransferLimit<C, P>
66where
67    C: TransferCounter,
68{
69    /// Set the maximum allowed number of bytes.
70    ///
71    /// A stream is allowed to produce exactly `limit` bytes. It fails on the
72    /// first chunk that makes the cumulative total greater than `limit`.
73    pub fn with_limit(mut self, limit: C) -> Self {
74        self.limit = Some(limit);
75        self
76    }
77
78    /// Remove the maximum byte limit while keeping progress tracking.
79    pub fn without_limit(mut self) -> Self {
80        self.limit = None;
81        self
82    }
83
84    /// Replace the progress callback.
85    ///
86    /// The callback receives cumulative bytes after every successful chunk read
87    /// from the inner stream, including the chunk that crosses the limit.
88    pub fn on_progress<F>(self, on_progress: F) -> TransferLimit<C, F>
89    where
90        F: FnMut(C),
91    {
92        TransferLimit {
93            limit: self.limit,
94            bytes_seen: self.bytes_seen,
95            failed: self.failed,
96            on_progress,
97        }
98    }
99
100    /// Return the configured maximum byte count, if any.
101    pub fn limit(&self) -> Option<C> {
102        self.limit
103    }
104
105    /// Wrap a fallible stream and apply this transfer limit.
106    pub fn wrap<S>(
107        mut self,
108        stream: S,
109    ) -> impl Stream<Item = Result<S::Ok, TransferLimitError<S::Error, C>>>
110    where
111        S: TryStream,
112        S::Ok: ChunkLength,
113        P: FnMut(C) + Unpin,
114    {
115        self.bytes_seen = C::ZERO;
116        self.failed = false;
117
118        let stream = Box::pin(stream.into_stream());
119        Box::pin(stream::unfold(
120            (stream, self),
121            |(mut stream, mut limit)| async move {
122                if limit.failed {
123                    return None;
124                }
125
126                let item = stream
127                    .next()
128                    .await?
129                    .map_err(TransferLimitError::inner)
130                    .and_then(|chunk| {
131                        limit
132                            .record_chunk(chunk.chunk_len())
133                            .inspect_err(|_| limit.failed = true)
134                            .map(|_| chunk)
135                    });
136
137                Some((item, (stream, limit)))
138            },
139        ))
140    }
141}
142
143impl<C, P> TransferLimit<C, P>
144where
145    C: TransferCounter,
146    P: FnMut(C),
147{
148    fn record_chunk<E>(&mut self, chunk_len: usize) -> Result<(), TransferLimitError<E, C>> {
149        self.bytes_seen = self
150            .bytes_seen
151            .checked_add_chunk(chunk_len)
152            .ok_or_else(|| TransferLimitError::CounterOverflow {
153                bytes_seen: self.bytes_seen,
154                chunk_len,
155            })?;
156        (self.on_progress)(self.bytes_seen);
157
158        self.limit
159            .filter(|&limit| self.bytes_seen > limit)
160            .map_or(Ok(()), |limit| {
161                Err(TransferLimitError::LimitExceeded {
162                    limit,
163                    actual: self.bytes_seen,
164                })
165            })
166    }
167}