1use std::io::{self, BufRead, BufReader, Write};
9use std::net::{TcpListener, TcpStream};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{mpsc, Arc, Mutex};
12use std::thread;
13use std::time::Duration;
14
/// HTTP request method.
///
/// Tokens are matched case-sensitively; any token without a dedicated variant is
/// represented as [`Method::Other`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Method {
    Get,
    Post,
    Put,
    Patch,
    Delete,
    Head,
    Options,
    /// Any method token not listed above.
    Other,
}

impl Method {
    /// Maps a request-line method token to a [`Method`], falling back to [`Method::Other`].
    pub fn parse(s: &str) -> Method {
        // Every variant that owns a dedicated token; `as_str` is the single source of truth
        // for what those tokens are, so the two directions cannot drift apart.
        const NAMED: [Method; 7] = [
            Method::Get,
            Method::Post,
            Method::Put,
            Method::Patch,
            Method::Delete,
            Method::Head,
            Method::Options,
        ];
        NAMED
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == s)
            .unwrap_or(Method::Other)
    }

    /// Returns the canonical upper-case token for this method (`"OTHER"` for
    /// [`Method::Other`]).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Get => "GET",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Patch => "PATCH",
            Self::Delete => "DELETE",
            Self::Head => "HEAD",
            Self::Options => "OPTIONS",
            Self::Other => "OTHER",
        }
    }
}
55
56#[derive(Clone, Debug)]
58pub struct Request {
59 pub method: Method,
60 pub path: String,
62 pub query: String,
64 pub version: String,
65 pub headers: Vec<(String, String)>,
66 pub body: Vec<u8>,
67 pub peer: Option<String>,
69}
70
71impl Request {
72 pub fn peer_ip(&self) -> Option<String> {
74 let p = self.peer.as_ref()?;
75 if let Some(end) = p.find(']') {
76 Some(p[..=end].to_string()) } else if let Some(i) = p.rfind(':') {
78 Some(p[..i].to_string()) } else {
80 Some(p.clone())
81 }
82 }
83
84 pub fn header(&self, name: &str) -> Option<&str> {
86 self.headers
87 .iter()
88 .find(|(k, _)| k.eq_ignore_ascii_case(name))
89 .map(|(_, v)| v.as_str())
90 }
91
92 pub fn content_type(&self) -> Option<&str> {
94 self.header("content-type")
95 }
96
97 pub fn is_json(&self) -> bool {
99 self.content_type()
100 .map(|ct| ct.contains("application/json"))
101 .unwrap_or(false)
102 }
103
104 pub fn cookie(&self, name: &str) -> Option<String> {
106 let header = self.header("cookie")?;
107 for pair in header.split(';') {
108 let pair = pair.trim();
109 if let Some((k, v)) = pair.split_once('=') {
110 if k.trim() == name {
111 return Some(v.trim().to_string());
112 }
113 }
114 }
115 None
116 }
117}
118
/// Callback that writes a streamed response body straight to the connection.
///
/// It is invoked once, after the response head has been written and flushed.
pub type StreamProducer = Box<dyn FnOnce(&mut dyn Write) -> io::Result<()> + Send + 'static>;

/// Payload of a [`Response`].
pub enum Body {
    /// The entire body, held in memory.
    Full(Vec<u8>),
    /// A body generated on demand by a [`StreamProducer`].
    Stream(StreamProducer),
}

/// An HTTP response built by a handler.
pub struct Response {
    /// Numeric status code, e.g. `200`.
    pub status: u16,
    /// Extra header fields, kept in insertion order.
    pub headers: Vec<(String, String)>,
    /// Response payload.
    pub body: Body,
}

impl Response {
    /// Creates a response with `status`, no headers, and an empty in-memory body.
    pub fn new(status: u16) -> Response {
        let body = Body::Full(Vec::new());
        Response {
            status,
            headers: Vec::new(),
            body,
        }
    }

    /// Appends a header field. Fields already present with the same name are kept.
    pub fn with_header(self, name: &str, value: &str) -> Response {
        let Response {
            status,
            mut headers,
            body,
        } = self;
        headers.push((name.to_owned(), value.to_owned()));
        Response {
            status,
            headers,
            body,
        }
    }

