tower_cgi/
lib.rs

1use std::error::Error;
2use std::path::{Path, PathBuf};
3use std::pin::Pin;
4use std::process::Stdio;
5use std::task::{Context, Poll};
6
7use futures::{Future, FutureExt, StreamExt};
8use http::uri::{Authority, Scheme};
9use http::{header, Request, Response, StatusCode};
10use hyper::Body;
11use tokio::io::{self, AsyncBufReadExt, BufReader};
12use tokio::process::Command;
13use tokio_util::io::{ReaderStream, StreamReader};
14use tower::Service;
15
/// A [`tower::Service`] that answers HTTP requests by running a CGI script.
///
/// Every call spawns the script at `path` as a new process: request metadata is passed
/// in CGI/1.1 environment variables, the request body on stdin, and the response is
/// read back from the script's stdout.
#[derive(Debug, Clone)]
pub struct Cgi {
    /// Path of the script to execute (canonicalized on each call).
    path: PathBuf,
    /// When `true`, the child starts from an empty environment before the CGI
    /// variables are set; when `false`, it also inherits this process's environment.
    env_clear: bool,
}

impl Cgi {
    /// Creates a service that runs the script at `path`.
    ///
    /// The child's environment is cleared by default; see [`Cgi::env_clear`].
    pub fn new<P: AsRef<Path>>(path: P) -> Self {
        Cgi {
            path: path.as_ref().to_path_buf(),
            env_clear: true,
        }
    }

    /// Sets whether the script's environment is cleared before the CGI variables are added.
    ///
    /// Passing `false` lets the script inherit the environment of the current process
    /// in addition to the CGI variables.
    #[must_use]
    pub fn env_clear(mut self, clear: bool) -> Self {
        self.env_clear = clear;
        self
    }
}
34
/// Type-erased, thread-safe error used by the CGI service and its helpers.
///
/// `Send + Sync` keeps it usable inside the service's boxed `Send` future.
type BoxedError = Box<dyn Error + Send + Sync>;
36
37impl Service<Request<Body>> for Cgi {
38    type Response = Response<Body>;
39    type Error = BoxedError;
40
41    #[allow(clippy::type_complexity)]
42    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
43
44    #[inline]
45    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
46        Poll::Ready(Ok(()))
47    }
48
49    fn call(&mut self, req: Request<Body>) -> Self::Future {
50        let script_path = self.path.clone();
51        let env_clear = self.env_clear;
52
53        async move {
54            let script_path = std::fs::canonicalize(script_path)?;
55
56            let mut cmd = Command::new(&script_path);
57            let cmd = if env_clear { cmd.env_clear() } else { &mut cmd };
58
59            let mut child = cmd
60                .env("GATEWAY_INTERFACE", "CGI/1.1")
61                .env("QUERY_STRING", req.uri().query().unwrap_or_default())
62                .env("PATH_INFO", req.uri().path())
63                .env("PATH_TRANSLATED", &script_path)
64                .env("REQUEST_METHOD", req.method().as_str().to_ascii_uppercase())
65                // TODO: should we use request extensions to get the remote address here?
66                // .env("REMOTE_ADDR", remote_addr.ip().to_string())
67                // .env("REMOTE_PORT", remote_addr.port().to_string())
68                .env("SCRIPT_NAME", req.uri().path())
69                .env(
70                    "SERVER_NAME",
71                    req.headers()
72                        .get(header::HOST)
73                        .and_then(|val| val.to_str().ok())
74                        .and_then(|host| host.parse::<Authority>().ok())
75                        .map(|authority| authority.host().to_owned())
76                        .unwrap_or_default(),
77                )
78                .env(
79                    "SERVER_PORT",
80                    req.uri()
81                        .port()
82                        .map(|port| port.to_string())
83                        .or_else(|| {
84                            req.headers().get("x-forwarded-proto").and_then(|val| {
85                                match val.to_str() {
86                                    Ok("http") => Some("80".to_string()),
87                                    Ok("https") => Some("443".to_string()),
88                                    _ => None,
89                                }
90                            })
91                        })
92                        .or_else(|| match req.uri().scheme() {
93                            Some(scheme) if *scheme == Scheme::HTTP => Some("80".to_string()),
94                            Some(scheme) if *scheme == Scheme::HTTPS => Some("443".to_string()),
95                            _ => None,
96                        })
97                        .unwrap_or_else(|| "80".to_string()),
98                )
99                .env("SERVER_PROTOCOL", format!("{:?}", req.version()))
100                .env("SERVER_SOFTWARE", "tower-cgi/0.0.1")
101                .env(
102                    "CONTENT_TYPE",
103                    req.headers()
104                        .get(header::CONTENT_TYPE)
105                        .and_then(|val| val.to_str().ok())
106                        .unwrap_or_default(),
107                )
108                .env(
109                    "CONTENT_LENGTH",
110                    req.headers()
111                        .get(header::CONTENT_LENGTH)
112                        .and_then(|val| val.to_str().ok())
113                        .unwrap_or_default(),
114                )
115                .envs(
116                    req.headers()
117                        .into_iter()
118                        .map(|(name, value)| {
119                            let name = format!("HTTP_{}", name)
120                                .replace("-", "_")
121                                .to_ascii_uppercase();
122                            Ok((name, value.to_str()?))
123                        })
124                        .collect::<Result<Vec<_>, Self::Error>>()?,
125                )
126                .stdin(Stdio::piped())
127                .stdout(Stdio::piped())
128                .stderr(Stdio::inherit())
129                .spawn()?;
130
131            let mut stdin = child.stdin.take().ok_or("Failed to get process STDIN")?;
132            let stdout = child.stdout.take().ok_or("Failed to get process STDOUT")?;
133
134            tokio::spawn(async move { child.wait().await.unwrap() });
135
136            let write_request_body = async move {
137                let request_body = req
138                    .into_body()
139                    .map(|chunk| chunk.map_err(|err| io::Error::new(io::ErrorKind::Other, err)));
140                let mut request_body_reader = StreamReader::new(request_body);
141                io::copy(&mut request_body_reader, &mut stdin).await?;
142                Ok::<_, Self::Error>(io::copy(&mut request_body_reader, &mut stdin).await?)
143            };
144
145            let read_response = async move {
146                let mut stdout_reader = BufReader::new(stdout);
147                let mut headers = Vec::new();
148
149                loop {
150                    stdout_reader.read_until(b'\n', &mut headers).await?;
151
152                    match headers.as_slice() {
153                        [.., b'\r', b'\n', b'\r', b'\n'] => break,
154                        [.., b'\n', b'\n'] => break,
155                        _ => continue,
156                    }
157                }
158
159                let mut parsed_headers = [httparse::EMPTY_HEADER; 64];
160                httparse::parse_headers(&headers, &mut parsed_headers)?;
161
162                let response = parsed_headers
163                    .into_iter()
164                    .filter(|header| *header != httparse::EMPTY_HEADER)
165                    .map(|header| (header.name, header.value))
166                    .try_fold(
167                        Response::builder().status(200),
168                        |response, (name, value)| {
169                            if name.to_ascii_lowercase() == "status" {
170                                Ok::<_, Self::Error>(
171                                    response.status(StatusCode::from_bytes(&value[0..3])?),
172                                )
173                            } else {
174                                Ok(response.header(name, value))
175                            }
176                        },
177                    )?;
178
179                let body_reader = ReaderStream::new(stdout_reader);
180                let response = response.body(Body::wrap_stream(body_reader))?;
181                Ok::<_, Self::Error>(response)
182            };
183
184            let (_, response) = tokio::try_join!(write_request_body, read_response)?;
185
186            Ok(response)
187        }
188        .boxed()
189    }
190}
191
#[cfg(test)]
mod tests {
    use std::fs::Permissions;
    use std::io;
    use std::io::Write;
    use std::os::unix::fs::PermissionsExt;

