Skip to main content

tako_rs_plugins/middleware/
upload_progress.rs

1//! Upload progress tracking middleware.
2//!
3//! Wraps the request body to track upload progress and report it via a callback
4//! or through request extensions. Handlers can access the progress tracker to
5//! monitor bytes received.
6//!
7//! # Examples
8//!
9//! ```rust
10//! use tako::middleware::upload_progress::UploadProgress;
11//! use tako::middleware::IntoMiddleware;
12//!
13//! // With callback
14//! let progress = UploadProgress::new()
15//!     .on_progress(|state| {
16//!         println!("{}% ({}/{})",
17//!             state.percent().unwrap_or(0),
18//!             state.bytes_read,
19//!             state.total_bytes.unwrap_or(0),
20//!         );
21//!     });
22//! let mw = progress.into_middleware();
23//! ```
24
25use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::sync::atomic::AtomicU64;
29use std::sync::atomic::Ordering;
30use std::task::Context;
31use std::task::Poll;
32
33use bytes::Bytes;
34use http_body::Body;
35use http_body::Frame;
36use http_body::SizeHint;
37use parking_lot::Mutex;
38use pin_project_lite::pin_project;
39use tako_rs_core::body::TakoBody;
40use tako_rs_core::middleware::IntoMiddleware;
41use tako_rs_core::middleware::Next;
42use tako_rs_core::types::BoxError;
43use tako_rs_core::types::Request;
44use tako_rs_core::types::Response;
45
/// Snapshot of upload progress, handed to progress callbacks and returned by
/// [`ProgressTracker::state`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProgressState {
  /// Number of body data bytes read so far.
  pub bytes_read: u64,
  /// Total expected bytes (from `Content-Length`), if known.
  pub total_bytes: Option<u64>,
}

impl ProgressState {
  /// Returns the upload percentage (0-100), if the total is known.
  ///
  /// The value is `floor(bytes_read * 100 / total_bytes)`, clamped to 100 so a
  /// body that delivers more than its announced length still reads as
  /// complete. A known total of zero reports 100.
  ///
  /// Exact integer arithmetic in `u128` is used instead of `f64`: the float
  /// form rounded some exact ratios down (29 of 100 bytes reported 28%), and
  /// `u128` cannot overflow for any pair of `u64` inputs.
  pub fn percent(&self) -> Option<u8> {
    self.total_bytes.map(|total| {
      if total == 0 {
        return 100;
      }
      let pct = u128::from(self.bytes_read) * 100 / u128::from(total);
      // After `min(100)` the value fits in a `u8`, so the cast is lossless.
      pct.min(100) as u8
    })
  }
}

/// Shared, read-only view of a request's upload progress.
///
/// The middleware inserts one into the request extensions so handlers can
/// check how many body bytes have been consumed so far. Clones share the same
/// underlying counter.
#[derive(Debug, Clone)]
pub struct ProgressTracker {
  bytes_read: Arc<AtomicU64>,
  total_bytes: Option<u64>,
}

impl ProgressTracker {
  /// Returns the current progress state.
  pub fn state(&self) -> ProgressState {
    ProgressState {
      bytes_read: self.bytes_read(),
      total_bytes: self.total_bytes,
    }
  }

  /// Returns the number of bytes read so far.
  pub fn bytes_read(&self) -> u64 {
    // The counter is a standalone statistic; no other memory is published
    // through it, so a relaxed load is sufficient.
    self.bytes_read.load(Ordering::Relaxed)
  }

  /// Returns the total expected bytes, if known.
  pub fn total_bytes(&self) -> Option<u64> {
    self.total_bytes
  }