    /// Replaces the body with the given in-memory bytes.
    pub fn with_body(self, body: impl Into<Vec<u8>>) -> Response {
        Response {
            body: Body::Full(body.into()),
            ..self
        }
    }

    /// Replaces the body with a streaming `producer`.
    pub fn with_stream(
        self,
        producer: impl FnOnce(&mut dyn Write) -> io::Result<()> + Send + 'static,
    ) -> Response {
        let boxed: StreamProducer = Box::new(producer);
        Response {
            body: Body::Stream(boxed),
            ..self
        }
    }

    /// Reports whether the body is a [`Body::Stream`].
    pub fn is_stream(&self) -> bool {
        match &self.body {
            Body::Stream(_) => true,
            Body::Full(_) => false,
        }
    }
}
177
/// Thin writer for streamed bodies that flushes after every chunk.
pub struct StreamSink<'a> {
    w: &'a mut dyn Write,
}

impl<'a> StreamSink<'a> {
    /// Wraps `w`.
    pub fn new(w: &'a mut dyn Write) -> StreamSink<'a> {
        Self { w }
    }

    /// Writes all of `bytes`, then flushes so the chunk is handed on immediately.
    pub fn write(&mut self, bytes: &[u8]) -> io::Result<()> {
        let out = &mut *self.w;
        out.write_all(bytes).and_then(|()| out.flush())
    }

    /// Writes the UTF-8 bytes of `s` and flushes.
    pub fn write_str(&mut self, s: &str) -> io::Result<()> {
        let bytes = s.as_bytes();
        self.write(bytes)
    }
}
198
/// Writer for `text/event-stream` (Server-Sent Events) frames.
///
/// Every public method writes one complete frame (terminated by a blank line) and flushes.
pub struct SseSink<'a> {
    w: &'a mut dyn Write,
}

impl<'a> SseSink<'a> {
    /// Wraps `w`.
    pub fn new(w: &'a mut dyn Write) -> SseSink<'a> {
        SseSink { w }
    }

    /// Sends an unnamed event carrying `data`.
    ///
    /// Each line of `data` becomes its own `data:` field. CRLF, lone CR and LF all count as
    /// line breaks, matching how SSE clients split the stream.
    pub fn data(&mut self, data: &str) -> io::Result<()> {
        self.write_data_lines(data)?;
        self.w.write_all(b"\n")?;
        self.w.flush()
    }

    /// Sends an event named `event` carrying `data`.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput`, writing nothing, if `event` contains CR or LF.
    /// Such a name would otherwise end the `event:` field early and inject extra fields.
    pub fn event(&mut self, event: &str, data: &str) -> io::Result<()> {
        if event.contains(['\r', '\n']) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "SSE event name must not contain CR or LF",
            ));
        }
        writeln!(self.w, "event: {}", event)?;
        self.write_data_lines(data)?;
        self.w.write_all(b"\n")?;
        self.w.flush()
    }

    /// Sends a comment frame (ignored by clients, often used as a keep-alive).
    ///
    /// Multi-line text is written as one `:`-prefixed line per input line, so no part of it
    /// can be mistaken for a field.
    pub fn comment(&mut self, text: &str) -> io::Result<()> {
        for line in Self::split_lines(text) {
            writeln!(self.w, ": {}", line)?;
        }
        self.w.write_all(b"\n")?;
        self.w.flush()
    }

    /// Sends a `retry:` frame with the given reconnection delay in milliseconds.
    pub fn retry(&mut self, millis: u64) -> io::Result<()> {
        write!(self.w, "retry: {}\n\n", millis)?;
        self.w.flush()
    }

    /// Writes one `data:` field per line of `data` (without the frame-ending blank line).
    fn write_data_lines(&mut self, data: &str) -> io::Result<()> {
        for line in Self::split_lines(data) {
            writeln!(self.w, "data: {}", line)?;
        }
        Ok(())
    }

    /// Splits `text` on CRLF, CR, or LF.
    ///
    /// Like `str::split('\n')`, this always yields at least one (possibly empty) piece, and a
    /// trailing break yields a trailing empty piece.
    fn split_lines(text: &str) -> Vec<&str> {
        let mut pieces = Vec::new();
        let mut rest = text;
        while let Some(i) = rest.find(['\r', '\n']) {
            pieces.push(&rest[..i]);
            // CRLF is a single break; counting it as two would emit a spurious empty line.
            let width = if rest[i..].starts_with("\r\n") { 2 } else { 1 };
            rest = &rest[i + width..];
        }
        pieces.push(rest);
        pieces
    }
}
242
/// Per-connection resource limits applied while reading a request.
#[derive(Clone, Copy, Debug)]
pub struct Limits {
    /// Largest accepted `Content-Length`, in bytes.
    pub max_body: usize,
    /// Largest accepted request head (request line plus header lines), in bytes.
    pub max_header_bytes: usize,
    /// Read/write timeout applied to each connection; `None` leaves the socket untouched.
    pub timeout: Option<Duration>,
}

