static_web_server/
fallback_page.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2// This file is part of Static Web Server.
3// See https://static-web-server.net/ for more information
4// Copyright (C) 2019-present Jose Quintana <joseluisq.net>
5
//! Fallback page module: serves a configured custom page in place of
//! `404 Not Found` responses.
7//!
8
9use headers::{AcceptRanges, ContentLength, ContentType, HeaderMapExt};
10use hyper::{Body, Request, Response, StatusCode};
11use mime_guess::mime;
12use std::path::Path;
13
14use crate::{handler::RequestHandlerOpts, helpers, http_ext::MethodExt, Error};
15
16/// Initializes fallback page processing
17pub(crate) fn init(file_path: &Path, handler_opts: &mut RequestHandlerOpts) {
18    let found = file_path.is_file();
19    if found {
20        handler_opts.page_fallback =
21            String::from_utf8_lossy(&helpers::read_bytes_default(file_path))
22                .trim()
23                .as_bytes()
24                .to_owned();
25    } else {
26        tracing::debug!("fallback page path not found or not a regular file");
27    }
28
29    tracing::info!(
30        "fallback page: enabled={}, value=\"{}\"",
31        found,
32        file_path.display()
33    );
34}
35
36/// Replace 404 Not Found by the configured fallback page
37pub(crate) fn post_process<T>(
38    opts: &RequestHandlerOpts,
39    req: &Request<T>,
40    resp: Response<Body>,
41) -> Result<Response<Body>, Error> {
42    Ok(
43        if req.method().is_get()
44            && resp.status() == StatusCode::NOT_FOUND
45            && !opts.page_fallback.is_empty()
46        {
47            fallback_response(&opts.page_fallback)
48        } else {
49            resp
50        },
51    )
52}
53
54/// Checks if a fallback response can be generated, i.e. if it is a `GET` request
55/// that would result in a `404` error and a fallback page is configured.
56/// If a response can be generated then is returned otherwise `None`.
57pub fn fallback_response(page_fallback: &[u8]) -> Response<Body> {
58    let body = Body::from(page_fallback.to_owned());
59    let len = page_fallback.len() as u64;
60
61    let mut resp = Response::new(body);
62    *resp.status_mut() = StatusCode::OK;
63
64    resp.headers_mut().typed_insert(ContentLength(len));
65    resp.headers_mut()
66        .typed_insert(ContentType::from(mime::TEXT_HTML_UTF_8));
67    resp.headers_mut().typed_insert(AcceptRanges::bytes());
68
69    resp
70}
71
#[cfg(test)]
mod tests {
    use super::{fallback_response, post_process};
    use crate::{error_page, handler::RequestHandlerOpts, Error};
    use hyper::{Body, Method, Request, Response, StatusCode, Uri};
    use std::path::PathBuf;

    /// The fallback page used by most tests; its length (3) is what the
    /// `Content-Length` assertions below distinguish on.
    const PAGE: &[u8] = &[1, 2, 3];

    fn make_request(method: &str) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri("/")
            .body(Body::empty())
            .unwrap()
    }

    fn make_response(status: &StatusCode) -> Response<Body> {
        error_page::error_response(
            &Uri::try_from("/").unwrap(),
            &Method::GET,
            status,
            &PathBuf::new(),
            &PathBuf::new(),
        )
        .unwrap()
    }

    fn opts_with_fallback(page: &[u8]) -> RequestHandlerOpts {
        RequestHandlerOpts {
            page_fallback: page.to_vec(),
            ..Default::default()
        }
    }

    /// Returns the value of header `name`, if present and valid visible ASCII.
    fn header<'a>(resp: &'a Response<Body>, name: &str) -> Option<&'a str> {
        resp.headers().get(name).and_then(|v| v.to_str().ok())
    }

    fn content_length(resp: &Response<Body>) -> Option<&str> {
        header(resp, "Content-Length")
    }

    /// Asserts that `resp` was passed through rather than replaced: it must
    /// carry a `Content-Length` header that differs from the fallback's `3`.
    fn assert_passthrough_length(resp: &Response<Body>) {
        let len = content_length(resp).expect("passthrough response should carry Content-Length");
        assert_ne!(len, "3");
    }

    #[test]
    fn test_success_code() -> Result<(), Error> {
        let opts = opts_with_fallback(PAGE);
        let req = make_request("GET");
        let resp = make_response(&StatusCode::OK);

        let resp = post_process(&opts, &req, resp)?;
        assert_eq!(resp.status(), StatusCode::OK);
        assert_passthrough_length(&resp);

        Ok(())
    }

    #[test]
    fn test_wrong_error() -> Result<(), Error> {
        let opts = opts_with_fallback(PAGE);
        let req = make_request("GET");
        let resp = make_response(&StatusCode::INTERNAL_SERVER_ERROR);

        let resp = post_process(&opts, &req, resp)?;
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_passthrough_length(&resp);

        Ok(())
    }

    #[test]
    fn test_wrong_method() -> Result<(), Error> {
        let opts = opts_with_fallback(PAGE);
        let req = make_request("POST");
        let resp = make_response(&StatusCode::NOT_FOUND);

        let resp = post_process(&opts, &req, resp)?;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        assert_passthrough_length(&resp);

        Ok(())
    }

    #[test]
    fn test_unconfigured() -> Result<(), Error> {
        let opts = opts_with_fallback(&[]);
        let req = make_request("GET");
        let resp = make_response(&StatusCode::NOT_FOUND);

        let resp = post_process(&opts, &req, resp)?;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);

        Ok(())
    }

    #[test]
    fn test_fallback() -> Result<(), Error> {
        let opts = opts_with_fallback(PAGE);
        let req = make_request("GET");
        let resp = make_response(&StatusCode::NOT_FOUND);

        let resp = post_process(&opts, &req, resp)?;
        assert_eq!(resp.status(), StatusCode::OK);
        // Must be present: a missing header must fail, not silently pass.
        assert_eq!(content_length(&resp), Some("3"));
        assert_eq!(
            header(&resp, "Content-Type"),
            Some("text/html; charset=utf-8")
        );
        assert_eq!(header(&resp, "Accept-Ranges"), Some("bytes"));

        Ok(())
    }

    #[test]
    fn test_fallback_response_headers() {
        let resp = fallback_response(b"<p>hi</p>");
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(content_length(&resp), Some("9"));
        assert_eq!(
            header(&resp, "Content-Type"),
            Some("text/html; charset=utf-8")
        );
        assert_eq!(header(&resp, "Accept-Ranges"), Some("bytes"));
    }
}