1use std::error::Error;
2use std::path::{Path, PathBuf};
3use std::pin::Pin;
4use std::process::Stdio;
5use std::task::{Context, Poll};
6
7use futures::{Future, FutureExt, StreamExt};
8use http::uri::{Authority, Scheme};
9use http::{header, Request, Response, StatusCode};
10use hyper::Body;
11use tokio::io::{self, AsyncBufReadExt, BufReader};
12use tokio::process::Command;
13use tokio_util::io::{ReaderStream, StreamReader};
14use tower::Service;
15
/// A [`tower::Service`] that answers each HTTP request by running a CGI script.
///
/// Every call spawns the script at `path` as a child process, exports request metadata
/// as environment variables, pipes the request body to the script's standard input and
/// turns its standard output into the HTTP response.
#[derive(Debug, Clone)]
pub struct Cgi {
    /// Path of the CGI executable; it is canonicalized again on every request.
    path: PathBuf,
    /// Whether the child starts from an empty environment (`true`) or inherits the
    /// environment of the current process (`false`) before CGI variables are added.
    env_clear: bool,
}
20
21impl Cgi {
22 pub fn new<P: AsRef<Path>>(path: P) -> Self {
23 Cgi {
24 path: path.as_ref().to_path_buf(),
25 env_clear: true,
26 }
27 }
28
29 pub fn env_clear(mut self, clear: bool) -> Self {
30 self.env_clear = clear;
31 self
32 }
33}
34
/// Type-erased, thread-safe error used as the service's `Error` type, so that I/O,
/// HTTP, and header-parsing failures can all be propagated with `?`.
type BoxedError = Box<dyn Error + Send + Sync>;
36
37impl Service<Request<Body>> for Cgi {
38 type Response = Response<Body>;
39 type Error = BoxedError;
40
41 #[allow(clippy::type_complexity)]
42 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
43
44 #[inline]
45 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
46 Poll::Ready(Ok(()))
47 }
48
49 fn call(&mut self, req: Request<Body>) -> Self::Future {
50 let script_path = self.path.clone();
51 let env_clear = self.env_clear;
52
53 async move {
54 let script_path = std::fs::canonicalize(script_path)?;
55
56 let mut cmd = Command::new(&script_path);
57 let cmd = if env_clear { cmd.env_clear() } else { &mut cmd };
58
59 let mut child = cmd
60 .env("GATEWAY_INTERFACE", "CGI/1.1")
61 .env("QUERY_STRING", req.uri().query().unwrap_or_default())
62 .env("PATH_INFO", req.uri().path())
63 .env("PATH_TRANSLATED", &script_path)
64 .env("REQUEST_METHOD", req.method().as_str().to_ascii_uppercase())
65 .env("SCRIPT_NAME", req.uri().path())
69 .env(
70 "SERVER_NAME",
71 req.headers()
72 .get(header::HOST)
73 .and_then(|val| val.to_str().ok())
74 .and_then(|host| host.parse::<Authority>().ok())
75 .map(|authority| authority.host().to_owned())
76 .unwrap_or_default(),
77 )
78 .env(
79 "SERVER_PORT",
80 req.uri()
81 .port()
82 .map(|port| port.to_string())
83 .or_else(|| {
84 req.headers().get("x-forwarded-proto").and_then(|val| {
85 match val.to_str() {
86 Ok("http") => Some("80".to_string()),
87 Ok("https") => Some("443".to_string()),
88 _ => None,
89 }
90 })
91 })
92 .or_else(|| match req.uri().scheme() {
93 Some(scheme) if *scheme == Scheme::HTTP => Some("80".to_string()),
94 Some(scheme) if *scheme == Scheme::HTTPS => Some("443".to_string()),
95 _ => None,
96 })
97 .unwrap_or_else(|| "80".to_string()),
98 )
99 .env("SERVER_PROTOCOL", format!("{:?}", req.version()))
100 .env("SERVER_SOFTWARE", "tower-cgi/0.0.1")
101 .env(
102 "CONTENT_TYPE",
103 req.headers()
104 .get(header::CONTENT_TYPE)
105 .and_then(|val| val.to_str().ok())
106 .unwrap_or_default(),
107 )
108 .env(
109 "CONTENT_LENGTH",
110 req.headers()
111 .get(header::CONTENT_LENGTH)
112 .and_then(|val| val.to_str().ok())
113 .unwrap_or_default(),
114 )
115 .envs(
116 req.headers()
117 .into_iter()
118 .map(|(name, value)| {
119 let name = format!("HTTP_{}", name)
120 .replace("-", "_")
121 .to_ascii_uppercase();
122 Ok((name, value.to_str()?))
123 })
124 .collect::<Result<Vec<_>, Self::Error>>()?,
125 )
126 .stdin(Stdio::piped())
127 .stdout(Stdio::piped())
128 .stderr(Stdio::inherit())
129 .spawn()?;
130
131 let mut stdin = child.stdin.take().ok_or("Failed to get process STDIN")?;
132 let stdout = child.stdout.take().ok_or("Failed to get process STDOUT")?;
133
134 tokio::spawn(async move { child.wait().await.unwrap() });
135
136 let write_request_body = async move {
137 let request_body = req
138 .into_body()
139 .map(|chunk| chunk.map_err(|err| io::Error::new(io::ErrorKind::Other, err)));
140 let mut request_body_reader = StreamReader::new(request_body);
141 io::copy(&mut request_body_reader, &mut stdin).await?;
142 Ok::<_, Self::Error>(io::copy(&mut request_body_reader, &mut stdin).await?)
143 };
144
145 let read_response = async move {
146 let mut stdout_reader = BufReader::new(stdout);
147 let mut headers = Vec::new();
148
149 loop {
150 stdout_reader.read_until(b'\n', &mut headers).await?;
151
152 match headers.as_slice() {
153 [.., b'\r', b'\n', b'\r', b'\n'] => break,
154 [.., b'\n', b'\n'] => break,
155 _ => continue,
156 }
157 }
158
159 let mut parsed_headers = [httparse::EMPTY_HEADER; 64];
160 httparse::parse_headers(&headers, &mut parsed_headers)?;
161
162 let response = parsed_headers
163 .into_iter()
164 .filter(|header| *header != httparse::EMPTY_HEADER)
165 .map(|header| (header.name, header.value))
166 .try_fold(
167 Response::builder().status(200),
168 |response, (name, value)| {
169 if name.to_ascii_lowercase() == "status" {
170 Ok::<_, Self::Error>(
171 response.status(StatusCode::from_bytes(&value[0..3])?),
172 )
173 } else {
174 Ok(response.header(name, value))
175 }
176 },
177 )?;
178
179 let body_reader = ReaderStream::new(stdout_reader);
180 let response = response.body(Body::wrap_stream(body_reader))?;
181 Ok::<_, Self::Error>(response)
182 };
183
184 let (_, response) = tokio::try_join!(write_request_body, read_response)?;
185
186 Ok(response)
187 }
188 .boxed()
189 }
190}
191
#[cfg(test)]
mod tests {
    use std::fs::Permissions;
    use std::io;
    use std::io::Write;
    use std::os::unix::fs::PermissionsExt;

