1use std::io::{self, BufRead, BufReader, Write};
9use std::net::{TcpListener, TcpStream};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{mpsc, Arc, Mutex};
12use std::thread;
13use std::time::Duration;
14
/// HTTP request method taken from the request line.
///
/// Tokens without a dedicated variant collapse into [`Method::Other`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Method {
    Get,
    Post,
    Put,
    Patch,
    Delete,
    Head,
    Options,
    Other,
}

impl Method {
    /// Every method that has its own variant, i.e. everything except `Other`.
    const KNOWN: [Method; 7] = [
        Method::Get,
        Method::Post,
        Method::Put,
        Method::Patch,
        Method::Delete,
        Method::Head,
        Method::Options,
    ];

    /// Maps a request-line token to a [`Method`].
    ///
    /// Matching is exact and case-sensitive (`"get"` is not `GET`); any token that is
    /// not one of the known canonical spellings yields [`Method::Other`].
    pub fn parse(s: &str) -> Method {
        Self::KNOWN
            .iter()
            .copied()
            .find(|known| known.as_str() == s)
            .unwrap_or(Method::Other)
    }

    /// Canonical upper-case token for this method; [`Method::Other`] renders as `"OTHER"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Get => "GET",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Patch => "PATCH",
            Self::Delete => "DELETE",
            Self::Head => "HEAD",
            Self::Options => "OPTIONS",
            Self::Other => "OTHER",
        }
    }
}
55
56#[derive(Clone, Debug)]
58pub struct Request {
59 pub method: Method,
60 pub path: String,
62 pub query: String,
64 pub version: String,
65 pub headers: Vec<(String, String)>,
66 pub body: Vec<u8>,
67 pub peer: Option<String>,
69}
70
71impl Request {
72 pub fn peer_ip(&self) -> Option<String> {
74 let p = self.peer.as_ref()?;
75 if let Some(end) = p.find(']') {
76 Some(p[..=end].to_string()) } else if let Some(i) = p.rfind(':') {
78 Some(p[..i].to_string()) } else {
80 Some(p.clone())
81 }
82 }
83
84 pub fn header(&self, name: &str) -> Option<&str> {
86 self.headers
87 .iter()
88 .find(|(k, _)| k.eq_ignore_ascii_case(name))
89 .map(|(_, v)| v.as_str())
90 }
91
92 pub fn content_type(&self) -> Option<&str> {
94 self.header("content-type")
95 }
96
97 pub fn is_json(&self) -> bool {
99 self.content_type()
100 .map(|ct| ct.contains("application/json"))
101 .unwrap_or(false)
102 }
103
104 pub fn cookie(&self, name: &str) -> Option<String> {
106 let header = self.header("cookie")?;
107 for pair in header.split(';') {
108 let pair = pair.trim();
109 if let Some((k, v)) = pair.split_once('=') {
110 if k.trim() == name {
111 return Some(v.trim().to_string());
112 }
113 }
114 }
115 None
116 }
117}
118
/// Payload of a [`Response`].
pub enum Body {
    /// Entire body held in memory, so its length is known before the head is written.
    Full(Vec<u8>),
    /// Body produced incrementally by a closure that receives the connection writer
    /// after the response head has been sent.
    Stream(Box<dyn FnOnce(&mut dyn Write) -> io::Result<()> + Send + 'static>),
}

// Written by hand because the `Stream` closure has no `Debug` impl. `Full` reports only
// its length so that large payloads do not flood logs.
impl std::fmt::Debug for Body {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Body::Full(bytes) => f.debug_struct("Full").field("len", &bytes.len()).finish(),
            Body::Stream(_) => f.write_str("Stream(..)"),
        }
    }
}

/// An HTTP response produced by a request handler.
#[derive(Debug)]
pub struct Response {
    /// Numeric status code, e.g. `200`.
    pub status: u16,
    /// Header fields, written in this order after the status line.
    pub headers: Vec<(String, String)>,
    /// Response payload.
    pub body: Body,
}

impl Response {
    /// Creates a response with `status`, no headers and an empty in-memory body.
    pub fn new(status: u16) -> Response {
        Response {
            status,
            headers: Vec::new(),
            body: Body::Full(Vec::new()),
        }
    }

    /// Appends a header; repeating a name adds another entry rather than replacing it.
    pub fn with_header(mut self, name: &str, value: &str) -> Response {
        self.headers.push((name.to_string(), value.to_string()));
        self
    }

