rust_web_server/router/
mod.rs1#[cfg(test)]
2mod tests;
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::request::Request;
8use crate::response::Response;
9use crate::server::ConnectionInfo;
10
/// Named values captured from the request path by the route that matched.
///
/// A pattern segment written `:name` or `*name` stores the text it matched
/// under `name`; handlers look values up with [`PathParams::get`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PathParams {
    params: HashMap<String, String>,
}

impl PathParams {
    /// Creates an empty parameter set.
    fn new() -> Self {
        PathParams { params: HashMap::new() }
    }

    /// Wraps an already-built map from parameter name to captured value.
    pub(crate) fn from_map(params: HashMap<String, String>) -> Self {
        PathParams { params }
    }

    /// Returns the value captured for `name`, or `None` if there is no
    /// parameter of that name.
    pub fn get(&self, name: &str) -> Option<&str> {
        self.params.get(name).map(String::as_str)
    }

    /// Records `value` under `key`, replacing any earlier value for that key.
    fn insert(&mut self, key: String, value: String) {
        self.params.insert(key, value);
    }
}
42
/// One `/`-separated piece of a parsed route pattern.
#[derive(Clone)]
enum Segment {
    /// Written `:name`: matches any single path segment and captures it
    /// under `name`.
    Param(String),
    /// Written `*name`: captures all remaining path segments, re-joined
    /// with `/`, under `name`.
    Wildcard(String),
    /// Any other text: must equal the corresponding path segment exactly.
    Literal(String),
}
49
50type HandlerFn =
51 Arc<dyn Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static>;
52
53#[derive(Clone)]
54struct Route {
55 method: String,
56 segments: Vec<Segment>,
57 handler: HandlerFn,
58}
59
/// Read-only description of a registered route, as returned by
/// [`Router::route_entries`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RouteInfo {
    /// HTTP method the route was registered for, e.g. `"GET"`.
    pub method: String,
    /// Path pattern in registration syntax, e.g. `/users/:id`, rebuilt from
    /// the parsed segments (so `/users/` is reported as `/users`).
    pub pattern: String,
}
104
105#[derive(Clone)]
106pub struct Router {
107 routes: Vec<Route>,
108 host: Option<String>,
111}
112
113impl Router {
114 pub fn new() -> Self {
115 Router { routes: Vec::new(), host: None }
116 }
117
118 pub fn with_host(mut self, host: &str) -> Self {
121 self.host = Some(host.to_string());
122 self
123 }
124
125 pub fn get<F>(self, pattern: &str, handler: F) -> Self
127 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
128 self.add("GET", pattern, handler)
129 }
130
131 pub fn post<F>(self, pattern: &str, handler: F) -> Self
133 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
134 self.add("POST", pattern, handler)
135 }
136
137 pub fn put<F>(self, pattern: &str, handler: F) -> Self
139 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
140 self.add("PUT", pattern, handler)
141 }
142
143 pub fn patch<F>(self, pattern: &str, handler: F) -> Self
145 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
146 self.add("PATCH", pattern, handler)
147 }
148
149 pub fn delete<F>(self, pattern: &str, handler: F) -> Self
151 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
152 self.add("DELETE", pattern, handler)
153 }
154
155 fn add<F>(mut self, method: &str, pattern: &str, handler: F) -> Self
156 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
157 self.routes.push(Route {
158 method: method.to_string(),
159 segments: Self::parse_pattern(pattern),
160 handler: Arc::new(handler),
161 });
162 self
163 }
164
165 fn parse_pattern(pattern: &str) -> Vec<Segment> {
166 if pattern == "/" {
167 return vec![];
168 }
169 pattern
170 .split('/')
171 .filter(|s| !s.is_empty())
172 .map(|seg| {
173 if let Some(name) = seg.strip_prefix(':') {
174 Segment::Param(name.to_string())
175 } else if let Some(name) = seg.strip_prefix('*') {
176 Segment::Wildcard(name.to_string())
177 } else {
178 Segment::Literal(seg.to_string())
179 }
180 })
181 .collect()
182 }
183
184 pub fn route_entries(&self) -> Vec<RouteInfo> {
189 self.routes.iter().map(|r| RouteInfo {
190 method: r.method.clone(),
191 pattern: Self::segments_to_pattern(&r.segments),
192 }).collect()
193 }
194
195 fn segments_to_pattern(segs: &[Segment]) -> String {
196 if segs.is_empty() {
197 return "/".to_string();
198 }
199 let parts: Vec<String> = segs.iter().map(|s| match s {
200 Segment::Literal(l) => l.clone(),
201 Segment::Param(n) => format!(":{}", n),
202 Segment::Wildcard(n) => format!("*{}", n),
203 }).collect();
204 format!("/{}", parts.join("/"))
205 }
206
207 pub fn handle(&self, request: &Request, connection: &ConnectionInfo) -> Option<Response> {
215 if let Some(required_host) = &self.host {
216 let actual = connection.sni_hostname.as_deref().or_else(|| {
217 request.headers.iter()
218 .find(|h| h.name.eq_ignore_ascii_case("host"))
219 .map(|h| h.value.as_str())
220 });
221 if actual != Some(required_host.as_str()) {
222 return None;
223 }
224 }
225
226 let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
227 let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
228
229 for route in &self.routes {
230 if route.method != request.method {
231 continue;
232 }
233 if let Some(params) = Self::try_match(&route.segments, &path_segs) {
234 return Some((route.handler)(request, ¶ms, connection));
235 }
236 }
237 None
238 }
239
240 fn try_match(pattern: &[Segment], path: &[&str]) -> Option<PathParams> {
241 let mut params = PathParams::new();
242 let mut pi = 0;
243
244 for (si, seg) in pattern.iter().enumerate() {
245 match seg {
246 Segment::Literal(lit) => {
247 if pi >= path.len() || path[pi] != lit.as_str() {
248 return None;
249 }
250 pi += 1;
251 }
252 Segment::Param(name) => {
253 if pi >= path.len() {
254 return None;
255 }
256 params.insert(name.clone(), path[pi].to_string());
257 pi += 1;
258 }
259 Segment::Wildcard(name) => {
260 if si != pattern.len() - 1 {
261 return None; }
263 params.insert(name.clone(), path[pi..].join("/"));
264 pi = path.len();
265 }
266 }
267 }
268
269 if pi == path.len() { Some(params) } else { None }
270 }
271}