snapcast_server/
crypto.rs1use chacha20poly1305::aead::{Aead, KeyInit};
7use chacha20poly1305::{ChaCha20Poly1305, Nonce};
8use hkdf::Hkdf;
9use sha2::Sha256;
10
/// Length in bytes of the ChaCha20-Poly1305 nonce (96 bits) that is prepended
/// to every encrypted chunk.
const NONCE_SIZE: usize = 96 / 8;
13
14fn derive_key(psk: &[u8], salt: &[u8]) -> [u8; 32] {
16 let hk = Hkdf::<Sha256>::new(Some(salt), psk);
17 let mut key = [0u8; 32];
18 hk.expand(b"snapcast-f32lz4e", &mut key)
19 .expect("32 bytes is a valid HKDF-SHA256 output length");
20 key
21}
22
23pub struct ChunkEncryptor {
25 cipher: ChaCha20Poly1305,
26 counter: u64,
27}
28
29impl ChunkEncryptor {
30 pub fn new(psk: &str, salt: &[u8]) -> Self {
32 let key = derive_key(psk.as_bytes(), salt);
33 Self {
34 cipher: ChaCha20Poly1305::new(&key.into()),
35 counter: 0,
36 }
37 }
38
39 pub fn encrypt(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, chacha20poly1305::Error> {
41 let mut nonce_bytes = [0u8; NONCE_SIZE];
42 nonce_bytes[..8].copy_from_slice(&self.counter.to_le_bytes());
43 self.counter += 1;
44
45 let nonce = Nonce::from(nonce_bytes);
46 let ciphertext = self.cipher.encrypt(&nonce, plaintext)?;
47
48 let mut out = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
49 out.extend_from_slice(&nonce_bytes);
50 out.extend_from_slice(&ciphertext);
51 Ok(out)
52 }
53}
54
/// Test-only counterpart of [`ChunkEncryptor`] that opens its `nonce || ciphertext` frames.
#[cfg(test)]
pub struct ChunkDecryptor {
    cipher: ChaCha20Poly1305,
}

#[cfg(test)]
impl ChunkDecryptor {
    /// Builds a decryptor using the same key derivation as [`ChunkEncryptor::new`].
    pub fn new(psk: &str, salt: &[u8]) -> Self {
        let derived = derive_key(psk.as_bytes(), salt);
        Self {
            cipher: ChaCha20Poly1305::new(&derived.into()),
        }
    }

    /// Splits `data` into its leading nonce and the authenticated ciphertext,
    /// then decrypts it.
    ///
    /// # Errors
    ///
    /// Returns [`chacha20poly1305::Error`] when `data` is too short to hold a
    /// nonce plus a tag, or when authentication fails (wrong key, or tampered data).
    pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, chacha20poly1305::Error> {
        // Length of the Poly1305 authentication tag in bytes.
        const TAG_SIZE: usize = 16;

        // Checking the length first also keeps `split_at` from panicking when
        // the input is shorter than a nonce.
        if data.len() >= NONCE_SIZE + TAG_SIZE {
            let (nonce, sealed) = data.split_at(NONCE_SIZE);
            self.cipher.decrypt(Nonce::from_slice(nonce), sealed)
        } else {
            Err(chacha20poly1305::Error)
        }
    }
}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the Poly1305 tag appended to every ciphertext.
    const TAG_SIZE: usize = 16;

    #[test]
    fn encrypt_decrypt_roundtrip() {
        let salt = b"test-session-salt";
        let mut enc = ChunkEncryptor::new("my-secret", salt);
        let dec = ChunkDecryptor::new("my-secret", salt);

        let plaintext = b"hello audio data";
        let encrypted = enc.encrypt(plaintext).unwrap();

        // Wire format: nonce || ciphertext (same length as plaintext) || tag.
        assert_eq!(encrypted.len(), NONCE_SIZE + plaintext.len() + TAG_SIZE);

        let decrypted = dec.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn empty_plaintext_roundtrip() {
        let salt = b"empty-test";
        let mut enc = ChunkEncryptor::new("key", salt);
        let dec = ChunkDecryptor::new("key", salt);

        let encrypted = enc.encrypt(b"").unwrap();
        assert_eq!(encrypted.len(), NONCE_SIZE + TAG_SIZE);
        assert!(dec.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn wrong_key_fails() {
        let salt = b"test-salt";
        let mut enc = ChunkEncryptor::new("correct-key", salt);
        let dec = ChunkDecryptor::new("wrong-key", salt);

        let encrypted = enc.encrypt(b"secret audio").unwrap();
        assert!(dec.decrypt(&encrypted).is_err());
    }

    #[test]
    fn wrong_salt_fails() {
        let mut enc = ChunkEncryptor::new("shared-key", b"salt-a");
        let dec = ChunkDecryptor::new("shared-key", b"salt-b");

        let encrypted = enc.encrypt(b"secret audio").unwrap();
        assert!(dec.decrypt(&encrypted).is_err());
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let salt = b"tamper-test";
        let mut enc = ChunkEncryptor::new("key", salt);
        let dec = ChunkDecryptor::new("key", salt);

        let mut encrypted = enc.encrypt(b"audio frame").unwrap();
        // Flip one bit in the first ciphertext byte; authentication must reject it.
        encrypted[NONCE_SIZE] ^= 0x01;
        assert!(dec.decrypt(&encrypted).is_err());
    }

    #[test]
    fn short_input_is_rejected_without_panicking() {
        let dec = ChunkDecryptor::new("key", b"short-test");

        assert!(dec.decrypt(b"").is_err());
        assert!(dec.decrypt(&[0u8; NONCE_SIZE - 1]).is_err());
        assert!(dec.decrypt(&[0u8; NONCE_SIZE + TAG_SIZE - 1]).is_err());
    }

    #[test]
    fn nonce_increments() {
        let salt = b"nonce-test";
        let mut enc = ChunkEncryptor::new("key", salt);

        let a = enc.encrypt(b"chunk1").unwrap();
        let b = enc.encrypt(b"chunk2").unwrap();

        assert_ne!(&a[..NONCE_SIZE], &b[..NONCE_SIZE]);
        // The nonce is a little-endian chunk counter starting at 0.
        assert_eq!(&a[..8], &0u64.to_le_bytes());
        assert_eq!(&b[..8], &1u64.to_le_bytes());
    }
}