1use std::{
2 future::Future,
3 mem,
4 ops::Add,
5 pin::Pin,
6 sync::{Arc, Mutex, Weak},
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use futures_core::ready;
12use tokio::{
13 sync::{mpsc, Semaphore},
14 time::{sleep_until, Sleep},
15};
16use tower::Service;
17use tracing::{debug, trace};
18
19use super::{
20 error::{Closed, ServiceError},
21 message::{Message, Tx},
22 BatchControl,
23};
24
/// Shared, cloneable view of the worker's failure state.
///
/// The worker's `Bridge::failed` stores the first service error here, and
/// `get_error_on_closed` reads it back.
#[derive(Debug)]
pub(crate) struct Handle {
    /// `None` until the worker records its first service error.
    inner: Arc<Mutex<Option<ServiceError>>>,
}
30
/// Worker-side connection to the request channel, plus the state needed to shut it down.
#[derive(Debug)]
struct Bridge<Fut, Request> {
    /// Receiving end of the channel that delivers requests to the worker.
    rx: mpsc::UnboundedReceiver<Message<Request, Fut>>,
    /// Shared error slot, written the first time the service fails.
    handle: Handle,
    /// A request handed back via `return_msg` (the service was not ready yet);
    /// `poll_next_msg` resumes it before reading `rx` again.
    current_message: Option<Message<Request, Fut>>,
    /// Weak reference to the semaphore; taken (left `None`) by the first `close_semaphore`.
    close: Option<Weak<Semaphore>>,
    /// Copy of the first recorded service error, if any.
    failed: Option<ServiceError>,
}
40
/// The batch currently being collected: pending response senders plus the lot's size/time limits.
#[derive(Debug)]
struct Lot<Fut> {
    /// Item count at which `is_full` reports `true`.
    max_size: usize,
    /// Delay, measured from the lot's first item, after which the deadline fires.
    max_time: Duration,
    /// Response senders paired with the service's result (a response future or an error),
    /// delivered by `notify`.
    responses: Vec<(Tx<Fut>, Result<Fut, ServiceError>)>,
    /// Deadline timer: armed by the first `add` of a lot and cleared by `notify`.
    time_elapses: Option<Pin<Box<Sleep>>>,
    /// Set once the deadline has been observed as elapsed; cleared by `notify`.
    time_elapsed: bool,
}
49
// States of the worker's batching state machine.
// Only `//` comments are used inside the macro: pin_project_lite accepts just `#[pin]`
// as a field attribute.
pin_project_lite::pin_project! {
    #[project = StateProj]
    #[derive(Debug)]
    enum State<Fut> {
        // Receiving requests and adding them to the current lot.
        Collecting,
        // Flushing the current lot.
        Flushing {
            // Why the flush started ("size" or "time" in this file). `Some` while in this
            // state; taken only when moving on to the next `Flushing` value.
            reason: Option<String>,
            // `None` until the service has been called with `BatchControl::Flush`;
            // afterwards, the flush future being polled.
            #[pin]
            flush_fut: Option<Fut>,
        },
        // Terminal state: the worker future has completed.
        Finished
    }
}
63
pin_project_lite::pin_project! {
    /// Future that drives batching: it receives requests from the channel, calls the
    /// wrapped service for each one, and flushes the collected batch by size or time.
    #[derive(Debug)]
    pub struct Worker<T, Request>
    where
        T: Service<BatchControl<Request>>,
        T::Error: Into<crate::BoxError>,
    {
        // The wrapped batch service.
        service: T,
        // Request channel, stashed message and failure bookkeeping.
        bridge: Bridge<T::Future, Request>,
        // The batch currently being collected.
        lot: Lot<T::Future>,
        // Current state; pinned because `Flushing` holds the pinned flush future.
        #[pin]
        state: State<T::Future>,
    }
}
85
86impl<T, Request> Worker<T, Request>
89where
90 T: Service<BatchControl<Request>>,
91 T::Error: Into<crate::BoxError>,
92{
93 pub(crate) fn new(
94 rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
95 service: T,
96 max_size: usize,
97 max_time: Duration,
98 semaphore: &Arc<Semaphore>,
99 ) -> (Handle, Worker<T, Request>) {
100 trace!("creating Batch worker");
101
102 let handle = Handle {
103 inner: Arc::new(Mutex::new(None)),
104 };
105
106 let semaphore = Arc::downgrade(semaphore);
110 let worker = Self {
111 service,
112 bridge: Bridge {
113 rx,
114 current_message: None,
115 handle: handle.clone(),
116 close: Some(semaphore),
117 failed: None,
118 },
119 lot: Lot::new(max_size, max_time),
120 state: State::Collecting,
121 };
122
123 (handle, worker)
124 }
125}
126
127impl<T, Request> Future for Worker<T, Request>
128where
129 T: Service<BatchControl<Request>>,
130 T::Error: Into<crate::BoxError>,
131{
132 type Output = ();
133
134 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
135 trace!("polling worker");
136
137 let mut this = self.project();
138
139 if let Poll::Ready(Some(())) = this.lot.poll_max_time(cx) {
141 this.state.set(State::flushing("time".to_owned(), None))
142 }
143
144 loop {
145 match this.state.as_mut().project() {
146 StateProj::Collecting => {
147 match ready!(this.bridge.poll_next_msg(cx)) {
148 Some((msg, first)) => {
149 let _guard = msg.span.enter();
150
151 trace!(resumed = !first, message = "worker received request");
152
153 trace!(message = "waiting for service readiness");
155 match this.service.poll_ready(cx) {
156 Poll::Ready(Ok(())) => {
157 debug!(service.ready = true, message = "adding item");
158
159 let response = this.service.call(msg.request.into());
160 this.lot.add((msg.tx, Ok(response)));
161
162 if this.lot.is_full() {
164 this.state.set(State::flushing("size".to_owned(), None));
165 }
166
167 if this.lot.poll_max_time(cx).is_ready() {
169 this.state.set(State::flushing("time".to_owned(), None));
170 }
171 }
172 Poll::Pending => {
173 drop(_guard);
174 debug!(service.ready = false, message = "delay item addition");
175 this.bridge.return_msg(msg);
176 return Poll::Pending;
177 }
178 Poll::Ready(Err(e)) => {
179 drop(_guard);
180 this.bridge.failed("item addition", e.into());
181 if let Some(ref e) = this.bridge.failed {
182 this.lot.add((msg.tx, Err(e.clone())));
184 this.lot.notify(Some(e.clone()));
185 }
186 }
187 }
188 }
189 None => {
190 trace!("shutting down, no more requests _ever_");
191 this.state.set(State::Finished);
192 return Poll::Ready(());
193 }
194 }
195 }
196 StateProj::Flushing { reason, flush_fut } => match flush_fut.as_pin_mut() {
197 None => {
198 trace!(
199 reason = reason.as_mut().unwrap().as_str(),
200 message = "waiting for service readiness"
201 );
202 match this.service.poll_ready(cx) {
203 Poll::Ready(Ok(())) => {
204 debug!(
205 service.ready = true,
206 reason = reason.as_mut().unwrap().as_str(),
207 message = "flushing batch"
208 );
209 let response = this.service.call(BatchControl::Flush);
210 let reason = reason.take().expect("missing reason");
211 this.state.set(State::flushing(reason, Some(response)));
212 }
213 Poll::Pending => {
214 debug!(
215 service.ready = false,
216 reason = reason.as_mut().unwrap().as_str(),
217 message = "delay flush"
218 );
219 return Poll::Pending;
220 }
221 Poll::Ready(Err(e)) => {
222 this.bridge.failed("flush", e.into());
223 if let Some(ref e) = this.bridge.failed {
224 this.lot.notify(Some(e.clone()));
225 }
226 }
227 }
228 }
229 Some(future) => {
230 match ready!(future.poll(cx)) {
231 Ok(_) => {
232 debug!(reason = reason.as_mut().unwrap().as_str(), "batch flushed");
233 this.lot.notify(None);
234 this.state.set(State::Collecting)
235 },
236 Err(e) => {
237 this.bridge.failed("flush", e.into());
238 if let Some(ref e) = this.bridge.failed {
239 this.lot.notify(Some(e.clone()));
240 }
241 this.state.set(State::Finished);
242 return Poll::Ready(());
243 }
244 }
245 }
246 },
247 StateProj::Finished => {
248 return Poll::Ready(());
250 }
251 }
252 }
253 }
254}
255
256impl<Fut> State<Fut> {
259 fn flushing(reason: String, f: Option<Fut>) -> Self {
260 Self::Flushing {
261 reason: Some(reason),
262 flush_fut: f,
263 }
264 }
265}
266
267impl<Fut, Request> Drop for Bridge<Fut, Request> {
270 fn drop(&mut self) {
271 self.close_semaphore()
272 }
273}
274
275impl<Fut, Request> Bridge<Fut, Request> {
276 fn close_semaphore(&mut self) {
278 if let Some(close) = self
279 .close
280 .take()
281 .as_ref()
282 .and_then(Weak::<Semaphore>::upgrade)
283 {
284 debug!("buffer closing; waking pending tasks");
285 close.close();
286 } else {
287 trace!("buffer already closed");
288 }
289 }
290
291 fn failed(&mut self, action: &str, error: crate::BoxError) {
292 debug!(action, %error , "service failed");
293
294 let error = ServiceError::new(error);
307
308 let mut inner = self.handle.inner.lock().unwrap();
309
310 if inner.is_some() {
311 return;
313 }
314
315 *inner = Some(error.clone());
316 drop(inner);
317
318 self.rx.close();
319
320 self.close_semaphore();
322
323 self.failed = Some(error);
327 }
328
329 fn poll_next_msg(
334 &mut self,
335 cx: &mut Context<'_>,
336 ) -> Poll<Option<(Message<Request, Fut>, bool)>> {
337 trace!("worker polling for next message");
338
339 if let Some(msg) = self.current_message.take() {
341 if !msg.tx.is_closed() {
344 trace!("resuming buffered request");
345 return Poll::Ready(Some((msg, false)));
346 }
347
348 trace!("dropping cancelled buffered request");
349 }
350
351 while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
353 if !msg.tx.is_closed() {
354 trace!("processing new request");
355 return Poll::Ready(Some((msg, true)));
356 }
357
358 trace!("dropping cancelled request");
360 }
361
362 Poll::Ready(None)
363 }
364
365 fn return_msg(&mut self, msg: Message<Request, Fut>) {
366 self.current_message = Some(msg)
367 }
368}
369
370impl<Fut> Lot<Fut> {
373 fn new(max_size: usize, max_time: Duration) -> Self {
374 Self {
375 max_size,
376 max_time,
377 responses: Vec::with_capacity(max_size),
378 time_elapses: None,
379 time_elapsed: false,
380 }
381 }
382
383 fn poll_max_time(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
384 if self.time_elapsed {
389 return Poll::Ready(None);
390 }
391
392 if let Some(ref mut sleep) = self.time_elapses {
393 if Pin::new(sleep).poll(cx).is_ready() {
394 self.time_elapsed = true;
395 return Poll::Ready(Some(()));
396 }
397 }
398
399 Poll::Pending
400 }
401
402 fn is_full(&self) -> bool {
403 self.responses.len() == self.max_size
404 }
405
406 fn add(&mut self, item: (Tx<Fut>, Result<Fut, ServiceError>)) {
407 if self.responses.is_empty() {
408 self.time_elapses = Some(Box::pin(sleep_until(
409 Instant::now().add(self.max_time).into(),
410 )));
411 }
412 self.responses.push(item);
413 }
414
415 fn notify(&mut self, err: Option<ServiceError>) {
416 for (tx, response) in mem::replace(&mut self.responses, Vec::with_capacity(self.max_size)) {
417 if let Some(ref response) = err {
418 let _ = tx.send(Err(response.clone()));
419 } else {
420 let _ = tx.send(response);
421 }
422 }
423 self.time_elapses = None;
424 self.time_elapsed = false;
425 }
426}
427
428impl Handle {
431 pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
432 self.inner
433 .lock()
434 .unwrap()
435 .as_ref()
436 .map(|svc_err| svc_err.clone().into())
437 .unwrap_or_else(|| Closed::new().into())
438 }
439}
440
441impl Clone for Handle {
442 fn clone(&self) -> Self {
443 Handle {
444 inner: self.inner.clone(),
445 }
446 }
447}