tower_batch/service.rs
use std::{
    fmt::Debug,
    sync::Arc,
    task::{Context, Poll},
};

use futures_core::ready;
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower::Service;

use super::{
    future::ResponseFuture,
    message::Message,
    worker::{Handle, Worker},
    BatchControl,
};

/// Handle for submitting requests to a batch worker.
///
/// Each `Batch` handle communicates with a single background worker over a
/// shared channel. Handles are cheap to [`Clone`]: every clone sends to the
/// same worker, so you can hand them to multiple tasks.
///
/// See the [module documentation](crate) for the full lifecycle and error
/// semantics.
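///
/// # Example
///
/// A minimal sketch of sharing one handle across tasks, where `my_service` is
/// a hypothetical service implementing `Service<BatchControl<Request>>`:
///
/// ```ignore
/// let batch = Batch::new(my_service, 32, std::time::Duration::from_millis(10));
///
/// // Clones are cheap, and every clone submits to the same worker.
/// let handle = batch.clone();
/// tokio::spawn(async move {
///     // ... drive requests through `handle` as a `tower::Service` ...
/// });
/// ```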
#[derive(Debug)]
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    // Note: this actually _is_ bounded, but rather than using Tokio's bounded
    // channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    // When the buffer's channel is full, we want to exert backpressure in
    // `poll_ready`, so that callers such as load balancers can choose to call
    // another service rather than waiting for buffer capacity.
    //
    // Unfortunately, this can't be done easily using Tokio's bounded MPSC
    // channel, because it doesn't expose a polling-based interface, only an
    // `async fn ready`, which borrows the sender. Therefore, we implement our
    // own bounded MPSC on top of the unbounded channel, using a semaphore to
    // limit how many items are in the channel.
    semaphore: PollSemaphore,

    // The current semaphore permit, if one has been acquired.
    //
    // This is acquired in `poll_ready` and taken in `call`.
    permit: Option<OwnedSemaphorePermit>,

    // A handle shared with the worker, used to retrieve its error after the
    // channel has closed.
    handle: Handle,
}

impl<T, Request> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service`.
    ///
    /// `size` is the maximum number of items per batch and `time` is the
    /// maximum duration before a batch is flushed. The worker flushes
    /// whichever limit is hit first.
    ///
    /// The background worker is spawned on the default Tokio executor, so
    /// this method must be called while on the Tokio runtime.
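    ///
    /// # Example
    ///
    /// A minimal sketch, assuming a hypothetical `MyBatchVerifier` type that
    /// implements `Service<BatchControl<Item>>`:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// // Flush a batch at 32 items, or once `time` elapses, whichever is first.
    /// let batch = Batch::new(MyBatchVerifier::default(), 32, Duration::from_millis(10));
    /// ```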
    pub fn new(service: T, size: usize, time: std::time::Duration) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (service, worker) = Self::pair(service, size, time);
        tokio::spawn(worker);
        service
    }

    /// Creates a new `Batch` wrapping `service`, but also returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the `tokio`
    /// runtime but instead want to use your own executor. This returns both the
    /// `Batch` and the background `Worker`, which you can then spawn yourself.
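    ///
    /// # Example
    ///
    /// A minimal sketch, again assuming a hypothetical `MyBatchVerifier`
    /// service, plus an executor exposing a `spawn` method:
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let (batch, worker) =
    ///     Batch::pair(MyBatchVerifier::default(), 32, Duration::from_millis(10));
    ///
    /// // Drive the worker on an executor of your choice.
    /// my_executor.spawn(worker);
    /// ```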
    pub fn pair(service: T, size: usize, time: std::time::Duration) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        // The semaphore bound limits the maximum number of concurrent requests
        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
        // used their semaphore reservation in a `call` yet).
        // We choose a bound that allows callers to check readiness for every item in
        // a batch, then actually submit those items.
        let (tx, rx) = mpsc::unbounded_channel();
        let bound = size;
        let semaphore = Arc::new(Semaphore::new(bound));

        let (handle, worker) = Worker::new(rx, service, size, time, &semaphore);

        let batch = Self {
            tx,
            semaphore: PollSemaphore::new(semaphore),
            permit: None,
            handle,
        };
        (batch, worker)
    }

    /// Gets the error reported by the worker, for use when its channel is closed.
    fn get_worker_error(&self) -> crate::BoxError {
        self.handle.get_error_on_closed()
    }
}

impl<T, Request> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    // Our response is effectively the response of the service used by the Worker.
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        tracing::debug!("checking if service is ready");

        // First, check if the worker is still alive.
        if self.tx.is_closed() {
            // If the inner service has errored, then we error here.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Then, check if we've already acquired a permit.
        if self.permit.is_some() {
            // We've already reserved capacity to send a request. We're ready!
            return Poll::Ready(Ok(()));
        }

        // Finally, if we haven't already acquired a permit, poll the semaphore to
        // acquire one. If we acquire a permit, then there's enough buffer capacity
        // to send a new request. Otherwise, we need to wait for capacity.
        //
        // The current task must be scheduled for wakeup every time we return
        // `Poll::Pending`. If `poll_acquire` returns `Pending`, the semaphore has
        // already scheduled the task for wakeup when the next permit becomes
        // available.
        let permit =
            ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
        self.permit = Some(permit);

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::debug!("sending request to batch worker");

        let permit = self
            .permit
            .take()
            .expect("batch full; poll_ready must be called first");

        // Get the current Span so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this request wouldn't
        // be counted towards its span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been acquired, so
        // we can freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        // The worker is in control of completing the request now.
        match self.tx.send(Message {
            request,
            tx,
            span,
            _permit: permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(()) => ResponseFuture::new(rx),
        }
    }
}

impl<T, Request> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            semaphore: self.semaphore.clone(),
            handle: self.handle.clone(),

            // The new clone hasn't acquired a permit yet. It will when it's next polled ready.
            permit: None,
        }
    }
}
199}