use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use std::fs::File;
use std::future::Future;
use std::io::{self, BufReader, Cursor, Read};
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_rustls::rustls::{
    server::{AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, NoClientAuth},
    Certificate, Error as TlsError, PrivateKey, RootCertStore, ServerConfig,
};
use crate::transport::Transport;
/// Errors that `TlsConfigBuilder::build` can return while assembling a
/// rustls `ServerConfig`.
#[derive(Debug)]
pub enum TlsConfigError {
    /// An I/O error surfaced while reading the private key source or while
    /// reading the client-auth trust anchor PEM.
    Io(io::Error),
    /// The certificate PEM could not be read, or none of the trust anchor
    /// certificates could be added to the root store.
    CertParseError,
    /// The private key data could not be read as a sequence of PEM items.
    InvalidIdentityPem,
    /// The key source yielded zero bytes, or its PEM contained no items.
    EmptyKey,
    /// The key PEM contained an item that is not an `RSAKey`, `PKCS8Key` or
    /// `ECKey`.
    UnknownPrivateKeyFormat,
    /// rustls rejected the certificate chain / private key pair; wraps the
    /// error returned when installing them on the `ServerConfig`.
    InvalidKey(TlsError),
}
impl std::fmt::Display for TlsConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsConfigError::Io(err) => err.fmt(f),
            TlsConfigError::CertParseError => write!(f, "certificate parse error"),
            TlsConfigError::InvalidIdentityPem => write!(f, "identity PEM is invalid"),
            TlsConfigError::UnknownPrivateKeyFormat => write!(f, "unknown private key format"),
            TlsConfigError::EmptyKey => write!(f, "key contains no private key"),
            TlsConfigError::InvalidKey(err) => write!(f, "key contains an invalid key, {err}"),
        }
    }
}
// Marker impl only: `source()` is left at its default, and the wrapped
// `io::Error` / TLS error are already rendered by the `Display` impl.
impl std::error::Error for TlsConfigError {}
/// Client-certificate authentication mode stored on a `TlsConfigBuilder`.
///
/// The boxed readers supply the PEM-encoded trust anchor certificates; they
/// are only read when `TlsConfigBuilder::build` runs.
pub enum TlsClientAuth {
    /// No client authentication; `build` installs `NoClientAuth`.
    Off,
    /// `build` installs `AllowAnyAnonymousOrAuthenticatedClient` backed by the
    /// trust anchors read from this reader.
    Optional(Box<dyn Read + Send + Sync>),
    /// `build` installs `AllowAnyAuthenticatedClient` backed by the trust
    /// anchors read from this reader.
    Required(Box<dyn Read + Send + Sync>),
}
/// Builder that collects certificate, key, client-auth and OCSP inputs and
/// turns them into a rustls `ServerConfig` via `build`.
pub struct TlsConfigBuilder {
    // Source of the PEM certificate chain; parsed in `build`.
    cert: Box<dyn Read + Send + Sync>,
    // Source of the PEM private key; read to the end in `build`.
    key: Box<dyn Read + Send + Sync>,
    // Selected client authentication mode (and its trust anchor source).
    client_auth: TlsClientAuth,
    // Raw OCSP response bytes handed to rustls together with the certificate;
    // empty unless set through `ocsp_resp`.
    ocsp_resp: Vec<u8>,
}
impl std::fmt::Debug for TlsConfigBuilder {
    /// Prints only the type name; none of the fields (boxed readers, key
    /// material sources) are included in the output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("TlsConfigBuilder")
    }
}
impl TlsConfigBuilder {
    pub fn new() -> TlsConfigBuilder {
        TlsConfigBuilder {
            key: Box::new(io::empty()),
            cert: Box::new(io::empty()),
            client_auth: TlsClientAuth::Off,
            ocsp_resp: Vec::new(),
        }
    }
    pub fn key_path(mut self, path: impl AsRef<Path>) -> Self {
        self.key = Box::new(LazyFile {
            path: path.as_ref().into(),
            file: None,
        });
        self
    }
    pub fn key(mut self, key: &[u8]) -> Self {
        self.key = Box::new(Cursor::new(Vec::from(key)));
        self
    }
    pub fn cert_path(mut self, path: impl AsRef<Path>) -> Self {
        self.cert = Box::new(LazyFile {
            path: path.as_ref().into(),
            file: None,
        });
        self
    }
    pub fn cert(mut self, cert: &[u8]) -> Self {
        self.cert = Box::new(Cursor::new(Vec::from(cert)));
        self
    }
    pub fn client_auth_optional_path(mut self, path: impl AsRef<Path>) -> Self {
        let file = Box::new(LazyFile {
            path: path.as_ref().into(),
            file: None,
        });
        self.client_auth = TlsClientAuth::Optional(file);
        self
    }
    pub fn client_auth_optional(mut self, trust_anchor: &[u8]) -> Self {
        let cursor = Box::new(Cursor::new(Vec::from(trust_anchor)));
        self.client_auth = TlsClientAuth::Optional(cursor);
        self
    }
    pub fn client_auth_required_path(mut self, path: impl AsRef<Path>) -> Self {
        let file = Box::new(LazyFile {
            path: path.as_ref().into(),
            file: None,
        });
        self.client_auth = TlsClientAuth::Required(file);
        self
    }
    pub fn client_auth_required(mut self, trust_anchor: &[u8]) -> Self {
        let cursor = Box::new(Cursor::new(Vec::from(trust_anchor)));
        self.client_auth = TlsClientAuth::Required(cursor);
        self
    }
    pub fn ocsp_resp(mut self, ocsp_resp: &[u8]) -> Self {
        self.ocsp_resp = Vec::from(ocsp_resp);
        self
    }
    pub fn build(mut self) -> Result<ServerConfig, TlsConfigError> {
        let mut cert_rdr = BufReader::new(self.cert);
        let cert = rustls_pemfile::certs(&mut cert_rdr)
            .map_err(|_e| TlsConfigError::CertParseError)?
            .into_iter()
            .map(Certificate)
            .collect();
        let mut key_vec = Vec::new();
        self.key
            .read_to_end(&mut key_vec)
            .map_err(TlsConfigError::Io)?;
        if key_vec.is_empty() {
            return Err(TlsConfigError::EmptyKey);
        }
        let mut key = None;
        let mut reader = std::io::Cursor::new(key_vec);
        for item in rustls_pemfile::read_all(&mut reader)
            .map_err(|_e| TlsConfigError::InvalidIdentityPem)?
        {
            match item {
                rustls_pemfile::Item::RSAKey(k) => key = Some(PrivateKey(k)),
                rustls_pemfile::Item::PKCS8Key(k) => key = Some(PrivateKey(k)),
                rustls_pemfile::Item::ECKey(k) => key = Some(PrivateKey(k)),
                _ => return Err(TlsConfigError::UnknownPrivateKeyFormat),
            }
        }
        let key = match key {
            Some(k) => k,
            _ => return Err(TlsConfigError::EmptyKey),
        };
        fn read_trust_anchor(
            trust_anchor: Box<dyn Read + Send + Sync>,
        ) -> Result<RootCertStore, TlsConfigError> {
            let trust_anchors = {
                let mut reader = BufReader::new(trust_anchor);
                rustls_pemfile::certs(&mut reader).map_err(TlsConfigError::Io)?
            };
            let mut store = RootCertStore::empty();
            let (added, _skipped) = store.add_parsable_certificates(&trust_anchors);
            if added == 0 {
                return Err(TlsConfigError::CertParseError);
            }
            Ok(store)
        }
        let client_auth = match self.client_auth {
            TlsClientAuth::Off => NoClientAuth::boxed(),
            TlsClientAuth::Optional(trust_anchor) => {
                AllowAnyAnonymousOrAuthenticatedClient::new(read_trust_anchor(trust_anchor)?)
                    .boxed()
            }
            TlsClientAuth::Required(trust_anchor) => {
                AllowAnyAuthenticatedClient::new(read_trust_anchor(trust_anchor)?).boxed()
            }
        };
        let mut config = ServerConfig::builder()
            .with_safe_defaults()
            .with_client_cert_verifier(client_auth)
            .with_single_cert_with_ocsp_and_sct(cert, key, self.ocsp_resp, Vec::new())
            .map_err(TlsConfigError::InvalidKey)?;
        config.alpn_protocols = vec!["h2".into(), "http/1.1".into()];
        Ok(config)
    }
}
impl Default for TlsConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}
/// A reader over a file that is opened on the first read instead of at
/// construction time.
struct LazyFile {
    // Location of the file to open.
    path: PathBuf,
    // The open handle; `None` until a read has successfully opened the file.
    file: Option<File>,
}

