// tower_batch/worker.rs

1use std::{
2    future::Future,
3    mem,
4    ops::Add,
5    pin::Pin,
6    sync::{Arc, Mutex, Weak},
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use futures_core::ready;
12use tokio::{
13    sync::{mpsc, Semaphore},
14    time::{sleep_until, Sleep},
15};
16use tower::Service;
17use tracing::{debug, trace};
18
19use super::{
20    error::{Closed, ServiceError},
21    message::{Message, Tx},
22    BatchControl,
23};
24
/// Shared handle through which `Batch` front-ends read the worker's terminal
/// error after the request channel has closed.
#[derive(Debug)]
pub(crate) struct Handle {
    // `Some(error)` once the wrapped service has failed; `None` while healthy.
    inner: Arc<Mutex<Option<ServiceError>>>,
}
30
/// Wrap `Service` channel for easier use through projections.
///
/// Owns the receiving half of the request channel plus the pieces needed to
/// propagate failure to callers.
#[derive(Debug)]
struct Bridge<Fut, Request> {
    // Incoming requests from the `Batch` front-end.
    rx: mpsc::UnboundedReceiver<Message<Request, Fut>>,
    // Shared slot where a terminal `ServiceError` is published.
    handle: Handle,
    // A message taken from `rx` but parked because the service was not ready;
    // it is retried before any new message is received.
    current_message: Option<Message<Request, Fut>>,
    // Weak reference to the capacity semaphore; closed (at most once) on
    // failure or drop to wake pending senders.
    close: Option<Weak<Semaphore>>,
    // Set after the wrapped service fails; cloned into later responses.
    failed: Option<ServiceError>,
}
40
/// The in-progress batch: pending response channels plus the timer that
/// bounds how long collected items may wait before a flush.
#[derive(Debug)]
struct Lot<Fut> {
    // Flush once this many items have been collected.
    max_size: usize,
    // Flush once the first item of the batch has waited this long.
    max_time: Duration,
    // One entry per batched request: the caller's oneshot sender plus the
    // result (service future or error) to deliver on flush.
    responses: Vec<(Tx<Fut>, Result<Fut, ServiceError>)>,
    // Timer armed when the first item of a batch is added; `None` when empty.
    time_elapses: Option<Pin<Box<Sleep>>>,
    // Latched after the timer fires so it triggers a flush exactly once.
    time_elapsed: bool,
}
49
pin_project_lite::pin_project! {
    // Worker state machine: gather items, flush them, or shut down.
    #[project = StateProj]
    #[derive(Debug)]
    enum State<Fut> {
        // Accepting new requests into the current `Lot`.
        Collecting,
        // Flushing the current batch. `flush_fut` stays `None` until the
        // service reports ready and `BatchControl::Flush` is dispatched.
        Flushing {
            // Why the flush started ("size" or "time"); used for tracing.
            reason: Option<String>,
            #[pin]
            flush_fut: Option<Fut>,
        },
        // Channel closed (or fatal service error); terminal state.
        Finished
    }
}
63
pin_project_lite::pin_project! {
    /// Task that handles processing the buffer. This type should not be used
    /// directly, instead `Batch` requires an `Executor` that can accept this task.
    ///
    /// The struct is `pub` in the private module and the type is *not* re-exported
    /// as part of the public API. This is the "sealed" pattern to include "private"
    /// types in public traits that are not meant for consumers of the library to
    /// implement (only call).
    #[derive(Debug)]
    pub struct Worker<T, Request>
    where
        T: Service<BatchControl<Request>>,
        T::Error: Into<crate::BoxError>,
    {
        // The wrapped batch service.
        service: T,
        // Request channel plus failure-propagation machinery.
        bridge: Bridge<T::Future, Request>,
        // The batch currently being collected.
        lot: Lot<T::Future>,
        // Current phase of the collect/flush/finish state machine.
        #[pin]
        state: State<T::Future>,
    }
}
85
86// ===== impl Worker =====
87
88impl<T, Request> Worker<T, Request>
89where
90    T: Service<BatchControl<Request>>,
91    T::Error: Into<crate::BoxError>,
92{
93    pub(crate) fn new(
94        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
95        service: T,
96        max_size: usize,
97        max_time: Duration,
98        semaphore: &Arc<Semaphore>,
99    ) -> (Handle, Worker<T, Request>) {
100        trace!("creating Batch worker");
101
102        let handle = Handle {
103            inner: Arc::new(Mutex::new(None)),
104        };
105
106        // The service and worker have a parent - child relationship, so we must
107        // downgrade the Arc to Weak, to ensure a cycle between Arc pointers will
108        // never be deallocated.
109        let semaphore = Arc::downgrade(semaphore);
110        let worker = Self {
111            service,
112            bridge: Bridge {
113                rx,
114                current_message: None,
115                handle: handle.clone(),
116                close: Some(semaphore),
117                failed: None,
118            },
119            lot: Lot::new(max_size, max_time),
120            state: State::Collecting,
121        };
122
123        (handle, worker)
124    }
125}
126
impl<T, Request> Future for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    type Output = ();

    /// Drives the batching state machine: collect items until the batch is
    /// full or the max wait time elapses, flush, then collect again — until
    /// the request channel closes or the wrapped service fails.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        trace!("polling worker");

        let mut this = self.project();

        // Flush if the max wait time is reached. `poll_max_time` yields
        // `Some(())` at most once per batch (it latches `time_elapsed`), so
        // this cannot repeatedly re-enter Flushing for the same batch.
        // NOTE(review): if the worker is already `Flushing` with a dispatched
        // `flush_fut` (entered via the "size" trigger), a timer fire here
        // would overwrite that state and drop the in-flight flush future —
        // confirm the timer cannot fire in that window, or guard this on the
        // `Collecting` state.
        if let Poll::Ready(Some(())) = this.lot.poll_max_time(cx) {
            this.state.set(State::flushing("time".to_owned(), None))
        }

        loop {
            match this.state.as_mut().project() {
                StateProj::Collecting => {
                    match ready!(this.bridge.poll_next_msg(cx)) {
                        Some((msg, first)) => {
                            let _guard = msg.span.enter();

                            trace!(resumed = !first, message = "worker received request");

                            // Wait for the service to be ready
                            trace!(message = "waiting for service readiness");
                            match this.service.poll_ready(cx) {
                                Poll::Ready(Ok(())) => {
                                    debug!(service.ready = true, message = "adding item");

                                    // Hand the item to the service; the
                                    // returned future is stored in the lot and
                                    // delivered to the caller on flush.
                                    let response = this.service.call(msg.request.into());
                                    this.lot.add((msg.tx, Ok(response)));

                                    // Flush if the batch is full.
                                    if this.lot.is_full() {
                                        this.state.set(State::flushing("size".to_owned(), None));
                                    }

                                    // Or flush if the max time has elapsed.
                                    // Note `is_ready()` also matches the
                                    // latched `Ready(None)` case; `notify`
                                    // resets the latch after each flush.
                                    if this.lot.poll_max_time(cx).is_ready() {
                                        this.state.set(State::flushing("time".to_owned(), None));
                                    }
                                }
                                Poll::Pending => {
                                    // Service not ready: park the message so it
                                    // is retried first on the next poll, and
                                    // propagate Pending.
                                    drop(_guard);
                                    debug!(service.ready = false, message = "delay item addition");
                                    this.bridge.return_msg(msg);
                                    return Poll::Pending;
                                }
                                Poll::Ready(Err(e)) => {
                                    // Service failed: publish the error, then
                                    // fail this caller and everyone already
                                    // queued in the current lot. The loop keeps
                                    // draining the (now closed) channel so
                                    // remaining senders also get the error.
                                    drop(_guard);
                                    this.bridge.failed("item addition", e.into());
                                    if let Some(ref e) = this.bridge.failed {
                                        // Ensure the current caller is notified too.
                                        this.lot.add((msg.tx, Err(e.clone())));
                                        this.lot.notify(Some(e.clone()));
                                    }
                                }
                            }
                        }
                        None => {
                            trace!("shutting down, no more requests _ever_");
                            this.state.set(State::Finished);
                            return Poll::Ready(());
                        }
                    }
                }
                StateProj::Flushing { reason, flush_fut } => match flush_fut.as_pin_mut() {
                    // Flush decided but not yet dispatched: wait for service
                    // readiness, then issue `BatchControl::Flush`.
                    None => {
                        trace!(
                            reason = reason.as_mut().unwrap().as_str(),
                            message = "waiting for service readiness"
                        );
                        match this.service.poll_ready(cx) {
                            Poll::Ready(Ok(())) => {
                                debug!(
                                    service.ready = true,
                                    reason = reason.as_mut().unwrap().as_str(),
                                    message = "flushing batch"
                                );
                                let response = this.service.call(BatchControl::Flush);
                                let reason = reason.take().expect("missing reason");
                                // Re-enter Flushing with the dispatched future.
                                this.state.set(State::flushing(reason, Some(response)));
                            }
                            Poll::Pending => {
                                debug!(
                                    service.ready = false,
                                    reason = reason.as_mut().unwrap().as_str(),
                                    message = "delay flush"
                                );
                                return Poll::Pending;
                            }
                            Poll::Ready(Err(e)) => {
                                // Service failed before the flush could start:
                                // publish the error and fail the whole lot.
                                this.bridge.failed("flush", e.into());
                                if let Some(ref e) = this.bridge.failed {
                                    this.lot.notify(Some(e.clone()));
                                }
                            }
                        }
                    }
                    // Flush dispatched: drive the flush future to completion.
                    Some(future) => {
                        match ready!(future.poll(cx)) {
                            Ok(_) => {
                                debug!(reason = reason.as_mut().unwrap().as_str(), "batch flushed");
                                // Deliver all queued responses and start a new
                                // batch.
                                this.lot.notify(None);
                                this.state.set(State::Collecting)
                            },
                            Err(e) => {
                                // Flush failed: fail the lot and shut down.
                                this.bridge.failed("flush", e.into());
                                if let Some(ref e) = this.bridge.failed {
                                    this.lot.notify(Some(e.clone()));
                                }
                                this.state.set(State::Finished);
                                return Poll::Ready(());
                            }
                        }
                    }
                },
                StateProj::Finished => {
                    // We've already received None and are shutting down
                    return Poll::Ready(());
                }
            }
        }
    }
}
255
256// ===== impl State =====
257
258impl<Fut> State<Fut> {
259    fn flushing(reason: String, f: Option<Fut>) -> Self {
260        Self::Flushing {
261            reason: Some(reason),
262            flush_fut: f,
263        }
264    }
265}
266
267// ===== impl Bridge =====
268
269impl<Fut, Request> Drop for Bridge<Fut, Request> {
270    fn drop(&mut self) {
271        self.close_semaphore()
272    }
273}
274
impl<Fut, Request> Bridge<Fut, Request> {
    /// Closes the buffer's semaphore if it is still open, waking any pending tasks.
    ///
    /// Idempotent: `close` is `take`n, so a second call (e.g. `failed` followed
    /// by `drop`) is a no-op.
    fn close_semaphore(&mut self) {
        if let Some(close) = self
            .close
            .take()
            .as_ref()
            .and_then(Weak::<Semaphore>::upgrade)
        {
            debug!("buffer closing; waking pending tasks");
            close.close();
        } else {
            trace!("buffer already closed");
        }
    }

