1use chacha20poly1305::{
16 aead::{Aead, KeyInit, Payload},
17 XChaCha20Poly1305, Key, XNonce,
18};
19use hkdf::Hkdf;
20use sha2::{Digest, Sha256, Sha512};
21
22use crate::{Error, Result};
23
/// HKDF `info` label used to derive `DerivedKeys::auth_key`.
const AUTH_KEY_INFO: &[u8] = b"totalreclaw-auth-key-v1";
/// HKDF `info` label used to derive `DerivedKeys::encryption_key`.
const ENCRYPTION_KEY_INFO: &[u8] = b"totalreclaw-encryption-key-v1";
// NOTE(review): the two labels below use an "openmemory" prefix, unlike the
// "totalreclaw" labels above. They are HKDF inputs, so editing either one
// changes every derived dedup key / LSH seed — keep them byte-for-byte unless
// a key migration is planned.
/// HKDF `info` label used to derive `DerivedKeys::dedup_key`.
const DEDUP_KEY_INFO: &[u8] = b"openmemory-dedup-v1";
/// HKDF `info` label used by `derive_lsh_seed`.
const LSH_SEED_INFO: &[u8] = b"openmemory-lsh-seed-v1";

/// XChaCha20 extended nonce size in bytes; the nonce is the first field of the wire format.
const NONCE_LENGTH: usize = 24;
/// Poly1305 authentication tag size in bytes; the tag follows the nonce in the wire format.
const TAG_LENGTH: usize = 16;
35
/// Key material derived from a BIP-39 mnemonic.
///
/// Every field is secret: the three keys are HKDF-SHA256 outputs keyed by the
/// BIP-39 seed, and `salt` is the first 32 bytes of that seed itself.
/// `Debug` is therefore implemented by hand and prints only field names, so a
/// stray `{:?}` in a log line or error message cannot leak key bytes.
#[derive(Clone)]
pub struct DerivedKeys {
    /// HKDF output for the authentication-key label.
    pub auth_key: [u8; 32],
    /// HKDF output for the encryption-key label (the XChaCha20-Poly1305 key).
    pub encryption_key: [u8; 32],
    /// HKDF output for the dedup-key label.
    pub dedup_key: [u8; 32],
    /// First 32 bytes of the BIP-39 seed, used as the HKDF salt. Secret-derived.
    pub salt: [u8; 32],
}

