use crate::{
    header,
    types::{PayloadError, RealIp},
    Body, BodyState, Bytes, FromRequest, Future, Request, Result,
};
use headers::HeaderMapExt;
use http_body_util::{BodyExt, Collected};
#[cfg(any(feature = "params", feature = "multipart"))]
use std::sync::Arc;
#[cfg(feature = "limits")]
use crate::types::Limits;
#[cfg(feature = "limits")]
use http_body_util::{LengthLimitError, Limited};
#[cfg(any(feature = "form", feature = "json", feature = "multipart"))]
use crate::types::Payload;
#[cfg(feature = "form")]
use crate::types::Form;
#[cfg(feature = "json")]
use crate::types::Json;
#[cfg(feature = "multipart")]
use crate::types::Multipart;
#[cfg(feature = "cookie")]
use crate::types::{Cookie, Cookies, CookiesError};
#[cfg(feature = "session")]
use crate::types::Session;
#[cfg(feature = "params")]
use crate::types::{ParamsError, PathDeserializer, RouteInfo};
pub trait RequestExt: private::Sealed + Sized {
    fn schema(&self) -> Option<&http::uri::Scheme>;
    fn path(&self) -> &str;
    fn query_string(&self) -> Option<&str>;
    #[cfg(feature = "query")]
    fn query<T>(&self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned;
    fn header<K, T>(&self, key: K) -> Option<T>
    where
        K: header::AsHeaderName,
        T: std::str::FromStr;
    fn header_typed<H>(&self) -> Option<H>
    where
        H: headers::Header;
    fn content_length(&self) -> Option<u64>;
    fn content_type(&self) -> Option<mime::Mime>;
    fn extract<T>(&mut self) -> impl Future<Output = Result<T, T::Error>> + Send
    where
        T: FromRequest;
    fn incoming(&mut self) -> Result<Body, PayloadError>;
    fn bytes(&mut self) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;
    fn text(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;
    #[cfg(feature = "form")]
    fn form<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;
    #[cfg(feature = "json")]
    fn json<T>(&mut self) -> impl Future<Output = Result<T, PayloadError>> + Send
    where
        T: serde::de::DeserializeOwned;
    #[cfg(feature = "multipart")]
    fn multipart(&mut self) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
    #[cfg(feature = "state")]
    fn state<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static;
    #[cfg(feature = "state")]
    fn set_state<T>(&mut self, t: T) -> Option<T>
    where
        T: Clone + Send + Sync + 'static;
    #[cfg(feature = "cookie")]
    fn cookies(&self) -> Result<Cookies, CookiesError>;
    #[cfg(feature = "cookie")]
    fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
    where
        S: AsRef<str>;
    #[cfg(feature = "session")]
    fn session(&self) -> &Session;
    #[cfg(feature = "params")]
    fn params<T>(&self) -> Result<T, ParamsError>
    where
        T: serde::de::DeserializeOwned;
    #[cfg(feature = "params")]
    fn param<T>(&self, name: &str) -> Result<T, ParamsError>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Display;
    #[cfg(feature = "params")]
    fn route_info(&self) -> &Arc<RouteInfo>;
    fn remote_addr(&self) -> Option<&std::net::SocketAddr>;
    fn realip(&self) -> Option<RealIp>;
}
impl RequestExt for Request {
    fn schema(&self) -> Option<&http::uri::Scheme> {
        self.uri().scheme()
    }
    fn path(&self) -> &str {
        self.uri().path()
    }
    fn query_string(&self) -> Option<&str> {
        self.uri().query()
    }
    #[cfg(feature = "query")]
    fn query<T>(&self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        serde_urlencoded::from_str(self.query_string().unwrap_or_default())
            .map_err(PayloadError::UrlDecode)
    }
    fn header<K, T>(&self, key: K) -> Option<T>
    where
        K: header::AsHeaderName,
        T: std::str::FromStr,
    {
        self.headers()
            .get(key)
            .map(header::HeaderValue::to_str)
            .and_then(Result::ok)
            .map(str::parse)
            .and_then(Result::ok)
    }
    fn header_typed<H>(&self) -> Option<H>
    where
        H: headers::Header,
    {
        self.headers().typed_get()
    }
    fn content_length(&self) -> Option<u64> {
        self.header(header::CONTENT_LENGTH)
    }
    fn content_type(&self) -> Option<mime::Mime> {
        self.header(header::CONTENT_TYPE)
    }
    async fn extract<T>(&mut self) -> Result<T, T::Error>
    where
        T: FromRequest,
    {
        T::extract(self).await
    }
    fn incoming(&mut self) -> Result<Body, PayloadError> {
        if let Some(state) = self.extensions().get::<BodyState>() {
            match state {
                BodyState::Empty => Err(PayloadError::Empty)?,
                BodyState::Used => Err(PayloadError::Used)?,
                BodyState::Normal => {}
            }
        }
        let (state, result) = match std::mem::replace(self.body_mut(), Body::Empty) {
            Body::Empty => (BodyState::Empty, Err(PayloadError::Empty)),
            body => (BodyState::Used, Ok(body)),
        };
        self.extensions_mut().insert(state);
        result
    }
    async fn bytes(&mut self) -> Result<Bytes, PayloadError> {
        self.incoming()?
            .collect()
            .await
            .map_err(|err| {
                #[cfg(feature = "limits")]
                if err.is::<LengthLimitError>() {
                    return PayloadError::TooLarge;
                }
                if let Ok(err) = err.downcast::<hyper::Error>() {
                    return PayloadError::Hyper(err);
                }
                PayloadError::Read
            })
            .map(Collected::to_bytes)
    }
    async fn text(&mut self) -> Result<String, PayloadError> {
        let bytes = self.bytes().await?;
        String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
    }
    #[cfg(feature = "form")]
    async fn form<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        <Form as Payload>::check_type(self.content_type())?;
        let bytes = self.bytes().await?;
        serde_urlencoded::from_reader(bytes::Buf::reader(bytes)).map_err(PayloadError::UrlDecode)
    }
    #[cfg(feature = "json")]
    async fn json<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        <Json as Payload>::check_type(self.content_type())?;
        let bytes = self.bytes().await?;
        serde_json::from_slice(&bytes).map_err(PayloadError::Json)
    }
    #[cfg(feature = "multipart")]
    async fn multipart(&mut self) -> Result<Multipart, PayloadError> {
        let m = <Multipart as Payload>::check_type(self.content_type())?;
        let boundary = m
            .get_param(mime::BOUNDARY)
            .ok_or(PayloadError::MissingBoundary)?
            .as_str();
        Ok(Multipart::new(self.incoming()?, boundary))
    }
    #[cfg(feature = "state")]
    fn state<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.extensions().get().cloned()
    }
    #[cfg(feature = "state")]
    fn set_state<T>(&mut self, t: T) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.extensions_mut().insert(t)
    }
    #[cfg(feature = "cookie")]
    fn cookies(&self) -> Result<Cookies, CookiesError> {
        self.extensions()
            .get::<Cookies>()
            .cloned()
            .ok_or(CookiesError::Read)
    }
    #[cfg(feature = "cookie")]
    fn cookie<S>(&self, name: S) -> Option<Cookie<'_>>
    where
        S: AsRef<str>,
    {
        self.extensions().get::<Cookies>()?.get(name.as_ref())
    }
    #[cfg(feature = "session")]
    fn session(&self) -> &Session {
        self.extensions().get().expect("should get a session")
    }
    #[cfg(feature = "params")]
    fn params<T>(&self) -> Result<T, ParamsError>
    where
        T: serde::de::DeserializeOwned,
    {
        T::deserialize(PathDeserializer::new(&self.route_info().params)).map_err(ParamsError::Parse)
    }
    #[cfg(feature = "params")]
    fn param<T>(&self, name: &str) -> Result<T, ParamsError>
    where
        T: std::str::FromStr,
        T::Err: std::fmt::Display,
    {
        self.route_info().params.find(name)
    }
    fn remote_addr(&self) -> Option<&std::net::SocketAddr> {
        self.extensions().get()
    }
    #[cfg(feature = "params")]
    fn route_info(&self) -> &Arc<RouteInfo> {
        self.extensions().get().expect("should get current route")
    }
    fn realip(&self) -> Option<RealIp> {
        RealIp::parse(self)
    }
}
/// Body readers that enforce the size limits configured through [`Limits`].
///
/// The `private::Sealed` supertrait keeps implementations inside this module;
/// the only implementor is [`Request`].
#[cfg(feature = "limits")]
pub trait RequestLimitsExt: private::Sealed + Sized {
    /// Returns the [`Limits`] stored in the request extensions.
    ///
    /// # Panics
    ///
    /// Panics if no `Limits` value is present in the request extensions.
    fn limits(&self) -> &Limits;

