Skip to main content

totalreclaw_core/
crypto.rs

1//! Key derivation and XChaCha20-Poly1305 encryption.
2//!
3//! Matches the TypeScript implementation in `mcp/src/subgraph/crypto.ts` byte-for-byte.
4//!
5//! Key derivation chain (BIP-39 path):
6//!   mnemonic -> PBKDF2-HMAC-SHA512(mnemonic, "mnemonic", 2048) -> 64-byte seed
7//!   salt = seed[0..32]
8//!   HKDF-SHA256(seed, salt, "totalreclaw-auth-key-v1", 32)       -> authKey
9//!   HKDF-SHA256(seed, salt, "totalreclaw-encryption-key-v1", 32) -> encryptionKey
10//!   HKDF-SHA256(seed, salt, "openmemory-dedup-v1", 32)           -> dedupKey
11//!   HKDF-SHA256(seed, salt, "openmemory-lsh-seed-v1", 32)        -> lshSeed
12//!
13//! XChaCha20-Poly1305 wire format: nonce(24) || tag(16) || ciphertext -> base64
14
15use chacha20poly1305::{
16    aead::{Aead, KeyInit, Payload},
17    XChaCha20Poly1305, Key, XNonce,
18};
19use hkdf::Hkdf;
20use sha2::{Digest, Sha256, Sha512};
21
22use crate::{Error, Result};
23
24// ---------------------------------------------------------------------------
25// Constants
26// ---------------------------------------------------------------------------
27
/// HKDF `info` label used when deriving the auth key from the BIP-39 seed.
const AUTH_KEY_INFO: &[u8] = b"totalreclaw-auth-key-v1";
/// HKDF `info` label used when deriving the encryption key from the BIP-39 seed.
const ENCRYPTION_KEY_INFO: &[u8] = b"totalreclaw-encryption-key-v1";
/// HKDF `info` label used when deriving the dedup key from the BIP-39 seed.
// NOTE(review): the `openmemory-` prefix differs from the `totalreclaw-` labels
// above; presumably retained so previously derived keys stay stable — confirm.
const DEDUP_KEY_INFO: &[u8] = b"openmemory-dedup-v1";
/// HKDF `info` label used when deriving the LSH seed from the BIP-39 seed.
const LSH_SEED_INFO: &[u8] = b"openmemory-lsh-seed-v1";

/// Size in bytes of the nonce at the start of the encrypted wire format.
const NONCE_LENGTH: usize = 24;
/// Size in bytes of the authentication tag that follows the nonce on the wire.
const TAG_LENGTH: usize = 16;
35
36// ---------------------------------------------------------------------------
37// Key material
38// ---------------------------------------------------------------------------
39
/// Key material derived from a BIP-39 mnemonic.
///
/// Every field is secret or secret-derived (`salt` is the first 32 bytes of
/// the BIP-39 seed), so the `Debug` implementation prints a placeholder for
/// each field instead of its bytes. This keeps key material out of logs,
/// panic messages and `{:?}` output of any struct that embeds this type.
#[derive(Clone)]
pub struct DerivedKeys {
    /// 32-byte key derived with the auth-key HKDF label.
    pub auth_key: [u8; 32],
    /// 32-byte key derived with the encryption-key HKDF label.
    pub encryption_key: [u8; 32],
    /// 32-byte key derived with the dedup HKDF label.
    pub dedup_key: [u8; 32],
    /// HKDF salt: the first 32 bytes of the BIP-39 seed.
    pub salt: [u8; 32],
}

