1#![deny(missing_docs)]
16
17use super::p2c::Balance;
18use crate::error;
19use futures_core::ready;
20use pin_project::pin_project;
21use slab::Slab;
22use std::{
23 fmt,
24 future::Future,
25 pin::Pin,
26 task::{Context, Poll},
27};
28use tower_discover::{Change, Discover};
29use tower_load::Load;
30use tower_make::MakeService;
31use tower_service::Service;
32
33#[cfg(test)]
34mod test;
35
/// Coarse provisioning state of the pool, shared between `Pool::poll_ready`
/// (which sets it) and `PoolDiscoverer::poll_discover` (which acts on it).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Level {
    /// The pool is over-provisioned; a service may be removed.
    Low,
    /// The pool is appropriately provisioned; nothing to do.
    Normal,
    /// The pool is under-provisioned; a new service should be made.
    High,
}
45
46#[pin_project]
49pub struct PoolDiscoverer<MS, Target, Request>
50where
51 MS: MakeService<Target, Request>,
52{
53 maker: MS,
54 #[pin]
55 making: Option<MS::Future>,
56 target: Target,
57 load: Level,
58 services: Slab<()>,
59 died_tx: tokio::sync::mpsc::UnboundedSender<usize>,
60 #[pin]
61 died_rx: tokio::sync::mpsc::UnboundedReceiver<usize>,
62 limit: Option<usize>,
63}
64
65impl<MS, Target, Request> fmt::Debug for PoolDiscoverer<MS, Target, Request>
66where
67 MS: MakeService<Target, Request> + fmt::Debug,
68 Target: fmt::Debug,
69{
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("PoolDiscoverer")
72 .field("maker", &self.maker)
73 .field("making", &self.making.is_some())
74 .field("target", &self.target)
75 .field("load", &self.load)
76 .field("services", &self.services)
77 .field("limit", &self.limit)
78 .finish()
79 }
80}
81
82impl<MS, Target, Request> Discover for PoolDiscoverer<MS, Target, Request>
83where
84 MS: MakeService<Target, Request>,
85 MS::MakeError: Into<error::Error>,
86 MS::Error: Into<error::Error>,
87 Target: Clone,
88{
89 type Key = usize;
90 type Service = DropNotifyService<MS::Service>;
91 type Error = MS::MakeError;
92
93 fn poll_discover(
94 self: Pin<&mut Self>,
95 cx: &mut Context<'_>,
96 ) -> Poll<Result<Change<Self::Key, Self::Service>, Self::Error>> {
97 let mut this = self.project();
98
99 while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_recv(cx) {
100 this.services.remove(sid);
101 tracing::trace!(
102 pool.services = this.services.len(),
103 message = "removing dropped service"
104 );
105 }
106
107 if this.services.len() == 0 && this.making.is_none() {
108 let _ = ready!(this.maker.poll_ready(cx))?;
109 tracing::trace!("construct initial pool connection");
110 this.making
111 .set(Some(this.maker.make_service(this.target.clone())));
112 }
113
114 if let Level::High = this.load {
115 if this.making.is_none() {
116 if this
117 .limit
118 .map(|limit| this.services.len() >= limit)
119 .unwrap_or(false)
120 {
121 return Poll::Pending;
122 }
123
124 tracing::trace!(
125 pool.services = this.services.len(),
126 message = "decided to add service to loaded pool"
127 );
128 ready!(this.maker.poll_ready(cx))?;
129 tracing::trace!("making new service");
130 this.making
132 .set(Some(this.maker.make_service(this.target.clone())));
133 }
134 }
135
136 if let Some(fut) = this.making.as_mut().as_pin_mut() {
137 let svc = ready!(fut.poll(cx))?;
138 this.making.set(None);
139
140 let id = this.services.insert(());
141 let svc = DropNotifyService {
142 svc,
143 id,
144 notify: this.died_tx.clone(),
145 };
146 tracing::trace!(
147 pool.services = this.services.len(),
148 message = "finished creating new service"
149 );
150 *this.load = Level::Normal;
151 return Poll::Ready(Ok(Change::Insert(id, svc)));
152 }
153
154 match this.load {
155 Level::High => {
156 unreachable!("found high load but no Service being made");
157 }
158 Level::Normal => Poll::Pending,
159 Level::Low if this.services.len() == 1 => Poll::Pending,
160 Level::Low => {
161 *this.load = Level::Normal;
162 let rm = this.services.iter().next().unwrap().0;
164 tracing::trace!(
167 pool.services = this.services.len(),
168 message = "removing service for over-provisioned pool"
169 );
170 Poll::Ready(Ok(Change::Remove(rm)))
171 }
172 }
173 }
174}
175
/// Tunable parameters for a `Pool`.
///
/// The pool tracks a moving average (`ewma`) of how often it was not ready;
/// these settings control how that average is updated and interpreted.
#[derive(Debug, Clone, Copy)]
pub struct Builder {
    /// Average below which the pool is considered over-provisioned.
    low: f64,
    /// Average above which the pool is considered under-provisioned.
    high: f64,
    /// Starting value of the average, also restored after a scale-down
    /// decision when more than one service exists.
    init: f64,
    /// Weight given to each new sample when updating the average.
    alpha: f64,
    /// Optional cap on the number of services in the pool.
    limit: Option<usize>,
}
189
190impl Default for Builder {
191 fn default() -> Self {
192 Builder {
193 init: 0.1,
194 low: 0.00001,
195 high: 0.2,
196 alpha: 0.03,
197 limit: None,
198 }
199 }
200}
201
202impl Builder {
203 pub fn new() -> Self {
207 Self::default()
208 }
209
210 pub fn underutilized_below(&mut self, low: f64) -> &mut Self {
216 self.low = low;
217 self
218 }
219
220 pub fn loaded_above(&mut self, high: f64) -> &mut Self {
227 self.high = high;
228 self
229 }
230
231 pub fn initial(&mut self, init: f64) -> &mut Self {
238 self.init = init;
239 self
240 }
241
242 pub fn urgency(&mut self, alpha: f64) -> &mut Self {
256 self.alpha = alpha.max(0.0).min(1.0);
257 self
258 }
259
260 pub fn max_services(&mut self, limit: Option<usize>) -> &mut Self {
267 self.limit = limit;
268 self
269 }
270
271 pub fn build<MS, Target, Request>(
273 &self,
274 make_service: MS,
275 target: Target,
276 ) -> Pool<MS, Target, Request>
277 where
278 MS: MakeService<Target, Request>,
279 MS::Service: Load,
280 <MS::Service as Load>::Metric: std::fmt::Debug,
281 MS::MakeError: Into<error::Error>,
282 MS::Error: Into<error::Error>,
283 Target: Clone,
284 {
285 let (died_tx, died_rx) = tokio::sync::mpsc::unbounded_channel();
286 let d = PoolDiscoverer {
287 maker: make_service,
288 making: None,
289 target,
290 load: Level::Normal,
291 services: Slab::new(),
292 died_tx,
293 died_rx,
294 limit: self.limit,
295 };
296
297 Pool {
298 balance: Balance::from_entropy(Box::pin(d)),
299 options: *self,
300 ewma: self.init,
301 }
302 }
303}
304
305pub struct Pool<MS, Target, Request>
307where
308 MS: MakeService<Target, Request>,
309 MS::MakeError: Into<error::Error>,
310 MS::Error: Into<error::Error>,
311 Target: Clone,
312{
313 balance: Balance<Pin<Box<PoolDiscoverer<MS, Target, Request>>>, Request>,
315 options: Builder,
316 ewma: f64,
317}
318
319impl<MS, Target, Request> fmt::Debug for Pool<MS, Target, Request>
320where
321 MS: MakeService<Target, Request> + fmt::Debug,
322 MS::MakeError: Into<error::Error>,
323 MS::Error: Into<error::Error>,
324 Target: Clone + fmt::Debug,
325 MS::Service: fmt::Debug,
326 Request: fmt::Debug,
327{
328 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
329 f.debug_struct("Pool")
330 .field("balance", &self.balance)
331 .field("options", &self.options)
332 .field("ewma", &self.ewma)
333 .finish()
334 }
335}
336
337impl<MS, Target, Request> Pool<MS, Target, Request>
338where
339 MS: MakeService<Target, Request>,
340 MS::Service: Load,
341 <MS::Service as Load>::Metric: std::fmt::Debug,
342 MS::MakeError: Into<error::Error>,
343 MS::Error: Into<error::Error>,
344 Target: Clone,
345{
346 pub fn new(make_service: MS, target: Target) -> Self {
353 Builder::new().build(make_service, target)
354 }
355}
356
357type PinBalance<S, Request> = Balance<Pin<Box<S>>, Request>;
358
359impl<MS, Target, Req> Service<Req> for Pool<MS, Target, Req>
360where
361 MS: MakeService<Target, Req>,
362 MS::Service: Load,
363 <MS::Service as Load>::Metric: std::fmt::Debug,
364 MS::MakeError: Into<error::Error>,
365 MS::Error: Into<error::Error>,
366 Target: Clone,
367{
368 type Response = <PinBalance<PoolDiscoverer<MS, Target, Req>, Req> as Service<Req>>::Response;
369 type Error = <PinBalance<PoolDiscoverer<MS, Target, Req>, Req> as Service<Req>>::Error;
370 type Future = <PinBalance<PoolDiscoverer<MS, Target, Req>, Req> as Service<Req>>::Future;
371
372 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
373 if let Poll::Ready(()) = self.balance.poll_ready(cx)? {
374 self.ewma = (1.0 - self.options.alpha) * self.ewma;
377
378 let discover = self.balance.discover_mut().as_mut().project();
379 if self.ewma < self.options.low {
380 if *discover.load != Level::Low {
381 tracing::trace!({ ewma = %self.ewma }, "pool is over-provisioned");
382 }
383 *discover.load = Level::Low;
384
385 if discover.services.len() > 1 {
386 self.ewma = self.options.init;
388 }
389 } else {
390 if *discover.load != Level::Normal {
391 tracing::trace!({ ewma = %self.ewma }, "pool is appropriately provisioned");
392 }
393 *discover.load = Level::Normal;
394 }
395
396 return Poll::Ready(Ok(()));
397 }
398
399 let discover = self.balance.discover_mut().as_mut().project();
400 if discover.making.is_none() {
401 self.ewma = self.options.alpha + (1.0 - self.options.alpha) * self.ewma;
404
405 if self.ewma > self.options.high {
406 if *discover.load != Level::High {
407 tracing::trace!({ ewma = %self.ewma }, "pool is under-provisioned");
408 }
409 *discover.load = Level::High;
410
411 self.ewma = self.options.high;
415
416 return self.balance.poll_ready(cx);
419 } else {
420 *discover.load = Level::Normal;
421 }
422 }
423
424 Poll::Pending
425 }
426
427 fn call(&mut self, req: Req) -> Self::Future {
428 self.balance.call(req)
429 }
430}
431
432#[doc(hidden)]
433#[derive(Debug)]
434pub struct DropNotifyService<Svc> {
435 svc: Svc,
436 id: usize,
437 notify: tokio::sync::mpsc::UnboundedSender<usize>,
438}
439
440impl<Svc> Drop for DropNotifyService<Svc> {
441 fn drop(&mut self) {
442 let _ = self.notify.send(self.id).is_ok();
443 }
444}
445
446impl<Svc: Load> Load for DropNotifyService<Svc> {
447 type Metric = Svc::Metric;
448 fn load(&self) -> Self::Metric {
449 self.svc.load()
450 }
451}
452
453impl<Request, Svc: Service<Request>> Service<Request> for DropNotifyService<Svc> {
454 type Response = Svc::Response;
455 type Future = Svc::Future;
456 type Error = Svc::Error;
457
458 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
459 self.svc.poll_ready(cx)
460 }
461
462 fn call(&mut self, req: Request) -> Self::Future {
463 self.svc.call(req)
464 }
465}