    /// Collects the body, rejecting it once it grows past `limit`
    /// (or past `max` when `limit` is `None`).
    ///
    /// # Errors
    ///
    /// Returns [`PayloadError::TooLarge`] when the size limit is exceeded.
    fn bytes_with(
        &mut self,
        limit: Option<u64>,
        max: u64,
    ) -> impl Future<Output = Result<Bytes, PayloadError>> + Send;

    /// Collects the body under the `"text"` limit and decodes it as UTF-8.
    fn text_with_limit(&mut self) -> impl Future<Output = Result<String, PayloadError>> + Send;

    /// Decodes a URL-encoded form body into `T` under the form size limit.
    #[cfg(feature = "form")]
    fn form_with_limit<T: serde::de::DeserializeOwned>(
        &mut self,
    ) -> impl Future<Output = Result<T, PayloadError>> + Send;

    /// Decodes a JSON body into `T` under the JSON size limit.
    #[cfg(feature = "json")]
    fn json_with_limit<T: serde::de::DeserializeOwned>(
        &mut self,
    ) -> impl Future<Output = Result<T, PayloadError>> + Send;

    /// Builds a [`Multipart`] reader using the multipart limits found in the
    /// request extensions, or default limits when none are present.
    #[cfg(feature = "multipart")]
    fn multipart_with_limit(
        &mut self,
    ) -> impl Future<Output = Result<Multipart, PayloadError>> + Send;
}
#[cfg(feature = "limits")]
impl RequestLimitsExt for Request {
    fn limits(&self) -> &Limits {
        self.extensions()
            .get::<Limits>()
            .expect("Limits middleware is required")
    }

