#[cfg(any(feature = "compression", feature = "compression-brotli"))]
use async_compression::tokio::bufread::BrotliEncoder;
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
use async_compression::tokio::bufread::DeflateEncoder;
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
use async_compression::tokio::bufread::GzipEncoder;
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
use async_compression::tokio::bufread::ZstdEncoder;
use bytes::Bytes;
use futures_util::Stream;
use headers::{ContentType, HeaderMap, HeaderMapExt, HeaderValue};
use hyper::{
    header::{CONTENT_ENCODING, CONTENT_LENGTH},
    Body, Method, Request, Response, StatusCode,
};
use lazy_static::lazy_static;
use mime_guess::{mime, Mime};
use pin_project::pin_project;
use std::collections::HashSet;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_util::io::{ReaderStream, StreamReader};
use crate::{
    error_page,
    handler::RequestHandlerOpts,
    headers_ext::{AcceptEncoding, ContentCoding},
    http_ext::MethodExt,
    Error, Result,
};
lazy_static! {
    /// MIME essence strings that `is_text` accepts as compressible even though
    /// they match neither of its generic checks (top-level `text` type, or an
    /// `xml` / `json` structured-syntax suffix).
    static ref TEXT_MIME_TYPES: HashSet<&'static str> = HashSet::from([
        "application/rtf",
        "application/javascript",
        "application/json",
        "application/xml",
        "font/ttf",
        "application/font-sfnt",
        "application/vnd.ms-fontobject",
        "application/wasm",
    ]);
}
/// Content codings that this build is able to produce.
///
/// Each entry is compiled in only when its dedicated Cargo feature, or the
/// umbrella `compression` feature, is enabled; with none of them enabled the
/// slice is empty. `get_preferred_encoding` uses this list to filter the
/// codings offered by the client.
const AVAILABLE_ENCODINGS: &[ContentCoding] = &[
    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    ContentCoding::DEFLATE,
    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    ContentCoding::GZIP,
    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    ContentCoding::BROTLI,
    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    ContentCoding::ZSTD,
];
/// Initializes on-the-fly compression.
///
/// Stores `enabled` in `handler_opts.compression` and logs the setting together
/// with the names of the compression formats compiled into this build.
pub fn init(enabled: bool, handler_opts: &mut RequestHandlerOpts) {
    // Human-readable names of the formats whose Cargo features are active.
    // Used for the log line below only.
    const ENABLED_FORMATS: &[&str] = &[
        #[cfg(any(feature = "compression", feature = "compression-deflate"))]
        "deflate",
        #[cfg(any(feature = "compression", feature = "compression-gzip"))]
        "gzip",
        #[cfg(any(feature = "compression", feature = "compression-brotli"))]
        "brotli",
        #[cfg(any(feature = "compression", feature = "compression-zstd"))]
        "zstd",
    ];

    handler_opts.compression = enabled;

    server_info!(
        "auto compression: enabled={enabled}, formats={}",
        ENABLED_FORMATS.join(",")
    );
}
/// Applies on-the-fly compression to an already built response.
///
/// The response is returned untouched when compression is disabled in `opts`
/// or when it already carries a `Content-Encoding` header. Otherwise a
/// `Vary: Accept-Encoding` value is appended and the body is handed to
/// [`auto`]. If `auto` fails, the error is logged and a
/// `500 Internal Server Error` page is produced instead.
pub(crate) fn post_process(
    opts: &RequestHandlerOpts,
    req: &Request<Body>,
    mut resp: Response<Body>,
) -> Result<Response<Body>, Error> {
    // Nothing to do if the feature is off or the body is already encoded.
    if !opts.compression || resp.headers().contains_key(CONTENT_ENCODING) {
        return Ok(resp);
    }

    // The representation depends on the request's `Accept-Encoding` header.
    let vary_value = HeaderValue::from_name(hyper::header::ACCEPT_ENCODING);
    resp.headers_mut().append(hyper::header::VARY, vary_value);

    auto(req.method(), req.headers(), resp).or_else(|err| {
        tracing::error!("error during body compression: {:?}", err);
        error_page::error_response(
            req.uri(),
            req.method(),
            &StatusCode::INTERNAL_SERVER_ERROR,
            &opts.page404,
            &opts.page50x,
        )
    })
}
/// Compresses `resp` with the encoding preferred by the client, if possible.
///
/// The response is returned unchanged when:
/// - the request method is `HEAD` or `OPTIONS`,
/// - no supported encoding can be selected from the request `headers`,
/// - the response has a `Content-Type` that [`is_text`] rejects, or
/// - the selected encoding has no matching compression feature compiled in.
///
/// A response without a `Content-Type` header is still compressed.
pub fn auto(
    method: &Method,
    headers: &HeaderMap<HeaderValue>,
    resp: Response<Body>,
) -> Result<Response<Body>> {
    if method.is_head() || method.is_options() {
        return Ok(resp);
    }

    let Some(encoding) = get_preferred_encoding(headers) else {
        return Ok(resp);
    };
    tracing::trace!(
        "preferred encoding selected from the accept-encoding header: {:?}",
        encoding
    );

    // Only a known, non-text content type prevents compression.
    let compressible = resp
        .headers()
        .typed_get::<ContentType>()
        .map_or(true, |content_type| is_text(Mime::from(content_type)));
    if !compressible {
        return Ok(resp);
    }

    #[cfg(any(feature = "compression", feature = "compression-gzip"))]
    if encoding == ContentCoding::GZIP {
        let (parts, payload) = resp.into_parts();
        return Ok(gzip(parts, payload.into()));
    }

    #[cfg(any(feature = "compression", feature = "compression-deflate"))]
    if encoding == ContentCoding::DEFLATE {
        let (parts, payload) = resp.into_parts();
        return Ok(deflate(parts, payload.into()));
    }

    #[cfg(any(feature = "compression", feature = "compression-brotli"))]
    if encoding == ContentCoding::BROTLI {
        let (parts, payload) = resp.into_parts();
        return Ok(brotli(parts, payload.into()));
    }

    #[cfg(any(feature = "compression", feature = "compression-zstd"))]
    if encoding == ContentCoding::ZSTD {
        let (parts, payload) = resp.into_parts();
        return Ok(zstd(parts, payload.into()));
    }

    tracing::trace!("no compression feature matched the preferred encoding, probably not enabled or unsupported");
    Ok(resp)
}
/// Reports whether `mime` is considered text-based and therefore worth compressing.
///
/// A MIME type qualifies when its top-level type is `text`, when its suffix is
/// `xml` or `json`, or when its essence string is listed in `TEXT_MIME_TYPES`.
fn is_text(mime: Mime) -> bool {
    if mime.type_() == mime::TEXT {
        return true;
    }

    if let Some(suffix) = mime.suffix() {
        if suffix == mime::XML || suffix == mime::JSON {
            return true;
        }
    }

    TEXT_MIME_TYPES.contains(mime.essence_str())
}
/// Builds a response whose body is `body` compressed on the fly with GZIP.
///
/// `Content-Length` is dropped from `head` and the `Content-Encoding` header is
/// replaced by one that includes the `gzip` coding.
#[cfg(any(feature = "compression", feature = "compression-gzip"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-gzip")))
)]
pub fn gzip(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
) -> Response<Body> {
    tracing::trace!("compressing response body on the fly using GZIP");

    // The original length no longer describes the encoded body.
    head.headers.remove(CONTENT_LENGTH);
    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    let encoding_value = create_encoding_header(previous_encoding, ContentCoding::GZIP);
    head.headers.append(CONTENT_ENCODING, encoding_value);

    // stream -> AsyncBufRead -> encoder -> stream -> hyper body
    let reader = StreamReader::new(body);
    let encoder = GzipEncoder::new(reader);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    Response::from_parts(head, compressed)
}
/// Builds a response whose body is `body` compressed on the fly with DEFLATE.
///
/// `Content-Length` is dropped from `head` and the `Content-Encoding` header is
/// replaced by one that includes the `deflate` coding.
#[cfg(any(feature = "compression", feature = "compression-deflate"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-deflate")))
)]
pub fn deflate(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
) -> Response<Body> {
    tracing::trace!("compressing response body on the fly using DEFLATE");

    // The original length no longer describes the encoded body.
    head.headers.remove(CONTENT_LENGTH);
    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    let encoding_value = create_encoding_header(previous_encoding, ContentCoding::DEFLATE);
    head.headers.append(CONTENT_ENCODING, encoding_value);

    // stream -> AsyncBufRead -> encoder -> stream -> hyper body
    let reader = StreamReader::new(body);
    let encoder = DeflateEncoder::new(reader);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    Response::from_parts(head, compressed)
}
/// Builds a response whose body is `body` compressed on the fly with Brotli.
///
/// `Content-Length` is dropped from `head` and the `Content-Encoding` header is
/// replaced by one that includes the Brotli coding.
#[cfg(any(feature = "compression", feature = "compression-brotli"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-brotli")))
)]
pub fn brotli(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
) -> Response<Body> {
    tracing::trace!("compressing response body on the fly using BROTLI");

    // The original length no longer describes the encoded body.
    head.headers.remove(CONTENT_LENGTH);
    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    let encoding_value = create_encoding_header(previous_encoding, ContentCoding::BROTLI);
    head.headers.append(CONTENT_ENCODING, encoding_value);

    // stream -> AsyncBufRead -> encoder -> stream -> hyper body
    let reader = StreamReader::new(body);
    let encoder = BrotliEncoder::new(reader);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    Response::from_parts(head, compressed)
}
/// Builds a response whose body is `body` compressed on the fly with Zstandard.
///
/// `Content-Length` is dropped from `head` and the `Content-Encoding` header is
/// replaced by one that includes the Zstandard coding.
#[cfg(any(feature = "compression", feature = "compression-zstd"))]
#[cfg_attr(
    docsrs,
    doc(cfg(any(feature = "compression", feature = "compression-zstd")))
)]
pub fn zstd(
    mut head: http::response::Parts,
    body: CompressableBody<Body, hyper::Error>,
) -> Response<Body> {
    tracing::trace!("compressing response body on the fly using ZSTD");

    // The original length no longer describes the encoded body.
    head.headers.remove(CONTENT_LENGTH);
    let previous_encoding = head.headers.remove(CONTENT_ENCODING);
    let encoding_value = create_encoding_header(previous_encoding, ContentCoding::ZSTD);
    head.headers.append(CONTENT_ENCODING, encoding_value);

    // stream -> AsyncBufRead -> encoder -> stream -> hyper body
    let reader = StreamReader::new(body);
    let encoder = ZstdEncoder::new(reader);
    let compressed = Body::wrap_stream(ReaderStream::new(encoder));

    Response::from_parts(head, compressed)
}
/// Builds the `Content-Encoding` header value for a body encoded with `coding`.
///
/// When `existing` holds a value that is representable as a string, `coding` is
/// appended to it as `"<existing>, <coding>"`. In every other case (no previous
/// value, a previous value that cannot be read as a string, or a combined value
/// that is not a valid header value) the result is `coding` on its own.
pub fn create_encoding_header(existing: Option<HeaderValue>, coding: ContentCoding) -> HeaderValue {
    let Some(previous) = existing else {
        return coding.into();
    };
    let Ok(previous_text) = previous.to_str() else {
        return coding.into();
    };

    let combined = format!("{previous_text}, {}", coding.as_str());
    match HeaderValue::from_str(&combined) {
        Ok(value) => value,
        Err(_) => coding.into(),
    }
}
#[inline(always)]
pub fn get_preferred_encoding(headers: &HeaderMap<HeaderValue>) -> Option<ContentCoding> {
    if let Some(ref accept_encoding) = headers.typed_get::<AcceptEncoding>() {
        tracing::trace!("request with accept-encoding header: {:?}", accept_encoding);
        for encoding in accept_encoding.sorted_encodings() {
            if AVAILABLE_ENCODINGS.contains(&encoding) {
                return Some(encoding);
            }
        }
    }
    None
}
/// Wrapper around a stream of `Result<Bytes, E>` chunks.
///
/// Its `Stream` implementation yields `std::io::Result<Bytes>` items instead,
/// which is the item type the compression functions in this module feed into
/// `StreamReader`.
#[pin_project]
#[derive(Debug)]
pub struct CompressableBody<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: std::error::Error,
{
    /// The wrapped stream; structurally pinned so it can be polled in place.
    #[pin]
    body: S,
}
impl<S, E> Stream for CompressableBody<S, E>
where
    S: Stream<Item = Result<Bytes, E>>,
    E: std::error::Error,
{
    /// Chunks are re-exposed with an I/O error type.
    type Item = std::io::Result<Bytes>;

    /// Polls the wrapped stream and converts its error type to `std::io::Error`.
    ///
    /// Successful chunks and end-of-stream are forwarded unchanged. An error from
    /// the wrapped stream becomes an `ErrorKind::InvalidData` I/O error whose
    /// message is the original error's `Display` text, so the cause of a failed
    /// body read is not lost.
    fn poll_next(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        use std::io::{Error, ErrorKind};

        let this = self.project();
        // `E` is only known to implement `std::error::Error` (no `Send + Sync +
        // 'static` bound), so it cannot be boxed as the error source; carry its
        // textual description instead of discarding it.
        S::poll_next(this.body, ctx)
            .map_err(|err| Error::new(ErrorKind::InvalidData, err.to_string()))
    }
}
/// Wraps a hyper [`Body`] so it can be passed to the compression functions.
impl From<Body> for CompressableBody<Body, hyper::Error> {
    #[inline(always)]
    fn from(body: Body) -> Self {
        Self { body }
    }
}