tower_batch/service.rs

1use std::{
2    fmt::Debug,
3    sync::Arc,
4    task::{Context, Poll},
5};
6
7use futures_core::ready;
8use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
9use tokio_util::sync::PollSemaphore;
10use tower::Service;
11
12use super::{
13    future::ResponseFuture,
14    message::Message,
15    worker::{Handle, Worker},
16    BatchControl,
17};
18
/// Handle for submitting requests to a batch worker.
///
/// Each `Batch` handle communicates with a single background worker over a
/// shared channel. Handles are cheap to [`Clone`] – every clone sends to the
/// same worker, so you can hand them to multiple tasks.
///
/// See the [module documentation](crate) for the full lifecycle and error
/// semantics.
#[derive(Debug)]
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    // Note: this actually _is_ bounded, but rather than using Tokio's bounded
    // channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    // When the buffer's channel is full, we want to exert backpressure in
    // `poll_ready`, so that callers such as load balancers could choose to call
    // another service rather than waiting for buffer capacity.
    //
    // Unfortunately, this can't be done easily using Tokio's bounded MPSC
    // channel, because it doesn't expose a polling-based interface, only an
    // `async fn ready`, which borrows the sender. Therefore, we implement our
    // own bounded MPSC on top of the unbounded channel, using a semaphore to
    // limit how many items are in the channel.
    semaphore: PollSemaphore,

    // The current semaphore permit, if one has been acquired.
    //
    // This is acquired in `poll_ready` and taken in `call`.
    permit: Option<OwnedSemaphorePermit>,

    // Handle shared with the background worker; used to retrieve the worker's
    // error once the channel has closed (see `get_worker_error`).
    handle: Handle,
}
53
54impl<T, Request> Batch<T, Request>
55where
56    T: Service<BatchControl<Request>>,
57    T::Error: Into<crate::BoxError>,
58{
59    /// Creates a new `Batch` wrapping `service`.
60    ///
61    /// `size` is the maximum number of items per batch and `time` is the
62    /// maximum duration before a batch is flushed. The worker flushes
63    /// whichever limit is hit first.
64    ///
65    /// The background worker is spawned on the default Tokio executor, so
66    /// this method must be called while on the Tokio runtime.
67    pub fn new(service: T, size: usize, time: std::time::Duration) -> Self
68    where
69        T: Send + 'static,
70        T::Future: Send,
71        T::Error: Send + Sync,
72        Request: Send + 'static,
73    {
74        let (service, worker) = Self::pair(service, size, time);
75        tokio::spawn(worker);
76        service
77    }
78
79    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
80    ///
81    /// This is useful if you do not want to spawn directly onto the `tokio`
82    /// runtime but instead want to use your own executor. This will return the
83    /// `Batch` and the background `Worker` that you can then spawn.
84    pub fn pair(service: T, size: usize, time: std::time::Duration) -> (Self, Worker<T, Request>)
85    where
86        T: Send + 'static,
87        T::Future: Send,
88        T::Error: Send + Sync,
89        Request: Send + 'static,
90    {
91        // The semaphore bound limits the maximum number of concurrent requests
92        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
93        // used their semaphore reservation in a `call` yet).
94        // We choose a bound that allows callers to check readiness for every item in
95        // a batch, then actually submit those items.
96        let (tx, rx) = mpsc::unbounded_channel();
97        let bound = size;
98        let semaphore = Arc::new(Semaphore::new(bound));
99
100        let (handle, worker) = Worker::new(rx, service, size, time, &semaphore);
101
102        let batch = Self {
103            tx,
104            semaphore: PollSemaphore::new(semaphore),
105            permit: None,
106            handle,
107        };
108        (batch, worker)
109    }
110
111    fn get_worker_error(&self) -> crate::BoxError {
112        self.handle.get_error_on_closed()
113    }
114}
115
116impl<T, Request> Service<Request> for Batch<T, Request>
117where
118    T: Service<BatchControl<Request>>,
119    T::Error: Into<crate::BoxError>,
120{
121    // Our response is effectively the response of the service used by the Worker
122    type Response = T::Response;
123    type Error = crate::BoxError;
124    type Future = ResponseFuture<T::Future>;
125
126    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
127        tracing::debug!("checking if service is ready");
128
129        // First, check if the worker is still alive.
130        if self.tx.is_closed() {
131            // If the inner service has errored, then we error here.
132            return Poll::Ready(Err(self.get_worker_error()));
133        }
134
135        // Then, check if we've already acquired a permit.
136        if self.permit.is_some() {
137            // We've already reserved capacity to send a request. We're ready!
138            return Poll::Ready(Ok(()));
139        }
140
141        // Finally, if we haven't already acquired a permit, poll the semaphore to acquire one. If
142        // we acquire a permit, then there's enough buffer capacity to send a new request.
143        // Otherwise, we need to wait for capacity.
144        //
145        // The current task must be scheduled for wakeup every time we return `Poll::Pending`. If
146        // it returns Pending, the semaphore also schedules the task for wakeup when the next permit
147        // is available.
148        let permit =
149            ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
150        self.permit = Some(permit);
151
152        Poll::Ready(Ok(()))
153    }
154
155    fn call(&mut self, request: Request) -> Self::Future {
156        tracing::debug!("sending request to batch worker");
157
158        let permit = self
159            .permit
160            .take()
161            .expect("batch full; poll_ready must be called first");
162
163        // Get the current Span so that we can explicitly propagate it to the worker
164        // if we didn't do this, events on the worker related to this span wouldn't be counted
165        // towards that span since the worker would have no way of entering it.
166        let span = tracing::Span::current();
167
168        // If we've made it here, then a semaphore permit has already been acquired, so we can
169        // freely allocate a oneshot.
170        let (tx, rx) = oneshot::channel();
171
172        // The worker is in control of completing the request now.
173        match self.tx.send(Message {
174            request,
175            tx,
176            span,
177            _permit: permit,
178        }) {
179            Err(_) => ResponseFuture::failed(self.get_worker_error()),
180            Ok(()) => ResponseFuture::new(rx),
181        }
182    }
183}
184
185impl<T, Request> Clone for Batch<T, Request>
186where
187    T: Service<BatchControl<Request>>,
188{
189    fn clone(&self) -> Self {
190        Self {
191            tx: self.tx.clone(),
192            semaphore: self.semaphore.clone(),
193            handle: self.handle.clone(),
194
195            // The new clone hasn't acquired a permit yet. It will when it's next polled ready.
196            permit: None,
197        }
198    }
199}