1#![allow(clippy::cast_possible_truncation)]
7
8use ring::{hmac, pbkdf2, rand::{SecureRandom, SystemRandom}};
9use subtle::ConstantTimeEq;
10use std::num::NonZeroU32;
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
13use crate::error::{Result, ShieldError};
14
/// Number of PBKDF2-HMAC-SHA256 rounds used when deriving a key from a password.
const PBKDF2_ITERATIONS: u32 = 100_000;

/// Length in bytes of the random nonce that prefixes every encrypted message.
const NONCE_SIZE: usize = 16;

/// Length in bytes of the truncated HMAC-SHA256 tag that ends every encrypted message.
const MAC_SIZE: usize = 16;

/// Length in bytes of the little-endian `u64` counter that is encrypted ahead of the plaintext.
const COUNTER_SIZE: usize = 8;

/// Smallest well-formed message: nonce, encrypted counter and tag around an empty plaintext.
const MIN_CIPHERTEXT_SIZE: usize = NONCE_SIZE + COUNTER_SIZE + MAC_SIZE;
26
27#[derive(Zeroize, ZeroizeOnDrop)]
35pub struct Shield {
36 key: [u8; 32],
37 #[zeroize(skip)]
38 #[allow(dead_code)]
39 counter: u64,
40}
41
42impl Shield {
43 #[must_use]
55 pub fn new(password: &str, service: &str) -> Self {
56 let salt = ring::digest::digest(&ring::digest::SHA256, service.as_bytes());
58
59 let mut key = [0u8; 32];
61 pbkdf2::derive(
62 pbkdf2::PBKDF2_HMAC_SHA256,
63 NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
64 salt.as_ref(),
65 password.as_bytes(),
66 &mut key,
67 );
68
69 Self { key, counter: 0 }
70 }
71
72 #[must_use]
74 pub fn with_key(key: [u8; 32]) -> Self {
75 Self { key, counter: 0 }
76 }
77
78 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
85 Self::encrypt_with_key(&self.key, plaintext)
86 }
87
88 pub fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
90 let rng = SystemRandom::new();
91
92 let mut nonce = [0u8; NONCE_SIZE];
94 rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
95
96 let counter_bytes = 0u64.to_le_bytes();
98
99 let mut data_to_encrypt = Vec::with_capacity(8 + plaintext.len());
101 data_to_encrypt.extend_from_slice(&counter_bytes);
102 data_to_encrypt.extend_from_slice(plaintext);
103
104 let keystream = generate_keystream(key, &nonce, data_to_encrypt.len());
106 let ciphertext: Vec<u8> = data_to_encrypt
107 .iter()
108 .zip(keystream.iter())
109 .map(|(p, k)| p ^ k)
110 .collect();
111
112 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
114 let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
115 hmac_data.extend_from_slice(&nonce);
116 hmac_data.extend_from_slice(&ciphertext);
117 let tag = hmac::sign(&hmac_key, &hmac_data);
118
119 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len() + MAC_SIZE);
121 result.extend_from_slice(&nonce);
122 result.extend_from_slice(&ciphertext);
123 result.extend_from_slice(&tag.as_ref()[..MAC_SIZE]);
124
125 Ok(result)
126 }
127
128 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
133 Self::decrypt_with_key(&self.key, encrypted)
134 }
135
136 pub fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
138 if encrypted.len() < MIN_CIPHERTEXT_SIZE {
139 return Err(ShieldError::CiphertextTooShort {
140 expected: MIN_CIPHERTEXT_SIZE,
141 actual: encrypted.len(),
142 });
143 }
144
145 let nonce = &encrypted[..NONCE_SIZE];
147 let ciphertext = &encrypted[NONCE_SIZE..encrypted.len() - MAC_SIZE];
148 let mac = &encrypted[encrypted.len() - MAC_SIZE..];
149
150 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
152 let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
153 hmac_data.extend_from_slice(nonce);
154 hmac_data.extend_from_slice(ciphertext);
155 let expected_tag = hmac::sign(&hmac_key, &hmac_data);
156
157 if mac.ct_eq(&expected_tag.as_ref()[..MAC_SIZE]).unwrap_u8() != 1 {
159 return Err(ShieldError::AuthenticationFailed);
160 }
161
162 let keystream = generate_keystream(key, nonce, ciphertext.len());
164 let decrypted: Vec<u8> = ciphertext
165 .iter()
166 .zip(keystream.iter())
167 .map(|(c, k)| c ^ k)
168 .collect();
169
170 Ok(decrypted[8..].to_vec())
172 }
173
174 #[must_use]
176 pub fn key(&self) -> &[u8; 32] {
177 &self.key
178 }
179}
180
181fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
183 let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
184 let num_blocks = length.div_ceil(32);
185
186 for i in 0..num_blocks {
187 let counter = (i as u32).to_le_bytes();
188
189 let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
191 data.extend_from_slice(key);
192 data.extend_from_slice(nonce);
193 data.extend_from_slice(&counter);
194
195 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
196 keystream.extend_from_slice(hash.as_ref());
197 }
198
199 keystream.truncate(length);
200 keystream
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical key, nonce and length must always give the same keystream.
    #[test]
    fn test_keystream_deterministic() {
        let (key, nonce) = ([1u8; 32], [2u8; 16]);

        let first = generate_keystream(&key, &nonce, 64);
        let second = generate_keystream(&key, &nonce, 64);

        assert_eq!(first, second);
    }

    /// Changing only the nonce must change the keystream.
    #[test]
    fn test_keystream_different_nonce() {
        let key = [1u8; 32];

        let with_twos = generate_keystream(&key, &[2u8; 16], 32);
        let with_threes = generate_keystream(&key, &[3u8; 16], 32);

        assert_ne!(with_twos, with_threes);
    }

    /// Output length is nonce + counter + plaintext + tag.
    #[test]
    fn test_encrypt_format() {
        let encrypted = Shield::new("password", "service")
            .encrypt(b"test")
            .unwrap();

        // 16-byte nonce, 8-byte counter, 4-byte plaintext, 16-byte tag.
        assert_eq!(encrypted.len(), 16 + 8 + 4 + 16);
    }
}