tower_batch_control/service.rs
1//! Wrapper service for batching items to an underlying service.
2
3use std::{
4 cmp::max,
5 fmt,
6 future::Future,
7 pin::Pin,
8 sync::{Arc, Mutex},
9 task::{Context, Poll},
10};
11
12use futures_core::ready;
13use tokio::{
14 pin,
15 sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore},
16 task::JoinHandle,
17};
18use tokio_util::sync::PollSemaphore;
19use tower::Service;
20use tracing::{info_span, Instrument};
21
22use crate::RequestWeight;
23
24use super::{
25 future::ResponseFuture,
26 message::Message,
27 worker::{ErrorHandle, Worker},
28 BatchControl,
29};
30
/// Upper bound on how many batches may be waiting in the request queue.
///
/// Without this cap, the queue would grow with the number of CPU threads, which leads to
/// very large queues on machines with hundreds or thousands of cores.
pub const QUEUE_BATCH_LIMIT: usize = 64;
35
36/// Allows batch processing of requests.
37///
38/// See the crate documentation for more details.
39pub struct Batch<T, Request: RequestWeight>
40where
41 T: Service<BatchControl<Request>>,
42{
43 // Batch management
44 //
45 /// A custom-bounded channel for sending requests to the batch worker.
46 ///
47 /// Note: this actually _is_ bounded, but rather than using Tokio's unbounded
48 /// channel, we use tokio's semaphore separately to implement the bound.
49 tx: mpsc::UnboundedSender<Message<Request, T::Future>>,
50
51 /// A semaphore used to bound the channel.
52 ///
53 /// When the buffer's channel is full, we want to exert backpressure in
54 /// `poll_ready`, so that callers such as load balancers could choose to call
55 /// another service rather than waiting for buffer capacity.
56 ///
57 /// Unfortunately, this can't be done easily using Tokio's bounded MPSC
58 /// channel, because it doesn't wake pending tasks on close. Therefore, we implement our
59 /// own bounded MPSC on top of the unbounded channel, using a semaphore to
60 /// limit how many items are in the channel.
61 semaphore: PollSemaphore,
62
63 /// A semaphore permit that allows this service to send one message on `tx`.
64 permit: Option<OwnedSemaphorePermit>,
65
66 // Errors
67 //
68 /// An error handle shared between all service clones for the same worker.
69 error_handle: ErrorHandle,
70
71 /// A worker task handle shared between all service clones for the same worker.
72 ///
73 /// Only used when the worker is spawned on the tokio runtime.
74 worker_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
75}
76
77impl<T, Request: RequestWeight> fmt::Debug for Batch<T, Request>
78where
79 T: Service<BatchControl<Request>>,
80{
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 let name = std::any::type_name::<Self>();
83 f.debug_struct(name)
84 .field("tx", &self.tx)
85 .field("semaphore", &self.semaphore)
86 .field("permit", &self.permit)
87 .field("error_handle", &self.error_handle)
88 .field("worker_handle", &self.worker_handle)
89 .finish()
90 }
91}
92
93impl<T, Request: RequestWeight> Batch<T, Request>
94where
95 T: Service<BatchControl<Request>>,
96 T::Future: Send + 'static,
97 T::Error: Into<crate::BoxError>,
98{
99 /// Creates a new `Batch` wrapping `service`.
100 ///
101 /// The wrapper is responsible for telling the inner service when to flush a
102 /// batch of requests. These parameters control this policy:
103 ///
104 /// * `max_items_weight_in_batch` gives the maximum item weight per batch.
105 /// * `max_batches` is an upper bound on the number of batches in the queue,
106 /// and the number of concurrently executing batches.
107 /// If this is `None`, we use the current number of [`rayon`] threads.
108 /// The number of batches in the queue is also limited by [`QUEUE_BATCH_LIMIT`].
109 /// * `max_latency` gives the maximum latency for a batch item to start verifying.
110 ///
111 /// The default Tokio executor is used to run the given service, which means
112 /// that this method must be called while on the Tokio runtime.
113 pub fn new(
114 service: T,
115 max_items_weight_in_batch: usize,
116 max_batches: impl Into<Option<usize>>,
117 max_latency: std::time::Duration,
118 ) -> Self
119 where
120 T: Send + 'static,
121 T::Future: Send,
122 T::Response: Send,
123 T::Error: Send + Sync,
124 Request: Send + 'static,
125 {
126 let (mut batch, worker) =
127 Self::pair(service, max_items_weight_in_batch, max_batches, max_latency);
128
129 let span = info_span!("batch worker", kind = std::any::type_name::<T>());
130
131 #[cfg(tokio_unstable)]
132 let worker_handle = {
133 let batch_kind = std::any::type_name::<T>();
134
135 // TODO: identify the unique part of the type name generically,
136 // or make it an argument to this method
137 let batch_kind = batch_kind.trim_start_matches("zebra_consensus::primitives::");
138 let batch_kind = batch_kind.trim_end_matches("::Verifier");
139
140 tokio::task::Builder::new()
141 .name(&format!("{} batch", batch_kind))
142 .spawn(worker.run().instrument(span))
143 .expect("panic on error to match tokio::spawn")
144 };
145 #[cfg(not(tokio_unstable))]
146 let worker_handle = tokio::spawn(worker.run().instrument(span));
147
148 batch.register_worker(worker_handle);
149
150 batch
151 }
152
153 /// Creates a new `Batch` wrapping `service`, but returns the background worker.
154 ///
155 /// This is useful if you do not want to spawn directly onto the `tokio`
156 /// runtime but instead want to use your own executor. This will return the
157 /// `Batch` and the background `Worker` that you can then spawn.
158 pub fn pair(
159 service: T,
160 max_items_weight_in_batch: usize,
161 max_batches: impl Into<Option<usize>>,
162 max_latency: std::time::Duration,
163 ) -> (Self, Worker<T, Request>)
164 where
165 T: Send + 'static,
166 T::Error: Send + Sync,
167 Request: Send + 'static,
168 {
169 let (tx, rx) = mpsc::unbounded_channel();
170
171 // Clamp config to sensible values.
172 let max_items_weight_in_batch = max(max_items_weight_in_batch, 1);
173 let max_batches = max_batches
174 .into()
175 .unwrap_or_else(rayon::current_num_threads);
176 let max_batches_in_queue = max_batches.clamp(1, QUEUE_BATCH_LIMIT);
177
178 // The semaphore bound limits the maximum number of concurrent requests
179 // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
180 // used their semaphore reservation in a `call` yet).
181 //
182 // We choose a bound that allows callers to check readiness for one batch per rayon CPU thread.
183 // This helps keep all CPUs filled with work: there is one batch executing, and another ready to go.
184 // Often there is only one verifier running, when that happens we want it to take all the cores.
185 //
186 // Requests with a request weight greater than 1 won't typically exhaust the number of available
187 // permits, but will still be bounded to the maximum possible number of concurrent requests.
188 let semaphore = Semaphore::new(max_items_weight_in_batch * max_batches_in_queue);
189 let semaphore = PollSemaphore::new(Arc::new(semaphore));
190
191 let (error_handle, worker) = Worker::new(
192 service,
193 rx,
194 max_items_weight_in_batch,
195 max_batches,
196 max_latency,
197 semaphore.clone(),
198 );
199
200 let batch = Batch {
201 tx,
202 semaphore,
203 permit: None,
204 error_handle,
205 worker_handle: Arc::new(Mutex::new(None)),
206 };
207
208 (batch, worker)
209 }
210
211 /// Ask the `Batch` to monitor the spawned worker task's [`JoinHandle`].
212 ///
213 /// Only used when the task is spawned on the tokio runtime.
214 pub fn register_worker(&mut self, worker_handle: JoinHandle<()>) {
215 *self
216 .worker_handle
217 .lock()
218 .expect("previous task panicked while holding the worker handle mutex") =
219 Some(worker_handle);
220 }
221
222 /// Returns the error from the batch worker's `error_handle`.
223 fn get_worker_error(&self) -> crate::BoxError {
224 self.error_handle.get_error_on_closed()
225 }
226}
227
228impl<T, Request: RequestWeight> Service<Request> for Batch<T, Request>
229where
230 T: Service<BatchControl<Request>>,
231 T::Future: Send + 'static,
232 T::Error: Into<crate::BoxError>,
233{
234 type Response = T::Response;
235 type Error = crate::BoxError;
236 type Future = ResponseFuture<T::Future>;
237
238 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
239 // Check to see if the worker has returned or panicked.
240 //
241 // Correctness: Registers this task for wakeup when the worker finishes.
242 if let Some(worker_handle) = self
243 .worker_handle
244 .lock()
245 .expect("previous task panicked while holding the worker handle mutex")
246 .as_mut()
247 {
248 // # Correctness
249 //
250 // The inner service used with `Batch` MUST NOT return recoverable errors from its:
251 // - `poll_ready` method, or
252 // - `call` method when called with a `BatchControl::Flush` request.
253 //
254 // If the inner service returns an error in those cases, this `poll_ready` method will
255 // return an error the first time its called, and will panic the second time its called
256 // as it attempts to call `poll` on a `JoinHandle` that has already completed.
257 match Pin::new(worker_handle).poll(cx) {
258 Poll::Ready(Ok(())) => {
259 let worker_error = self.get_worker_error();
260 tracing::warn!(?worker_error, "batch worker finished unexpectedly");
261 return Poll::Ready(Err(worker_error));
262 }
263 Poll::Ready(Err(task_cancelled)) if task_cancelled.is_cancelled() => {
264 tracing::warn!(
265 "batch task cancelled: {task_cancelled}\n\
266 Is Zebra shutting down?"
267 );
268
269 return Poll::Ready(Err(task_cancelled.into()));
270 }
271 Poll::Ready(Err(task_panic)) => {
272 std::panic::resume_unwind(task_panic.into_panic());
273 }
274 Poll::Pending => {}
275 }
276 }
277
278 // Check if the worker has set an error and closed its channels.
279 //
280 // Correctness: Registers this task for wakeup when the channel is closed.
281 let tx = self.tx.clone();
282 let closed = tx.closed();
283 pin!(closed);
284 if closed.poll(cx).is_ready() {
285 return Poll::Ready(Err(self.get_worker_error()));
286 }
287
288 // Poll to acquire a semaphore permit.
289 //
290 // CORRECTNESS
291 //
292 // If we acquire a permit, then there's enough buffer capacity to send a new request.
293 // Otherwise, we need to wait for capacity. When that happens, `poll_acquire()` registers
294 // this task for wakeup when the next permit is available, or when the semaphore is closed.
295 //
296 // When `poll_ready()` is called multiple times, and channel capacity is 1,
297 // avoid deadlocks by dropping any previous permit before acquiring another one.
298 // This also stops tasks holding a permit after an error.
299 //
300 // Calling `poll_ready()` multiple times can make tasks lose their previous permit
301 // to another concurrent task.
302 self.permit = None;
303
304 let permit = ready!(self.semaphore.poll_acquire(cx));
305 if let Some(permit) = permit {
306 // Calling poll_ready() more than once will drop any previous permit,
307 // releasing its capacity back to the semaphore.
308 self.permit = Some(permit);
309 } else {
310 // The semaphore has been closed.
311 return Poll::Ready(Err(self.get_worker_error()));
312 }
313
314 Poll::Ready(Ok(()))
315 }
316
317 fn call(&mut self, request: Request) -> Self::Future {
318 tracing::trace!("sending request to batch worker");
319 let _permit = self
320 .permit
321 .take()
322 .expect("poll_ready must be called before a batch request");
323
324 // get the current Span so that we can explicitly propagate it to the worker
325 // if we didn't do this, events on the worker related to this span wouldn't be counted
326 // towards that span since the worker would have no way of entering it.
327 let span = tracing::Span::current();
328
329 // If we've made it here, then a semaphore permit has already been
330 // acquired, so we can freely allocate a oneshot.
331 let (tx, rx) = oneshot::channel();
332
333 match self.tx.send(Message {
334 request,
335 tx,
336 span,
337 _permit,
338 }) {
339 Err(_) => ResponseFuture::failed(self.get_worker_error()),
340 Ok(_) => ResponseFuture::new(rx),
341 }
342 }
343}
344
345impl<T, Request: RequestWeight> Clone for Batch<T, Request>
346where
347 T: Service<BatchControl<Request>>,
348{
349 fn clone(&self) -> Self {
350 Self {
351 tx: self.tx.clone(),
352 semaphore: self.semaphore.clone(),
353 permit: None,
354 error_handle: self.error_handle.clone(),
355 worker_handle: self.worker_handle.clone(),
356 }
357 }
358}