1use std::collections::HashMap;
54use std::convert::TryFrom;
55use std::fmt::Debug;
56use std::io::{Read, Write};
57
58pub extern crate http;
59
60pub type Request = http::Request<Vec<u8>>;
62
63pub type Response = http::Response<Vec<u8>>;
65
66fn handle_with_io<F, R, W>(func: F, mut stdin: R, mut stdout: W)
67where
68 F: FnOnce(Request) -> Response,
69 R: Read,
70 W: Write,
71{
72 let env_vars: HashMap<String, String> = std::env::vars().collect();
73
74 let content_length: usize = env_vars
77 .get("CONTENT_LENGTH")
78 .and_then(|cl| cl.parse::<usize>().ok())
79 .unwrap_or(0);
80
81 let mut stdin_contents = vec![0; content_length];
82 stdin.read_exact(&mut stdin_contents).unwrap();
83
84 let request = parse_request(env_vars, stdin_contents);
85
86 let response = func(request);
87
88 let output = serialize_response(response);
89
90 stdout.write_all(&output).unwrap();
91}
92
93fn try_handle_with_io<E, F, R, W, X>(func: F, stdin: R, stdout: W, mut stderr: X)
94where
95 E: Debug,
96 F: FnOnce(Request) -> Result<Response, E>,
97 R: Read,
98 W: Write,
99 X: Write,
100{
101 handle_with_io(
102 |request: Request| match func(request) {
103 Ok(resp) => resp,
104 Err(err) => {
105 writeln!(stderr, "{:?}", err).unwrap_or_else(|_| eprintln!("{:?}", err));
106 empty_response(500)
107 }
108 },
109 stdin,
110 stdout,
111 )
112}
113
114pub fn handle<F>(func: F)
121where
122 F: FnOnce(Request) -> Response,
123{
124 handle_with_io(func, std::io::stdin(), std::io::stdout())
125}
126
127pub fn try_handle<E, F>(func: F)
133where
134 E: Debug,
135 F: FnOnce(Request) -> Result<Response, E>,
136{
137 try_handle_with_io(func, std::io::stdin(), std::io::stdout(), std::io::stderr())
138}
139
/// Generates `fn main()` that runs `$func` (a `FnOnce(Request) -> Response`)
/// through this crate's `handle` function.
///
/// Uses `$crate` so the macro keeps working when the crate is renamed in the
/// dependent's `Cargo.toml` or invoked from inside this crate.
#[macro_export]
macro_rules! cgi_main {
    ( $func:expr ) => {
        fn main() {
            $crate::handle($func);
        }
    };
}
160
/// Generates `fn main()` that runs the fallible handler `$func`
/// (a `FnOnce(Request) -> Result<Response, E>` with `E: Debug`) through this
/// crate's `try_handle` function.
///
/// Uses `$crate` so the macro keeps working when the crate is renamed in the
/// dependent's `Cargo.toml` or invoked from inside this crate.
#[macro_export]
macro_rules! cgi_try_main {
    ( $func:expr ) => {
        fn main() {
            $crate::try_handle($func);
        }
    };
}
186
187pub fn err_to_500<E>(res: Result<Response, E>) -> Response {
188 res.unwrap_or(empty_response(500))
189}
190
191pub fn empty_response<T>(status_code: T) -> Response
194where
195 http::StatusCode: TryFrom<T>,
196 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
197{
198 http::response::Builder::new()
199 .status(status_code)
200 .body(vec![])
201 .unwrap()
202}
203
204pub fn html_response<T, S>(status_code: T, body: S) -> Response
207where
208 http::StatusCode: TryFrom<T>,
209 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
210 S: Into<String>,
211{
212 let body: Vec<u8> = body.into().into_bytes();
213 http::response::Builder::new()
214 .status(status_code)
215 .header(http::header::CONTENT_TYPE, "text/html; charset=utf-8")
216 .header(
217 http::header::CONTENT_LENGTH,
218 format!("{}", body.len()).as_str(),
219 )
220 .body(body)
221 .unwrap()
222}
223
224pub fn string_response<T, S>(status_code: T, body: S) -> Response
226where
227 http::StatusCode: TryFrom<T>,
228 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
229 S: Into<String>,
230{
231 let body: Vec<u8> = body.into().into_bytes();
232 http::response::Builder::new()
233 .status(status_code)
234 .header(
235 http::header::CONTENT_LENGTH,
236 format!("{}", body.len()).as_str(),
237 )
238 .body(body)
239 .unwrap()
240}
241
242pub fn text_response<T, S>(status_code: T, body: S) -> Response
252where
253 http::StatusCode: TryFrom<T>,
254 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
255 S: Into<String>,
256{
257 let body: Vec<u8> = body.into().into_bytes();
258 http::response::Builder::new()
259 .status(status_code)
260 .header(
261 http::header::CONTENT_LENGTH,
262 format!("{}", body.len()).as_str(),
263 )
264 .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8")
265 .body(body)
266 .unwrap()
267}
268
269pub fn binary_response<'a, T>(
290 status_code: T,
291 content_type: impl Into<Option<&'a str>>,
292 body: Vec<u8>,
293) -> Response
294where
295 http::StatusCode: TryFrom<T>,
296 <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
297{
298 let content_type: Option<&str> = content_type.into();
299
300 let mut response = http::response::Builder::new().status(status_code).header(
301 http::header::CONTENT_LENGTH,
302 format!("{}", body.len()).as_str(),
303 );
304
305 if let Some(ct) = content_type {
306 response = response.header(http::header::CONTENT_TYPE, ct);
307 }
308
309 response.body(body).unwrap()
310}
311
/// Path of the running executable as a (lossily decoded) string, or an empty
/// string if it cannot be determined.
///
/// Used as the request URI path when the server does not set `SCRIPT_NAME`.
fn exe_url() -> String {
    std::env::current_exe()
        .map(|path| path.to_string_lossy().into_owned())
        .unwrap_or_default()
}
321
322fn parse_request(env_vars: HashMap<String, String>, stdin: Vec<u8>) -> Request {
323 let mut req = http::Request::builder();
324
325 req = req.method(env_vars.get("REQUEST_METHOD").map_or("GET", String::as_str));
326 let mut uri = env_vars
327 .get("SCRIPT_NAME")
328 .map_or_else(exe_url, String::clone);
329
330 if env_vars.contains_key("QUERY_STRING") {
331 uri.push_str("?");
332 uri.push_str(&env_vars["QUERY_STRING"]);
333 }
334 req = req.uri(uri.as_str());
335
336 if let Some(v) = env_vars.get("SERVER_PROTOCOL") {
337 if v == "HTTP/0.9" {
338 req = req.version(http::version::Version::HTTP_09);
339 } else if v == "HTTP/1.0" {
340 req = req.version(http::version::Version::HTTP_10);
341 } else if v == "HTTP/1.1" {
342 req = req.version(http::version::Version::HTTP_11);
343 } else if v == "HTTP/2.0" {
344 req = req.version(http::version::Version::HTTP_2);
345 } else {
346 unimplemented!("Unsupport HTTP SERVER_PROTOCOL {:?}", v);
347 }
348 }
349
350 for key in env_vars.keys().filter(|k| k.starts_with("HTTP_")) {
351 let header: String = key
352 .chars()
353 .skip(5)
354 .map(|c| if c == '_' { '-' } else { c })
355 .collect();
356 req = req.header(header.as_str(), env_vars[key].as_str().trim());
357 }
358
359 req = add_header(req, &env_vars, "AUTH_TYPE", "X-CGI-Auth-Type");
360 req = add_header(req, &env_vars, "CONTENT_LENGTH", "X-CGI-Content-Length");
361 req = add_header(req, &env_vars, "CONTENT_TYPE", "X-CGI-Content-Type");
362 req = add_header(
363 req,
364 &env_vars,
365 "GATEWAY_INTERFACE",
366 "X-CGI-Gateway-Interface",
367 );
368 req = add_header(req, &env_vars, "PATH_INFO", "X-CGI-Path-Info");
369 req = add_header(req, &env_vars, "PATH_TRANSLATED", "X-CGI-Path-Translated");
370 req = add_header(req, &env_vars, "QUERY_STRING", "X-CGI-Query-String");
371 req = add_header(req, &env_vars, "REMOTE_ADDR", "X-CGI-Remote-Addr");
372 req = add_header(req, &env_vars, "REMOTE_HOST", "X-CGI-Remote-Host");
373 req = add_header(req, &env_vars, "REMOTE_IDENT", "X-CGI-Remote-Ident");
374 req = add_header(req, &env_vars, "REMOTE_USER", "X-CGI-Remote-User");
375 req = add_header(req, &env_vars, "REQUEST_METHOD", "X-CGI-Request-Method");
376 req = add_header(req, &env_vars, "SCRIPT_NAME", "X-CGI-Script-Name");
377 req = add_header(req, &env_vars, "SERVER_PORT", "X-CGI-Server-Port");
378 req = add_header(req, &env_vars, "SERVER_PROTOCOL", "X-CGI-Server-Protocol");
379 req = add_header(req, &env_vars, "SERVER_SOFTWARE", "X-CGI-Server-Software");
380
381 req.body(stdin).unwrap()
382}
383
384fn add_header(
386 req: http::request::Builder,
387 env_vars: &HashMap<String, String>,
388 meta_var: &str,
389 target_header: &str,
390) -> http::request::Builder {
391 if let Some(var) = env_vars.get(meta_var) {
392 req.header(target_header, var.as_str())
393 } else {
394 req
395 }
396}
397
398fn serialize_response(response: Response) -> Vec<u8> {
400 let mut output = String::new();
401 output.push_str("Status: ");
402 output.push_str(response.status().as_str());
403 if let Some(reason) = response.status().canonical_reason() {
404 output.push_str(" ");
405 output.push_str(reason);
406 }
407 output.push_str("\n");
408
409 {
410 let headers = response.headers();
411 let mut keys: Vec<&http::header::HeaderName> = headers.keys().collect();
412 keys.sort_by_key(|h| h.as_str());
413 for key in keys {
414 output.push_str(key.as_str());
415 output.push_str(": ");
416 output.push_str(headers.get(key).unwrap().to_str().unwrap());
417 output.push_str("\n");
418 }
419 }
420
421 output.push_str("\n");
422
423 let mut output = output.into_bytes();
424
425 let (_, mut body) = response.into_parts();
426
427 output.append(&mut body);
428
429 output
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an environment map from `(name, value)` pairs.
    fn env(input: Vec<(&str, &str)>) -> HashMap<String, String> {
        input
            .into_iter()
            .map(|(a, b)| (a.to_owned(), b.to_owned()))
            .collect()
    }

    /// Serializes `resp` and returns the output as text (all fixtures are ASCII).
    fn serialized(resp: Response) -> String {
        String::from_utf8(serialize_response(resp)).expect("serialized response is not UTF-8")
    }

    #[test]
    fn test_empty() {
        let env_vars = env(vec![]);
        let stdin = Vec::new();
        let req = parse_request(env_vars, stdin);
        assert_eq!(req.method(), &http::method::Method::GET);
    }

    #[test]
    fn test_parse_request() {
        let env_vars = env(vec![
            ("REQUEST_METHOD", "GET"),
            ("SCRIPT_NAME", "/my/path/script"),
            ("SERVER_PROTOCOL", "HTTP/1.0"),
            ("HTTP_USER_AGENT", "MyBrowser/1.0"),
            ("QUERY_STRING", "foo=bar&baz=bop"),
        ]);
        let stdin = Vec::new();
        let req = parse_request(env_vars, stdin);
        assert_eq!(req.method(), &http::method::Method::GET);
        assert_eq!(req.uri(), "/my/path/script?foo=bar&baz=bop");
        assert_eq!(req.uri().path(), "/my/path/script");
        assert_eq!(req.uri().query(), Some("foo=bar&baz=bop"));
        assert_eq!(req.version(), http::version::Version::HTTP_10);
        assert_eq!(req.headers()[http::header::USER_AGENT], "MyBrowser/1.0");
        assert_eq!(req.body(), &vec![] as &Vec<u8>);
    }

    /// Asserts that `resp` with `body` serializes to exactly `expected_output`.
    fn test_serialized_response(resp: http::response::Builder, body: &str, expected_output: &str) {
        let resp: Response = resp.body(String::from(body).into_bytes()).unwrap();
        // Comparing as text makes a failing assertion print a readable diff.
        assert_eq!(serialized(resp), expected_output);
    }

    #[test]
    fn test_serialized_response1() {
        test_serialized_response(
            http::Response::builder().status(200),
            "Hello World",
            "Status: 200 OK\n\nHello World",
        );

        test_serialized_response(
            http::Response::builder()
                .status(200)
                .header("Content-Type", "text/html")
                .header("Content-Language", "en")
                .header("Cache-Control", "max-age=3600"),
            "<html><body><h1>Hello</h1></body></html>",
            "Status: 200 OK\ncache-control: max-age=3600\ncontent-language: en\ncontent-type: text/html\n\n<html><body><h1>Hello</h1></body></html>",
        );
    }

    #[test]
    fn test_shortcuts1() {
        assert_eq!(
            serialized(html_response(200, "<html><body><h1>Hello World</h1></body></html>")),
            "Status: 200 OK\ncontent-length: 46\ncontent-type: text/html; charset=utf-8\n\n<html><body><h1>Hello World</h1></body></html>"
        );
    }

    #[test]
    fn test_shortcuts2() {
        assert_eq!(
            serialized(binary_response(200, None, vec![65, 66, 67])),
            "Status: 200 OK\ncontent-length: 3\n\nABC"
        );

        assert_eq!(
            serialized(binary_response(200, "application/octet-stream", vec![65, 66, 67])),
            "Status: 200 OK\ncontent-length: 3\ncontent-type: application/octet-stream\n\nABC"
        );

        let ct: String = "image/png".to_string();
        assert_eq!(
            serialized(binary_response(200, ct.as_str(), vec![65, 66, 67])),
            "Status: 200 OK\ncontent-length: 3\ncontent-type: image/png\n\nABC"
        );
    }

    #[test]
    fn test_shortcuts3() {
        assert_eq!(serialized(empty_response(204)), "Status: 204 No Content\n\n");
        assert_eq!(
            serialized(string_response(404, "nope")),
            "Status: 404 Not Found\ncontent-length: 4\n\nnope"
        );
        assert_eq!(
            serialized(text_response(200, "hi")),
            "Status: 200 OK\ncontent-length: 2\ncontent-type: text/plain; charset=utf-8\n\nhi"
        );
    }

    #[test]
    fn test_err_to_500() {
        let ok: Result<Response, ()> = Ok(text_response(201, "made"));
        assert_eq!(err_to_500(ok).status(), http::StatusCode::CREATED);

        let err: Result<Response, &str> = Err("boom");
        assert_eq!(
            err_to_500(err).status(),
            http::StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[test]
    fn test_handle_success() {
        let input = std::io::Cursor::new(vec![]);
        let mut output = std::io::BufWriter::new(Vec::new());
        let mut error = std::io::BufWriter::new(Vec::new());

        try_handle_with_io(
            |_req: Request| Ok::<http::Response<Vec<u8>>, String>(text_response(200, "All good")),
            input,
            &mut output,
            &mut error,
        );

        let written = output.into_inner().unwrap();
        assert_eq!(
            String::from_utf8(written).unwrap(),
            "Status: 200 OK\ncontent-length: 8\ncontent-type: text/plain; charset=utf-8\n\nAll good"
        );
        assert_eq!(error.into_inner().unwrap().len(), 0);
    }

    #[test]
    fn test_handle_error() {
        let input = std::io::Cursor::new(vec![]);
        let mut output = std::io::BufWriter::new(Vec::new());
        let mut error = std::io::BufWriter::new(Vec::new());

        try_handle_with_io(
            |_req: Request| Err("Not good"),
            input,
            &mut output,
            &mut error,
        );

        let written = output.into_inner().unwrap();
        assert_eq!(
            String::from_utf8(written).unwrap(),
            "Status: 500 Internal Server Error\n\n"
        );
        assert_eq!(
            String::from_utf8(error.into_inner().unwrap()).unwrap(),
            "\"Not good\"\n"
        );
    }
}
590}