    /// Replaces the body with `body`, held in memory.
    pub fn with_body(mut self, body: impl Into<Vec<u8>>) -> Response {
        self.body = Body::Full(body.into());
        self
    }

    /// Replaces the body with a streaming `producer`.
    ///
    /// The producer runs once, with the connection writer, after the head is written.
    pub fn with_stream(
        mut self,
        producer: impl FnOnce(&mut dyn Write) -> io::Result<()> + Send + 'static,
    ) -> Response {
        self.body = Body::Stream(Box::new(producer));
        self
    }

    /// Whether the body is a streaming producer rather than in-memory bytes.
    pub fn is_stream(&self) -> bool {
        matches!(self.body, Body::Stream(_))
    }
}
172
/// Chunk writer for streamed bodies: each chunk is flushed right away so the peer
/// receives it without waiting for more data.
pub struct StreamSink<'a> {
    w: &'a mut dyn Write,
}

impl<'a> StreamSink<'a> {
    /// Wraps the writer handed to a stream producer.
    pub fn new(w: &'a mut dyn Write) -> StreamSink<'a> {
        Self { w }
    }

    /// Writes all of `bytes`, then flushes; stops at the first error.
    pub fn write(&mut self, bytes: &[u8]) -> io::Result<()> {
        let out = &mut *self.w;
        out.write_all(bytes).and_then(|()| out.flush())
    }

    /// Writes the UTF-8 bytes of `s` via [`StreamSink::write`].
    pub fn write_str(&mut self, s: &str) -> io::Result<()> {
        let bytes = s.as_bytes();
        self.write(bytes)
    }
}
193
/// Writer for Server-Sent Events (`text/event-stream`) frames.
///
/// Each public method writes one complete frame and flushes, so the client sees it
/// immediately.
pub struct SseSink<'a> {
    w: &'a mut dyn Write,
}