    /// Records a terminal service `error` (from `action`) and propagates it to
    /// every current and future caller.
    ///
    /// First-failure-wins: if an error was already published, this returns
    /// without overwriting it.
    fn failed(&mut self, action: &str, error: crate::BoxError) {
        debug!(action,  %error , "service failed");

        // The underlying service failed when we called `poll_ready` on it with the given `error`.
        // We need to communicate this to all the `Buffer` handles. To do so, we wrap up the error
        // in an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
        // requests will also fail with the same error.

        // Note that we need to handle the case where some handle is concurrently trying to send us
        // a request. We need to make sure that *either* the send of the request fails *or* it
        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
        // case where we send errors to all outstanding requests, and *then* the caller sends its
        // request. We do this by *first* exposing the error, *then* closing the channel used to
        // send more requests (so the client will see the error when the send fails), and *then*
        // sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.handle.inner.lock().unwrap();

        if inner.is_some() {
            // Future::poll was called after we've already errored out!
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        self.rx.close();

        // Wake any tasks waiting on channel capacity.
        self.close_semaphore();

        // By closing the mpsc::Receiver, we know that that the run() loop will drain all pending
        // requests. We just need to make sure that any requests that we receive before we've
        // exhausted the receiver receive the error:
        self.failed = Some(error);
    }

    /// Return the next queued Message that hasn't been canceled.
    ///
    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
    ///
    /// Messages whose oneshot receiver has been dropped (canceled callers) are
    /// silently discarded. Returns `Ready(None)` once the channel is closed
    /// and drained.
    fn poll_next_msg(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<(Message<Request, Fut>, bool)>> {
        trace!("worker polling for next message");

        // Pick any delayed request first
        if let Some(msg) = self.current_message.take() {
            // If the oneshot sender is closed, then the receiver is dropped, and nobody cares about
            // the response. If this is the case, we should continue to the next request.
            if !msg.tx.is_closed() {
                trace!("resuming buffered request");
                return Poll::Ready(Some((msg, false)));
            }

            trace!("dropping cancelled buffered request");
        }

        // Get the next request
        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
            if !msg.tx.is_closed() {
                trace!("processing new request");
                return Poll::Ready(Some((msg, true)));
            }

            // Otherwise, request is canceled, so pop the next one.
            trace!("dropping cancelled request");
        }

        Poll::Ready(None)
    }

    /// Parks `msg` so `poll_next_msg` retries it before receiving new ones
    /// (used when the service was not ready for it).
    fn return_msg(&mut self, msg: Message<Request, Fut>) {
        self.current_message = Some(msg)
    }
}
369
370// ===== impl Lot =====
371
372impl<Fut> Lot<Fut> {
373    fn new(max_size: usize, max_time: Duration) -> Self {
374        Self {
375            max_size,
376            max_time,
377            responses: Vec::with_capacity(max_size),
378            time_elapses: None,
379            time_elapsed: false,
380        }
381    }
382
383    fn poll_max_time(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
384        // When the Worker is polled and the time has elapsed, we return `Some` to let the Worker
385        // know it's time to enter the Flushing state. Subsequent polls (e.g. by the Flush future)
386        // will return None to prevent the Worker from getting stuck in an endless loop of entering
387        // the Flushing state.
388        if self.time_elapsed {
389            return Poll::Ready(None);
390        }
391
392        if let Some(ref mut sleep) = self.time_elapses {
393            if Pin::new(sleep).poll(cx).is_ready() {
394                self.time_elapsed = true;
395                return Poll::Ready(Some(()));
396            }
397        }
398
399        Poll::Pending
400    }
401
402    fn is_full(&self) -> bool {
403        self.responses.len() == self.max_size
404    }
405
406    fn add(&mut self, item: (Tx<Fut>, Result<Fut, ServiceError>)) {
407        if self.responses.is_empty() {
408            self.time_elapses = Some(Box::pin(sleep_until(
409                Instant::now().add(self.max_time).into(),
410            )));
411        }
412        self.responses.push(item);
413    }
414
415    fn notify(&mut self, err: Option<ServiceError>) {
416        for (tx, response) in mem::replace(&mut self.responses, Vec::with_capacity(self.max_size)) {
417            if let Some(ref response) = err {
418                let _ = tx.send(Err(response.clone()));
419            } else {
420                let _ = tx.send(response);
421            }
422        }
423        self.time_elapses = None;
424        self.time_elapsed = false;
425    }
426}
427
428// ===== impl Handle =====
429
430impl Handle {
431    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
432        self.inner
433            .lock()
434            .unwrap()
435            .as_ref()
436            .map(|svc_err| svc_err.clone().into())
437            .unwrap_or_else(|| Closed::new().into())
438    }
439}
440
441impl Clone for Handle {
442    fn clone(&self) -> Self {
443        Handle {
444            inner: self.inner.clone(),
445        }
446    }
447}