impl Default for Limits {
    /// 2 MiB body, 64 KiB head, 30 second timeout.
    fn default() -> Limits {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;
        Limits {
            timeout: Some(Duration::from_secs(30)),
            max_header_bytes: 64 * KIB,
            max_body: 2 * MIB,
        }
    }
}
263
264pub enum Incoming {
266 Request(Request),
267 TooLarge,
268}
269
270pub fn parse_request<R: BufRead>(reader: &mut R, limits: &Limits) -> io::Result<Option<Incoming>> {
275 let mut request_line = String::new();
276 if reader.read_line(&mut request_line)? == 0 {
277 return Ok(None);
278 }
279 let mut parts = request_line.split_whitespace();
280 let method = Method::parse(parts.next().unwrap_or(""));
281 let target = parts.next().unwrap_or("/").to_string();
282 let version = parts.next().unwrap_or("HTTP/1.1").to_string();
283
284 let (path, query) = match target.split_once('?') {
285 Some((p, q)) => (p.to_string(), q.to_string()),
286 None => (target, String::new()),
287 };
288
289 let mut headers = Vec::new();
290 let mut content_length = 0usize;
291 let mut header_bytes = request_line.len();
292 loop {
293 let mut line = String::new();
294 let n = reader.read_line(&mut line)?;
295 if n == 0 {
296 break;
297 }
298 header_bytes += n;
299 if header_bytes > limits.max_header_bytes {
300 return Ok(Some(Incoming::TooLarge));
301 }
302 let line = line.trim_end_matches(['\r', '\n']);
303 if line.is_empty() {
304 break;
305 }
306 if let Some((k, v)) = line.split_once(':') {
307 let k = k.trim().to_string();
308 let v = v.trim().to_string();
309 if k.eq_ignore_ascii_case("content-length") {
310 content_length = v.parse().unwrap_or(0);
311 }
312 headers.push((k, v));
313 }
314 }
315
316 if content_length > limits.max_body {
318 return Ok(Some(Incoming::TooLarge));
319 }
320 let mut body = vec![0u8; content_length];
321 if content_length > 0 {
322 reader.read_exact(&mut body)?;
323 }
324
325 Ok(Some(Incoming::Request(Request {
326 method,
327 path,
328 query,
329 version,
330 headers,
331 body,
332 peer: None,
333 })))
334}
335
336pub fn write_response<W: Write>(w: &mut W, resp: Response) -> io::Result<()> {
340 let reason = status_reason(resp.status);
341 let mut head = format!("HTTP/1.1 {} {}\r\n", resp.status, reason);
342 let mut has_content_type = false;
343 for (k, v) in &resp.headers {
344 if k.eq_ignore_ascii_case("content-type") {
345 has_content_type = true;
346 }
347 head.push_str(&format!("{}: {}\r\n", k, v));
348 }
349 if !has_content_type {
350 head.push_str("content-type: text/plain; charset=utf-8\r\n");
351 }
352
353 match resp.body {
354 Body::Full(bytes) => {
355 head.push_str(&format!("content-length: {}\r\n", bytes.len()));
356 head.push_str("connection: close\r\n\r\n");
357 w.write_all(head.as_bytes())?;
358 w.write_all(&bytes)?;
359 w.flush()
360 }
361 Body::Stream(producer) => {
362 head.push_str("connection: close\r\n\r\n");
364 w.write_all(head.as_bytes())?;
365 w.flush()?;
366 producer(w)
367 }
368 }
369}
370
/// Returns the reason phrase written after `status` in the response status line.
///
/// Common codes get their standard phrase. Other codes get a generic phrase for their class:
///
/// | Range     | Phrase          |
/// |-----------|-----------------|
/// | 1xx       | `Informational` |
/// | 2xx       | `OK`            |
/// | 3xx       | `Redirect`      |
/// | 4xx       | `Client Error`  |
/// | elsewhere | `Server Error`  |
pub fn status_reason(status: u16) -> &'static str {
    match status {
        100 => "Continue",
        101 => "Switching Protocols",
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        411 => "Length Required",
        // Emitted by this server itself when a request exceeds its limits.
        413 => "Payload Too Large",
        414 => "URI Too Long",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Entity",
        429 => "Too Many Requests",
        431 => "Request Header Fields Too Large",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        100..=199 => "Informational",
        200..=299 => "OK",
        300..=399 => "Redirect",
        400..=499 => "Client Error",
        _ => "Server Error",
    }
}
404
405pub fn serve<H>(addr: &str, workers: usize, limits: Limits, handler: H) -> io::Result<()>
408where
409 H: Fn(Request) -> Response + Send + Sync + 'static,
410{
411 let listener = TcpListener::bind(addr)?;
412 let handler = Arc::new(handler);
413 let pool = ThreadPool::new(workers.max(1));
414
415 for stream in listener.incoming() {
416 let stream = match stream {
417 Ok(s) => s,
418 Err(_) => continue,
419 };
420 let handler = Arc::clone(&handler);
421 pool.execute(move || {
422 let _ = handle_connection(stream, &*handler, &limits);
423 });
424 }
425 Ok(())
426}
427
428pub fn serve_until<H>(
434 addr: &str,
435 workers: usize,
436 limits: Limits,
437 handler: H,
438 shutdown: Arc<AtomicBool>,
439) -> io::Result<()>
440where
441 H: Fn(Request) -> Response + Send + Sync + 'static,
442{
443 let listener = TcpListener::bind(addr)?;
444 listener.set_nonblocking(true)?;
445 let handler = Arc::new(handler);
446 let pool = ThreadPool::new(workers.max(1));
447
448 while !shutdown.load(Ordering::Relaxed) {
449 match listener.accept() {
450 Ok((stream, _addr)) => {
451 let _ = stream.set_nonblocking(false);
453 let handler = Arc::clone(&handler);
454 pool.execute(move || {
455 let _ = handle_connection(stream, &*handler, &limits);
456 });
457 }
458 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
459 thread::sleep(Duration::from_millis(50));
460 }
461 Err(_) => {}
462 }
463 }
464
465 drop(pool);
468 Ok(())
469}
470
471fn handle_connection<H>(stream: TcpStream, handler: &H, limits: &Limits) -> io::Result<()>
472where
473 H: Fn(Request) -> Response,
474{
475 if let Some(t) = limits.timeout {
477 let _ = stream.set_read_timeout(Some(t));
478 let _ = stream.set_write_timeout(Some(t));
479 }
480 let peer = stream.peer_addr().ok().map(|a| a.to_string());
481 let mut reader = BufReader::new(stream.try_clone()?);
482
483 match parse_request(&mut reader, limits)? {
484 Some(Incoming::Request(mut req)) => {
485 req.peer = peer;
486 let resp = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| handler(req)))
489 .unwrap_or_else(|_| {
490 Response::new(500).with_body(&b"500 Internal Server Error"[..])
491 });
492 let mut writer = stream;
493 write_response(&mut writer, resp)?;
494 }
495 Some(Incoming::TooLarge) => {
496 let mut writer = stream;
497 write_response(
498 &mut writer,
499 Response::new(413).with_body(&b"413 Payload Too Large"[..]),
500 )?;
501 }
502 None => {}
503 }
504 Ok(())
505}
506
/// A boxed, sendable unit of work run by a `ThreadPool` worker.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// Fixed-size pool of worker threads that pull jobs from a shared channel.
///
/// A job that panics is contained: its worker keeps serving later jobs. Dropping the pool
/// closes the channel and joins every worker, so queued jobs run before `drop` returns.
/// With `size == 0` there are no workers and submitted jobs are silently dropped.
pub struct ThreadPool {
    // Taken (set to `None`) in `drop` to close the channel.
    sender: Option<mpsc::Sender<Job>>,
    workers: Vec<thread::JoinHandle<()>>,
}

