// warpdrive_proxy/router/upstream.rs

use anyhow::{Result, anyhow};
use pingora::prelude::*;
use pingora_load_balancing::selection::RoundRobin;
use pingora_load_balancing::{Backend, LoadBalancer};
use std::collections::BTreeSet;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use tracing::{debug, info};

use crate::config::toml::{BackendConfig, UpstreamConfig};
use crate::router::Protocol;

17pub struct Upstream {
22 pub name: String,
24
25 balancer: Arc<LoadBalancer<RoundRobin>>,
27
28 protocol: Protocol,
30
31 sni: Option<String>,
33}
34
35impl Upstream {
36 pub fn from_config(name: String, config: &UpstreamConfig) -> Result<Self> {
38 info!(
39 "Creating upstream '{}' with protocol {:?}",
40 name, config.protocol
41 );
42
43 let backends = if let Some(backend_configs) = &config.backends {
45 info!(
47 " Load balanced pool with {} backends",
48 backend_configs.len()
49 );
50 backend_configs
51 .iter()
52 .map(Self::backend_from_config)
53 .collect::<Result<BTreeSet<_>>>()?
54 } else {
55 info!(" Single backend (still using LoadBalancer)");
57 let backend = Self::single_backend_from_upstream_config(config)?;
58 [backend].into()
59 };
60
61 if backends.is_empty() {
62 return Err(anyhow!("Upstream '{}' has no backends configured", name));
63 }
64
65 let balancer = LoadBalancer::try_from_iter(backends)
67 .map_err(|e| anyhow!("Failed to create load balancer for '{}': {}", name, e))?;
68
69 Ok(Self {
73 name,
74 balancer: Arc::new(balancer),
75 protocol: config.protocol,
76 sni: config.sni.clone().or_else(|| config.host.clone()),
77 })
78 }
79
80 pub async fn get_peer(&self) -> Result<Box<HttpPeer>> {
82 let backend = self
84 .balancer
85 .select(b"", 256)
86 .ok_or_else(|| anyhow!("No healthy backend available for upstream '{}'", self.name))?;
87
88 debug!(
89 "Selected backend for '{}': {} (protocol: {:?})",
90 self.name, backend.addr, self.protocol
91 );
92
93 let (tls, sni) = if self.protocol.requires_tls() {
95 let sni = self.sni.clone().unwrap_or_else(|| {
96 backend
98 .addr
99 .to_string()
100 .split(':')
101 .next()
102 .unwrap_or("localhost")
103 .to_string()
104 });
105 (true, sni)
106 } else {
107 (false, String::new())
108 };
109
110 Ok(Box::new(HttpPeer::new(&backend.addr, tls, sni)))
111 }
112
113 fn backend_from_config(config: &BackendConfig) -> Result<Backend> {
115 let addr: Box<dyn std::fmt::Display + Send> = if let Some(socket) = &config.socket {
116 Box::new(format!("unix:{}", socket.display()))
118 } else if let (Some(host), Some(port)) = (&config.host, &config.port) {
119 Box::new(format!("{}:{}", host, port))
121 } else {
122 return Err(anyhow!(
123 "Backend must have either 'socket' or 'host'+'port' configured"
124 ));
125 };
126
127 Ok(Backend {
128 addr: addr.to_string().parse()?,
129 weight: 1,
130 ext: Default::default(),
131 })
132 }
133
134 fn single_backend_from_upstream_config(config: &UpstreamConfig) -> Result<Backend> {
136 let addr: Box<dyn std::fmt::Display + Send> = if let Some(socket) = &config.socket {
137 Box::new(format!("unix:{}", socket.display()))
139 } else if let (Some(host), Some(port)) = (&config.host, &config.port) {
140 Box::new(format!("{}:{}", host, port))
142 } else {
143 return Err(anyhow!(
144 "Upstream must have either 'socket' or 'host'+'port' configured"
145 ));
146 };
147
148 Ok(Backend {
149 addr: addr.to_string().parse()?,
150 weight: 1,
151 ext: Default::default(),
152 })
153 }
154}
155
156impl std::fmt::Debug for Upstream {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 f.debug_struct("Upstream")
159 .field("name", &self.name)
160 .field("protocol", &self.protocol)
161 .field("sni", &self.sni)
162 .finish_non_exhaustive()
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain-HTTP upstream config with no address fields set; tests fill in what they need.
    fn base_config() -> UpstreamConfig {
        UpstreamConfig {
            protocol: Protocol::Http,
            host: None,
            port: None,
            socket: None,
            backends: None,
            strategy: "round_robin".to_string(),
            sni: None,
            process: None,
        }
    }

    #[test]
    fn test_single_backend_upstream() {
        let config = UpstreamConfig {
            host: Some("127.0.0.1".to_string()),
            port: Some(3000),
            ..base_config()
        };

        let upstream = Upstream::from_config("rails".to_string(), &config).unwrap();
        assert_eq!(upstream.name, "rails");
        assert_eq!(upstream.protocol, Protocol::Http);
    }

    #[test]
    fn test_load_balanced_upstream() {
        let config = UpstreamConfig {
            backends: Some(vec![
                BackendConfig {
                    protocol: Protocol::Http,
                    host: Some("127.0.0.1".to_string()),
                    port: Some(3000),
                    socket: None,
                    sni: None,
                },
                BackendConfig {
                    protocol: Protocol::Http,
                    host: Some("127.0.0.1".to_string()),
                    port: Some(3001),
                    socket: None,
                    sni: None,
                },
            ]),
            ..base_config()
        };

        let upstream = Upstream::from_config("rails".to_string(), &config).unwrap();
        assert_eq!(upstream.name, "rails");
    }

    /// Without `backends`, `socket`, or `host`+`port` there is nothing to proxy to.
    #[test]
    fn test_upstream_without_address_is_rejected() {
        let result = Upstream::from_config("rails".to_string(), &base_config());
        assert!(result.is_err());
    }

    /// An explicitly empty `backends` list must not produce an upstream.
    #[test]
    fn test_empty_backend_list_is_rejected() {
        let config = UpstreamConfig {
            backends: Some(Vec::new()),
            ..base_config()
        };
        assert!(Upstream::from_config("rails".to_string(), &config).is_err());
    }
}