1use std::{
4    collections::Bound,
5    io::{Seek, SeekFrom},
6    path::{Path, PathBuf},
7    str::FromStr,
8    time::SystemTime,
9};
10use tokio::io::AsyncReadExt;
11use tokio_util::io::ReaderStream;
12
13use viz_core::{
14    headers::{
15        AcceptRanges, ContentLength, ContentRange, ContentType, ETag, HeaderMap, HeaderMapExt,
16        IfMatch, IfModifiedSince, IfNoneMatch, IfUnmodifiedSince, LastModified, Range,
17    },
18    Handler, IntoResponse, Method, Request, RequestExt, Response, ResponseExt, Result, StatusCode,
19};
20
21mod directory;
22mod error;
23
24use directory::Directory;
25pub use error::Error;
26
/// A handler that always serves one fixed file.
///
/// The request path and route parameters are not used to pick the file; only the request
/// headers are consulted (for conditional and `Range` handling).
#[derive(Clone, Debug)]
pub struct File {
    /// Path of the file served for every request.
    path: PathBuf,
}
32
33impl File {
34    #[must_use]
40    pub fn new(path: impl Into<PathBuf>) -> Self {
41        let path = path.into();
42
43        assert!(path.exists(), "{} not found", path.to_string_lossy());
44
45        Self { path }
46    }
47}
48
49#[viz_core::async_trait]
50impl Handler<Request> for File {
51    type Output = Result<Response>;
52
53    async fn call(&self, req: Request) -> Self::Output {
54        serve(&self.path, req.headers())
55    }
56}
57
/// A handler that serves files below a root directory.
///
/// The first route parameter (percent-decoded and sanitized) selects a path under the
/// root. A file is served directly; a directory is answered with its `index.html`, or
/// with a generated listing when listing has been enabled via [`Dir::listing`].
#[derive(Clone, Debug)]
pub struct Dir {
    /// Root directory that request paths are resolved against.
    path: PathBuf,
    /// Whether a listing is rendered for directories without an `index.html`.
    listing: bool,
    /// Names handed to the directory listing renderer; presumably entries to hide from
    /// the listing — TODO confirm against `Directory::new`.
    unlisted: Option<Vec<&'static str>>,
}
65
66impl Dir {
67    #[must_use]
73    pub fn new(path: impl Into<PathBuf>) -> Self {
74        let path = path.into();
75
76        assert!(path.exists(), "{} not found", path.to_string_lossy());
77
78        Self {
79            path,
80            listing: false,
81            unlisted: None,
82        }
83    }
84
85    #[must_use]
87    pub const fn listing(mut self) -> Self {
88        self.listing = true;
89        self
90    }
91
92    #[must_use]
94    pub fn unlisted(mut self, unlisted: Vec<&'static str>) -> Self {
95        self.unlisted.replace(unlisted);
96        self
97    }
98}
99
100#[viz_core::async_trait]
101impl Handler<Request> for Dir {
102    type Output = Result<Response>;
103
104    async fn call(&self, req: Request) -> Self::Output {
105        if req.method() != Method::GET {
106            Err(Error::MethodNotAllowed)?;
107        }
108
109        let mut prev = false;
110        let mut path = self.path.clone();
111
112        if let Some(param) = req.route_info().params.first().map(|(_, v)| v) {
113            let p = percent_encoding::percent_decode_str(param)
114                .decode_utf8()
115                .map_err(|_| Error::InvalidPath)?;
116            sanitize_path(&mut path, &p)?;
117            prev = true;
118        }
119
120        if !path.exists() {
121            Err(StatusCode::NOT_FOUND.into_error())?;
122        }
123
124        if path.is_file() {
125            return serve(&path, req.headers());
126        }
127
128        let index = path.join("index.html");
129        if index.exists() {
130            return serve(&index, req.headers());
131        }
132
133        if self.listing {
134            return Directory::new(req.path(), prev, &path, self.unlisted.as_ref())
135                .ok_or_else(|| StatusCode::INTERNAL_SERVER_ERROR.into_error())
136                .map(IntoResponse::into_response);
137        }
138
139        Ok(StatusCode::NOT_FOUND.into_response())
140    }
141}
142
143fn sanitize_path<'a>(path: &'a mut PathBuf, p: &'a str) -> Result<()> {
144    for seg in p.split('/') {
145        if seg.starts_with("..") {
146            return Err(StatusCode::NOT_FOUND.into_error());
147        }
148        if seg.contains('\\') {
149            return Err(StatusCode::NOT_FOUND.into_error());
150        }
151        path.push(seg);
152    }
153    Ok(())
154}
155
156fn extract_etag(mtime: &SystemTime, size: u64) -> Option<ETag> {
157    ETag::from_str(&format!(
158        r#""{}-{}""#,
159        mtime
160            .duration_since(SystemTime::UNIX_EPOCH)
161            .ok()?
162            .as_millis(),
163        size
164    ))
165    .ok()
166}
167
168#[inline]
169fn serve(path: &Path, headers: &HeaderMap) -> Result<Response> {
170    let mut file = std::fs::File::open(path).map_err(Error::Io)?;
171    let metadata = file
172        .metadata()
173        .map_err(|_| StatusCode::NOT_FOUND.into_error())?;
174
175    let mut etag = None;
176    let mut last_modified = None;
177    let mut content_range = None;
178    let mut max = metadata.len();
179
180    if let Ok(modified) = metadata.modified() {
181        etag = extract_etag(&modified, max);
182
183        if matches!((headers.typed_get::<IfMatch>(), &etag), (Some(if_match), Some(etag)) if !if_match.precondition_passes(etag))
184            || matches!(headers.typed_get::<IfUnmodifiedSince>(), Some(if_unmodified_since) if !if_unmodified_since.precondition_passes(modified))
185        {
186            Err(Error::PreconditionFailed)?;
187        }
188
189        if matches!((headers.typed_get::<IfNoneMatch>(), &etag), (Some(if_no_match), Some(etag)) if !if_no_match.precondition_passes(etag))
190            || matches!(headers.typed_get::<IfModifiedSince>(), Some(if_modified_since) if !if_modified_since.is_modified(modified))
191        {
192            return Ok(StatusCode::NOT_MODIFIED.into_response());
193        }
194
195        last_modified.replace(LastModified::from(modified));
196    }
197
198    if let Some((start, end)) = headers
200        .typed_get::<Range>()
201        .and_then(|range| range.satisfiable_ranges(100).next())
202    {
203        let start = match start {
204            Bound::Included(n) => n,
205            Bound::Excluded(n) => n + 1,
206            Bound::Unbounded => 0,
207        };
208        let end = match end {
209            Bound::Included(n) => n + 1,
210            Bound::Excluded(n) => n,
211            Bound::Unbounded => max,
212        };
213
214        if end < start || end > max {
215            Err(Error::RangeUnsatisfied(max))?;
216        }
217
218        if start != 0 || end != max {
219            if let Ok(range) = ContentRange::bytes(start..end, max) {
220                max = end - start;
221                content_range.replace(range);
222                file.seek(SeekFrom::Start(start)).map_err(Error::Io)?;
223            }
224        }
225    }
226
227    let mut res = if content_range.is_some() {
228        Response::stream(ReaderStream::new(tokio::fs::File::from_std(file).take(max)))
230    } else {
231        Response::stream(ReaderStream::new(tokio::fs::File::from_std(file)))
232    };
233
234    let headers = res.headers_mut();
235
236    headers.typed_insert(AcceptRanges::bytes());
237    headers.typed_insert(ContentLength(max));
238    headers.typed_insert(ContentType::from(
239        mime_guess::from_path(path).first_or_octet_stream(),
240    ));
241
242    if let Some(etag) = etag {
243        headers.typed_insert(etag);
244    }
245
246    if let Some(last_modified) = last_modified {
247        headers.typed_insert(last_modified);
248    }
249
250    if let Some(content_range) = content_range {
251        headers.typed_insert(content_range);
252        *res.status_mut() = StatusCode::PARTIAL_CONTENT;
253    };
254
255    Ok(res)
256}
257
#[cfg(test)]
mod tests {
    use super::{Dir, File};
    use std::sync::Arc;
    use viz_core::{
        types::{Params, RouteInfo},
        Handler, IntoResponse, Request, Result, StatusCode,
    };

