warpdrive_proxy/router/
upstream.rs

1//! Upstream backend management with load balancing
2//!
3//! This module provides the Upstream abstraction that always uses Pingora's LoadBalancer,
4//! even for single backends (n=1). This simplifies the architecture with negligible overhead (<100ns).
5
6use anyhow::{Result, anyhow};
7use pingora::prelude::*;
8use pingora_load_balancing::selection::RoundRobin;
9use pingora_load_balancing::{Backend, LoadBalancer};
10use std::collections::BTreeSet;
11use std::sync::Arc;
12use tracing::{debug, info};
13
14use crate::config::toml::{BackendConfig, UpstreamConfig};
15use crate::router::Protocol;
16
17/// Upstream service with load balancing
18///
19/// Always uses Pingora's LoadBalancer (even for single backend) for architectural simplicity.
20/// The <100ns overhead is negligible compared to network latency (~15µs for TCP loopback).
21pub struct Upstream {
22    /// Upstream name (from config)
23    pub name: String,
24
25    /// Load balancer with backends
26    balancer: Arc<LoadBalancer<RoundRobin>>,
27
28    /// Protocol for all backends (they must match)
29    protocol: Protocol,
30
31    /// SNI hostname for TLS connections
32    sni: Option<String>,
33}
34
35impl Upstream {
36    /// Create upstream from TOML configuration
37    pub fn from_config(name: String, config: &UpstreamConfig) -> Result<Self> {
38        info!(
39            "Creating upstream '{}' with protocol {:?}",
40            name, config.protocol
41        );
42
43        // Collect backends
44        let backends = if let Some(backend_configs) = &config.backends {
45            // Multiple backends: load balanced pool
46            info!(
47                "  Load balanced pool with {} backends",
48                backend_configs.len()
49            );
50            backend_configs
51                .iter()
52                .map(Self::backend_from_config)
53                .collect::<Result<BTreeSet<_>>>()?
54        } else {
55            // Single backend: still use LoadBalancer for simplicity
56            info!("  Single backend (still using LoadBalancer)");
57            let backend = Self::single_backend_from_upstream_config(config)?;
58            [backend].into()
59        };
60
61        if backends.is_empty() {
62            return Err(anyhow!("Upstream '{}' has no backends configured", name));
63        }
64
65        // Create LoadBalancer with RoundRobin strategy
66        let balancer = LoadBalancer::try_from_iter(backends)
67            .map_err(|e| anyhow!("Failed to create load balancer for '{}': {}", name, e))?;
68
69        // Optional: Add health checks
70        // balancer.set_health_check(TcpHealthCheck::new());
71
72        Ok(Self {
73            name,
74            balancer: Arc::new(balancer),
75            protocol: config.protocol,
76            sni: config.sni.clone().or_else(|| config.host.clone()),
77        })
78    }
79
80    /// Select a backend and create HttpPeer
81    pub async fn get_peer(&self) -> Result<Box<HttpPeer>> {
82        // Use LoadBalancer to select backend (synchronous)
83        let backend = self
84            .balancer
85            .select(b"", 256)
86            .ok_or_else(|| anyhow!("No healthy backend available for upstream '{}'", self.name))?;
87
88        debug!(
89            "Selected backend for '{}': {} (protocol: {:?})",
90            self.name, backend.addr, self.protocol
91        );
92
93        // Create HttpPeer with protocol-specific settings
94        let (tls, sni) = if self.protocol.requires_tls() {
95            let sni = self.sni.clone().unwrap_or_else(|| {
96                // Extract hostname from backend address
97                backend
98                    .addr
99                    .to_string()
100                    .split(':')
101                    .next()
102                    .unwrap_or("localhost")
103                    .to_string()
104            });
105            (true, sni)
106        } else {
107            (false, String::new())
108        };
109
110        Ok(Box::new(HttpPeer::new(&backend.addr, tls, sni)))
111    }
112
113    /// Create Backend from BackendConfig
114    fn backend_from_config(config: &BackendConfig) -> Result<Backend> {
115        let addr: Box<dyn std::fmt::Display + Send> = if let Some(socket) = &config.socket {
116            // Unix domain socket
117            Box::new(format!("unix:{}", socket.display()))
118        } else if let (Some(host), Some(port)) = (&config.host, &config.port) {
119            // TCP socket
120            Box::new(format!("{}:{}", host, port))
121        } else {
122            return Err(anyhow!(
123                "Backend must have either 'socket' or 'host'+'port' configured"
124            ));
125        };
126
127        Ok(Backend {
128            addr: addr.to_string().parse()?,
129            weight: 1,
130            ext: Default::default(),
131        })
132    }
133
134    /// Create Backend from single UpstreamConfig
135    fn single_backend_from_upstream_config(config: &UpstreamConfig) -> Result<Backend> {
136        let addr: Box<dyn std::fmt::Display + Send> = if let Some(socket) = &config.socket {
137            // Unix domain socket
138            Box::new(format!("unix:{}", socket.display()))
139        } else if let (Some(host), Some(port)) = (&config.host, &config.port) {
140            // TCP socket
141            Box::new(format!("{}:{}", host, port))
142        } else {
143            return Err(anyhow!(
144                "Upstream must have either 'socket' or 'host'+'port' configured"
145            ));
146        };
147
148        Ok(Backend {
149            addr: addr.to_string().parse()?,
150            weight: 1,
151            ext: Default::default(),
152        })
153    }
154}
155
156impl std::fmt::Debug for Upstream {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        f.debug_struct("Upstream")
159            .field("name", &self.name)
160            .field("protocol", &self.protocol)
161            .field("sni", &self.sni)
162            .finish_non_exhaustive()
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Upstream config with no inline address, optionally carrying a backend pool.
    fn pool_config(backends: Option<Vec<BackendConfig>>) -> UpstreamConfig {
        UpstreamConfig {
            protocol: Protocol::Http,
            host: None,
            port: None,
            socket: None,
            backends,
            strategy: "round_robin".to_string(),
            sni: None,
            process: None,
        }
    }

    #[test]
    fn test_single_backend_upstream() {
        let config = UpstreamConfig {
            protocol: Protocol::Http,
            host: Some("127.0.0.1".to_string()),
            port: Some(3000),
            socket: None,
            backends: None,
            strategy: "round_robin".to_string(),
            sni: None,
            process: None,
        };

        let upstream = Upstream::from_config("rails".to_string(), &config).unwrap();
        assert_eq!(upstream.name, "rails");
        assert_eq!(upstream.protocol, Protocol::Http);
        // Without an explicit `sni`, the upstream host is used as the SNI name.
        assert_eq!(upstream.sni.as_deref(), Some("127.0.0.1"));
    }

    #[test]
    fn test_explicit_sni_overrides_host() {
        let config = UpstreamConfig {
            host: Some("127.0.0.1".to_string()),
            port: Some(3000),
            sni: Some("app.internal".to_string()),
            ..pool_config(None)
        };

        let upstream = Upstream::from_config("rails".to_string(), &config).unwrap();
        assert_eq!(upstream.sni.as_deref(), Some("app.internal"));
    }

    #[test]
    fn test_load_balanced_upstream() {
        let config = pool_config(Some(vec![
            BackendConfig {
                protocol: Protocol::Http,
                host: Some("127.0.0.1".to_string()),
                port: Some(3000),
                socket: None,
                sni: None,
            },
            BackendConfig {
                protocol: Protocol::Http,
                host: Some("127.0.0.1".to_string()),
                port: Some(3001),
                socket: None,
                sni: None,
            },
        ]));

        let upstream = Upstream::from_config("rails".to_string(), &config).unwrap();
        assert_eq!(upstream.name, "rails");
        assert_eq!(upstream.protocol, Protocol::Http);
    }

    #[test]
    fn test_empty_backend_pool_rejected() {
        let config = pool_config(Some(Vec::new()));

        let err = Upstream::from_config("rails".to_string(), &config).unwrap_err();
        assert!(err.to_string().contains("has no backends configured"));
    }

    #[test]
    fn test_upstream_without_address_rejected() {
        let config = pool_config(None);

        let err = Upstream::from_config("rails".to_string(), &config).unwrap_err();
        assert!(err.to_string().contains("'socket' or 'host'+'port'"));
    }

    #[test]
    fn test_backend_without_port_rejected() {
        let config = pool_config(Some(vec![BackendConfig {
            protocol: Protocol::Http,
            host: Some("127.0.0.1".to_string()),
            port: None,
            socket: None,
            sni: None,
        }]));

        let err = Upstream::from_config("rails".to_string(), &config).unwrap_err();
        assert!(err.to_string().contains("Backend must have"));
    }
}
219}