warpdrive_proxy/router/
mod.rs1mod protocol;
7mod route;
8mod upstream;
9
10pub use protocol::Protocol;
11pub use route::Route;
12pub use upstream::Upstream;
13
14use anyhow::{Result, anyhow};
15use pingora::prelude::*;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tracing::{debug, info};
19
20use crate::config::toml::TomlConfig;
21
22#[derive(Debug)]
27pub struct Router {
28 upstreams: HashMap<String, Arc<Upstream>>,
30
31 routes: Vec<Route>,
33}
34
35impl Router {
36 pub fn from_config(config: &TomlConfig) -> Result<Self> {
38 info!(
39 "Initializing router with {} upstreams and {} routes",
40 config.upstreams.len(),
41 config.routes.len()
42 );
43
44 let mut upstreams = HashMap::new();
46 for (name, upstream_config) in &config.upstreams {
47 let upstream = Upstream::from_config(name.clone(), upstream_config)?;
48 upstreams.insert(name.clone(), Arc::new(upstream));
49 }
50
51 let routes = config
53 .routes
54 .iter()
55 .map(Route::from_config)
56 .collect::<Result<Vec<_>>>()?;
57
58 for route in &routes {
60 if !upstreams.contains_key(&route.upstream_name) {
61 return Err(anyhow!(
62 "Route references non-existent upstream '{}'",
63 route.upstream_name
64 ));
65 }
66 }
67
68 info!("Router initialized successfully");
69 Ok(Self { upstreams, routes })
70 }
71
72 pub fn select_upstream(&self, session: &Session) -> Result<&Arc<Upstream>> {
76 for route in &self.routes {
77 if route.matches(session) {
78 let upstream = self
79 .upstreams
80 .get(&route.upstream_name)
81 .ok_or_else(|| anyhow!("Upstream '{}' not found", route.upstream_name))?;
82
83 debug!(
84 "Router: selected upstream '{}' for {} {}",
85 route.upstream_name,
86 session.req_header().method,
87 session.req_header().uri.path()
88 );
89
90 return Ok(upstream);
91 }
92 }
93
94 Err(Error::explain(
95 ErrorType::HTTPStatus(502),
96 "No matching route found for request",
97 )
98 .into())
99 }
100
101 pub fn find_matching_route(&self, session: &Session) -> Option<&Route> {
105 self.routes.iter().find(|r| r.matches(session))
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an inline TOML snippet, panicking if the test fixture itself is malformed.
    fn parse_config(toml_str: &str) -> TomlConfig {
        toml::from_str(toml_str).expect("test TOML fixture should parse")
    }

    #[test]
    fn test_router_creation() {
        let config = parse_config(
            r#"
            [upstreams.rails]
            protocol = "http"
            host = "127.0.0.1"
            port = 3000

            [upstreams.cable]
            protocol = "wss"
            host = "127.0.0.1"
            port = 3001

            [[routes]]
            path_prefix = "/cable"
            upstream = "cable"

            [[routes]]
            path_prefix = "/"
            upstream = "rails"
            "#,
        );
        let router = Router::from_config(&config).unwrap();

        assert_eq!(router.upstreams.len(), 2);
        assert!(router.upstreams.contains_key("rails"));
        assert!(router.upstreams.contains_key("cable"));

        // Matching is first-match-wins, so configuration order must be kept.
        // The specific `/cable` prefix has to stay ahead of the catch-all `/`.
        assert_eq!(router.routes.len(), 2);
        assert_eq!(router.routes[0].upstream_name, "cable");
        assert_eq!(router.routes[1].upstream_name, "rails");
    }

    #[test]
    fn test_router_validates_upstream_refs() {
        let config = parse_config(
            r#"
            [upstreams.rails]
            protocol = "http"
            host = "127.0.0.1"
            port = 3000

            [[routes]]
            path_prefix = "/"
            upstream = "nonexistent"
            "#,
        );
        let result = Router::from_config(&config);

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("nonexistent"));
    }
}