1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
use aead::generic_array::GenericArray;
use aead::{Aead, KeyInit};
use aes_gcm::Aes256Gcm;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;

use super::CsrfCipher;

/// CSRF token cipher backed by AES-256-GCM.
///
/// A proof is the AEAD encryption of its token under `aead_key`, so only a holder of the key
/// can produce a proof that verifies against a given token.
pub struct AesGcmCipher {
    /// 256-bit key used to build the `Aes256Gcm` instance.
    aead_key: [u8; 32],
    /// Number of random bytes in each generated token (before base64 encoding).
    token_size: usize,
}

impl AesGcmCipher {
    /// Creates a cipher that protects CSRF tokens with the given 256-bit AES-GCM key.
    ///
    /// Generated tokens default to 32 random bytes; call [`token_size`](Self::token_size)
    /// to change that.
    #[inline]
    pub fn new(aead_key: [u8; 32]) -> Self {
        Self {
            aead_key,
            token_size: 32,
        }
    }

    /// Sets the number of random bytes in each generated token (before base64 encoding).
    ///
    /// # Panics
    ///
    /// Panics if `token_size` is less than 8.
    #[inline]
    #[must_use]
    pub fn token_size(mut self, token_size: usize) -> Self {
        // 8 is accepted, so the message must say "at least", not "larger than".
        assert!(token_size >= 8, "token size must be at least 8");
        self.token_size = token_size;
        self
    }

    /// Builds an `Aes256Gcm` instance keyed with this cipher's key.
    #[inline]
    fn aead(&self) -> Aes256Gcm {
        // The key is a fixed 32-byte array, so borrowing it as a `GenericArray` cannot fail
        // and avoids copying the key material into a temporary.
        Aes256Gcm::new(GenericArray::from_slice(&self.aead_key))
    }
}

impl CsrfCipher for AesGcmCipher {
    /// Returns `true` only if `proof` is an authentic encryption of `token` under this key.
    ///
    /// Both arguments are URL-safe, unpadded base64. The decoded proof is laid out as
    /// `nonce (12 bytes) || ciphertext`, which is the format produced by `generate`.
    /// Malformed base64, undersized inputs, authentication failures and plaintext/token
    /// mismatches all yield `false`; no input makes this method panic.
    fn verify(&self, token: &str, proof: &str) -> bool {
        const NONCE_LEN: usize = 12;
        const MIN_TOKEN_LEN: usize = 8;
        const MIN_PROOF_LEN: usize = 20;

        let (token, proof) = match (
            URL_SAFE_NO_PAD.decode(token.as_bytes()),
            URL_SAFE_NO_PAD.decode(proof.as_bytes()),
        ) {
            (Ok(token), Ok(proof)) => (token, proof),
            _ => return false,
        };

        // The length guard also guarantees `split_at(NONCE_LEN)` below is in bounds.
        if token.len() < MIN_TOKEN_LEN || proof.len() < MIN_PROOF_LEN {
            return false;
        }

        let (nonce, ciphertext) = proof.split_at(NONCE_LEN);
        match self.aead().decrypt(GenericArray::from_slice(nonce), ciphertext) {
            Ok(plaintext) => {
                // `token` is attacker-controlled: compare without an early exit so response
                // timing does not reveal how many leading bytes matched.
                plaintext.len() == token.len()
                    && plaintext
                        .iter()
                        .zip(&token)
                        .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                        == 0
            }
            Err(_) => false,
        }
    }

    /// Creates a fresh `(token, proof)` pair, both URL-safe unpadded base64.
    ///
    /// The token is `token_size` bytes from `random_bytes`. The proof is a 12-byte nonce
    /// (also from `random_bytes`) followed by the AES-GCM encryption of the token under
    /// that nonce.
    fn generate(&self) -> (String, String) {
        const NONCE_LEN: usize = 12;

        let token = self.random_bytes(self.token_size);
        let mut proof = self.random_bytes(NONCE_LEN);
        let ciphertext = self
            .aead()
            .encrypt(GenericArray::from_slice(&proof), token.as_slice())
            // Encryption only fails for oversized messages; a CSRF token is tiny.
            .expect("AES-GCM encryption of a CSRF token should never fail");
        proof.extend_from_slice(&ciphertext);
        (URL_SAFE_NO_PAD.encode(token), URL_SAFE_NO_PAD.encode(proof))
    }
}

#[cfg(test)]
mod tests {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine;

    use super::AesGcmCipher;
    use super::CsrfCipher;

    /// Decodes a URL-safe, unpadded base64 string emitted by the cipher.
    fn decode(s: &str) -> Vec<u8> {
        URL_SAFE_NO_PAD
            .decode(s)
            .expect("cipher output must be valid base64")
    }

    #[test]
    fn test_aes_gcm_cipher() {
        let aead_key = [0u8; 32];
        let cipher = AesGcmCipher::new(aead_key);

        let (token, proof) = cipher.generate();
        assert!(cipher.verify(&token, &proof));

        // Forgeries use the *decoded* lengths of the real values, so rejection comes from the
        // content check rather than from a length mismatch.
        let invalid_proof = URL_SAFE_NO_PAD.encode(vec![0u8; decode(&proof).len()]);
        assert!(!cipher.verify(&token, &invalid_proof));

        let invalid_token = URL_SAFE_NO_PAD.encode(vec![0u8; decode(&token).len()]);
        assert!(!cipher.verify(&invalid_token, &proof));
    }

    #[test]
    fn custom_token_size_round_trips() {
        let cipher = AesGcmCipher::new([1u8; 32]).token_size(16);
        let (token, proof) = cipher.generate();
        assert_eq!(decode(&token).len(), 16);
        assert!(cipher.verify(&token, &proof));
    }

    #[test]
    fn tampered_proof_is_rejected() {
        let cipher = AesGcmCipher::new([2u8; 32]);
        let (token, proof) = cipher.generate();

        let mut bytes = decode(&proof);
        *bytes.last_mut().expect("proof is never empty") ^= 0x01;
        assert!(!cipher.verify(&token, &URL_SAFE_NO_PAD.encode(&bytes)));
    }

    #[test]
    fn proof_from_other_key_is_rejected() {
        let (token, proof) = AesGcmCipher::new([3u8; 32]).generate();
        assert!(!AesGcmCipher::new([4u8; 32]).verify(&token, &proof));
    }

    #[test]
    fn malformed_inputs_are_rejected() {
        let cipher = AesGcmCipher::new([5u8; 32]);
        let (token, proof) = cipher.generate();

        assert!(!cipher.verify("", &proof));
        assert!(!cipher.verify(&token, ""));
        assert!(!cipher.verify("not base64!", &proof));
        assert!(!cipher.verify(&token, "not base64!"));
        // Decodes fine but is shorter than the minimum proof length.
        assert!(!cipher.verify(&token, &URL_SAFE_NO_PAD.encode([0u8; 4])));
    }

    #[test]
    #[should_panic]
    fn token_size_below_minimum_panics() {
        let _ = AesGcmCipher::new([0u8; 32]).token_size(7);
    }
}