impl<'a> SseSink<'a> {
    /// Wraps the connection writer `w`.
    pub fn new(w: &'a mut dyn Write) -> SseSink<'a> {
        SseSink { w }
    }

    /// Sends an unnamed message carrying `data`.
    ///
    /// Every line of `data` gets its own `data:` field so the client reassembles the
    /// original text. `\r\n`, `\r` and `\n` all end a line in the event-stream format,
    /// so all three are treated as line breaks; a bare `\r` left inside a field would
    /// make the client drop the rest of that line.
    pub fn data(&mut self, data: &str) -> io::Result<()> {
        self.write_data(data)?;
        self.finish_frame()
    }

    /// Sends a message of type `event` carrying `data` (split as in [`SseSink::data`]).
    ///
    /// Line breaks in `event` are replaced with spaces: left in place they would end the
    /// `event:` field early and let the remainder be parsed as additional fields.
    pub fn event(&mut self, event: &str, data: &str) -> io::Result<()> {
        writeln!(self.w, "event: {}", event.replace(['\r', '\n'], " "))?;
        self.write_data(data)?;
        self.finish_frame()
    }

    /// Sends a comment, which event-stream clients ignore.
    ///
    /// Each line of `text` becomes its own `:` line so a line break cannot turn the
    /// rest of the text into real fields.
    pub fn comment(&mut self, text: &str) -> io::Result<()> {
        for line in Self::split_lines(text) {
            writeln!(self.w, ": {}", line)?;
        }
        self.finish_frame()
    }

    /// Sends a `retry:` field asking the client to wait `millis` ms before reconnecting.
    pub fn retry(&mut self, millis: u64) -> io::Result<()> {
        write!(self.w, "retry: {}\n\n", millis)?;
        self.w.flush()
    }

    /// Writes one `data:` line per line of `data`.
    fn write_data(&mut self, data: &str) -> io::Result<()> {
        for line in Self::split_lines(data) {
            writeln!(self.w, "data: {}", line)?;
        }
        Ok(())
    }

    /// Writes the blank line that terminates (dispatches) a frame, then flushes.
    fn finish_frame(&mut self) -> io::Result<()> {
        self.w.write_all(b"\n")?;
        self.w.flush()
    }

    /// Splits `text` at every `\r\n`, `\r` or `\n`.
    ///
    /// Like `str::split`, an empty string yields one empty line and a trailing break
    /// yields a trailing empty line.
    fn split_lines(text: &str) -> Vec<&str> {
        let mut lines = Vec::new();
        let mut rest = text;
        while let Some(i) = rest.find(['\r', '\n']) {
            lines.push(&rest[..i]);
            // CR and LF are single bytes, so these offsets stay on char boundaries.
            let width = if rest[i..].starts_with("\r\n") { 2 } else { 1 };
            rest = &rest[i + width..];
        }
        lines.push(rest);
        lines
    }
}
237
/// Per-connection resource caps applied while reading a request.
#[derive(Clone, Copy, Debug)]
pub struct Limits {
    /// Largest `Content-Length` accepted; larger requests are answered with 413.
    pub max_body: usize,
    /// Cap on the request line plus header lines, line endings included.
    pub max_header_bytes: usize,
    /// Read and write timeout applied to each connection; `None` leaves it unset.
    pub timeout: Option<Duration>,
}

impl Default for Limits {
    /// 2 MiB bodies, 64 KiB of head, 30-second socket timeouts.
    fn default() -> Self {
        const KIB: usize = 1 << 10;
        const MIB: usize = 1 << 20;
        Self {
            timeout: Some(Duration::from_secs(30)),
            max_header_bytes: 64 * KIB,
            max_body: 2 * MIB,
        }
    }
}
258
259pub enum Incoming {
261 Request(Request),
262 TooLarge,
263}
264
265pub fn parse_request<R: BufRead>(reader: &mut R, limits: &Limits) -> io::Result<Option<Incoming>> {
270 let mut request_line = String::new();
271 if reader.read_line(&mut request_line)? == 0 {
272 return Ok(None);
273 }
274 let mut parts = request_line.trim_end().split_whitespace();
275 let method = Method::parse(parts.next().unwrap_or(""));
276 let target = parts.next().unwrap_or("/").to_string();
277 let version = parts.next().unwrap_or("HTTP/1.1").to_string();
278
279 let (path, query) = match target.split_once('?') {
280 Some((p, q)) => (p.to_string(), q.to_string()),
281 None => (target, String::new()),
282 };
283
284 let mut headers = Vec::new();
285 let mut content_length = 0usize;
286 let mut header_bytes = request_line.len();
287 loop {
288 let mut line = String::new();
289 let n = reader.read_line(&mut line)?;
290 if n == 0 {
291 break;
292 }
293 header_bytes += n;
294 if header_bytes > limits.max_header_bytes {
295 return Ok(Some(Incoming::TooLarge));
296 }
297 let line = line.trim_end_matches(['\r', '\n']);
298 if line.is_empty() {
299 break;
300 }
301 if let Some((k, v)) = line.split_once(':') {
302 let k = k.trim().to_string();
303 let v = v.trim().to_string();
304 if k.eq_ignore_ascii_case("content-length") {
305 content_length = v.parse().unwrap_or(0);
306 }
307 headers.push((k, v));
308 }
309 }
310
311 if content_length > limits.max_body {
313 return Ok(Some(Incoming::TooLarge));
314 }
315 let mut body = vec![0u8; content_length];
316 if content_length > 0 {
317 reader.read_exact(&mut body)?;
318 }
319
320 Ok(Some(Incoming::Request(Request {
321 method,
322 path,
323 query,
324 version,
325 headers,
326 body,
327 peer: None,
328 })))
329}
330
331pub fn write_response<W: Write>(w: &mut W, resp: Response) -> io::Result<()> {
335 let reason = status_reason(resp.status);
336 let mut head = format!("HTTP/1.1 {} {}\r\n", resp.status, reason);
337 let mut has_content_type = false;
338 for (k, v) in &resp.headers {
339 if k.eq_ignore_ascii_case("content-type") {
340 has_content_type = true;
341 }
342 head.push_str(&format!("{}: {}\r\n", k, v));
343 }
344 if !has_content_type {
345 head.push_str("content-type: text/plain; charset=utf-8\r\n");
346 }
347
348 match resp.body {
349 Body::Full(bytes) => {
350 head.push_str(&format!("content-length: {}\r\n", bytes.len()));
351 head.push_str("connection: close\r\n\r\n");
352 w.write_all(head.as_bytes())?;
353 w.write_all(&bytes)?;
354 w.flush()
355 }
356 Body::Stream(producer) => {
357 head.push_str("connection: close\r\n\r\n");
359 w.write_all(head.as_bytes())?;
360 w.flush()?;
361 producer(w)
362 }
363 }
364}
365
/// Reason phrase for `status`, used in the response status line.
///
/// Well-known codes get their usual phrase. The 413 phrase matches the
/// `413 Payload Too Large` body this module sends. Other codes fall back to a phrase
/// for their class (`1xx` through `5xx`), and codes outside 100–599 yield `"Unknown"`.
pub fn status_reason(status: u16) -> &'static str {
    match status {
        100 => "Continue",
        101 => "Switching Protocols",
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        411 => "Length Required",
        412 => "Precondition Failed",
        413 => "Payload Too Large",
        414 => "URI Too Long",
        415 => "Unsupported Media Type",
        416 => "Range Not Satisfiable",
        422 => "Unprocessable Entity",
        428 => "Precondition Required",
        429 => "Too Many Requests",
        431 => "Request Header Fields Too Large",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        505 => "HTTP Version Not Supported",
        100..=199 => "Informational",
        200..=299 => "OK",
        300..=399 => "Redirect",
        400..=499 => "Client Error",
        500..=599 => "Server Error",
        _ => "Unknown",
    }
}
399
400pub fn serve<H>(addr: &str, workers: usize, limits: Limits, handler: H) -> io::Result<()>
403where
404 H: Fn(Request) -> Response + Send + Sync + 'static,
405{
406 let listener = TcpListener::bind(addr)?;
407 let handler = Arc::new(handler);
408 let pool = ThreadPool::new(workers.max(1));
409
410 for stream in listener.incoming() {
411 let stream = match stream {
412 Ok(s) => s,
413 Err(_) => continue,
414 };
415 let handler = Arc::clone(&handler);
416 pool.execute(move || {
417 let _ = handle_connection(stream, &*handler, &limits);
418 });
419 }
420 Ok(())
421}
422
423pub fn serve_until<H>(
429 addr: &str,
430 workers: usize,
431 limits: Limits,
432 handler: H,
433 shutdown: Arc<AtomicBool>,
434) -> io::Result<()>
435where
436 H: Fn(Request) -> Response + Send + Sync + 'static,
437{
438 let listener = TcpListener::bind(addr)?;
439 listener.set_nonblocking(true)?;
440 let handler = Arc::new(handler);
441 let pool = ThreadPool::new(workers.max(1));
442
443 while !shutdown.load(Ordering::Relaxed) {
444 match listener.accept() {
445 Ok((stream, _addr)) => {
446 let _ = stream.set_nonblocking(false);
448 let handler = Arc::clone(&handler);
449 pool.execute(move || {
450 let _ = handle_connection(stream, &*handler, &limits);
451 });
452 }
453 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
454 thread::sleep(Duration::from_millis(50));
455 }
456 Err(_) => {}
457 }
458 }
459
460 drop(pool);
463 Ok(())
464}
465
466fn handle_connection<H>(stream: TcpStream, handler: &H, limits: &Limits) -> io::Result<()>
467where
468 H: Fn(Request) -> Response,
469{
470 if let Some(t) = limits.timeout {
472 let _ = stream.set_read_timeout(Some(t));
473 let _ = stream.set_write_timeout(Some(t));
474 }
475 let peer = stream.peer_addr().ok().map(|a| a.to_string());
476 let mut reader = BufReader::new(stream.try_clone()?);
477
478 match parse_request(&mut reader, limits)? {
479 Some(Incoming::Request(mut req)) => {
480 req.peer = peer;
481 let resp = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| handler(req)))
484 .unwrap_or_else(|_| Response::new(500).with_body(&b"500 Internal Server Error"[..]));
485 let mut writer = stream;
486 write_response(&mut writer, resp)?;
487 }
488 Some(Incoming::TooLarge) => {
489 let mut writer = stream;
490 write_response(&mut writer, Response::new(413).with_body(&b"413 Payload Too Large"[..]))?;
491 }
492 None => {}
493 }
494 Ok(())
495}
496
/// A queued unit of work for [`ThreadPool`].
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Fixed set of worker threads that pull jobs from one shared FIFO queue.
///
/// Dropping the pool closes the queue. Workers finish every job already queued and are
/// then joined.
pub struct ThreadPool {
    /// Queue sender; taken (set to `None`) in `drop` to close the queue.
    sender: Option<mpsc::Sender<Job>>,
    workers: Vec<thread::JoinHandle<()>>,
}

