salvo_csrf/
hmac_cipher.rs1use std::fmt::Debug;
2
3use base64::Engine;
4use base64::engine::general_purpose::URL_SAFE_NO_PAD;
5use hmac::{Hmac, Mac};
6use sha2::Sha256;
7
8use super::CsrfCipher;
9
/// CSRF cipher that authenticates tokens with HMAC-SHA256.
///
/// A token is `token_size` bytes (32 by default) and its proof is the HMAC-SHA256 tag of
/// those bytes under `hmac_key`.
///
/// The [`Debug`] implementation deliberately omits `hmac_key` so the secret cannot end up in
/// logs or panic messages.
#[derive(Clone)]
pub struct HmacCipher {
    hmac_key: [u8; 32],
    token_size: usize,
}

impl Debug for HmacCipher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the key; only non-secret configuration is shown, and `..` marks the
        // hidden field.
        f.debug_struct("HmacCipher")
            .field("token_size", &self.token_size)
            .finish_non_exhaustive()
    }
}
16
17impl HmacCipher {
18 #[inline]
20 #[must_use]
21 pub fn new(hmac_key: [u8; 32]) -> Self {
22 Self {
23 hmac_key,
24 token_size: 32,
25 }
26 }
27
28 #[inline]
30 #[must_use]
31 pub fn token_size(mut self, token_size: usize) -> Self {
32 assert!(token_size >= 8, "length must be larger than 8");
33 self.token_size = token_size;
34 self
35 }
36
37 #[inline]
38 fn hmac(&self) -> Hmac<Sha256> {
39 Hmac::<Sha256>::new_from_slice(&self.hmac_key).expect("HMAC can take key of any size")
40 }
41}
42
43impl CsrfCipher for HmacCipher {
44 fn verify(&self, token: &str, proof: &str) -> bool {
45 if let (Ok(token), Ok(proof)) = (
46 URL_SAFE_NO_PAD.decode(token.as_bytes()),
47 URL_SAFE_NO_PAD.decode(proof.as_bytes()),
48 ) {
49 if proof.len() != self.token_size {
50 false
51 } else {
52 let mut hmac = self.hmac();
53 hmac.update(&token);
54 hmac.verify((&*proof).into()).is_ok()
55 }
56 } else {
57 false
58 }
59 }
60 fn generate(&self) -> (String, String) {
61 let token = self.random_bytes(self.token_size);
62 let mut hmac = self.hmac();
63 hmac.update(&token);
64 let mac = hmac.finalize();
65 let proof = mac.into_bytes();
66 (URL_SAFE_NO_PAD.encode(token), URL_SAFE_NO_PAD.encode(proof))
67 }
68}
69
#[cfg(test)]
mod tests {
    use base64::Engine;
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;

    use super::*;

    /// Decodes URL-safe, unpadded base64 produced by the cipher, failing the test otherwise.
    fn decode(encoded: &str) -> Vec<u8> {
        URL_SAFE_NO_PAD
            .decode(encoded)
            .expect("cipher output must be valid URL-safe base64")
    }

    #[test]
    fn test_new() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key);
        assert_eq!(hmac_cipher.hmac_key, hmac_key);
        assert_eq!(hmac_cipher.token_size, 32);
    }

    #[test]
    fn test_with_token_size() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key).token_size(16);
        assert_eq!(hmac_cipher.token_size, 16);
    }

    #[test]
    fn test_token_size_accepts_minimum() {
        let hmac_cipher = HmacCipher::new([0u8; 32]).token_size(8);
        assert_eq!(hmac_cipher.token_size, 8);
    }

    #[test]
    #[should_panic]
    fn test_token_size_below_minimum_panics() {
        let _ = HmacCipher::new([0u8; 32]).token_size(7);
    }

    #[test]
    fn test_verify() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key);
        let (token, proof) = hmac_cipher.generate();
        assert!(hmac_cipher.verify(&token, &proof));
    }

    #[test]
    fn test_verify_invalid() {
        let hmac_key = [0u8; 32];
        let hmac_cipher = HmacCipher::new(hmac_key);
        let (token, _) = hmac_cipher.generate();
        let invalid_proof = URL_SAFE_NO_PAD.encode(vec![0u8; hmac_cipher.token_size]);
        assert!(!hmac_cipher.verify(&token, &invalid_proof));
    }

    #[test]
    fn test_verify_rejects_mismatched_pair() {
        let hmac_cipher = HmacCipher::new([3u8; 32]);
        let (token_a, _) = hmac_cipher.generate();
        let (_, proof_b) = hmac_cipher.generate();
        assert!(!hmac_cipher.verify(&token_a, &proof_b));
    }

    #[test]
    fn test_verify_rejects_other_key() {
        let signer = HmacCipher::new([1u8; 32]);
        let checker = HmacCipher::new([2u8; 32]);
        let (token, proof) = signer.generate();
        assert!(!checker.verify(&token, &proof));
    }

    #[test]
    fn test_verify_rejects_malformed_base64() {
        let hmac_cipher = HmacCipher::new([0u8; 32]);
        let (token, proof) = hmac_cipher.generate();
        assert!(!hmac_cipher.verify("not base64!", &proof));
        assert!(!hmac_cipher.verify(&token, "not base64!"));
    }

    #[test]
    fn test_verify_rejects_short_proof() {
        let hmac_cipher = HmacCipher::new([0u8; 32]);
        let (token, _) = hmac_cipher.generate();
        let short_proof = URL_SAFE_NO_PAD.encode([0u8; 16]);
        assert!(!hmac_cipher.verify(&token, &short_proof));
    }

    #[test]
    fn test_generate() {
        let hmac_cipher = HmacCipher::new([0u8; 32]);
        let (token, proof) = hmac_cipher.generate();
        // The token carries `token_size` bytes; the proof is always a 32-byte SHA-256 tag.
        assert_eq!(decode(&token).len(), hmac_cipher.token_size);
        assert_eq!(decode(&proof).len(), 32);
        assert!(hmac_cipher.verify(&token, &proof));
    }

    #[test]
    fn test_generate_respects_token_size() {
        let hmac_cipher = HmacCipher::new([0u8; 32]).token_size(16);
        let (token, proof) = hmac_cipher.generate();
        assert_eq!(decode(&token).len(), 16);
        // The proof length does not follow `token_size`.
        assert_eq!(decode(&proof).len(), 32);
    }
}