1#![allow(clippy::cast_possible_truncation)]
7
8use ring::{
9 hmac, pbkdf2,
10 rand::{SecureRandom, SystemRandom},
11};
12use std::num::NonZeroU32;
13use subtle::ConstantTimeEq;
14use zeroize::{Zeroize, ZeroizeOnDrop};
15
16use crate::error::{Result, ShieldError};
17
/// Size in bytes of the random nonce prepended to every ciphertext.
const NONCE_SIZE: usize = 16;

/// Size in bytes of the truncated HMAC-SHA256 tag appended to every ciphertext.
const MAC_SIZE: usize = 16;

/// Smallest valid encrypted message: the nonce, the 8-byte encrypted header
/// that precedes the plaintext, and the tag (i.e. an empty plaintext).
const MIN_CIPHERTEXT_SIZE: usize = NONCE_SIZE + MAC_SIZE + 8;

/// PBKDF2-HMAC-SHA256 iteration count used when deriving a key from a password.
const PBKDF2_ITERATIONS: u32 = 100_000;
29
30#[derive(Zeroize, ZeroizeOnDrop)]
38pub struct Shield {
39 key: [u8; 32],
40 #[zeroize(skip)]
41 #[allow(dead_code)]
42 counter: u64,
43}
44
45impl Shield {
46 #[must_use]
58 pub fn new(password: &str, service: &str) -> Self {
59 let salt = ring::digest::digest(&ring::digest::SHA256, service.as_bytes());
61
62 let mut key = [0u8; 32];
64 pbkdf2::derive(
65 pbkdf2::PBKDF2_HMAC_SHA256,
66 NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
67 salt.as_ref(),
68 password.as_bytes(),
69 &mut key,
70 );
71
72 Self { key, counter: 0 }
73 }
74
75 #[must_use]
77 pub fn with_key(key: [u8; 32]) -> Self {
78 Self { key, counter: 0 }
79 }
80
81 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
88 Self::encrypt_with_key(&self.key, plaintext)
89 }
90
91 pub fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
93 let rng = SystemRandom::new();
94
95 let mut nonce = [0u8; NONCE_SIZE];
97 rng.fill(&mut nonce)
98 .map_err(|_| ShieldError::RandomFailed)?;
99
100 let counter_bytes = 0u64.to_le_bytes();
102
103 let mut data_to_encrypt = Vec::with_capacity(8 + plaintext.len());
105 data_to_encrypt.extend_from_slice(&counter_bytes);
106 data_to_encrypt.extend_from_slice(plaintext);
107
108 let keystream = generate_keystream(key, &nonce, data_to_encrypt.len());
110 let ciphertext: Vec<u8> = data_to_encrypt
111 .iter()
112 .zip(keystream.iter())
113 .map(|(p, k)| p ^ k)
114 .collect();
115
116 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
118 let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
119 hmac_data.extend_from_slice(&nonce);
120 hmac_data.extend_from_slice(&ciphertext);
121 let tag = hmac::sign(&hmac_key, &hmac_data);
122
123 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len() + MAC_SIZE);
125 result.extend_from_slice(&nonce);
126 result.extend_from_slice(&ciphertext);
127 result.extend_from_slice(&tag.as_ref()[..MAC_SIZE]);
128
129 Ok(result)
130 }
131
132 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
137 Self::decrypt_with_key(&self.key, encrypted)
138 }
139
140 pub fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
142 if encrypted.len() < MIN_CIPHERTEXT_SIZE {
143 return Err(ShieldError::CiphertextTooShort {
144 expected: MIN_CIPHERTEXT_SIZE,
145 actual: encrypted.len(),
146 });
147 }
148
149 let nonce = &encrypted[..NONCE_SIZE];
151 let ciphertext = &encrypted[NONCE_SIZE..encrypted.len() - MAC_SIZE];
152 let mac = &encrypted[encrypted.len() - MAC_SIZE..];
153
154 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
156 let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
157 hmac_data.extend_from_slice(nonce);
158 hmac_data.extend_from_slice(ciphertext);
159 let expected_tag = hmac::sign(&hmac_key, &hmac_data);
160
161 if mac.ct_eq(&expected_tag.as_ref()[..MAC_SIZE]).unwrap_u8() != 1 {
163 return Err(ShieldError::AuthenticationFailed);
164 }
165
166 let keystream = generate_keystream(key, nonce, ciphertext.len());
168 let decrypted: Vec<u8> = ciphertext
169 .iter()
170 .zip(keystream.iter())
171 .map(|(c, k)| c ^ k)
172 .collect();
173
174 Ok(decrypted[8..].to_vec())
176 }
177
178 #[must_use]
180 pub fn key(&self) -> &[u8; 32] {
181 &self.key
182 }
183}
184
185fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
187 let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
188 let num_blocks = length.div_ceil(32);
189
190 for i in 0..num_blocks {
191 let counter = (i as u32).to_le_bytes();
192
193 let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
195 data.extend_from_slice(key);
196 data.extend_from_slice(nonce);
197 data.extend_from_slice(&counter);
198
199 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
200 keystream.extend_from_slice(hash.as_ref());
201 }
202
203 keystream.truncate(length);
204 keystream
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keystream_deterministic() {
        let key = [1u8; 32];
        let nonce = [2u8; 16];

        let ks1 = generate_keystream(&key, &nonce, 64);
        let ks2 = generate_keystream(&key, &nonce, 64);

        assert_eq!(ks1, ks2);
    }

    #[test]
    fn test_keystream_different_nonce() {
        let key = [1u8; 32];
        let nonce1 = [2u8; 16];
        let nonce2 = [3u8; 16];

        let ks1 = generate_keystream(&key, &nonce1, 32);
        let ks2 = generate_keystream(&key, &nonce2, 32);

        assert_ne!(ks1, ks2);
    }

    /// A partial final block must be a prefix of the full-length keystream.
    #[test]
    fn test_keystream_truncation() {
        let key = [1u8; 32];
        let nonce = [2u8; 16];

        let short = generate_keystream(&key, &nonce, 33);
        let long = generate_keystream(&key, &nonce, 64);

        assert_eq!(short.len(), 33);
        assert_eq!(short[..], long[..33]);
        assert!(generate_keystream(&key, &nonce, 0).is_empty());
    }

    /// Layout is nonce || 8-byte header || plaintext || tag.
    #[test]
    fn test_encrypt_format() {
        let shield = Shield::new("password", "service");
        let encrypted = shield.encrypt(b"test").unwrap();

        assert_eq!(encrypted.len(), NONCE_SIZE + 8 + 4 + MAC_SIZE);
    }

    #[test]
    fn test_roundtrip() {
        let shield = Shield::with_key([7u8; 32]);
        let cases: [&[u8]; 4] = [b"", b"x", b"hello, shield", &[0xAB; 100]];

        for plaintext in cases {
            let encrypted = shield.encrypt(plaintext).unwrap();
            assert_eq!(shield.decrypt(&encrypted).unwrap(), plaintext);
        }
    }

    /// A fresh nonce per call means identical plaintexts encrypt differently.
    #[test]
    fn test_encryption_is_randomized() {
        let shield = Shield::with_key([7u8; 32]);

        assert_ne!(shield.encrypt(b"same").unwrap(), shield.encrypt(b"same").unwrap());
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let shield = Shield::with_key([7u8; 32]);
        let mut encrypted = shield.encrypt(b"secret").unwrap();
        encrypted[NONCE_SIZE] ^= 1;

        assert!(matches!(
            shield.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_wrong_key_rejected() {
        let encrypted = Shield::with_key([1u8; 32]).encrypt(b"secret").unwrap();
        let result = Shield::with_key([2u8; 32]).decrypt(&encrypted);

        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_short_input_rejected() {
        let shield = Shield::with_key([7u8; 32]);
        let result = shield.decrypt(&[0u8; MIN_CIPHERTEXT_SIZE - 1]);

        assert!(matches!(
            result,
            Err(ShieldError::CiphertextTooShort { expected: MIN_CIPHERTEXT_SIZE, actual })
                if actual == MIN_CIPHERTEXT_SIZE - 1
        ));
    }
}