Skip to main content

snapcast_server/
crypto.rs

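//! ChaCha20-Poly1305 encryption for audio chunks.
//!
//! Derives a 256-bit key from a pre-shared key (PSK) and a per-session salt via
//! HKDF-SHA256. Each chunk is encrypted under a 96-bit nonce built from a
//! per-encryptor 64-bit counter, and that nonce is prepended to the ciphertext.
//! Nonces are unique only while each (PSK, salt) pair backs a single encryptor.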
5
6use chacha20poly1305::aead::{Aead, KeyInit};
7use chacha20poly1305::{ChaCha20Poly1305, Nonce};
8use hkdf::Hkdf;
9use sha2::Sha256;
10
/// Length in bytes of a ChaCha20-Poly1305 (IETF) nonce: 96 bits.
const NONCE_SIZE: usize = 96 / 8;
13
14/// Derives a 256-bit encryption key from a PSK and salt via HKDF-SHA256.
15fn derive_key(psk: &[u8], salt: &[u8]) -> [u8; 32] {
16    let hk = Hkdf::<Sha256>::new(Some(salt), psk);
17    let mut key = [0u8; 32];
18    hk.expand(b"snapcast-f32lz4e", &mut key)
19        .expect("32 bytes is a valid HKDF-SHA256 output length");
20    key
21}
22
23/// Audio chunk encryptor.
24pub struct ChunkEncryptor {
25    cipher: ChaCha20Poly1305,
26    counter: u64,
27}
28
29impl ChunkEncryptor {
30    /// Create from PSK and session salt.
31    pub fn new(psk: &str, salt: &[u8]) -> Self {
32        let key = derive_key(psk.as_bytes(), salt);
33        Self {
34            cipher: ChaCha20Poly1305::new(&key.into()),
35            counter: 0,
36        }
37    }
38
39    /// Encrypt a chunk. Returns `[12-byte nonce][ciphertext + 16-byte tag]`.
40    pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, chacha20poly1305::Error> {
41        let mut nonce_bytes = [0u8; NONCE_SIZE];
42        nonce_bytes[..8].copy_from_slice(&self.counter.to_le_bytes());
43        self.counter += 1;
44
45        let nonce = Nonce::from(nonce_bytes);
46        let ciphertext = self.cipher.encrypt(&nonce, plaintext)?;
47
48        let mut out = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
49        out.extend_from_slice(&nonce_bytes);
50        out.extend_from_slice(&ciphertext);
51        Ok(out)
52    }
53}
54
55/// Audio chunk decryptor.
56pub struct ChunkDecryptor {
57    cipher: ChaCha20Poly1305,
58}
59
60impl ChunkDecryptor {
61    /// Create from PSK and session salt.
62    pub fn new(psk: &str, salt: &[u8]) -> Self {
63        let key = derive_key(psk.as_bytes(), salt);
64        Self {
65            cipher: ChaCha20Poly1305::new(&key.into()),
66        }
67    }
68
69    /// Decrypt a chunk. Input: `[12-byte nonce][ciphertext + 16-byte tag]`.
70    pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, chacha20poly1305::Error> {
71        if data.len() < NONCE_SIZE + 16 {
72            return Err(chacha20poly1305::Error);
73        }
74        let nonce = Nonce::from_slice(&data[..NONCE_SIZE]);
75        self.cipher.decrypt(nonce, &data[NONCE_SIZE..])
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Poly1305 authentication tag length appended to every ciphertext.
    const TAG_SIZE: usize = 16;

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let salt = b"test-session-salt";
        let mut enc = ChunkEncryptor::new("my-secret", salt);
        let dec = ChunkDecryptor::new("my-secret", salt);

        let plaintext = b"hello audio data";
        let encrypted = enc.encrypt(plaintext).unwrap();

        // 12 nonce + 16 plaintext + 16 tag = 44
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + TAG_SIZE);

        let decrypted = dec.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn empty_plaintext_roundtrip() {
        let salt = b"empty-chunk";
        let mut enc = ChunkEncryptor::new("key", salt);
        let dec = ChunkDecryptor::new("key", salt);

        // Nonce + tag only; exactly the minimum length the decryptor accepts.
        let encrypted = enc.encrypt(b"").unwrap();
        assert_eq!(encrypted.len(), NONCE_SIZE + TAG_SIZE);
        assert!(dec.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn wrong_key_fails() {
        let salt = b"test-salt";
        let mut enc = ChunkEncryptor::new("correct-key", salt);
        let dec = ChunkDecryptor::new("wrong-key", salt);

        let encrypted = enc.encrypt(b"secret audio").unwrap();
        assert!(dec.decrypt(&encrypted).is_err());
    }

    #[test]
    fn wrong_salt_fails() {
        let mut enc = ChunkEncryptor::new("shared-key", b"salt-a");
        let dec = ChunkDecryptor::new("shared-key", b"salt-b");

        let encrypted = enc.encrypt(b"secret audio").unwrap();
        assert!(dec.decrypt(&encrypted).is_err());
    }

    #[test]
    fn tampered_chunk_fails() {
        let salt = b"tamper-test";
        let mut enc = ChunkEncryptor::new("key", salt);
        let dec = ChunkDecryptor::new("key", salt);
        let encrypted = enc.encrypt(b"audio frame").unwrap();

        // Flip one bit in the nonce, the ciphertext body, and the tag in turn;
        // authentication must reject every variant.
        for idx in [0, NONCE_SIZE, encrypted.len() - 1] {
            let mut forged = encrypted.clone();
            forged[idx] ^= 0x01;
            assert!(
                dec.decrypt(&forged).is_err(),
                "bit flip at byte {} was accepted",
                idx
            );
        }
    }

    #[test]
    fn short_input_rejected() {
        let salt = b"short-input";
        let mut enc = ChunkEncryptor::new("key", salt);
        let dec = ChunkDecryptor::new("key", salt);

        assert!(dec.decrypt(&[]).is_err());
        assert!(dec.decrypt(&[0u8; NONCE_SIZE + TAG_SIZE - 1]).is_err());

        // A genuine chunk with its last byte dropped must also be rejected.
        let encrypted = enc.encrypt(b"chunk").unwrap();
        assert!(dec.decrypt(&encrypted[..encrypted.len() - 1]).is_err());
    }

    #[test]
    fn nonce_increments() {
        let salt = b"nonce-test";
        let mut enc = ChunkEncryptor::new("key", salt);

        let a = enc.encrypt(b"chunk1").unwrap();
        let b = enc.encrypt(b"chunk2").unwrap();

        // Nonces should differ (first 12 bytes)
        assert_ne!(&a[..NONCE_SIZE], &b[..NONCE_SIZE]);

        // Layout: little-endian counter in bytes 0..8, zero padding in 8..12.
        assert_eq!(&a[..8], &0u64.to_le_bytes());
        assert_eq!(&b[..8], &1u64.to_le_bytes());
        assert_eq!(&a[8..NONCE_SIZE], &[0u8; 4]);
        assert_eq!(&b[8..NONCE_SIZE], &[0u8; 4]);
    }
}