rust_web_server/router/
mod.rs1#[cfg(test)]
2mod tests;
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use crate::request::Request;
8use crate::response::Response;
9use crate::server::ConnectionInfo;
10
/// Named values captured from a request path while matching a route.
///
/// When filled by the router, each key is the name written after `:` or `*`
/// in the route pattern and each value is the text it captured.
pub struct PathParams {
    params: HashMap<String, String>,
}

impl PathParams {
    /// Creates a parameter set with nothing captured yet.
    fn new() -> Self {
        Self::from_map(HashMap::new())
    }

    /// Wraps an already-built name → value map.
    pub(crate) fn from_map(params: HashMap<String, String>) -> Self {
        Self { params }
    }

    /// Returns the value captured under `name`, or `None` if nothing was
    /// captured under that name.
    pub fn get(&self, name: &str) -> Option<&str> {
        let found = self.params.get(name);
        found.map(|value| value.as_str())
    }

    /// Stores `value` under `key`; an existing value for `key` is replaced.
    fn insert(&mut self, key: String, value: String) {
        let _previous = self.params.insert(key, value);
    }
}
42
/// One component of a parsed route pattern (the pattern is split on `/`).
///
/// `Debug` and `PartialEq`/`Eq` are derived so parsed patterns can be printed
/// while debugging routing and compared directly in tests.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Segment {
    /// Plain text that must equal the corresponding path segment exactly.
    Literal(String),
    /// `:name` — captures exactly one path segment under `name`.
    Param(String),
    /// `*name` — captures all remaining path segments, joined with `/`, under
    /// `name`. Only matches when it is the last segment of the pattern.
    Wildcard(String),
}
49
50type HandlerFn =
51 Arc<dyn Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static>;
52
53#[derive(Clone)]
54struct Route {
55 method: String,
56 segments: Vec<Segment>,
57 handler: HandlerFn,
58}
59
60#[derive(Clone)]
99pub struct Router {
100 routes: Vec<Route>,
101}
102
103impl Router {
104 pub fn new() -> Self {
105 Router { routes: Vec::new() }
106 }
107
108 pub fn get<F>(self, pattern: &str, handler: F) -> Self
110 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
111 self.add("GET", pattern, handler)
112 }
113
114 pub fn post<F>(self, pattern: &str, handler: F) -> Self
116 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
117 self.add("POST", pattern, handler)
118 }
119
120 pub fn put<F>(self, pattern: &str, handler: F) -> Self
122 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
123 self.add("PUT", pattern, handler)
124 }
125
126 pub fn patch<F>(self, pattern: &str, handler: F) -> Self
128 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
129 self.add("PATCH", pattern, handler)
130 }
131
132 pub fn delete<F>(self, pattern: &str, handler: F) -> Self
134 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
135 self.add("DELETE", pattern, handler)
136 }
137
138 fn add<F>(mut self, method: &str, pattern: &str, handler: F) -> Self
139 where F: Fn(&Request, &PathParams, &ConnectionInfo) -> Response + Send + Sync + 'static {
140 self.routes.push(Route {
141 method: method.to_string(),
142 segments: Self::parse_pattern(pattern),
143 handler: Arc::new(handler),
144 });
145 self
146 }
147
148 fn parse_pattern(pattern: &str) -> Vec<Segment> {
149 if pattern == "/" {
150 return vec![];
151 }
152 pattern
153 .split('/')
154 .filter(|s| !s.is_empty())
155 .map(|seg| {
156 if let Some(name) = seg.strip_prefix(':') {
157 Segment::Param(name.to_string())
158 } else if let Some(name) = seg.strip_prefix('*') {
159 Segment::Wildcard(name.to_string())
160 } else {
161 Segment::Literal(seg.to_string())
162 }
163 })
164 .collect()
165 }
166
167 pub fn handle(&self, request: &Request, connection: &ConnectionInfo) -> Option<Response> {
172 let path = request.request_uri.split('?').next().unwrap_or(&request.request_uri);
173 let path_segs: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
174
175 for route in &self.routes {
176 if route.method != request.method {
177 continue;
178 }
179 if let Some(params) = Self::try_match(&route.segments, &path_segs) {
180 return Some((route.handler)(request, ¶ms, connection));
181 }
182 }
183 None
184 }
185
186 fn try_match(pattern: &[Segment], path: &[&str]) -> Option<PathParams> {
187 let mut params = PathParams::new();
188 let mut pi = 0;
189
190 for (si, seg) in pattern.iter().enumerate() {
191 match seg {
192 Segment::Literal(lit) => {
193 if pi >= path.len() || path[pi] != lit.as_str() {
194 return None;
195 }
196 pi += 1;
197 }
198 Segment::Param(name) => {
199 if pi >= path.len() {
200 return None;
201 }
202 params.insert(name.clone(), path[pi].to_string());
203 pi += 1;
204 }
205 Segment::Wildcard(name) => {
206 if si != pattern.len() - 1 {
207 return None; }
209 params.insert(name.clone(), path[pi..].join("/"));
210 pi = path.len();
211 }
212 }
213 }
214
215 if pi == path.len() { Some(params) } else { None }
216 }
217}