#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(unreachable_pub)]
#![forbid(unsafe_code)]
#![warn(missing_docs)]
#![warn(clippy::future_not_send)]
#![warn(rustdoc::broken_intra_doc_links)]
use std::error::Error as StdError;
mod finder;
pub use finder::{CsrfTokenFinder, FormFinder, HeaderFinder, JsonFinder};
use rand::distributions::Standard;
use rand::Rng;
use salvo_core::handler::Skipper;
use salvo_core::http::{Method, StatusCode};
use salvo_core::{async_trait, Depot, FlowCtrl, Handler, Request, Response};
#[macro_use]
mod cfg;
cfg_feature! {
    #![feature = "cookie-store"]
    mod cookie_store;
    pub use cookie_store::CookieStore;
    pub fn cookie_store<>() -> CookieStore {
        CookieStore::new()
    }
}
cfg_feature! {
    #![feature = "session-store"]
    mod session_store;
    pub use session_store::SessionStore;
    pub fn session_store() -> SessionStore {
        SessionStore::new()
    }
}
cfg_feature! {
    #![feature = "bcrypt-cipher"]
    mod bcrypt_cipher;
    pub use bcrypt_cipher::BcryptCipher;
    pub fn bcrypt_csrf<S>(store: S, finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, S> where S: CsrfStore {
        Csrf::new(BcryptCipher::new(), store, finder)
    }
}
cfg_feature! {
    #![all(feature = "bcrypt-cipher", feature = "cookie-store")]
    pub fn bcrypt_cookie_csrf(finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, CookieStore> {
        Csrf::new(BcryptCipher::new(), CookieStore::new(), finder)
    }
}
cfg_feature! {
    #![all(feature = "bcrypt-cipher", feature = "session-store")]
    pub fn bcrypt_session_csrf(finder: impl CsrfTokenFinder ) -> Csrf<BcryptCipher, SessionStore> {
        Csrf::new(BcryptCipher::new(), SessionStore::new(), finder)
    }
}
cfg_feature! {
    #![feature = "hmac-cipher"]
    mod hmac_cipher;
    pub use hmac_cipher::HmacCipher;
    pub fn hmac_csrf<S>(hmac_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, S> where S: CsrfStore {
        Csrf::new(HmacCipher::new(hmac_key), store, finder)
    }
}
cfg_feature! {
    #![all(feature = "hmac-cipher", feature = "cookie-store")]
    pub fn hmac_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, CookieStore> {
        Csrf::new(HmacCipher::new(aead_key), CookieStore::new(), finder)
    }
}
cfg_feature! {
    #![all(feature = "hmac-cipher", feature = "session-store")]
    pub fn hmac_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<HmacCipher, SessionStore> {
        Csrf::new(HmacCipher::new(aead_key), SessionStore::new(), finder)
    }
}
cfg_feature! {
    #![feature = "aes-gcm-cipher"]
    mod aes_gcm_cipher;
    pub use aes_gcm_cipher::AesGcmCipher;
    pub fn aes_gcm_csrf<S>(aead_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, S> where S: CsrfStore {
        Csrf::new(AesGcmCipher::new(aead_key), store, finder)
    }
}
cfg_feature! {
    #![all(feature = "aes-gcm-cipher", feature = "cookie-store")]
    pub fn aes_gcm_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, CookieStore> {
        Csrf::new(AesGcmCipher::new(aead_key), CookieStore::new(), finder)
    }
}
cfg_feature! {
    #![all(feature = "aes-gcm-cipher", feature = "session-store")]
    pub fn aes_gcm_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<AesGcmCipher, SessionStore> {
        Csrf::new(AesGcmCipher::new(aead_key), SessionStore::new(), finder)
    }
}
cfg_feature! {
    #![feature = "ccp-cipher"]
    mod ccp_cipher;
    pub use ccp_cipher::CcpCipher;
    pub fn ccp_csrf<S>(aead_key: [u8; 32], store: S, finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, S> where S: CsrfStore {
        Csrf::new(CcpCipher::new(aead_key), store, finder)
    }
}
cfg_feature! {
    #![all(feature = "ccp-cipher", feature = "cookie-store")]
    pub fn ccp_cookie_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, CookieStore> {
        Csrf::new(CcpCipher::new(aead_key), CookieStore::new(), finder)
    }
}
cfg_feature! {
    #![all(feature = "ccp-cipher", feature = "session-store")]
    pub fn ccp_session_csrf(aead_key: [u8; 32], finder: impl CsrfTokenFinder ) -> Csrf<CcpCipher, SessionStore> {
        Csrf::new(CcpCipher::new(aead_key), SessionStore::new(), finder)
    }
}
pub const CSRF_TOKEN_KEY: &str = "salvo.csrf.token";
fn default_skipper(req: &mut Request, _depot: &Depot) -> bool {
    ![Method::POST, Method::PATCH, Method::DELETE, Method::PUT].contains(req.method())
}
#[async_trait]
pub trait CsrfStore: Send + Sync + 'static {
    type Error: StdError + Send + Sync + 'static;
    async fn load<C: CsrfCipher>(&self, req: &mut Request, depot: &mut Depot, cipher: &C) -> Option<(String, String)>;
    async fn save(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        token: &str,
        proof: &str,
    ) -> Result<(), Self::Error>;
}
pub trait CsrfCipher: Send + Sync + 'static {
    fn verify(&self, token: &str, proof: &str) -> bool;
    fn generate(&self) -> (String, String);
    fn random_bytes(&self, len: usize) -> Vec<u8> {
        rand::thread_rng().sample_iter(Standard).take(len).collect()
    }
}
pub trait CsrfDepotExt {
    fn csrf_token(&self) -> Option<&String>;
}
impl CsrfDepotExt for Depot {
    #[inline]
    fn csrf_token(&self) -> Option<&String> {
        self.get(CSRF_TOKEN_KEY).ok()
    }
}
pub struct Csrf<C, S> {
    cipher: C,
    store: S,
    skipper: Box<dyn Skipper>,
    finders: Vec<Box<dyn CsrfTokenFinder>>,
}
impl<C: CsrfCipher, S: CsrfStore> Csrf<C, S> {
    #[inline]
    pub fn new(cipher: C, store: S, finder: impl CsrfTokenFinder) -> Self {
        Self {
            cipher,
            store,
            skipper: Box::new(default_skipper),
            finders: vec![Box::new(finder)],
        }
    }
    #[inline]
    pub fn add_finder(mut self, finder: impl CsrfTokenFinder) -> Self {
        self.finders.push(Box::new(finder));
        self
    }
    async fn find_token(&self, req: &mut Request) -> Option<String> {
        for finder in self.finders.iter() {
            if let Some(token) = finder.find_token(req).await {
                return Some(token);
            }
        }
        None
    }
}
#[async_trait]
impl<C: CsrfCipher, S: CsrfStore> Handler for Csrf<C, S> {
    async fn handle(&self, req: &mut Request, depot: &mut Depot, res: &mut Response, ctrl: &mut FlowCtrl) {
        match self.store.load(req, depot, &self.cipher).await {
            Some((token, proof)) => {
                depot.insert(CSRF_TOKEN_KEY, token);
                if !self.skipper.skipped(req, depot) {
                    if let Some(token) = &self.find_token(req).await {
                        tracing::debug!("csrf token: {token}");
                        if !self.cipher.verify(token, &proof) {
                            tracing::debug!("rejecting request due to invalid or expired CSRF token");
                            res.status_code(StatusCode::FORBIDDEN);
                            ctrl.skip_rest();
                            return;
                        } else {
                            tracing::debug!("cipher verify CSRF token success");
                        }
                    } else {
                        tracing::debug!("rejecting request due to missing CSRF token",);
                        res.status_code(StatusCode::FORBIDDEN);
                        ctrl.skip_rest();
                        return;
                    }
                }
                ctrl.call_next(req, depot, res).await;
            }
            None => {
                if !self.skipper.skipped(req, depot) {
                    tracing::debug!("rejecting request due to missing CSRF token",);
                    res.status_code(StatusCode::FORBIDDEN);
                    ctrl.skip_rest();
                } else {
                    let (token, proof) = self.cipher.generate();
                    if let Err(e) = self.store.save(req, depot, res, &token, &proof).await {
                        tracing::error!(error = ?e, "salvo csrf token failed");
                    }
                    tracing::debug!("new token: {:?}", token);
                    depot.insert(CSRF_TOKEN_KEY, token);
                    ctrl.call_next(req, depot, res).await;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};
    #[handler]
    async fn get_index(depot: &mut Depot) -> String {
        depot.csrf_token().unwrap().to_owned()
    }
    #[handler]
    async fn post_index() -> &'static str {
        "POST"
    }
    #[tokio::test]
    async fn test_exposes_csrf_request_extensions() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index);
        let res = TestClient::get("http://127.0.0.1:5801").send(router).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
    }
    #[tokio::test]
    async fn test_adds_csrf_cookie_sets_request_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index);
        let mut res = TestClient::get("http://127.0.0.1:5801").send(router).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_ne!(res.take_string().await.unwrap(), "");
        assert_ne!(res.cookie("salvo.csrf"), None);
    }
    #[tokio::test]
    async fn test_validates_token_in_header() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }
    #[tokio::test]
    async fn test_validates_token_in_custom_header() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-mycsrf-header"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-mycsrf-header", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }
    #[tokio::test]
    async fn test_validates_token_in_query() {
        let csrf = Csrf::new(BcryptCipher::new(), CookieStore::new(), HeaderFinder::new("csrf-token"));
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let mut res = TestClient::post("http://127.0.0.1:5801?a=1&b=2")
            .add_header("csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }
    #[cfg(feadture = "hmac-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_alternate_query() {
        let csrf = Csrf::new(
            HmacCipher::new(*b"01234567012345670123456701234567"),
            CookieStore::new(),
            HeaderFinder::new("my-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let mut res = TestClient::post("http://127.0.0.1:5801?a=1&b=2")
            .add_header("my-csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string(), true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }
    #[cfg(feature = "hmac-cipher")]
    #[tokio::test]
    async fn test_validates_token_in_form() {
        let csrf = Csrf::new(
            HmacCipher::new(*b"01234567012345670123456701234567"),
            CookieStore::new(),
            FormFinder::new("csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("cookie", cookie.to_string(), true)
            .form(&[("a", "1"), ("csrf-token", &*csrf_token), ("b", "2")])
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }
    #[tokio::test]
    async fn test_validates_token_in_alternate_form() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            FormFinder::new("my-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let mut res = TestClient::post("http://127.0.0.1:5801")
            .add_header("cookie", cookie.to_string(), true)
            .form(&[("a", "1"), ("my-csrf-token", &*csrf_token), ("b", "2")])
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        assert_eq!(res.take_string().await.unwrap(), "POST");
    }
    #[tokio::test]
    async fn test_rejects_short_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", "aGVsbG8=", true)
            .add_header("cookie", cookie.to_string().split_once('.').unwrap().0, true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
    }
    #[tokio::test]
    async fn test_rejects_invalid_base64_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", "aGVsbG8", true)
            .add_header("cookie", cookie.to_string().split_once('.').unwrap().0, true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
    }
    #[tokio::test]
    async fn test_rejects_mismatched_token() {
        let csrf = Csrf::new(
            BcryptCipher::new(),
            CookieStore::new(),
            HeaderFinder::new("x-csrf-token"),
        );
        let router = Router::new().hoop(csrf).get(get_index).post(post_index);
        let service = Service::new(router);
        let mut res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let csrf_token = res.take_string().await.unwrap();
        let res = TestClient::get("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::OK);
        let cookie = res.cookie("salvo.csrf").unwrap();
        let res = TestClient::post("http://127.0.0.1:5801").send(&service).await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
        let res = TestClient::post("http://127.0.0.1:5801")
            .add_header("x-csrf-token", csrf_token, true)
            .add_header("cookie", cookie.to_string().split_once('.').unwrap().0, true)
            .send(&service)
            .await;
        assert_eq!(res.status_code.unwrap(), StatusCode::FORBIDDEN);
    }
}