    use http::Request;
    use hyper::Body;
    use indoc::indoc;
    use tempfile::{NamedTempFile, TempPath};
    use tower::ServiceExt;

    use crate::Cgi;

    /// Writes `program` into a new executable (mode 0755) temporary file.
    ///
    /// The file is deleted when the returned [`TempPath`] is dropped, so a test must keep
    /// it alive for as long as the script may run.
    async fn temp_cgi_script(program: &str) -> io::Result<TempPath> {
        let mut temp = NamedTempFile::new()?;
        let mode = Permissions::from_mode(0o755);
        temp.as_file_mut().set_permissions(mode)?;
        writeln!(temp, "{}", program)?;
        let script_path = temp.into_temp_path();
        // Pause so the OS has finished closing the file before a test executes it;
        // without this the tests can fail intermittently.
        std::thread::sleep(std::time::Duration::from_secs(1));
        Ok(script_path)
    }

    #[tokio::test]
    async fn test_status_code() {
        let script_path = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 201 Created"
            echo ""
        "#})
        .await
        .unwrap();

        let request = Request::builder().body(Body::empty()).unwrap();
        let response = Cgi::new(&script_path).oneshot(request).await.unwrap();

        assert_eq!(response.status(), 201);
    }

    #[tokio::test]
    async fn test_response_headers() {
        let script_path = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo "x-some-header: hello"
            echo "x-other-header: bye"
            echo ""
        "#})
        .await
        .unwrap();

        let request = Request::builder().body(Body::empty()).unwrap();
        let response = Cgi::new(&script_path).oneshot(request).await.unwrap();
        let headers = response.headers();

        assert_eq!(headers["x-some-header"], "hello");
        assert_eq!(headers["x-other-header"], "bye");
    }

    #[tokio::test]
    async fn test_request_headers() {
        let script_path = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo "x-req-header: ${HTTP_SOME_REQUEST_HEADER}"
            echo ""
        "#})
        .await
        .unwrap();

        let request = Request::builder()
            .header("some-request-header", "hello")
            .body(Body::empty())
            .unwrap();
        let response = Cgi::new(&script_path).oneshot(request).await.unwrap();

        assert_eq!(response.headers()["x-req-header"], "hello");
    }

    #[tokio::test]
    async fn test_response_body() {
        let script_path = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo ""
            printf "Hello"
        "#})
        .await
        .unwrap();

        let request = Request::builder().body(Body::empty()).unwrap();
        let response = Cgi::new(&script_path).oneshot(request).await.unwrap();
        let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();

        assert_eq!(&bytes[..], b"Hello");
    }

    #[tokio::test]
    async fn test_request_body() {
        let script_path = temp_cgi_script(indoc! {r#"
            #!/bin/sh
            echo "Status: 200"
            echo ""
            cat -
        "#})
        .await
        .unwrap();

        let input = Body::from(&b"input"[..]);
        let request = Request::builder().body(input).unwrap();
        let response = Cgi::new(&script_path).oneshot(request).await.unwrap();
        let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();

        assert_eq!(&bytes[..], b"input");
    }
}