use std::collections::HashMap;
use std::convert::TryFrom;
use std::io::{stdin, Read, Write};
// Re-exported so that code using this crate can reach the `http` types it needs
// (status codes, header names, builders) as `rust_cgi::http` — presumably to avoid a
// separate, possibly mismatched, `http` dependency in callers; confirm.
pub extern crate http;
/// An HTTP request built from the CGI environment; the body is the raw bytes read from stdin.
pub type Request = http::Request<Vec<u8>>;
/// An HTTP response; the body is written as raw bytes to stdout after the CGI header block.
pub type Response = http::Response<Vec<u8>>;
/// Runs `func` as the CGI request handler for the current process.
///
/// Collects the CGI meta-variables from the environment, reads exactly `CONTENT_LENGTH`
/// bytes of request body from stdin (zero when the variable is missing or is not a valid
/// `usize`), converts both into a [`Request`], and writes the serialized [`Response`]
/// returned by `func` to stdout, flushing it before returning.
///
/// Environment entries whose name or value is not valid UTF-8 are converted lossily
/// (invalid sequences become U+FFFD) instead of aborting the process.
///
/// # Panics
///
/// Panics if stdin ends before `CONTENT_LENGTH` bytes have been read, if the environment
/// does not describe a valid HTTP request, or if writing the response to stdout fails.
pub fn handle<F>(func: F)
where
    F: FnOnce(Request) -> Response,
{
    // `std::env::vars()` panics on the first non-UTF-8 entry, and `HTTP_*` variables carry
    // client-supplied header bytes, so convert lossily instead of crashing before `func` runs.
    let env_vars: HashMap<String, String> = std::env::vars_os()
        .map(|(key, value)| {
            (
                key.to_string_lossy().into_owned(),
                value.to_string_lossy().into_owned(),
            )
        })
        .collect();
    let content_length: usize = env_vars
        .get("CONTENT_LENGTH")
        .and_then(|cl| cl.parse::<usize>().ok())
        .unwrap_or(0);
    let mut stdin_contents = vec![0; content_length];
    stdin()
        .read_exact(&mut stdin_contents)
        .expect("stdin ended before CONTENT_LENGTH bytes of request body were read");
    let request = parse_request(env_vars, stdin_contents);
    let response = func(request);
    let output = serialize_response(response);
    // Flush explicitly: stdout is line-buffered and the body need not end with a newline,
    // so make sure every byte is handed to the server and any write error is reported here.
    let stdout = std::io::stdout();
    let mut out = stdout.lock();
    out.write_all(&output)
        .and_then(|()| out.flush())
        .expect("failed to write the CGI response to stdout");
}
/// Defines `fn main()` so that it runs the given handler through `handle`.
///
/// `$func` must be usable as `FnOnce(Request) -> Response`. Paths go through `$crate`,
/// so the macro keeps working when this crate is renamed in a dependent's `Cargo.toml`.
#[macro_export]
macro_rules! cgi_main {
    ( $func:expr ) => {
        fn main() {
            $crate::handle($func);
        }
    };
}
/// Defines `fn main()` for a fallible handler.
///
/// `$func` is called with the `Request` and must return `Result<Response, E>` where
/// `E: Debug`. On `Ok` the response is sent unchanged. On `Err` the error is printed
/// to stderr with `{:?}` and an empty `500` response is sent instead. Paths go through
/// `$crate`, so the macro keeps working when this crate is renamed.
#[macro_export]
macro_rules! cgi_try_main {
    ( $func:expr ) => {
        fn main() {
            $crate::handle(|request: $crate::Request| match $func(request) {
                Ok(resp) => resp,
                Err(err) => {
                    eprintln!("{:?}", err);
                    $crate::empty_response(500)
                }
            })
        }
    };
}
/// Unwraps a handler result, replacing any error with an empty `500` response.
///
/// The error value is discarded without being logged.
pub fn err_to_500<E>(res: Result<Response, E>) -> Response {
    // `unwrap_or_else` only builds the fallback response on the error path;
    // `unwrap_or` would construct it for every successful request too.
    res.unwrap_or_else(|_| empty_response(500))
}
/// Builds a response carrying only a status code: no headers and an empty body.
///
/// `status_code` may be anything `http::StatusCode` can be fallibly built from, such as
/// an integer literal like `404`.
///
/// # Panics
///
/// Panics if `status_code` is not a valid HTTP status code.
pub fn empty_response<T>(status_code: T) -> Response
where
    http::StatusCode: TryFrom<T>,
    <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
{
    let builder = http::Response::builder().status(status_code);
    builder.body(Vec::new()).unwrap()
}
/// Builds an HTML response from `body`.
///
/// The body is encoded as UTF-8. The response gets `Content-Type: text/html; charset=utf-8`
/// and a `Content-Length` equal to the encoded byte count.
///
/// # Panics
///
/// Panics if `status_code` is not a valid HTTP status code.
pub fn html_response<T, S>(status_code: T, body: S) -> Response
where
    http::StatusCode: TryFrom<T>,
    <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
    S: Into<String>,
{
    let html: String = body.into();
    let bytes = html.into_bytes();
    let length = bytes.len().to_string();
    let builder = http::Response::builder()
        .status(status_code)
        .header(http::header::CONTENT_TYPE, "text/html; charset=utf-8")
        .header(http::header::CONTENT_LENGTH, length.as_str());
    builder.body(bytes).unwrap()
}
/// Builds a response whose body is `body` encoded as UTF-8.
///
/// The response has a matching `Content-Length` header and no `Content-Type` header.
///
/// # Panics
///
/// Panics if `status_code` is not a valid HTTP status code.
pub fn string_response<T, S>(status_code: T, body: S) -> Response
where
    http::StatusCode: TryFrom<T>,
    <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
    S: Into<String>,
{
    let payload = body.into().into_bytes();
    let content_length = payload.len().to_string();
    let builder = http::Response::builder().status(status_code);
    let builder = builder.header(http::header::CONTENT_LENGTH, content_length.as_str());
    builder.body(payload).unwrap()
}
/// Builds a plain-text response from `body`.
///
/// The body is encoded as UTF-8. The response gets a `Content-Length` equal to the encoded
/// byte count and `Content-Type: text/plain; charset=utf-8`.
///
/// # Panics
///
/// Panics if `status_code` is not a valid HTTP status code.
pub fn text_response<T, S>(status_code: T, body: S) -> Response
where
    http::StatusCode: TryFrom<T>,
    <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
    S: Into<String>,
{
    let text: String = body.into();
    let bytes: Vec<u8> = text.into_bytes();
    let byte_count = bytes.len().to_string();
    let builder = http::Response::builder()
        .status(status_code)
        .header(http::header::CONTENT_LENGTH, byte_count.as_str())
        .header(http::header::CONTENT_TYPE, "text/plain; charset=utf-8");
    builder.body(bytes).unwrap()
}
/// Builds a response with an arbitrary byte `body` and a matching `Content-Length`.
///
/// `content_type` may be `None`, a `&str`, or `Some(&str)`. When a value is given, it is
/// sent as the `Content-Type` header; otherwise no `Content-Type` header is set.
///
/// # Panics
///
/// Panics if `status_code` is not a valid HTTP status code or `content_type` is not a
/// valid header value.
pub fn binary_response<'a, T>(
    status_code: T,
    content_type: impl Into<Option<&'a str>>,
    body: Vec<u8>,
) -> Response
where
    http::StatusCode: TryFrom<T>,
    <http::StatusCode as TryFrom<T>>::Error: Into<http::Error>,
{
    let mime: Option<&str> = content_type.into();
    let length = body.len().to_string();
    let builder = http::Response::builder()
        .status(status_code)
        .header(http::header::CONTENT_LENGTH, length.as_str());
    let builder = match mime {
        Some(value) => builder.header(http::header::CONTENT_TYPE, value),
        None => builder,
    };
    builder.body(body).unwrap()
}
/// Returns the running executable's path as a string.
///
/// Non-UTF-8 parts are replaced lossily. Returns an empty string when the path cannot be
/// determined.
fn exe_url() -> String {
    std::env::current_exe()
        .map(|path| path.to_string_lossy().into_owned())
        .unwrap_or_default()
}
/// Builds a [`Request`] from the CGI meta-variables in `env_vars` and the request body
/// already read from stdin.
///
/// * Method: `REQUEST_METHOD`, defaulting to `GET` when unset.
/// * URI: `SCRIPT_NAME` (or the running executable's path when unset).
///   - If `QUERY_STRING` is present and non-empty, `?` and its value are appended.
///   - RFC 3875 defines `QUERY_STRING` as `""` when the request had no query. An empty
///     value therefore yields a URI with no query component, not one ending in a bare `?`.
/// * Version: taken from `SERVER_PROTOCOL` when it names a recognised HTTP version.
///   - Any other value (e.g. `HTTP/3.0` or `INCLUDED`) leaves the builder's default
///     version in place instead of aborting the process.
///   - The raw value is still exposed as the `X-CGI-Server-Protocol` header.
/// * Headers:
///   - Each `HTTP_*` variable becomes a header named after the rest of the variable name,
///     with `_` replaced by `-` and the value trimmed.
///   - A fixed set of other CGI meta-variables is exposed as `X-CGI-*` headers.
///
/// # Panics
///
/// Panics if the resulting method, URI, or any header name or value is invalid.
fn parse_request(env_vars: HashMap<String, String>, stdin: Vec<u8>) -> Request {
    // CGI meta-variables without a direct HTTP equivalent, paired with the header that
    // carries them on the built request. Headers are added in this order.
    const CGI_META_HEADERS: &[(&str, &str)] = &[
        ("AUTH_TYPE", "X-CGI-Auth-Type"),
        ("CONTENT_LENGTH", "X-CGI-Content-Length"),
        ("CONTENT_TYPE", "X-CGI-Content-Type"),
        ("GATEWAY_INTERFACE", "X-CGI-Gateway-Interface"),
        ("PATH_INFO", "X-CGI-Path-Info"),
        ("PATH_TRANSLATED", "X-CGI-Path-Translated"),
        ("QUERY_STRING", "X-CGI-Query-String"),
        ("REMOTE_ADDR", "X-CGI-Remote-Addr"),
        ("REMOTE_HOST", "X-CGI-Remote-Host"),
        ("REMOTE_IDENT", "X-CGI-Remote-Ident"),
        ("REMOTE_USER", "X-CGI-Remote-User"),
        ("REQUEST_METHOD", "X-CGI-Request-Method"),
        ("SCRIPT_NAME", "X-CGI-Script-Name"),
        ("SERVER_PORT", "X-CGI-Server-Port"),
        ("SERVER_PROTOCOL", "X-CGI-Server-Protocol"),
        ("SERVER_SOFTWARE", "X-CGI-Server-Software"),
    ];

    let method = env_vars
        .get("REQUEST_METHOD")
        .map_or("GET", String::as_str);
    let mut req = http::Request::builder().method(method);

    let mut uri = env_vars
        .get("SCRIPT_NAME")
        .map_or_else(exe_url, String::clone);
    if let Some(query) = env_vars.get("QUERY_STRING").filter(|q| !q.is_empty()) {
        uri.push('?');
        uri.push_str(query);
    }
    req = req.uri(uri.as_str());

    if let Some(version) = env_vars
        .get("SERVER_PROTOCOL")
        .and_then(|protocol| http_version_from_protocol(protocol))
    {
        req = req.version(version);
    }

    for (key, value) in &env_vars {
        if let Some(suffix) = key.strip_prefix("HTTP_") {
            let header = suffix.replace('_', "-");
            req = req.header(header.as_str(), value.trim());
        }
    }

    for &(meta_var, target_header) in CGI_META_HEADERS {
        req = add_header(req, &env_vars, meta_var, target_header);
    }

    req.body(stdin)
        .expect("CGI environment does not describe a valid HTTP request")
}