    /// Builds a request routed through a `/*` pattern whose wildcard captured `param`,
    /// with its URI set to `uri`.
    fn routed_request(param: &'static str, uri: &'static str) -> Request {
        let mut req: Request = Request::default();
        req.extensions_mut().insert(Arc::new(RouteInfo {
            id: 2,
            pattern: "/*".to_string(),
            params: Into::<Params>::into(vec![("*1", param)]),
        }));
        *req.uri_mut() = uri.parse().unwrap();
        req
    }

    #[tokio::test]
    async fn file() -> Result<()> {
        let serve = File::new("src/serve.rs");

        // `File` ignores the route parameter, so both requests hit the same file.
        for (param, uri) in [("serve.rs", "/serve.rs"), ("serve", "/serve")] {
            let result = serve.call(routed_request(param, uri)).await;
            assert_eq!(result.unwrap().status(), StatusCode::OK);
        }

        Ok(())
    }

    #[tokio::test]
    async fn dir() -> Result<()> {
        let serve = Dir::new("src/serve");

        let existing = serve.call(routed_request("list.tpl", "/list.tpl")).await;
        assert_eq!(existing.unwrap().status(), StatusCode::OK);

        let missing = serve
            .call(routed_request("list", "/list"))
            .await
            .map_err(IntoResponse::into_response);
        assert_eq!(missing.unwrap_err().status(), StatusCode::NOT_FOUND);

        Ok(())
    }
}