tower_batch/service.rs

use std::{
    fmt::Debug,
    sync::Arc,
    task::{Context, Poll},
};

use futures_core::ready;
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower::Service;

use super::{
    future::ResponseFuture,
    message::Message,
    worker::{Handle, Worker},
    BatchControl,
};

/// Allows batch processing of requests.
///
/// See the module documentation for more details.
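///
/// # Example
///
/// A minimal sketch, assuming a hypothetical `verifier` service that
/// implements `Service<BatchControl<Item>>`:
///
/// ```ignore
/// use std::time::Duration;
///
/// // Wrap the verifier; the background worker flushes batches for us.
/// let batch = Batch::new(verifier, 32, Duration::from_millis(50));
///
/// // Clones share the same channel, semaphore, and worker, so the wrapper
/// // can be handed out to many callers cheaply.
/// let batch_clone = batch.clone();
/// ```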
#[derive(Debug)]
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    // Note: this actually _is_ bounded, but rather than using Tokio's bounded
    // channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    // When the buffer's channel is full, we want to exert backpressure in
    // `poll_ready`, so that callers such as load balancers could choose to call
    // another service rather than waiting for buffer capacity.
    //
    // Unfortunately, this can't be done easily using Tokio's bounded MPSC
    // channel, because it doesn't expose a polling-based interface, only an
    // `async fn ready`, which borrows the sender. Therefore, we implement our
    // own bounded MPSC on top of the unbounded channel, using a semaphore to
    // limit how many items are in the channel.
    semaphore: PollSemaphore,

    // The current semaphore permit, if one has been acquired.
    //
    // This is acquired in `poll_ready` and taken in `call`.
    permit: Option<OwnedSemaphorePermit>,
    handle: Handle,
}

impl<T, Request> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service`.
    ///
    /// The wrapper is responsible for telling the inner service when to flush a
    /// batch of requests.
    ///
    /// The default Tokio executor is used to run the given service, which means
    /// that this method must be called while on the Tokio runtime.
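    ///
    /// # Example
    ///
    /// A minimal sketch, assuming hypothetical `verifier` and `item` values,
    /// and taking `size` and `time` to be the maximum batch size and flush
    /// latency handed to the worker:
    ///
    /// ```ignore
    /// use std::time::Duration;
    /// use tower::{Service, ServiceExt};
    ///
    /// // Must be called from within a Tokio runtime, since the worker is spawned here.
    /// let mut batch = Batch::new(verifier, 32, Duration::from_millis(50));
    ///
    /// // The wrapper is used like any other `tower::Service`.
    /// let response = batch.ready().await?.call(item).await?;
    /// ```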
    pub fn new(service: T, size: usize, time: std::time::Duration) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (service, worker) = Self::pair(service, size, time);
        tokio::spawn(worker);
        service
    }

    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the `tokio`
    /// runtime but instead want to use your own executor. This will return the
    /// `Batch` and the background `Worker` that you can then spawn.
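    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a hypothetical `verifier` service and a
    /// hypothetical `my_executor` with a `spawn` method:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let (batch, worker) = Batch::pair(verifier, 32, Duration::from_millis(50));
    ///
    /// // The caller is responsible for driving the worker future to completion,
    /// // for example on a custom executor.
    /// my_executor.spawn(worker);
    ///
    /// // `batch` can now be used like any other `tower::Service`.
    /// ```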
    pub fn pair(service: T, size: usize, time: std::time::Duration) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        // The semaphore bound limits the maximum number of concurrent requests
        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
        // used their semaphore reservation in a `call` yet).
        // We choose a bound that allows callers to check readiness for every item in
        // a batch, then actually submit those items.
        let (tx, rx) = mpsc::unbounded_channel();
        let bound = size;
        let semaphore = Arc::new(Semaphore::new(bound));

        let (handle, worker) = Worker::new(rx, service, size, time, &semaphore);

        let batch = Self {
            tx,
            semaphore: PollSemaphore::new(semaphore),
            permit: None,
            handle,
        };
        (batch, worker)
    }

    fn get_worker_error(&self) -> crate::BoxError {
        self.handle.get_error_on_closed()
    }
}

impl<T, Request> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    // Our response is effectively the response of the service used by the Worker
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        tracing::debug!("checking if service is ready");

        // First, check if the worker is still alive.
        if self.tx.is_closed() {
            // If the inner service has errored, then we error here.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Then, check if we've already acquired a permit.
        if self.permit.is_some() {
            // We've already reserved capacity to send a request. We're ready!
            return Poll::Ready(Ok(()));
        }

        // Finally, if we haven't already acquired a permit, poll the semaphore to acquire one. If
        // we acquire a permit, then there's enough buffer capacity to send a new request.
        // Otherwise, we need to wait for capacity.
        //
        // The current task must be scheduled for wakeup every time we return `Poll::Pending`. If
        // `poll_acquire` returns `Pending`, the semaphore has already scheduled this task for
        // wakeup when the next permit becomes available.
        let permit =
            ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
        self.permit = Some(permit);

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::debug!("sending request to batch worker");

        let _permit = self
            .permit
            .take()
            .expect("batch full; poll_ready must be called first");

        // Get the current `Span` so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this span wouldn't be
        // counted towards that span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been acquired, so we can
        // freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        // The worker is in control of completing the request now.
        match self.tx.send(Message {
            request,
            span,
            tx,
            _permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(_) => ResponseFuture::new(rx),
        }
    }
}

impl<T, Request> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            semaphore: self.semaphore.clone(),
            handle: self.handle.clone(),

            // The new clone hasn't acquired a permit yet. It will when it's next polled ready.
            permit: None,
        }
    }
}