/// Maps a CGI `SERVER_PROTOCOL` value to the matching HTTP version.
///
/// Returns `None` for values this crate does not recognise.
fn http_version_from_protocol(protocol: &str) -> Option<http::version::Version> {
    match protocol {
        "HTTP/0.9" => Some(http::version::Version::HTTP_09),
        "HTTP/1.0" => Some(http::version::Version::HTTP_10),
        "HTTP/1.1" => Some(http::version::Version::HTTP_11),
        // Accept the bare `HTTP/2` spelling as well as `HTTP/2.0`.
        "HTTP/2" | "HTTP/2.0" => Some(http::version::Version::HTTP_2),
        _ => None,
    }
}
/// Adds a `target_header` header carrying the value of the CGI meta-variable `meta_var`.
///
/// Returns `req` unchanged when `meta_var` is absent from `env_vars`.
fn add_header(
    req: http::request::Builder,
    env_vars: &HashMap<String, String>,
    meta_var: &str,
    target_header: &str,
) -> http::request::Builder {
    match env_vars.get(meta_var) {
        Some(value) => req.header(target_header, value.as_str()),
        None => req,
    }
}
/// Serializes `response` into the CGI output format.
///
/// The output consists of, in order:
/// * a `Status:` line with the numeric code, plus the canonical reason phrase when one
///   is known;
/// * one `name: value` line per header value, with header names in sorted order;
/// * an empty line;
/// * the raw body bytes.
///
/// Every line ends with `\n`.
///
/// Each value of a repeated header (e.g. several `Set-Cookie`s) gets its own line.
/// Header values are copied as raw bytes, so values containing non-ASCII bytes (which
/// `HeaderValue::to_str` rejects) are passed through instead of causing a panic.
fn serialize_response(response: Response) -> Vec<u8> {
    let (parts, mut body) = response.into_parts();

    let mut output: Vec<u8> = Vec::new();
    output.extend_from_slice(b"Status: ");
    output.extend_from_slice(parts.status.as_str().as_bytes());
    if let Some(reason) = parts.status.canonical_reason() {
        output.push(b' ');
        output.extend_from_slice(reason.as_bytes());
    }
    output.push(b'\n');

    // `keys()` yields each distinct name once; sort them for a deterministic order.
    let mut names: Vec<&http::header::HeaderName> = parts.headers.keys().collect();
    names.sort_by_key(|h| h.as_str());
    for name in names {
        // `get_all` yields every value stored under `name` in insertion order;
        // `get` would silently drop all but the first one.
        for value in parts.headers.get_all(name).iter() {
            output.extend_from_slice(name.as_str().as_bytes());
            output.extend_from_slice(b": ");
            output.extend_from_slice(value.as_bytes());
            output.push(b'\n');
        }
    }
    output.push(b'\n');

    output.append(&mut body);
    output
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an environment map from `(name, value)` pairs.
    fn env(input: Vec<(&str, &str)>) -> HashMap<String, String> {
        input
            .into_iter()
            .map(|(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_empty() {
        let env_vars = env(vec![]);
        let stdin = Vec::new();
        let req = parse_request(env_vars, stdin);
        assert_eq!(req.method(), &http::method::Method::GET);
    }

    #[test]
    fn test_parse_request() {
        let env_vars = env(vec![
            ("REQUEST_METHOD", "GET"),
            ("SCRIPT_NAME", "/my/path/script"),
            ("SERVER_PROTOCOL", "HTTP/1.0"),
            ("HTTP_USER_AGENT", "MyBrowser/1.0"),
            ("QUERY_STRING", "foo=bar&baz=bop"),
        ]);
        let stdin = Vec::new();
        let req = parse_request(env_vars, stdin);
        assert_eq!(req.method(), &http::method::Method::GET);
        assert_eq!(req.uri(), "/my/path/script?foo=bar&baz=bop");
        assert_eq!(req.uri().path(), "/my/path/script");
        assert_eq!(req.uri().query(), Some("foo=bar&baz=bop"));
        assert_eq!(req.version(), http::version::Version::HTTP_10);
        assert_eq!(req.headers()[http::header::USER_AGENT], "MyBrowser/1.0");
        assert!(req.body().is_empty());
    }

    /// Builds a response from `resp` and `body`, serializes it, and asserts the CGI output
    /// equals `expected_output` exactly.
    fn test_serialized_response(
        resp: http::response::Builder,
        body: &str,
        expected_output: &str,
    ) {
        let resp: Response = resp.body(String::from(body).into_bytes()).unwrap();
        let output = serialize_response(resp);
        // Comparing as `&str` makes a failing assertion print both sides readably; equal
        // UTF-8 strings have equal bytes, so this is still an exact comparison.
        let output = std::str::from_utf8(&output).expect("serialized output is not UTF-8");
        assert_eq!(output, expected_output);
    }

    #[test]
    fn test_serialized_response1() {
        test_serialized_response(
            http::Response::builder().status(200),
            "Hello World",
            "Status: 200 OK\n\nHello World",
        );
        test_serialized_response(
            http::Response::builder()
                .status(200)
                .header("Content-Type", "text/html")
                .header("Content-Language", "en")
                .header("Cache-Control", "max-age=3600"),
            "<html><body><h1>Hello</h1></body></html>",
            "Status: 200 OK\ncache-control: max-age=3600\ncontent-language: en\n\
             content-type: text/html\n\n<html><body><h1>Hello</h1></body></html>",
        );
    }

    #[test]
    fn test_shortcuts1() {
        assert_eq!(
            std::str::from_utf8(&serialize_response(html_response(
                200,
                "<html><body><h1>Hello World</h1></body></html>"
            )))
            .unwrap(),
            "Status: 200 OK\ncontent-length: 46\ncontent-type: text/html; charset=utf-8\n\n\
             <html><body><h1>Hello World</h1></body></html>"
        );
    }

    #[test]
    fn test_shortcuts2() {
        assert_eq!(
            std::str::from_utf8(&serialize_response(binary_response(
                200,
                None,
                vec![65, 66, 67]
            )))
            .unwrap(),
            "Status: 200 OK\ncontent-length: 3\n\nABC"
        );
        assert_eq!(
            std::str::from_utf8(&serialize_response(binary_response(
                200,
                "application/octet-stream",
                vec![65, 66, 67]
            )))
            .unwrap(),
            "Status: 200 OK\ncontent-length: 3\ncontent-type: application/octet-stream\n\nABC"
        );
        let ct: String = "image/png".to_string();
        assert_eq!(
            std::str::from_utf8(&serialize_response(binary_response(
                200,
                ct.as_str(),
                vec![65, 66, 67]
            )))
            .unwrap(),
            "Status: 200 OK\ncontent-length: 3\ncontent-type: image/png\n\nABC"
        );
    }
}