Skip to main content

salvo_csrf/
hmac_cipher.rs

1use std::fmt::Debug;
2
3use base64::Engine;
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use hmac::{Hmac, Mac};
6use sha2::Sha256;
7
8use super::CsrfCipher;
9
/// A CSRF protection implementation that uses HMAC.
///
/// Each generated proof is the HMAC-SHA256 tag of its token, computed with the configured key.
#[derive(Clone)]
pub struct HmacCipher {
    /// Secret HMAC-SHA256 key. Deliberately left out of the `Debug` output.
    hmac_key: [u8; 32],
    /// Length in bytes of generated tokens (before base64 encoding).
    token_size: usize,
}

// Hand-written instead of derived so that `{:?}` formatting (logs, panic messages,
// debug dumps of enclosing handlers) never exposes the secret key.
impl Debug for HmacCipher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HmacCipher")
            .field("token_size", &self.token_size)
            .finish_non_exhaustive()
    }
}
16
17impl HmacCipher {
18    /// Given an HMAC key, return an `HmacCipher` instance.
19    #[inline]
20    #[must_use]
21    pub fn new(hmac_key: [u8; 32]) -> Self {
22        Self {
23            hmac_key,
24            token_size: 32,
25        }
26    }
27
28    /// Sets the length of the token.
29    #[inline]
30    #[must_use]
31    pub fn token_size(mut self, token_size: usize) -> Self {
32        assert!(token_size >= 8, "length must be larger than 8");
33        self.token_size = token_size;
34        self
35    }
36
37    #[inline]
38    fn hmac(&self) -> Hmac<Sha256> {
39        Hmac::<Sha256>::new_from_slice(&self.hmac_key).expect("HMAC can take key of any size")
40    }
41}
42
43impl CsrfCipher for HmacCipher {
44    fn verify(&self, token: &str, proof: &str) -> bool {
45        if let (Ok(token), Ok(proof)) = (
46            URL_SAFE_NO_PAD.decode(token.as_bytes()),
47            URL_SAFE_NO_PAD.decode(proof.as_bytes()),
48        ) {
49            if proof.len() != self.token_size {
50                false
51            } else {
52                let mut hmac = self.hmac();
53                hmac.update(&token);
54                hmac.verify((&*proof).into()).is_ok()
55            }
56        } else {
57            false
58        }
59    }
60    fn generate(&self) -> (String, String) {
61        let token = self.random_bytes(self.token_size);
62        let mut hmac = self.hmac();
63        hmac.update(&token);
64        let mac = hmac.finalize();
65        let proof = mac.into_bytes();
66        (URL_SAFE_NO_PAD.encode(token), URL_SAFE_NO_PAD.encode(proof))
67    }
68}
69
#[cfg(test)]
mod tests {
    use base64::Engine;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    use super::*;

    #[test]
    fn test_new() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key);
        assert_eq!(hmac_cipher.hmac_key, hmac_key);
        assert_eq!(hmac_cipher.token_size, 32);
    }

    #[test]
    fn test_with_token_size() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key).token_size(16);
        assert_eq!(hmac_cipher.token_size, 16);
    }

    #[test]
    fn test_token_size_minimum_accepted() {
        let hmac_cipher = HmacCipher::new([0u8; 32]).token_size(8);
        assert_eq!(hmac_cipher.token_size, 8);
    }

    #[test]
    #[should_panic]
    fn test_token_size_below_minimum_panics() {
        let _ = HmacCipher::new([0u8; 32]).token_size(7);
    }

    #[test]
    fn test_verify() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key);
        let (token, proof) = hmac_cipher.generate();
        assert!(hmac_cipher.verify(&token, &proof));
    }

    #[test]
    fn test_verify_invalid() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key);
        let (token, _) = hmac_cipher.generate();
        let invalid_proof = URL_SAFE_NO_PAD.encode(vec![0u8; hmac_cipher.token_size]);
        assert!(!hmac_cipher.verify(&token, &invalid_proof));
    }

    #[test]
    fn test_verify_rejects_undecodable_input() {
        let hmac_cipher = HmacCipher::new([0u8; 32]);
        let (token, proof) = hmac_cipher.generate();
        // `*` is outside the URL-safe base64 alphabet, so decoding fails.
        assert!(!hmac_cipher.verify("***", &proof));
        assert!(!hmac_cipher.verify(&token, "***"));
    }

    #[test]
    fn test_verify_rejects_short_proof() {
        let hmac_cipher = HmacCipher::new([0u8; 32]);
        let (token, _) = hmac_cipher.generate();
        let short_proof = URL_SAFE_NO_PAD.encode([0u8; 4]);
        assert!(!hmac_cipher.verify(&token, &short_proof));
    }

    #[test]
    fn test_verify_rejects_tampered_token() {
        let hmac_cipher = HmacCipher::new([0u8; 32]);
        let (token, proof) = hmac_cipher.generate();
        let mut bytes = URL_SAFE_NO_PAD
            .decode(token.as_bytes())
            .expect("generated token is valid base64");
        bytes[0] ^= 1;
        let tampered = URL_SAFE_NO_PAD.encode(&bytes);
        assert!(!hmac_cipher.verify(&tampered, &proof));
    }

    #[test]
    fn test_verify_rejects_other_key() {
        let (token, proof) = HmacCipher::new([0u8; 32]).generate();
        assert!(!HmacCipher::new([1u8; 32]).verify(&token, &proof));
    }

    #[test]
    fn test_generate() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key);
        let (token, proof) = hmac_cipher.generate();
        assert!(hmac_cipher.verify(&token, &proof));
    }
}
116}