volo_http/client/
loadbalance.rs

//! HTTP Loadbalance Layer
//!
//! This is a copy of `volo::loadbalance::layer` without the retry logic. Retrying requires the
//! request type `Req` to implement `Clone`, but an HTTP body may be a stream, which cannot be
//! cloned, so the retry-related code has been removed here.
//!
//! In addition, an HTTP service can use DNS for service discovery, so the default load balance
//! uses a DNS resolver to pick a target address (the DNS resolver picks only one address because
//! it does not need load balancing).
10
11use std::{fmt::Debug, sync::Arc};
12
13use async_broadcast::RecvError;
14use motore::{layer::Layer, service::Service};
15use volo::{
16    context::Context,
17    discovery::Discover,
18    loadbalance::{LoadBalance, MkLbLayer, random::WeightedRandomBalance},
19};
20
21use super::dns::{DiscoverKey, DnsResolver};
22use crate::{
23    context::ClientContext,
24    error::{
25        ClientError,
26        client::{lb_error, no_available_endpoint},
27    },
28    request::Request,
29};
30
31/// Default load balance with [`DnsResolver`]
32pub type DefaultLb = LbConfig<WeightedRandomBalance<DiscoverKey>, DnsResolver>;
33/// Default load balance service generated by [`DefaultLb`]
34pub type DefaultLbService<S> =
35    LoadBalanceService<WeightedRandomBalance<DiscoverKey>, DnsResolver, S>;
36
/// Load balance layer generator with a [`LoadBalance`] and a [`Discover`]
///
/// `Clone` and `Debug` are derived so the configuration can be duplicated and inspected like its
/// sibling [`LoadBalanceLayer`]. `Default` is deliberately not derived: a concrete `Default` impl
/// exists for [`DefaultLb`], and a blanket derive would overlap with it.
#[derive(Clone, Debug)]
pub struct LbConfig<L, D> {
    load_balance: L,
    discover: D,
}
42
43impl Default for DefaultLb {
44    fn default() -> Self {
45        LbConfig::new(WeightedRandomBalance::new(), DnsResolver::default())
46    }
47}
48
49impl<L, D> LbConfig<L, D> {
50    /// Create a new [`LbConfig`] using a [`LoadBalance`] and a [`Discover`]
51    pub fn new(load_balance: L, discover: D) -> Self {
52        LbConfig {
53            load_balance,
54            discover,
55        }
56    }
57
58    /// Set a [`LoadBalance`] to the [`LbConfig`] and replace the previous one
59    pub fn load_balance<NL>(self, load_balance: NL) -> LbConfig<NL, D> {
60        LbConfig {
61            load_balance,
62            discover: self.discover,
63        }
64    }
65
66    /// Set a [`Discover`] to the [`LbConfig`] and replace the previous one
67    pub fn discover<ND>(self, discover: ND) -> LbConfig<L, ND> {
68        LbConfig {
69            load_balance: self.load_balance,
70            discover,
71        }
72    }
73}
74
75impl<LB, D> MkLbLayer for LbConfig<LB, D> {
76    type Layer = LoadBalanceLayer<LB, D>;
77
78    fn make(self) -> Self::Layer {
79        LoadBalanceLayer::new(self.load_balance, self.discover)
80    }
81}
82
/// [`Layer`] for load balance generated by [`LbConfig`]
///
/// Holds the load balancer and discoverer by value until `layer` hands them to a new
/// [`LoadBalanceService`]. `Debug` is derived alongside the existing derives so this public type
/// can be inspected like the service it produces.
#[derive(Clone, Copy, Debug, Default)]
pub struct LoadBalanceLayer<LB, D> {
    load_balance: LB,
    discover: D,
}
89
90impl<LB, D> LoadBalanceLayer<LB, D> {
91    fn new(load_balance: LB, discover: D) -> Self {
92        LoadBalanceLayer {
93            load_balance,
94            discover,
95        }
96    }
97}
98
99impl<LB, D, S> Layer<S> for LoadBalanceLayer<LB, D>
100where
101    LB: LoadBalance<D>,
102    D: Discover,
103{
104    type Service = LoadBalanceService<LB, D, S>;
105
106    fn layer(self, inner: S) -> Self::Service {
107        LoadBalanceService::new(self.load_balance, self.discover, inner)
108    }
109}
110
/// [`Service`] for load balance generated by [`LoadBalanceLayer`]
///
/// The load balancer is stored behind an [`Arc`], so every clone of the service shares the same
/// balancer instance.
pub struct LoadBalanceService<LB, D, S> {
    load_balance: Arc<LB>,
    discover: D,
    service: S,
}

// Implemented by hand instead of `#[derive(Clone)]`: the derive would add an `LB: Clone` bound
// even though only the `Arc` pointer is cloned, needlessly excluding non-`Clone` load balancers.
impl<LB, D, S> Clone for LoadBalanceService<LB, D, S>
where
    D: Clone,
    S: Clone,
{
    fn clone(&self) -> Self {
        Self {
            load_balance: Arc::clone(&self.load_balance),
            discover: self.discover.clone(),
            service: self.service.clone(),
        }
    }
}
118
119impl<LB, D, S> LoadBalanceService<LB, D, S>
120where
121    LB: LoadBalance<D>,
122    D: Discover,
123{
124    fn new(load_balance: LB, discover: D, service: S) -> Self {
125        let lb = Arc::new(load_balance);
126
127        let service = Self {
128            load_balance: lb.clone(),
129            discover,
130            service,
131        };
132
133        let Some(mut channel) = service.discover.watch(None) else {
134            return service;
135        };
136
137        tokio::spawn(async move {
138            loop {
139                match channel.recv().await {
140                    Ok(recv) => lb.rebalance(recv),
141                    Err(err) => match err {
142                        RecvError::Closed => break,
143                        _ => tracing::warn!("[Volo-HTTP] discovering subscription error: {err}"),
144                    },
145                }
146            }
147        });
148
149        service
150    }
151}
152
153impl<LB, D, S, B> Service<ClientContext, Request<B>> for LoadBalanceService<LB, D, S>
154where
155    LB: LoadBalance<D>,
156    D: Discover,
157    S: Service<ClientContext, Request<B>, Error = ClientError> + Send + Sync,
158    B: Send,
159{
160    type Response = S::Response;
161    type Error = S::Error;
162
163    async fn call(
164        &self,
165        cx: &mut ClientContext,
166        req: Request<B>,
167    ) -> Result<Self::Response, Self::Error> {
168        let callee = cx.rpc_info().callee();
169
170        let mut picker = match &callee.address {
171            None => self
172                .load_balance
173                .get_picker(callee, &self.discover)
174                .await
175                .map_err(lb_error)?,
176            _ => {
177                return self.service.call(cx, req).await;
178            }
179        };
180
181        let addr = picker.next().ok_or_else(no_available_endpoint)?;
182        cx.rpc_info_mut().callee_mut().set_address(addr);
183
184        self.service.call(cx, req).await
185    }
186}
187
188impl<LB, D, S> Debug for LoadBalanceService<LB, D, S>
189where
190    LB: Debug,
191    D: Debug,
192    S: Debug,
193{
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        f.debug_struct("LBService")
196            .field("load_balancer", &self.load_balance)
197            .field("discover", &self.discover)
198            .finish()
199    }
200}