    use http::Request;
    use hyper::Body;
    use indoc::indoc;
    use tempfile::{NamedTempFile, TempPath};
    use tower::ServiceExt;

    use crate::Cgi;

    /// Writes `program` to a temporary file with mode 0o755 and returns its path. The file
    /// is deleted when the returned [`TempPath`] is dropped.
    ///
    /// This is a plain function rather than an `async fn`: it never awaits, and all of its
    /// work (file I/O and the sleep below) is blocking anyway.
    fn temp_cgi_script(program: &str) -> io::Result<TempPath> {
        let mut file = NamedTempFile::new()?;
        file.as_file_mut()
            .set_permissions(Permissions::from_mode(0o755))?;
        writeln!(file, "{}", program)?;
        let path = file.into_temp_path();
        // NOTE(review): the reason for this one-second pause is not recorded; presumably it
        // works around `ETXTBSY` ("Text file busy") when the freshly written script is
        // executed while another test thread still holds a write handle — confirm before
        // removing it.
        std::thread::sleep(std::time::Duration::from_secs(1));
        Ok(path)
    }

    #[tokio::test]
    async fn test_status_code() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 201 Created"
            echo ""
        "#})
        .unwrap();

        let svc = Cgi::new(&script);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.status(), 201);
    }

    #[tokio::test]
    async fn test_response_headers() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo "x-some-header: hello"
            echo "x-other-header: bye"
            echo ""
        "#})
        .unwrap();

        let svc = Cgi::new(&script);

        let req = Request::builder().body(Body::empty()).unwrap();
        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.headers()["x-some-header"], "hello");
        assert_eq!(res.headers()["x-other-header"], "bye");
    }

    #[tokio::test]
    async fn test_request_headers() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo "x-req-header: ${HTTP_SOME_REQUEST_HEADER}"
            echo ""
        "#})
        .unwrap();

        let svc = Cgi::new(&script);

        let req = Request::builder()
            .header("some-request-header", "hello")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.headers()["x-req-header"], "hello");
    }

    #[tokio::test]
    async fn test_response_body() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo ""
            printf "Hello"
        "#})
        .unwrap();

        let svc = Cgi::new(&script);

        let req = Request::builder().body(Body::empty()).unwrap();

        let res = svc.oneshot(req).await.unwrap();
        let body = hyper::body::to_bytes(res.into_body()).await.unwrap();

        assert_eq!(&body[..], b"Hello");
    }

    #[tokio::test]
    async fn test_request_body() {
        let script = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo ""
            cat -
        "#})
        .unwrap();

        let svc = Cgi::new(&script);

        let req = Request::builder().body(Body::from(&b"input"[..])).unwrap();

        let res = svc.oneshot(req).await.unwrap();
        let body = hyper::body::to_bytes(res.into_body()).await.unwrap();

        assert_eq!(&body[..], b"input");
    }
}