impl ThreadPool {
    /// Spawns `size` worker threads.
    pub fn new(size: usize) -> ThreadPool {
        let (sender, receiver) = mpsc::channel::<Job>();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);
        for _ in 0..size {
            let receiver = Arc::clone(&receiver);
            workers.push(thread::spawn(move || loop {
                // The guard is a temporary of this statement, so the lock is released before
                // the job runs and other workers can keep receiving meanwhile. Poisoning is
                // tolerated: the receiver itself cannot be left in a broken state.
                let next = receiver
                    .lock()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
                    .recv();
                match next {
                    Ok(job) => {
                        // Contain panics. An unwinding job would otherwise kill this thread
                        // and permanently shrink the pool.
                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
                    }
                    // Every sender is gone and the queue is drained: shut down.
                    Err(_) => break,
                }
            }));
        }
        ThreadPool {
            sender: Some(sender),
            workers,
        }
    }

    /// Queues `f` to run on some worker thread.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        if let Some(sender) = &self.sender {
            // Fails only once every worker (and thus the receiver) is gone; the job is
            // dropped in that case.
            let _ = sender.send(Box::new(f));
        }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Closing the channel makes each worker's `recv` fail once the queue is drained.
        drop(self.sender.take());
        for worker in self.workers.drain(..) {
            let _ = worker.join();
        }
    }
}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` under default limits and returns the request; panics on a 413 outcome
    /// or on end of stream.
    fn parse_ok(raw: &str) -> Request {
        let mut reader = BufReader::new(raw.as_bytes());
        let outcome = parse_request(&mut reader, &Limits::default())
            .unwrap()
            .unwrap();
        match outcome {
            Incoming::Request(request) => request,
            Incoming::TooLarge => panic!("unexpected 413"),
        }
    }

    /// Serialises `resp` into exactly the text `write_response` puts on the wire.
    fn render(resp: Response) -> String {
        let mut wire = Vec::new();
        write_response(&mut wire, resp).unwrap();
        String::from_utf8(wire).unwrap()
    }

    /// Reports whether `raw` is rejected as too large under `limits`.
    fn rejected_as_too_large(raw: &str, limits: &Limits) -> bool {
        let mut reader = BufReader::new(raw.as_bytes());
        matches!(
            parse_request(&mut reader, limits).unwrap(),
            Some(Incoming::TooLarge)
        )
    }

    #[test]
    fn parses_request_with_body() {
        let req = parse_ok(
            "POST /todos?x=1 HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nhello",
        );
        assert_eq!(req.method, Method::Post);
        assert_eq!(req.path, "/todos");
        assert_eq!(req.query, "x=1");
        assert_eq!(req.body, b"hello");
        assert_eq!(req.header("host"), Some("localhost"));
    }

    #[test]
    fn writes_response_with_default_content_type() {
        let wire = render(Response::new(200).with_body("hi"));
        assert!(wire.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(wire.contains("content-length: 2\r\n"));
        assert!(wire.ends_with("\r\n\r\nhi"));
    }

    #[test]
    fn rejects_oversized_body() {
        let limits = Limits {
            max_body: 100,
            ..Limits::default()
        };
        assert!(rejected_as_too_large(
            "POST / HTTP/1.1\r\nContent-Length: 5000\r\n\r\n",
            &limits
        ));
    }

    #[test]
    fn request_content_type_and_cookies() {
        let req = parse_ok(
            "GET / HTTP/1.1\r\nContent-Type: application/json\r\nCookie: sid=abc; theme=dark\r\n\r\n",
        );
        assert!(req.is_json());
        for (name, expected) in [("sid", Some("abc")), ("theme", Some("dark"))] {
            assert_eq!(req.cookie(name).as_deref(), expected);
        }
        assert_eq!(req.cookie("missing"), None);
    }

    #[test]
    fn streams_without_content_length() {
        let wire = render(
            Response::new(200)
                .with_header("content-type", "text/event-stream")
                .with_stream(|w| {
                    let mut sse = SseSink::new(w);
                    ["one", "two"].iter().try_for_each(|msg| sse.data(msg))
                }),
        );
        assert!(wire.contains("content-type: text/event-stream\r\n"));
        assert!(!wire.to_lowercase().contains("content-length"));
        assert!(wire.contains("data: one\n\n"));
        assert!(wire.contains("data: two\n\n"));
    }

    #[test]
    fn method_parse_and_as_str_roundtrip() {
        let table = [
            ("GET", Method::Get),
            ("POST", Method::Post),
            ("PUT", Method::Put),
            ("PATCH", Method::Patch),
            ("DELETE", Method::Delete),
            ("HEAD", Method::Head),
            ("OPTIONS", Method::Options),
        ];
        for &(token, method) in table.iter() {
            assert_eq!(Method::parse(token), method);
            assert_eq!(method.as_str(), token);
        }
        assert_eq!(Method::parse("BREW"), Method::Other);
    }

    #[test]
    fn status_reason_known_and_fallbacks() {
        let expectations = [
            (200, "OK"),
            (404, "Not Found"),
            (422, "Unprocessable Entity"),
            // Unlisted codes fall back to a per-class phrase.
            (299, "OK"),
            (399, "Redirect"),
            (418, "Client Error"),
            (599, "Server Error"),
        ];
        for (code, phrase) in expectations {
            assert_eq!(status_reason(code), phrase, "status {}", code);
        }
    }

    #[test]
    fn peer_ip_handles_ipv4_and_ipv6() {
        let with_peer = |peer: Option<&str>| Request {
            method: Method::Get,
            path: "/".into(),
            query: String::new(),
            version: "HTTP/1.1".into(),
            headers: Vec::new(),
            body: Vec::new(),
            peer: peer.map(String::from),
        };
        assert_eq!(
            with_peer(Some("1.2.3.4:55000")).peer_ip().as_deref(),
            Some("1.2.3.4")
        );
        assert_eq!(
            with_peer(Some("[::1]:8080")).peer_ip().as_deref(),
            Some("[::1]")
        );
        assert_eq!(with_peer(None).peer_ip(), None);
    }

    #[test]
    fn rejects_oversized_headers() {
        let raw = format!("GET / HTTP/1.1\r\nX-Pad: {}\r\n\r\n", "a".repeat(500));
        let limits = Limits {
            max_header_bytes: 100,
            ..Limits::default()
        };
        assert!(rejected_as_too_large(&raw, &limits));
    }

    #[test]
    fn empty_stream_yields_none() {
        let mut reader = BufReader::new("".as_bytes());
        let outcome = parse_request(&mut reader, &Limits::default()).unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    fn explicit_content_type_is_not_overridden() {
        let wire = render(
            Response::new(200)
                .with_header("content-type", "application/json")
                .with_body("{}"),
        );
        assert!(wire.contains("content-type: application/json\r\n"));
        assert!(!wire.contains("text/plain"));
    }

    #[test]
    fn stream_sink_writes_and_flushes() {
        let mut collected = Vec::new();
        {
            let mut sink = StreamSink::new(&mut collected);
            sink.write_str("chunk-").unwrap();
            sink.write(b"bytes").unwrap();
        }
        assert_eq!(collected, b"chunk-bytes");
    }

    #[test]
    fn sse_named_event_and_comment() {
        let mut collected = Vec::new();
        {
            let mut sink = SseSink::new(&mut collected);
            sink.event("tick", "1\n2").unwrap();
            sink.comment("keep-alive").unwrap();
            sink.retry(3000).unwrap();
        }
        let text = String::from_utf8(collected).unwrap();
        for needle in [
            "event: tick\n",
            "data: 1\ndata: 2\n\n",
            ": keep-alive\n\n",
            "retry: 3000\n\n",
        ] {
            assert!(text.contains(needle), "missing {:?}", needle);
        }
    }

    #[test]
    fn thread_pool_runs_jobs() {
        use std::sync::atomic::{AtomicU32, Ordering};
        let counter = Arc::new(AtomicU32::new(0));
        let pool = ThreadPool::new(3);
        for _ in 0..30 {
            let counter = Arc::clone(&counter);
            pool.execute(move || {
                counter.fetch_add(1, Ordering::Relaxed);
            });
        }
        // Dropping the pool waits for every queued job to finish.
        drop(pool);
        assert_eq!(counter.load(Ordering::Relaxed), 30);
    }
}