use std::error::Error as StdError;
use std::io;
use std::sync::Arc;
use bytes::buf::Buf;
use futures::stream::{Stream, StreamExt};
use headers::HeaderMapExt;
use http::{Request, Response, StatusCode};
use crate::body::{Body, InBody};
use crate::davheaders;
use crate::util::{dav_method, notfound, AllowedMethods, Method};
use crate::davpath::DavPath;
use crate::errors::DavError;
use crate::fs::*;
use crate::ls::*;
use crate::DavResult;
/// WebDAV request handler.
///
/// The configuration is held behind an `Arc`, so cloning a `DavHandler` shares the
/// same `DavConfig` instead of copying it.
#[derive(Clone)]
pub struct DavHandler {
    // Shared, immutable handler configuration.
    config: Arc<DavConfig>,
}
/// Configuration for a [`DavHandler`].
///
/// Every field is optional. When a `DavConfig` is passed to `DavHandler::handle_with`,
/// fields that are `Some` override the handler's own configuration and fields that are
/// `None` fall back to it.
#[derive(Default)]
pub struct DavConfig {
    /// URL prefix, passed to `DavPath::from_uri` together with the request URI.
    /// `None` is treated as the empty prefix.
    pub prefix: Option<String>,
    /// Filesystem backend. If no filesystem is available for a request (neither here
    /// nor in an override), the request is answered with a "not found" response.
    pub fs: Option<Box<dyn DavFileSystem>>,
    /// Optional lock system.
    pub ls: Option<Box<dyn DavLockSystem>>,
    /// If set, requests whose method is not allowed are refused with
    /// `405 Method Not Allowed`.
    pub allow: Option<AllowedMethods>,
    /// Presumably the principal (user) name; in this part of the file it is only
    /// passed through to the request handlers.
    pub principal: Option<String>,
    /// Presumably controls whether symbolic links are hidden from clients; in this part
    /// of the file it is only passed through to the request handlers.
    pub hide_symlinks: Option<bool>,
}
/// Resolved configuration used to process a single request.
///
/// Built from a `DavConfig` for each call to `DavHandler::handle` / `handle_with`.
/// Unlike `DavConfig`, the prefix and the filesystem are always present here.
pub(crate) struct DavInner {
    /// URL prefix; empty when the configuration had none.
    pub prefix: String,
    /// Filesystem backend.
    pub fs: Box<dyn DavFileSystem>,
    /// Optional lock system.
    pub ls: Option<Box<dyn DavLockSystem>>,
    /// Optional set of allowed methods; `None` means no method filtering.
    pub allow: Option<AllowedMethods>,
    /// Principal name, passed through from the configuration.
    pub principal: Option<String>,
    /// Symlink-hiding flag, passed through from the configuration.
    pub hide_symlinks: Option<bool>,
}
impl From<DavConfig> for DavInner {
    /// Consumes a `DavConfig` and resolves it into a `DavInner`.
    ///
    /// A missing prefix becomes the empty string.
    ///
    /// # Panics
    ///
    /// Panics if `cfg.fs` is `None`; callers must check for a filesystem first.
    fn from(cfg: DavConfig) -> Self {
        DavInner {
            prefix: cfg.prefix.unwrap_or_default(),
            fs: cfg.fs.expect("DavConfig::fs must be set before building a DavInner"),
            ls: cfg.ls,
            allow: cfg.allow,
            principal: cfg.principal,
            hide_symlinks: cfg.hide_symlinks,
        }
    }
}
impl From<&DavConfig> for DavInner {
    /// Builds a `DavInner` from a borrowed `DavConfig`, cloning the fields it needs.
    ///
    /// A missing prefix becomes the empty string.
    ///
    /// # Panics
    ///
    /// Panics if `cfg.fs` is `None`; callers must check for a filesystem first.
    fn from(cfg: &DavConfig) -> Self {
        DavInner {
            prefix: cfg.prefix.clone().unwrap_or_default(),
            fs: cfg
                .fs
                .clone()
                .expect("DavConfig::fs must be set before building a DavInner"),
            ls: cfg.ls.clone(),
            // `Option<AllowedMethods>` and `Option<bool>` are `Copy`.
            allow: cfg.allow,
            principal: cfg.principal.clone(),
            hide_symlinks: cfg.hide_symlinks,
        }
    }
}
impl Clone for DavInner {
    /// Clones every field, using each field's own `Clone` implementation
    /// (including the boxed filesystem and lock system).
    fn clone(&self) -> Self {
        let DavInner {
            prefix,
            fs,
            ls,
            allow,
            principal,
            hide_symlinks,
        } = self;
        DavInner {
            prefix: prefix.clone(),
            fs: fs.clone(),
            ls: ls.clone(),
            allow: allow.clone(),
            principal: principal.clone(),
            hide_symlinks: hide_symlinks.clone(),
        }
    }
}
impl DavHandler {
pub fn new(
prefix: Option<&str>,
fs: Box<dyn DavFileSystem>,
ls: Option<Box<dyn DavLockSystem>>,
) -> DavHandler
{
let config = DavConfig {
prefix: prefix.map(|s| s.to_string()),
fs: Some(fs),
ls: ls,
allow: None,
principal: None,
hide_symlinks: None,
};
DavHandler {
config: Arc::new(config),
}
}
pub fn new_with(config: DavConfig) -> DavHandler {
DavHandler {
config: Arc::new(config),
}
}
pub async fn handle<ReqBody, ReqData, ReqError>(
&self,
req: Request<ReqBody>,
) -> io::Result<Response<Body>>
where
ReqData: Buf + Send,
ReqError: StdError + Send + Sync + 'static,
ReqBody: http_body::Body<Data = ReqData, Error = ReqError> + Send,
{
if self.config.fs.is_none() {
return Ok(notfound());
}
let inner = DavInner::from(&*self.config);
inner.handle(req).await
}
pub async fn handle_with<ReqBody, ReqData, ReqError>(
&self,
config: DavConfig,
req: Request<ReqBody>,
) -> io::Result<Response<Body>>
where
ReqData: Buf + Send,
ReqError: StdError + Send + Sync + 'static,
ReqBody: http_body::Body<Data = ReqData, Error = ReqError> + Send,
{
let orig = &*self.config;
let newconf = DavConfig {
prefix: config.prefix.or(orig.prefix.clone()),
fs: config.fs.or(orig.fs.clone()),
ls: config.ls.or(orig.ls.clone()),
allow: config.allow.or(orig.allow.clone()),
principal: config.principal.or(orig.principal.clone()),
hide_symlinks: config.hide_symlinks.or(orig.hide_symlinks.clone()),
};
if newconf.fs.is_none() {
return Ok(notfound());
}
let inner = DavInner::from(newconf);
inner.handle(req).await
}
}
impl DavInner {
pub(crate) async fn has_parent<'a>(&'a self, path: &'a DavPath) -> bool {
let p = path.parent();
self.fs.metadata(&p).await.map(|m| m.is_dir()).unwrap_or(false)
}
pub(crate) fn path(&self, req: &Request<()>) -> DavPath {
DavPath::from_uri(req.uri(), &self.prefix).unwrap()
}
pub(crate) fn fixpath(
&self,
res: &mut Response<Body>,
path: &mut DavPath,
meta: Box<dyn DavMetaData>,
) -> Box<dyn DavMetaData>
{
if meta.is_dir() && !path.is_collection() {
path.add_slash();
let newloc = path.as_url_string_with_prefix();
res.headers_mut()
.typed_insert(davheaders::ContentLocation(newloc));
}
meta
}
pub(crate) async fn read_request<'a, ReqBody, ReqError>(
&'a self,
body: ReqBody,
max_size: usize,
) -> DavResult<Vec<u8>>
where
ReqBody: Stream<Item = Result<bytes::Bytes, ReqError>> + Send + 'a,
ReqError: StdError + Send + Sync + 'static,
{
let mut data = Vec::new();
pin_utils::pin_mut!(body);
while let Some(res) = body.next().await {
let chunk = res.map_err(|_| {
DavError::IoError(io::Error::new(io::ErrorKind::UnexpectedEof, "UnexpectedEof"))
})?;
if data.len() + chunk.len() > max_size {
return Err(StatusCode::PAYLOAD_TOO_LARGE.into());
}
data.extend_from_slice(&chunk);
}
Ok(data)
}
async fn handle<ReqBody, ReqData, ReqError>(self, req: Request<ReqBody>) -> io::Result<Response<Body>>
where
ReqData: Buf + Send,
ReqError: StdError + Send + Sync + 'static,
ReqBody: http_body::Body<Data = ReqData, Error = ReqError> + Send,
{
let (req, body) = {
let (parts, body) = req.into_parts();
(Request::from_parts(parts, ()), InBody::from(body))
};
let is_ms = req
.headers()
.get("user-agent")
.and_then(|s| s.to_str().ok())
.map(|s| s.contains("Microsoft"))
.unwrap_or(false);
match self.handle2(req, body).await {
Ok(resp) => {
debug!("== END REQUEST result OK");
Ok(resp)
},
Err(err) => {
debug!("== END REQUEST result {:?}", err);
let mut resp = Response::builder();
if is_ms && err.statuscode() == StatusCode::NOT_FOUND {
resp.header("Cache-Control", "no-store, no-cache, must-revalidate");
resp.header("Progma", "no-cache");
resp.header("Expires", "0");
resp.header("Vary", "*");
}
resp.header("Content-Length", "0");
resp.status(err.statuscode());
if err.must_close() {
resp.header("connection", "close");
}
let resp = resp.body(Body::empty()).unwrap();
Ok(resp)
},
}
}
async fn handle2<ReqBody, ReqError>(self, req: Request<()>, body: ReqBody) -> DavResult<Response<Body>>
where
ReqBody: Stream<Item = Result<bytes::Bytes, ReqError>> + Send,
ReqError: StdError + Send + Sync + 'static,
{
if log_enabled!(log::Level::Debug) {
if let Some(t) = req.headers().typed_get::<davheaders::XLitmus>() {
debug!("X-Litmus: {:?}", t);
}
}
let method = match dav_method(req.method()) {
Ok(m) => m,
Err(e) => {
debug!("refusing method {} request {}", req.method(), req.uri());
return Err(e);
},
};
if let Some(ref a) = self.allow {
if !a.allowed(method) {
debug!("method {} not allowed on request {}", req.method(), req.uri());
return Err(DavError::StatusClose(StatusCode::METHOD_NOT_ALLOWED));
}
}
let path = DavPath::from_uri(req.uri(), &self.prefix)?;
let (body_strm, body_data) = match method {
Method::Put | Method::Patch => (Some(body), Vec::new()),
_ => (None, self.read_request(body, 65536).await?),
};
match method {
Method::Put | Method::Patch | Method::PropFind | Method::PropPatch | Method::Lock => {},
_ => {
if body_data.len() > 0 {
return Err(StatusCode::UNSUPPORTED_MEDIA_TYPE.into());
}
},
}
debug!("== START REQUEST {:?} {}", method, path);
let res = match method {
Method::Options => self.handle_options(req).await,
Method::PropFind => self.handle_propfind(req, body_data).await,
Method::PropPatch => self.handle_proppatch(req, body_data).await,
Method::MkCol => self.handle_mkcol(req).await,
Method::Delete => self.handle_delete(req).await,
Method::Lock => self.handle_lock(req, body_data).await,
Method::Unlock => self.handle_unlock(req).await,
Method::Head | Method::Get => self.handle_get(req).await,
Method::Put | Method::Patch => self.handle_put(req, body_strm.unwrap()).await,
Method::Copy | Method::Move => self.handle_copymove(req, method).await,
};
res
}
}