  /// Returns the upload percentage (0-100), if the total is known.
  ///
  /// Same rules as [`ProgressState::percent`].
  pub fn percent(&self) -> Option<u8> {
    self.state().percent()
  }
}
101
102/// Upload progress middleware configuration.
103///
104/// # Examples
105///
106/// ```rust
107/// use tako::middleware::upload_progress::UploadProgress;
108/// use tako::middleware::IntoMiddleware;
109///
110/// // Simple progress tracking (access via ProgressTracker in extensions)
111/// let progress = UploadProgress::new();
112///
113/// // With progress callback
114/// let progress = UploadProgress::new()
115///     .on_progress(|state| {
116///         if let Some(pct) = state.percent() {
117///             println!("Upload: {pct}%");
118///         }
119///     })
120///     .min_notify_interval_bytes(8192); // notify at most every 8KB
121/// ```
122pub struct UploadProgress {
123  callback: Option<Arc<dyn Fn(ProgressState) + Send + Sync + 'static>>,
124  min_notify_interval: u64,
125}
126
127impl Default for UploadProgress {
128  fn default() -> Self {
129    Self::new()
130  }
131}
132
133impl UploadProgress {
134  /// Creates a new upload progress middleware.
135  pub fn new() -> Self {
136    Self {
137      callback: None,
138      min_notify_interval: 0,
139    }
140  }
141
142  /// Sets a callback that is called as bytes are received.
143  pub fn on_progress<F>(mut self, f: F) -> Self
144  where
145    F: Fn(ProgressState) + Send + Sync + 'static,
146  {
147    self.callback = Some(Arc::new(f));
148    self
149  }
150
151  /// Sets the minimum byte interval between progress notifications.
152  ///
153  /// This prevents the callback from being called too frequently for
154  /// large uploads. Default is 0 (notify on every chunk).
155  pub fn min_notify_interval_bytes(mut self, bytes: u64) -> Self {
156    self.min_notify_interval = bytes;
157    self
158  }
159}
160
161pin_project! {
162  /// Body wrapper that tracks bytes read frame-by-frame without buffering.
163  ///
164  /// Increments the shared counter as each data frame flows through and fires
165  /// the optional callback when the configured byte interval is exceeded. Errors
166  /// and end-of-stream are forwarded transparently.
167  struct ProgressBody<B> {
168    #[pin]
169    inner: B,
170    bytes_read: Arc<AtomicU64>,
171    total_bytes: Option<u64>,
172    last_notified_at: u64,
173    min_interval: u64,
174    callback: Option<Arc<dyn Fn(ProgressState) + Send + Sync + 'static>>,
175    final_notified: Arc<Mutex<bool>>,
176  }
177}
178
179impl<B> Body for ProgressBody<B>
180where
181  B: Body<Data = Bytes>,
182  B::Error: Into<BoxError>,
183{
184  type Data = Bytes;
185  type Error = BoxError;
186
187  fn poll_frame(
188    self: Pin<&mut Self>,
189    cx: &mut Context<'_>,
190  ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
191    let this = self.project();
192    match this.inner.poll_frame(cx) {
193      Poll::Ready(Some(Ok(frame))) => {
194        if let Some(data) = frame.data_ref() {
195          let added = data.len() as u64;
196          let total = this.bytes_read.fetch_add(added, Ordering::Relaxed) + added;
197          if let Some(cb) = this.callback.as_ref()
198            && (*this.min_interval == 0 || total - *this.last_notified_at >= *this.min_interval)
199          {
200            *this.last_notified_at = total;
201            cb(ProgressState {
202              bytes_read: total,
203              total_bytes: *this.total_bytes,
204            });
205          }
206        }
207        Poll::Ready(Some(Ok(frame)))
208      }
209      Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
210      Poll::Ready(None) => {
211        // Fire a final callback exactly once when the body ends, so callers see
212        // the closing total even if the last interval did not trigger a notify.
213        // Empty uploads (CL=0) used to slip through here because
214        // `bytes_read == last_notified_at == 0` skipped the call. Fire it
215        // unconditionally on EOF so callers always observe a terminal event.
216        if let Some(cb) = this.callback.as_ref() {
217          let mut guard = this.final_notified.lock();
218          if !*guard {
219            *guard = true;
220            let final_read = this.bytes_read.load(Ordering::Relaxed);
221            cb(ProgressState {
222              bytes_read: final_read,
223              total_bytes: *this.total_bytes,
224            });
225            *this.last_notified_at = final_read;
226          }
227        }
228        Poll::Ready(None)
229      }
230      Poll::Pending => Poll::Pending,
231    }
232  }
233
234  fn is_end_stream(&self) -> bool {
235    self.inner.is_end_stream()
236  }
237
238  fn size_hint(&self) -> SizeHint {
239    self.inner.size_hint()
240  }
241}
242
243impl IntoMiddleware for UploadProgress {
244  fn into_middleware(
245    self,
246  ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
247  + Clone
248  + Send
249  + Sync
250  + 'static {
251    let callback = self.callback;
252    let min_interval = self.min_notify_interval;
253
254    move |mut req: Request, next: Next| {
255      let callback = callback.clone();
256
257      Box::pin(async move {
258        // Extract total from Content-Length header
259        let total_bytes = req
260          .headers()
261          .get(http::header::CONTENT_LENGTH)
262          .and_then(|v| v.to_str().ok())
263          .and_then(|s| s.parse::<u64>().ok());
264
265        let bytes_read = Arc::new(AtomicU64::new(0));
266
267        // Insert tracker into extensions for handler access
268        let tracker = ProgressTracker {
269          bytes_read: Arc::clone(&bytes_read),
270          total_bytes,
271        };
272        req.extensions_mut().insert(tracker);
273
274        // Wrap the body in a streaming progress tracker — no buffering.
275        let (parts, body) = req.into_parts();
276        let progress_body = ProgressBody {
277          inner: body,
278          bytes_read,
279          total_bytes,
280          last_notified_at: 0,
281          min_interval,
282          callback,
283          final_notified: Arc::new(Mutex::new(false)),
284        };
285        let req = http::Request::from_parts(parts, TakoBody::new(progress_body));
286
287        next.run(req).await
288      })
289    }
290  }
291}