    async fn bytes_with(&mut self, limit: Option<u64>, max: u64) -> Result<Bytes, PayloadError> {
        // Saturate rather than fail where `usize` is narrower than `u64`.
        let limit = usize::try_from(limit.unwrap_or(max)).unwrap_or(usize::MAX);
        Limited::new(self.incoming()?, limit)
            .collect()
            .await
            .map_err(|err| {
                // The collect error is type-erased; downcast to recover known causes.
                if err.is::<LengthLimitError>() {
                    return PayloadError::TooLarge;
                }
                if let Ok(err) = err.downcast::<hyper::Error>() {
                    return PayloadError::Hyper(*err);
                }
                PayloadError::Read
            })
            .map(Collected::to_bytes)
    }

    async fn text_with_limit(&mut self) -> Result<String, PayloadError> {
        let bytes = self
            .bytes_with(self.limits().get("text"), Limits::NORMAL)
            .await?;
        String::from_utf8(bytes.to_vec()).map_err(PayloadError::Utf8)
    }

    #[cfg(feature = "form")]
    async fn form_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let limit = self.limits().get(<Form as Payload>::NAME);
        // Reject on headers alone before touching the body.
        <Form as Payload>::check_header(self.content_type(), self.content_length(), limit)?;
        let bytes = self.bytes_with(limit, <Form as Payload>::LIMIT).await?;
        // Parse the collected buffer in place; `from_reader` would first copy
        // the whole body into a second buffer.
        serde_urlencoded::from_bytes(&bytes).map_err(PayloadError::UrlDecode)
    }

    #[cfg(feature = "json")]
    async fn json_with_limit<T>(&mut self) -> Result<T, PayloadError>
    where
        T: serde::de::DeserializeOwned,
    {
        let limit = self.limits().get(<Json as Payload>::NAME);
        // Reject on headers alone before touching the body.
        <Json as Payload>::check_header(self.content_type(), self.content_length(), limit)?;
        let bytes = self.bytes_with(limit, <Json as Payload>::LIMIT).await?;
        serde_json::from_slice(&bytes).map_err(PayloadError::Json)
    }

    #[cfg(feature = "multipart")]
    async fn multipart_with_limit(&mut self) -> Result<Multipart, PayloadError> {
        let limit = self.limits().get(<Multipart as Payload>::NAME);
        let m = <Multipart as Payload>::check_header(
            self.content_type(),
            self.content_length(),
            limit,
        )?;
        // The boundary parameter is required to split the body into parts.
        let boundary = m
            .get_param(mime::BOUNDARY)
            .ok_or(PayloadError::MissingBoundary)?
            .as_str();
        Ok(Multipart::with_limits(
            self.incoming()?,
            boundary,
            // Per-part limits come from the extensions when configured,
            // otherwise the type's defaults apply.
            self.extensions()
                .get::<std::sync::Arc<crate::types::MultipartLimits>>()
                .map(AsRef::as_ref)
                .cloned()
                .unwrap_or_default(),
        ))
    }
}
mod private {
    pub trait Sealed {}
    impl Sealed for super::Request {}
}