impl std::fmt::Debug for DerivedKeys {
    /// Formats the struct with every field value replaced by `<redacted>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `format_args!` prints the placeholder without surrounding quotes.
        f.debug_struct("DerivedKeys")
            .field("auth_key", &format_args!("<redacted>"))
            .field("encryption_key", &format_args!("<redacted>"))
            .field("dedup_key", &format_args!("<redacted>"))
            .field("salt", &format_args!("<redacted>"))
            .finish()
    }
}
48
49/// Derive the 64-byte BIP-39 seed from a mnemonic phrase (strict mode).
50///
51/// Uses PBKDF2-HMAC-SHA512 with passphrase="mnemonic" and 2048 iterations,
52/// matching the BIP-39 spec and `@scure/bip39`'s `mnemonicToSeedSync`.
53fn mnemonic_to_seed(mnemonic: &str) -> Result<[u8; 64]> {
54    // Validate BIP-39 mnemonic (strict: checksum must be valid).
55    let trimmed = mnemonic.trim();
56    bip39::Mnemonic::parse(trimmed).map_err(|e| {
57        Error::InvalidMnemonic(format!("invalid BIP-39 mnemonic: {}", e))
58    })?;
59
60    pbkdf2_seed(trimmed)
61}
62
63/// Derive the 64-byte BIP-39 seed from a mnemonic phrase (lenient mode).
64///
65/// Validates that all words are in the BIP-39 English wordlist but does NOT
66/// check the checksum. This allows LLM-generated mnemonics where the words
67/// are valid but the checksum is wrong.
68fn mnemonic_to_seed_lenient(mnemonic: &str) -> Result<[u8; 64]> {
69    let trimmed = mnemonic.trim();
70    let words: Vec<&str> = trimmed.split_whitespace().collect();
71
72    // Must be 12 or 24 words
73    if words.len() != 12 && words.len() != 24 {
74        return Err(Error::InvalidMnemonic(format!(
75            "expected 12 or 24 words, got {}",
76            words.len()
77        )));
78    }
79
80    // Validate each word is in the BIP-39 English wordlist
81    let lang = bip39::Language::English;
82    for word in &words {
83        if lang.find_word(word).is_none() {
84            return Err(Error::InvalidMnemonic(format!(
85                "word '{}' not in BIP-39 English wordlist",
86                word
87            )));
88        }
89    }
90
91    // Derive seed (skip checksum validation)
92    pbkdf2_seed(trimmed)
93}
94
95/// PBKDF2-HMAC-SHA512 seed derivation (shared between strict and lenient).
96fn pbkdf2_seed(mnemonic: &str) -> Result<[u8; 64]> {
97    let salt = b"mnemonic";
98    let mut seed = [0u8; 64];
99    pbkdf2::pbkdf2_hmac::<Sha512>(mnemonic.as_bytes(), salt, 2048, &mut seed);
100    Ok(seed)
101}
102
103/// Derive encryption keys from a BIP-39 mnemonic (strict checksum validation).
104///
105/// Matches `deriveKeysFromMnemonic()` in `mcp/src/subgraph/crypto.ts`.
106pub fn derive_keys_from_mnemonic(mnemonic: &str) -> Result<DerivedKeys> {
107    let seed = mnemonic_to_seed(mnemonic)?;
108    derive_keys_from_seed(&seed)
109}
110
111/// Derive encryption keys from a BIP-39 mnemonic (lenient — skips checksum).
112///
113/// Validates words are in the BIP-39 wordlist but accepts invalid checksums.
114/// Use this for LLM-generated mnemonics.
115pub fn derive_keys_from_mnemonic_lenient(mnemonic: &str) -> Result<DerivedKeys> {
116    let seed = mnemonic_to_seed_lenient(mnemonic)?;
117    derive_keys_from_seed(&seed)
118}
119
120/// Internal: derive keys from a 64-byte seed.
121fn derive_keys_from_seed(seed: &[u8; 64]) -> Result<DerivedKeys> {
122    let mut salt = [0u8; 32];
123    salt.copy_from_slice(&seed[..32]);
124
125    let auth_key = hkdf_sha256(seed, &salt, AUTH_KEY_INFO)?;
126    let encryption_key = hkdf_sha256(seed, &salt, ENCRYPTION_KEY_INFO)?;
127    let dedup_key = hkdf_sha256(seed, &salt, DEDUP_KEY_INFO)?;
128
129    Ok(DerivedKeys {
130        auth_key,
131        encryption_key,
132        dedup_key,
133        salt,
134    })
135}
136
137/// Public access to the raw BIP-39 seed bytes (64 bytes).
138/// Used by wallet derivation (BIP-32/BIP-44).
139pub fn mnemonic_to_seed_bytes(mnemonic: &str) -> Result<[u8; 64]> {
140    mnemonic_to_seed(mnemonic)
141}
142
143/// Derive the 32-byte LSH seed from a BIP-39 mnemonic.
144///
145/// Matches `deriveLshSeed()` in `mcp/src/subgraph/crypto.ts` (BIP-39 path).
146pub fn derive_lsh_seed(mnemonic: &str, salt: &[u8; 32]) -> Result<[u8; 32]> {
147    let seed = mnemonic_to_seed(mnemonic)?;
148    hkdf_sha256(&seed, salt, LSH_SEED_INFO)
149}
150
151/// Compute SHA-256(authKey) as a hex string.
152///
153/// Matches `computeAuthKeyHash()` in `mcp/src/subgraph/crypto.ts`.
154pub fn compute_auth_key_hash(auth_key: &[u8; 32]) -> String {
155    let hash = Sha256::digest(auth_key);
156    hex::encode(hash)
157}
158
159/// Single HKDF-SHA256 expand producing 32 bytes.
160fn hkdf_sha256(ikm: &[u8], salt: &[u8], info: &[u8]) -> Result<[u8; 32]> {
161    let hk = Hkdf::<Sha256>::new(Some(salt), ikm);
162    let mut okm = [0u8; 32];
163    hk.expand(info, &mut okm)
164        .map_err(|e| Error::Crypto(format!("HKDF expand failed: {}", e)))?;
165    Ok(okm)
166}
167
168// ---------------------------------------------------------------------------
169// XChaCha20-Poly1305
170// ---------------------------------------------------------------------------
171
172/// Encrypt a UTF-8 plaintext string with XChaCha20-Poly1305.
173///
174/// Wire format (base64-encoded): nonce(24) || tag(16) || ciphertext
175///
176/// Uses a random 24-byte nonce.
177pub fn encrypt(plaintext: &str, encryption_key: &[u8; 32]) -> Result<String> {
178    let nonce_bytes: [u8; NONCE_LENGTH] = rand::random();
179    encrypt_with_nonce(plaintext, encryption_key, &nonce_bytes)
180}
181
182/// Encrypt with a specific nonce (for deterministic testing).
183pub fn encrypt_with_nonce(
184    plaintext: &str,
185    encryption_key: &[u8; 32],
186    nonce: &[u8; NONCE_LENGTH],
187) -> Result<String> {
188    let key = Key::from_slice(encryption_key);
189    let cipher = XChaCha20Poly1305::new(key);
190    let xnonce = XNonce::from_slice(nonce);
191
192    let ciphertext_with_tag = cipher
193        .encrypt(xnonce, Payload { msg: plaintext.as_bytes(), aad: b"" })
194        .map_err(|e| Error::Crypto(format!("XChaCha20-Poly1305 encrypt failed: {}", e)))?;
195
196    // chacha20poly1305 appends the tag at the end: ciphertext || tag
197    // We need wire format: nonce || tag || ciphertext
198    let ct_len = ciphertext_with_tag.len() - TAG_LENGTH;
199    let ciphertext = &ciphertext_with_tag[..ct_len];
200    let tag = &ciphertext_with_tag[ct_len..];
201
202    let mut combined = Vec::with_capacity(NONCE_LENGTH + TAG_LENGTH + ct_len);
203    combined.extend_from_slice(nonce);
204    combined.extend_from_slice(tag);
205    combined.extend_from_slice(ciphertext);
206
207    use base64::Engine;
208    Ok(base64::engine::general_purpose::STANDARD.encode(&combined))
209}
210
211/// Decrypt a base64-encoded XChaCha20-Poly1305 blob back to a UTF-8 string.
212///
213/// Expects wire format: nonce(24) || tag(16) || ciphertext
214pub fn decrypt(encrypted_base64: &str, encryption_key: &[u8; 32]) -> Result<String> {
215    use base64::Engine;
216    let combined = base64::engine::general_purpose::STANDARD
217        .decode(encrypted_base64)
218        .map_err(|e| Error::Crypto(format!("base64 decode failed: {}", e)))?;
219
220    if combined.len() < NONCE_LENGTH + TAG_LENGTH {
221        return Err(Error::Crypto("Encrypted data too short".into()));
222    }
223
224    let nonce = &combined[..NONCE_LENGTH];
225    let tag = &combined[NONCE_LENGTH..NONCE_LENGTH + TAG_LENGTH];
226    let ciphertext = &combined[NONCE_LENGTH + TAG_LENGTH..];
227
228    // chacha20poly1305 expects: ciphertext || tag
229    let mut ct_with_tag = Vec::with_capacity(ciphertext.len() + TAG_LENGTH);
230    ct_with_tag.extend_from_slice(ciphertext);
231    ct_with_tag.extend_from_slice(tag);
232
233    let key = Key::from_slice(encryption_key);
234    let cipher = XChaCha20Poly1305::new(key);
235    let xnonce = XNonce::from_slice(nonce);
236
237    let plaintext_bytes = cipher
238        .decrypt(xnonce, Payload { msg: &ct_with_tag, aad: b"" })
239        .map_err(|e| Error::Crypto(format!("XChaCha20-Poly1305 decrypt failed: {}", e)))?;
240
241    String::from_utf8(plaintext_bytes)
242        .map_err(|e| Error::Crypto(format!("UTF-8 decode failed: {}", e)))
243}
244
245// ---------------------------------------------------------------------------
246// HKDF utilities for LSH (chunked output)
247// ---------------------------------------------------------------------------
248
249/// Derive `length` pseudo-random bytes using chunked HKDF-SHA256.
250///
251/// A single HKDF-SHA256 call can output at most 255 * 32 = 8160 bytes.
252/// For larger outputs we iterate over sub-block indices in the info string.
253///
254/// Matches `deriveRandomBytes()` in `mcp/src/subgraph/lsh.ts`.
255pub fn derive_random_bytes(seed: &[u8], base_info: &str, length: usize) -> Result<Vec<u8>> {
256    const MAX_HKDF_OUTPUT: usize = 255 * 32;
257    let mut result = vec![0u8; length];
258    let mut offset = 0;
259    let mut block_index = 0;
260
261    while offset < length {
262        let remaining = length - offset;
263        let chunk_len = remaining.min(MAX_HKDF_OUTPUT);
264        let info = format!("{}_block_{}", base_info, block_index);
265
266        // Empty salt (matches TypeScript: `new Uint8Array(0)`)
267        let hk = Hkdf::<Sha256>::new(Some(&[]), seed);
268        hk.expand(info.as_bytes(), &mut result[offset..offset + chunk_len])
269            .map_err(|e| Error::Crypto(format!("HKDF expand failed: {}", e)))?;
270
271        offset += chunk_len;
272        block_index += 1;
273    }
274
275    Ok(result)
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// 12-word phrase the strict path accepts (used as the known-good input).
    const VALID_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon about";

    /// All words are in the wordlist, but the checksum is wrong.
    const BAD_CHECKSUM_MNEMONIC: &str = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon";

    /// Parse the shared cross-implementation test vectors.
    fn load_vectors() -> serde_json::Value {
        serde_json::from_str(include_str!("../tests/fixtures/crypto_vectors.json")).unwrap()
    }

    /// Read a JSON value as a string, panicking if it is not one.
    fn text(value: &serde_json::Value) -> &str {
        value.as_str().unwrap()
    }

    // Derived keys, salt and auth-key hash must equal the fixture values.
    #[test]
    fn test_key_derivation_parity() {
        let vectors = load_vectors();
        let kd = &vectors["key_derivation"];

        let keys = derive_keys_from_mnemonic(text(&kd["mnemonic"])).unwrap();

        assert_eq!(hex::encode(keys.salt), text(&kd["salt_hex"]));
        assert_eq!(hex::encode(keys.auth_key), text(&kd["auth_key_hex"]));
        assert_eq!(
            hex::encode(keys.encryption_key),
            text(&kd["encryption_key_hex"])
        );
        assert_eq!(hex::encode(keys.dedup_key), text(&kd["dedup_key_hex"]));

        // Auth key hash
        assert_eq!(
            compute_auth_key_hash(&keys.auth_key),
            text(&kd["auth_key_hash"])
        );
    }

    // The raw 64-byte seed must equal the fixture value.
    #[test]
    fn test_bip39_seed_parity() {
        let vectors = load_vectors();
        let kd = &vectors["key_derivation"];

        let seed = mnemonic_to_seed(text(&kd["mnemonic"])).unwrap();
        assert_eq!(hex::encode(seed), text(&kd["bip39_seed_hex"]));
    }

    // The LSH seed derived with the key-derivation salt must equal the fixture.
    #[test]
    fn test_lsh_seed_parity() {
        let vectors = load_vectors();
        let mnemonic = text(&vectors["key_derivation"]["mnemonic"]);

        let keys = derive_keys_from_mnemonic(mnemonic).unwrap();
        let lsh_seed = derive_lsh_seed(mnemonic, &keys.salt).unwrap();

        assert_eq!(
            hex::encode(lsh_seed),
            text(&vectors["lsh"]["lsh_seed_hex"])
        );
    }

    // With a fixed nonce the ciphertext must equal the fixture and round-trip.
    #[test]
    fn test_xchacha_fixed_nonce_parity() {
        let vectors = load_vectors();
        let xc = &vectors["xchacha20"];

        let mut key = [0u8; 32];
        key.copy_from_slice(&hex::decode(text(&xc["encryption_key_hex"])).unwrap());

        let mut nonce = [0u8; 24];
        nonce.copy_from_slice(&hex::decode(text(&xc["fixed_nonce_hex"])).unwrap());

        let plaintext = text(&xc["plaintext"]);

        let encrypted = encrypt_with_nonce(plaintext, &key, &nonce).unwrap();
        assert_eq!(encrypted, text(&xc["fixed_nonce_encrypted_base64"]));

        let decrypted = decrypt(&encrypted, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    // Random-nonce encryption must decrypt back to the original text.
    #[test]
    fn test_xchacha_round_trip() {
        let keys = derive_keys_from_mnemonic(VALID_MNEMONIC).unwrap();

        let plaintext = "Hello, TotalReclaw!";
        let encrypted = encrypt(plaintext, &keys.encryption_key).unwrap();
        let decrypted = decrypt(&encrypted, &keys.encryption_key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    // A valid BIP-39 mnemonic should yield identical keys in both modes.
    #[test]
    fn test_lenient_accepts_valid_mnemonic() {
        let strict = derive_keys_from_mnemonic(VALID_MNEMONIC).unwrap();
        let lenient = derive_keys_from_mnemonic_lenient(VALID_MNEMONIC).unwrap();

        assert_eq!(strict.auth_key, lenient.auth_key);
        assert_eq!(strict.encryption_key, lenient.encryption_key);
        assert_eq!(strict.dedup_key, lenient.dedup_key);
        assert_eq!(strict.salt, lenient.salt);
    }

    // A word outside the wordlist is rejected and named in the error.
    #[test]
    fn test_lenient_rejects_invalid_words() {
        let mnemonic = "abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon abandon xyzzy";
        let result = derive_keys_from_mnemonic_lenient(mnemonic);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("xyzzy"));
    }

    // Phrases that are not 12 or 24 words long are rejected.
    #[test]
    fn test_lenient_rejects_wrong_word_count() {
        let mnemonic = "abandon abandon abandon"; // only 3 words
        let result = derive_keys_from_mnemonic_lenient(mnemonic);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("expected 12 or 24"));
    }

    // Strict mode must refuse a phrase whose checksum is wrong.
    #[test]
    fn test_strict_rejects_bad_checksum() {
        let result = derive_keys_from_mnemonic(BAD_CHECKSUM_MNEMONIC);
        assert!(result.is_err());
    }

    // Lenient mode must accept the same bad-checksum phrase.
    #[test]
    fn test_lenient_accepts_bad_checksum() {
        let result = derive_keys_from_mnemonic_lenient(BAD_CHECKSUM_MNEMONIC);
        assert!(result.is_ok());
    }
}