tako_rs_plugins/middleware/
upload_progress.rs1use std::future::Future;
26use std::pin::Pin;
27use std::sync::Arc;
28use std::sync::atomic::AtomicU64;
29use std::sync::atomic::Ordering;
30use std::task::Context;
31use std::task::Poll;
32
33use bytes::Bytes;
34use http_body::Body;
35use http_body::Frame;
36use http_body::SizeHint;
37use parking_lot::Mutex;
38use pin_project_lite::pin_project;
39use tako_rs_core::body::TakoBody;
40use tako_rs_core::middleware::IntoMiddleware;
41use tako_rs_core::middleware::Next;
42use tako_rs_core::types::BoxError;
43use tako_rs_core::types::Request;
44use tako_rs_core::types::Response;
45
/// Point-in-time snapshot of how much of a request body has been read.
///
/// Returned by `ProgressTracker::state` and passed to the callback registered
/// with `UploadProgress::on_progress`.
#[derive(Debug, Clone)]
pub struct ProgressState {
    /// Number of body data bytes read so far.
    pub bytes_read: u64,
    /// Expected body length, if known (the middleware fills this from a
    /// parseable `Content-Length` header).
    pub total_bytes: Option<u64>,
}

impl ProgressState {
    /// Returns completion as a whole percentage in `0..=100`, rounded down.
    ///
    /// Returns `None` when the total size is unknown. A declared total of `0`
    /// is reported as `100`, and byte counts beyond the declared total are
    /// clamped to `100`.
    pub fn percent(&self) -> Option<u8> {
        self.total_bytes.map(|total| {
            if total == 0 {
                100
            } else {
                // Integer math instead of `f64`: e.g. `29.0 / 100.0 * 100.0`
                // evaluates to 28.999..., which would truncate to 28.
                // `u128` holds `u64::MAX * 100` without overflow.
                let read = u128::from(self.bytes_read.min(total));
                let pct = read * 100 / u128::from(total);
                // `read <= total` bounds `pct` to 100, so the fallback is unreachable.
                u8::try_from(pct).unwrap_or(100)
            }
        })
    }
}
67
68#[derive(Clone)]
72pub struct ProgressTracker {
73 bytes_read: Arc<AtomicU64>,
74 total_bytes: Option<u64>,
75}
76
77impl ProgressTracker {
78 pub fn state(&self) -> ProgressState {
80 ProgressState {
81 bytes_read: self.bytes_read.load(Ordering::Relaxed),
82 total_bytes: self.total_bytes,
83 }
84 }
85
86 pub fn bytes_read(&self) -> u64 {
88 self.bytes_read.load(Ordering::Relaxed)
89 }
90
91 pub fn total_bytes(&self) -> Option<u64> {
93 self.total_bytes
94 }
95
96 pub fn percent(&self) -> Option<u8> {
98 self.state().percent()
99 }
100}
101
102pub struct UploadProgress {
123 callback: Option<Arc<dyn Fn(ProgressState) + Send + Sync + 'static>>,
124 min_notify_interval: u64,
125}
126
127impl Default for UploadProgress {
128 fn default() -> Self {
129 Self::new()
130 }
131}
132
133impl UploadProgress {
134 pub fn new() -> Self {
136 Self {
137 callback: None,
138 min_notify_interval: 0,
139 }
140 }
141
142 pub fn on_progress<F>(mut self, f: F) -> Self
144 where
145 F: Fn(ProgressState) + Send + Sync + 'static,
146 {
147 self.callback = Some(Arc::new(f));
148 self
149 }
150
151 pub fn min_notify_interval_bytes(mut self, bytes: u64) -> Self {
156 self.min_notify_interval = bytes;
157 self
158 }
159}
160
161pin_project! {
162 struct ProgressBody<B> {
168 #[pin]
169 inner: B,
170 bytes_read: Arc<AtomicU64>,
171 total_bytes: Option<u64>,
172 last_notified_at: u64,
173 min_interval: u64,
174 callback: Option<Arc<dyn Fn(ProgressState) + Send + Sync + 'static>>,
175 final_notified: Arc<Mutex<bool>>,
176 }
177}
178
179impl<B> Body for ProgressBody<B>
180where
181 B: Body<Data = Bytes>,
182 B::Error: Into<BoxError>,
183{
184 type Data = Bytes;
185 type Error = BoxError;
186
187 fn poll_frame(
188 self: Pin<&mut Self>,
189 cx: &mut Context<'_>,
190 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
191 let this = self.project();
192 match this.inner.poll_frame(cx) {
193 Poll::Ready(Some(Ok(frame))) => {
194 if let Some(data) = frame.data_ref() {
195 let added = data.len() as u64;
196 let total = this.bytes_read.fetch_add(added, Ordering::Relaxed) + added;
197 if let Some(cb) = this.callback.as_ref()
198 && (*this.min_interval == 0 || total - *this.last_notified_at >= *this.min_interval)
199 {
200 *this.last_notified_at = total;
201 cb(ProgressState {
202 bytes_read: total,
203 total_bytes: *this.total_bytes,
204 });
205 }
206 }
207 Poll::Ready(Some(Ok(frame)))
208 }
209 Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e.into()))),
210 Poll::Ready(None) => {
211 if let Some(cb) = this.callback.as_ref() {
217 let mut guard = this.final_notified.lock();
218 if !*guard {
219 *guard = true;
220 let final_read = this.bytes_read.load(Ordering::Relaxed);
221 cb(ProgressState {
222 bytes_read: final_read,
223 total_bytes: *this.total_bytes,
224 });
225 *this.last_notified_at = final_read;
226 }
227 }
228 Poll::Ready(None)
229 }
230 Poll::Pending => Poll::Pending,
231 }
232 }
233
234 fn is_end_stream(&self) -> bool {
235 self.inner.is_end_stream()
236 }
237
238 fn size_hint(&self) -> SizeHint {
239 self.inner.size_hint()
240 }
241}
242
243impl IntoMiddleware for UploadProgress {
244 fn into_middleware(
245 self,
246 ) -> impl Fn(Request, Next) -> Pin<Box<dyn Future<Output = Response> + Send + 'static>>
247 + Clone
248 + Send
249 + Sync
250 + 'static {
251 let callback = self.callback;
252 let min_interval = self.min_notify_interval;
253
254 move |mut req: Request, next: Next| {
255 let callback = callback.clone();
256
257 Box::pin(async move {
258 let total_bytes = req
260 .headers()
261 .get(http::header::CONTENT_LENGTH)
262 .and_then(|v| v.to_str().ok())
263 .and_then(|s| s.parse::<u64>().ok());
264
265 let bytes_read = Arc::new(AtomicU64::new(0));
266
267 let tracker = ProgressTracker {
269 bytes_read: Arc::clone(&bytes_read),
270 total_bytes,
271 };
272 req.extensions_mut().insert(tracker);
273
274 let (parts, body) = req.into_parts();
276 let progress_body = ProgressBody {
277 inner: body,
278 bytes_read,
279 total_bytes,
280 last_notified_at: 0,
281 min_interval,
282 callback,
283 final_notified: Arc::new(Mutex::new(false)),
284 };
285 let req = http::Request::from_parts(parts, TakoBody::new(progress_body));
286
287 next.run(req).await
288 })
289 }
290 }
291}