impl ThreadPool {
    /// Starts a pool of `size` worker threads.
    ///
    /// A `size` of 0 is raised to 1. A pool without workers would accept jobs and never
    /// run them.
    pub fn new(size: usize) -> ThreadPool {
        let size = size.max(1);
        let (sender, receiver) = mpsc::channel::<Job>();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);
        for _ in 0..size {
            let receiver = Arc::clone(&receiver);
            workers.push(thread::spawn(move || loop {
                // The guard is a temporary of this statement, so the lock is released before
                // the job runs and other workers can dequeue meanwhile. The lock only guards
                // `recv`, which cannot leave the receiver inconsistent, so poisoning is
                // ignored.
                let next = receiver
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
                    .recv();
                match next {
                    Ok(job) => {
                        // Contain panics (e.g. from a streaming body producer) so a failing
                        // job cannot kill its worker and permanently shrink the pool.
                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
                    }
                    // Queue closed and drained: the pool is being dropped.
                    Err(_) => break,
                }
            }));
        }
        ThreadPool {
            sender: Some(sender),
            workers,
        }
    }

    /// Queues `f` to run on the next free worker.
    ///
    /// If the queue is already closed, or every worker has exited, the job is dropped
    /// without running.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        if let Some(sender) = &self.sender {
            let _ = sender.send(Box::new(f));
        }
    }
}