impl std::fmt::Debug for DerivedKeys {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never format the byte arrays themselves: they are key material.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("DerivedKeys")
            .field("auth_key", &REDACTED)
            .field("encryption_key", &REDACTED)
            .field("dedup_key", &REDACTED)
            .field("salt", &REDACTED)
            .finish()
    }
}
48
49fn mnemonic_to_seed(mnemonic: &str) -> Result<[u8; 64]> {
54 let trimmed = mnemonic.trim();
56 bip39::Mnemonic::parse(trimmed).map_err(|e| {
57 Error::InvalidMnemonic(format!("invalid BIP-39 mnemonic: {}", e))
58 })?;
59
60 pbkdf2_seed(trimmed)
61}
62
63fn mnemonic_to_seed_lenient(mnemonic: &str) -> Result<[u8; 64]> {
69 let trimmed = mnemonic.trim();
70 let words: Vec<&str> = trimmed.split_whitespace().collect();
71
72 if words.len() != 12 && words.len() != 24 {
74 return Err(Error::InvalidMnemonic(format!(
75 "expected 12 or 24 words, got {}",
76 words.len()
77 )));
78 }
79
80 let lang = bip39::Language::English;
82 for word in &words {
83 if lang.find_word(word).is_none() {
84 return Err(Error::InvalidMnemonic(format!(
85 "word '{}' not in BIP-39 English wordlist",
86 word
87 )));
88 }
89 }
90
91 pbkdf2_seed(trimmed)
93}
94
95fn pbkdf2_seed(mnemonic: &str) -> Result<[u8; 64]> {
97 let salt = b"mnemonic";
98 let mut seed = [0u8; 64];
99 pbkdf2::pbkdf2_hmac::<Sha512>(mnemonic.as_bytes(), salt, 2048, &mut seed);
100 Ok(seed)
101}
102
103pub fn derive_keys_from_mnemonic(mnemonic: &str) -> Result<DerivedKeys> {
107 let seed = mnemonic_to_seed(mnemonic)?;
108 derive_keys_from_seed(&seed)
109}
110
111pub fn derive_keys_from_mnemonic_lenient(mnemonic: &str) -> Result<DerivedKeys> {
116 let seed = mnemonic_to_seed_lenient(mnemonic)?;
117 derive_keys_from_seed(&seed)
118}
119
120fn derive_keys_from_seed(seed: &[u8; 64]) -> Result<DerivedKeys> {
122 let mut salt = [0u8; 32];
123 salt.copy_from_slice(&seed[..32]);
124
125 let auth_key = hkdf_sha256(seed, &salt, AUTH_KEY_INFO)?;
126 let encryption_key = hkdf_sha256(seed, &salt, ENCRYPTION_KEY_INFO)?;
127 let dedup_key = hkdf_sha256(seed, &salt, DEDUP_KEY_INFO)?;
128
129 Ok(DerivedKeys {
130 auth_key,
131 encryption_key,
132 dedup_key,
133 salt,
134 })
135}
136
137pub fn mnemonic_to_seed_bytes(mnemonic: &str) -> Result<[u8; 64]> {
140 mnemonic_to_seed(mnemonic)
141}
142
143pub fn derive_lsh_seed(mnemonic: &str, salt: &[u8; 32]) -> Result<[u8; 32]> {
147 let seed = mnemonic_to_seed(mnemonic)?;
148 hkdf_sha256(&seed, salt, LSH_SEED_INFO)
149}
150
151pub fn compute_auth_key_hash(auth_key: &[u8; 32]) -> String {
155 let hash = Sha256::digest(auth_key);
156 hex::encode(hash)
157}
158
159fn hkdf_sha256(ikm: &[u8], salt: &[u8], info: &[u8]) -> Result<[u8; 32]> {
161 let hk = Hkdf::<Sha256>::new(Some(salt), ikm);
162 let mut okm = [0u8; 32];
163 hk.expand(info, &mut okm)
164 .map_err(|e| Error::Crypto(format!("HKDF expand failed: {}", e)))?;
165 Ok(okm)
166}
167
168pub fn encrypt(plaintext: &str, encryption_key: &[u8; 32]) -> Result<String> {
178 let nonce_bytes: [u8; NONCE_LENGTH] = rand::random();
179 encrypt_with_nonce(plaintext, encryption_key, &nonce_bytes)
180}
181
182pub fn encrypt_with_nonce(
184 plaintext: &str,
185 encryption_key: &[u8; 32],
186 nonce: &[u8; NONCE_LENGTH],
187) -> Result<String> {
188 let key = Key::from_slice(encryption_key);
189 let cipher = XChaCha20Poly1305::new(key);
190 let xnonce = XNonce::from_slice(nonce);
191
192 let ciphertext_with_tag = cipher
193 .encrypt(xnonce, Payload { msg: plaintext.as_bytes(), aad: b"" })
194 .map_err(|e| Error::Crypto(format!("XChaCha20-Poly1305 encrypt failed: {}", e)))?;
195
196 let ct_len = ciphertext_with_tag.len() - TAG_LENGTH;
199 let ciphertext = &ciphertext_with_tag[..ct_len];
200 let tag = &ciphertext_with_tag[ct_len..];
201
202 let mut combined = Vec::with_capacity(NONCE_LENGTH + TAG_LENGTH + ct_len);
203 combined.extend_from_slice(nonce);
204 combined.extend_from_slice(tag);
205 combined.extend_from_slice(ciphertext);
206
207 use base64::Engine;
208 Ok(base64::engine::general_purpose::STANDARD.encode(&combined))
209}
210
211pub fn decrypt(encrypted_base64: &str, encryption_key: &[u8; 32]) -> Result<String> {
215 use base64::Engine;
216 let combined = base64::engine::general_purpose::STANDARD
217 .decode(encrypted_base64)
218 .map_err(|e| Error::Crypto(format!("base64 decode failed: {}", e)))?;
219
220 if combined.len() < NONCE_LENGTH + TAG_LENGTH {
221 return Err(Error::Crypto("Encrypted data too short".into()));
222 }
223
224 let nonce = &combined[..NONCE_LENGTH];
225 let tag = &combined[NONCE_LENGTH..NONCE_LENGTH + TAG_LENGTH];
226 let ciphertext = &combined[NONCE_LENGTH + TAG_LENGTH..];
227
228 let mut ct_with_tag = Vec::with_capacity(ciphertext.len() + TAG_LENGTH);
230 ct_with_tag.extend_from_slice(ciphertext);
231 ct_with_tag.extend_from_slice(tag);
232
233 let key = Key::from_slice(encryption_key);
234 let cipher = XChaCha20Poly1305::new(key);
235 let xnonce = XNonce::from_slice(nonce);
236
237 let plaintext_bytes = cipher
238 .decrypt(xnonce, Payload { msg: &ct_with_tag, aad: b"" })
239 .map_err(|e| Error::Crypto(format!("XChaCha20-Poly1305 decrypt failed: {}", e)))?;
240
241 String::from_utf8(plaintext_bytes)
242 .map_err(|e| Error::Crypto(format!("UTF-8 decode failed: {}", e)))
243}
244
245pub fn derive_random_bytes(seed: &[u8], base_info: &str, length: usize) -> Result<Vec<u8>> {
256 const MAX_HKDF_OUTPUT: usize = 255 * 32;
257 let mut result = vec![0u8; length];
258 let mut offset = 0;
259 let mut block_index = 0;
260
261 while offset < length {
262 let remaining = length - offset;
263 let chunk_len = remaining.min(MAX_HKDF_OUTPUT);
264 let info = format!("{}_block_{}", base_info, block_index);
265
266 let hk = Hkdf::<Sha256>::new(Some(&[]), seed);
268 hk.expand(info.as_bytes(), &mut result[offset..offset + chunk_len])
269 .map_err(|e| Error::Crypto(format!("HKDF expand failed: {}", e)))?;
270
271 offset += chunk_len;
272 block_index += 1;
273 }
274
275 Ok(result)
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard 12-word BIP-39 test mnemonic with a valid checksum.
    const VALID_MNEMONIC: &str =
        "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// Twelve wordlist words whose BIP-39 checksum is wrong.
    const BAD_CHECKSUM_MNEMONIC: &str =
        "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon";

    /// Parses the shared crypto test vectors used by the `*_parity` tests.
    fn fixture() -> serde_json::Value {
        serde_json::from_str(include_str!("../tests/fixtures/crypto_vectors.json"))
            .expect("crypto_vectors.json must be valid JSON")
    }

    #[test]
    fn test_key_derivation_parity() {
        let vectors = fixture();
        let kd = &vectors["key_derivation"];
        let mnemonic = kd["mnemonic"].as_str().unwrap();

        let keys = derive_keys_from_mnemonic(mnemonic).unwrap();

        assert_eq!(hex::encode(keys.salt), kd["salt_hex"].as_str().unwrap());
        assert_eq!(hex::encode(keys.auth_key), kd["auth_key_hex"].as_str().unwrap());
        assert_eq!(
            hex::encode(keys.encryption_key),
            kd["encryption_key_hex"].as_str().unwrap()
        );
        assert_eq!(hex::encode(keys.dedup_key), kd["dedup_key_hex"].as_str().unwrap());

        let hash = compute_auth_key_hash(&keys.auth_key);
        assert_eq!(hash, kd["auth_key_hash"].as_str().unwrap());
    }

    #[test]
    fn test_bip39_seed_parity() {
        let vectors = fixture();
        let mnemonic = vectors["key_derivation"]["mnemonic"].as_str().unwrap();
        let expected_seed_hex = vectors["key_derivation"]["bip39_seed_hex"].as_str().unwrap();

        let seed = mnemonic_to_seed(mnemonic).unwrap();
        assert_eq!(hex::encode(seed), expected_seed_hex);
    }

    #[test]
    fn test_lsh_seed_parity() {
        let vectors = fixture();
        let mnemonic = vectors["key_derivation"]["mnemonic"].as_str().unwrap();
        let keys = derive_keys_from_mnemonic(mnemonic).unwrap();
        let lsh_seed = derive_lsh_seed(mnemonic, &keys.salt).unwrap();

        assert_eq!(
            hex::encode(lsh_seed),
            vectors["lsh"]["lsh_seed_hex"].as_str().unwrap()
        );
    }

    #[test]
    fn test_xchacha_fixed_nonce_parity() {
        let vectors = fixture();
        let xc = &vectors["xchacha20"];

        let key_bytes = hex::decode(xc["encryption_key_hex"].as_str().unwrap()).unwrap();
        let mut key = [0u8; 32];
        key.copy_from_slice(&key_bytes);

        let nonce_bytes = hex::decode(xc["fixed_nonce_hex"].as_str().unwrap()).unwrap();
        let mut nonce = [0u8; NONCE_LENGTH];
        nonce.copy_from_slice(&nonce_bytes);

        let plaintext = xc["plaintext"].as_str().unwrap();
        let expected_b64 = xc["fixed_nonce_encrypted_base64"].as_str().unwrap();

        let encrypted = encrypt_with_nonce(plaintext, &key, &nonce).unwrap();
        assert_eq!(encrypted, expected_b64);

        let decrypted = decrypt(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_xchacha_round_trip() {
        let keys = derive_keys_from_mnemonic(VALID_MNEMONIC).unwrap();

        let plaintext = "Hello, TotalReclaw!";
        let encrypted = encrypt(plaintext, &keys.encryption_key).unwrap();
        let decrypted = decrypt(&encrypted, &keys.encryption_key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce() {
        let key = [3u8; 32];
        let first = encrypt("same text", &key).unwrap();
        let second = encrypt("same text", &key).unwrap();
        assert_ne!(first, second);
        assert_eq!(decrypt(&first, &key).unwrap(), "same text");
        assert_eq!(decrypt(&second, &key).unwrap(), "same text");
    }

    #[test]
    fn test_decrypt_rejects_wrong_key() {
        let encrypted = encrypt("secret", &[1u8; 32]).unwrap();
        assert!(decrypt(&encrypted, &[2u8; 32]).is_err());
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        use base64::Engine;

        let key = [7u8; 32];
        let encrypted = encrypt("secret", &key).unwrap();
        let mut raw = base64::engine::general_purpose::STANDARD
            .decode(&encrypted)
            .unwrap();
        // The last byte belongs to the ciphertext (layout: nonce || tag || ciphertext).
        let last = raw.len() - 1;
        raw[last] ^= 0x01;
        let tampered = base64::engine::general_purpose::STANDARD.encode(&raw);

        assert!(decrypt(&tampered, &key).is_err());
    }

    #[test]
    fn test_decrypt_rejects_short_input() {
        use base64::Engine;

        let short = base64::engine::general_purpose::STANDARD
            .encode([0u8; NONCE_LENGTH + TAG_LENGTH - 1]);
        assert!(decrypt(&short, &[0u8; 32]).is_err());
        assert!(decrypt("not base64!!", &[0u8; 32]).is_err());
    }

    #[test]
    fn test_derive_random_bytes_lengths_and_determinism() {
        assert!(derive_random_bytes(b"seed", "test", 0).unwrap().is_empty());

        let short = derive_random_bytes(b"seed", "test", 100).unwrap();
        assert_eq!(short.len(), 100);
        assert_eq!(short, derive_random_bytes(b"seed", "test", 100).unwrap());
        assert_ne!(short, derive_random_bytes(b"seed", "other", 100).unwrap());

        // Crosses the 255 * 32 byte single-expand limit, forcing a second block.
        let long = derive_random_bytes(b"seed", "test", 255 * 32 + 10).unwrap();
        assert_eq!(long.len(), 255 * 32 + 10);
        assert_eq!(&long[..100], &short[..]);
    }

    #[test]
    fn test_lenient_accepts_valid_mnemonic() {
        let strict = derive_keys_from_mnemonic(VALID_MNEMONIC).unwrap();
        let lenient = derive_keys_from_mnemonic_lenient(VALID_MNEMONIC).unwrap();

        assert_eq!(strict.auth_key, lenient.auth_key);
        assert_eq!(strict.encryption_key, lenient.encryption_key);
        assert_eq!(strict.dedup_key, lenient.dedup_key);
        assert_eq!(strict.salt, lenient.salt);
    }

    #[test]
    fn test_lenient_rejects_invalid_words() {
        let mnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon xyzzy";
        let result = derive_keys_from_mnemonic_lenient(mnemonic);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("xyzzy"));
    }

    #[test]
    fn test_lenient_rejects_wrong_word_count() {
        let mnemonic = "abandon abandon abandon";
        let result = derive_keys_from_mnemonic_lenient(mnemonic);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("expected 12 or 24"));
    }

    #[test]
    fn test_strict_rejects_bad_checksum() {
        assert!(derive_keys_from_mnemonic(BAD_CHECKSUM_MNEMONIC).is_err());
    }

    #[test]
    fn test_lenient_accepts_bad_checksum() {
        assert!(derive_keys_from_mnemonic_lenient(BAD_CHECKSUM_MNEMONIC).is_ok());
    }
}