impl LazyFile {
    /// Opens the file if that has not happened yet, then reads into `buf`.
    ///
    /// If opening fails the handle stays unset, so a later call tries to open
    /// the file again.
    fn lazy_read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let handle = match self.file {
            Some(ref mut open) => open,
            None => self.file.insert(File::open(&self.path)?),
        };
        handle.read(buf)
    }
}

impl Read for LazyFile {
    /// Reads from the lazily opened file.
    ///
    /// Any failure is re-wrapped in a new `io::Error` that keeps the original
    /// `ErrorKind` and prefixes the message with the file path.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.lazy_read(buf) {
            Ok(count) => Ok(count),
            Err(cause) => {
                let message =
                    format!("error reading file ({:?}): {}", self.path.display(), cause);
                Err(io::Error::new(cause.kind(), message))
            }
        }
    }
}
impl Transport for TlsStream {
    /// Returns the peer address stored on the stream; this is always `Some`.
    fn remote_addr(&self) -> Option<SocketAddr> {
        let peer = self.remote_addr;
        Some(peer)
    }
}
/// Connection phase of a `TlsStream`.
enum State {
    /// The TLS accept future is still pending; it is polled from
    /// `poll_read` / `poll_write`.
    Handshaking(tokio_rustls::Accept<AddrStream>),
    /// The accept future resolved; all I/O is delegated to this TLS stream.
    Streaming(tokio_rustls::server::TlsStream<AddrStream>),
}
/// A server-side TLS connection over an `AddrStream` whose handshake is
/// driven by the first reads/writes rather than at accept time.
pub struct TlsStream {
    // Current phase: still handshaking, or streaming application data.
    state: State,
    // Peer address captured from the `AddrStream` when the stream was created.
    remote_addr: SocketAddr,
}
impl TlsStream {
    /// Wraps an accepted socket, starting in the handshaking state.
    ///
    /// The accept future is only created here; it is polled later by
    /// `poll_read` / `poll_write`.
    fn new(stream: AddrStream, config: Arc<ServerConfig>) -> TlsStream {
        // Capture the peer address before the socket is moved into the
        // accept future.
        let remote_addr = stream.remote_addr();
        let acceptor = tokio_rustls::TlsAcceptor::from(config);
        Self {
            remote_addr,
            state: State::Handshaking(acceptor.accept(stream)),
        }
    }
}
impl AsyncRead for TlsStream {
    /// Reads decrypted data, first driving the TLS accept future to
    /// completion if the stream is still handshaking.
    ///
    /// While the accept future is pending this returns `Poll::Pending`; a
    /// failed accept is returned as the read error. Once accepted, the read
    /// is attempted immediately on the new TLS stream and the state switches
    /// to `Streaming`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        match &mut this.state {
            State::Streaming(tls) => Pin::new(tls).poll_read(cx, buf),
            State::Handshaking(handshake) => {
                let mut tls = ready!(Pin::new(handshake).poll(cx))?;
                // Perform the read before storing the stream so the result of
                // this very call reflects the freshly accepted connection.
                let outcome = Pin::new(&mut tls).poll_read(cx, buf);
                this.state = State::Streaming(tls);
                outcome
            }
        }
    }
}
impl AsyncWrite for TlsStream {
    /// Writes data to the TLS stream, first driving the TLS accept future to
    /// completion if the stream is still handshaking.
    ///
    /// While the accept future is pending this returns `Poll::Pending`; a
    /// failed accept is returned as the write error. Once accepted, the write
    /// is attempted immediately and the state switches to `Streaming`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        match &mut this.state {
            State::Streaming(tls) => Pin::new(tls).poll_write(cx, buf),
            State::Handshaking(handshake) => {
                let mut tls = ready!(Pin::new(handshake).poll(cx))?;
                let outcome = Pin::new(&mut tls).poll_write(cx, buf);
                this.state = State::Streaming(tls);
                outcome
            }
        }
    }

    /// Flushes the TLS stream; reports success immediately (without polling
    /// the accept future) while still handshaking.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut self.get_mut().state {
            State::Streaming(tls) => Pin::new(tls).poll_flush(cx),
            State::Handshaking(_) => Poll::Ready(Ok(())),
        }
    }

    /// Shuts down the TLS stream; reports success immediately (without
    /// polling the accept future) while still handshaking.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut self.get_mut().state {
            State::Streaming(tls) => Pin::new(tls).poll_shutdown(cx),
            State::Handshaking(_) => Poll::Ready(Ok(())),
        }
    }
}
/// Adapter that turns the plain connections produced by an `AddrIncoming`
/// into `TlsStream`s.
pub struct TlsAcceptor {
    // TLS configuration; a clone of this `Arc` is handed to every accepted
    // connection.
    config: Arc<ServerConfig>,
    // Underlying source of accepted sockets.
    incoming: AddrIncoming,
}
impl TlsAcceptor {
    pub fn new(config: ServerConfig, incoming: AddrIncoming) -> TlsAcceptor {
        TlsAcceptor {
            config: Arc::new(config),
            incoming,
        }
    }
}
impl Accept for TlsAcceptor {
    type Conn = TlsStream;
    type Error = io::Error;
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
        let pin = self.get_mut();
        match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) {
            Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))),
            Some(Err(e)) => Poll::Ready(Some(Err(e))),
            None => Poll::Ready(None),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Certificate and key supplied as file paths.
    #[test]
    fn file_cert_key() {
        let builder = TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.pem")
            .key_path("tests/tls/local.dev_key.pem");
        builder.build().unwrap();
    }

    // Certificate and key supplied as in-memory bytes.
    #[test]
    fn bytes_cert_key() {
        let cert_pem = include_str!("../tests/tls/local.dev_cert.pem").as_bytes();
        let key_pem = include_str!("../tests/tls/local.dev_key.pem").as_bytes();
        let builder = TlsConfigBuilder::new().key(key_pem).cert(cert_pem);
        builder.build().unwrap();
    }

    // Same as `file_cert_key`, using the `.ecc.pem` fixture pair.
    #[test]
    fn file_cert_key_ecc() {
        let builder = TlsConfigBuilder::new()
            .cert_path("tests/tls/local.dev_cert.ecc.pem")
            .key_path("tests/tls/local.dev_key.ecc.pem");
        builder.build().unwrap();
    }

    // Same as `bytes_cert_key`, using the `.ecc.pem` fixture pair.
    #[test]
    fn bytes_cert_key_ecc() {
        let cert_pem = include_str!("../tests/tls/local.dev_cert.ecc.pem").as_bytes();
        let key_pem = include_str!("../tests/tls/local.dev_key.ecc.pem").as_bytes();
        let builder = TlsConfigBuilder::new().key(key_pem).cert(cert_pem);
        builder.build().unwrap();
    }
}