impl Drop for ThreadPool {
    /// Closes the queue, then waits for every worker to drain it and exit.
    fn drop(&mut self) {
        drop(self.sender.take());
        for worker in self.workers.drain(..) {
            let _ = worker.join();
        }
    }
}
547
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` under `limits`, failing the test unless a full request comes back.
    fn parse_ok(raw: &str, limits: &Limits) -> Request {
        let mut reader = BufReader::new(raw.as_bytes());
        match parse_request(&mut reader, limits).unwrap().unwrap() {
            Incoming::Request(r) => r,
            Incoming::TooLarge => panic!("unexpected 413"),
        }
    }

    /// Serialises `resp` and returns the bytes on the wire as text.
    fn render(resp: Response) -> String {
        let mut buf = Vec::new();
        write_response(&mut buf, resp).unwrap();
        String::from_utf8(buf).unwrap()
    }

    #[test]
    fn parses_request_with_body() {
        let req = parse_ok(
            "POST /todos?x=1 HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nhello",
            &Limits::default(),
        );
        assert_eq!(req.method, Method::Post);
        assert_eq!(req.path, "/todos");
        assert_eq!(req.query, "x=1");
        assert_eq!(req.body, b"hello");
        assert_eq!(req.header("host"), Some("localhost"));
    }

    #[test]
    fn writes_response_with_default_content_type() {
        let wire = render(Response::new(200).with_body("hi"));
        assert!(wire.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(wire.contains("content-length: 2\r\n"));
        assert!(wire.ends_with("\r\n\r\nhi"));
    }

    #[test]
    fn rejects_oversized_body() {
        let limits = Limits {
            max_body: 100,
            ..Limits::default()
        };
        let mut reader = BufReader::new("POST / HTTP/1.1\r\nContent-Length: 5000\r\n\r\n".as_bytes());
        let outcome = parse_request(&mut reader, &limits).unwrap();
        assert!(matches!(outcome, Some(Incoming::TooLarge)));
    }

    #[test]
    fn request_content_type_and_cookies() {
        let req = parse_ok(
            "GET / HTTP/1.1\r\nContent-Type: application/json\r\nCookie: sid=abc; theme=dark\r\n\r\n",
            &Limits::default(),
        );
        assert!(req.is_json());
        let expected = [("sid", Some("abc")), ("theme", Some("dark")), ("missing", None)];
        for (name, want) in expected.iter() {
            assert_eq!(req.cookie(name).as_deref(), *want);
        }
    }

    #[test]
    fn streams_without_content_length() {
        let resp = Response::new(200)
            .with_header("content-type", "text/event-stream")
            .with_stream(|w| {
                let mut sink = SseSink::new(w);
                for msg in ["one", "two"].iter() {
                    sink.data(msg)?;
                }
                Ok(())
            });
        let wire = render(resp);
        assert!(wire.contains("content-type: text/event-stream\r\n"));
        assert!(!wire.to_lowercase().contains("content-length"));
        assert!(wire.contains("data: one\n\n"));
        assert!(wire.contains("data: two\n\n"));
    }
}