Skip to main content

shard_core/
keychain.rs

1use anyhow::Result;
2use ed25519_dalek::{Signer, SigningKey, Verifier, VerifyingKey};
3use serde::{Deserialize, Serialize};
4use std::fs;
5use std::path::Path;
6use std::time::{SystemTime, UNIX_EPOCH};
7
/// Name of the file under the keys directory whose contents are the active key's id.
const CURRENT_REF: &str = "current";
/// Subdirectory holding one `<key_id>/` folder per rotated-out key, containing
/// that key's `secret.key` and `public.key`.
const ARCHIVE_DIR: &str = "archive";
/// Subdirectory holding one `<key_id>.json` [`KeyRecord`] per known key.
const RECORDS_DIR: &str = "records";
/// Subdirectory holding one `<rotation_id>.json` [`KeyRotation`] per rotation.
const ROTATIONS_DIR: &str = "rotations";
12
13pub fn key_id_from_public_key(pk: &VerifyingKey) -> String {
14    blake3::hash(&pk.to_bytes()).to_hex().to_string()
15}
16
17#[derive(Serialize, Deserialize, Debug, Clone)]
18pub struct KeyRecord {
19    pub key_id: String,
20    pub public_key_hex: String,
21    pub created_at: u64,
22    pub previous_key_id: Option<String>,
23}
24
25#[derive(Serialize, Deserialize, Debug, Clone)]
26pub struct KeyRotation {
27    pub rotation_id: String,
28    pub old_key_id: String,
29    pub new_key_id: String,
30    pub new_public_key_hex: String,
31    pub timestamp: u64,
32    pub signature_hex: String,
33}
34
35impl KeyRotation {
36    /// Verify that this rotation was signed by the old key's private key.
37    pub fn verify(&self, old_public_key: &VerifyingKey) -> Result<()> {
38        let payload = serde_json::json!({
39            "old_key_id": self.old_key_id,
40            "new_key_id": self.new_key_id,
41            "new_public_key_hex": self.new_public_key_hex,
42            "timestamp": self.timestamp,
43        });
44        let payload_bytes = serde_json::to_vec(&payload)?;
45        let sig_bytes = hex::decode(&self.signature_hex)?;
46        let signature = ed25519_dalek::Signature::from_bytes(sig_bytes.as_slice().try_into()?);
47        old_public_key.verify(&payload_bytes, &signature)?;
48        Ok(())
49    }
50}
51
52/// Initialize the keychain with the current key as the genesis root.
53/// Must be called after the initial keypair is saved to `keys_dir`.
54pub fn init_keychain(keys_dir: &Path) -> Result<String> {
55    let pub_bytes = fs::read(keys_dir.join("public.key"))?;
56    let pk = VerifyingKey::from_bytes(pub_bytes.as_slice().try_into()?)?;
57    let key_id = key_id_from_public_key(&pk);
58
59    fs::create_dir_all(keys_dir.join(RECORDS_DIR))?;
60    fs::create_dir_all(keys_dir.join(ROTATIONS_DIR))?;
61    fs::create_dir_all(keys_dir.join(ARCHIVE_DIR))?;
62
63    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis() as u64;
64    let record = KeyRecord {
65        key_id: key_id.clone(),
66        public_key_hex: hex::encode(pk.to_bytes()),
67        created_at: now,
68        previous_key_id: None,
69    };
70    let record_path = keys_dir.join(RECORDS_DIR).join(format!("{}.json", key_id));
71    fs::write(&record_path, serde_json::to_string_pretty(&record)?)?;
72
73    set_current_key(keys_dir, &key_id)?;
74    Ok(key_id)
75}
76
77/// Return the key_id of the currently active key.
78pub fn get_current_key_id(keys_dir: &Path) -> Result<String> {
79    let current_path = keys_dir.join(CURRENT_REF);
80    let key_id = fs::read_to_string(&current_path)?.trim().to_string();
81    if key_id.is_empty() {
82        anyhow::bail!("current key ref is empty");
83    }
84    Ok(key_id)
85}
86
87fn set_current_key(keys_dir: &Path, key_id: &str) -> Result<()> {
88    fs::write(keys_dir.join(CURRENT_REF), key_id)?;
89    Ok(())
90}
91
92/// Generate a new ed25519 signing keypair, archive the old one, and
93/// persist a signed rotation record.
94pub fn rotate_signing_key(keys_dir: &Path) -> Result<KeyRotation> {
95    let old_secret = fs::read(keys_dir.join("secret.key"))?;
96    let old_signing_key = SigningKey::from_bytes(old_secret.as_slice().try_into()?);
97    let old_verifying_key = old_signing_key.verifying_key();
98    let old_key_id = get_current_key_id(keys_dir)?;
99
100    // Generate new keypair
101    use rand::RngCore;
102    let mut bytes = [0u8; 32];
103    let mut csprng = rand::rngs::OsRng;
104    csprng.fill_bytes(&mut bytes);
105    let new_signing_key = SigningKey::from_bytes(&bytes);
106    let new_verifying_key = new_signing_key.verifying_key();
107    let new_key_id = key_id_from_public_key(&new_verifying_key);
108
109    // Archive old key
110    let archive_dir = keys_dir.join(ARCHIVE_DIR).join(&old_key_id);
111    fs::create_dir_all(&archive_dir)?;
112    fs::write(archive_dir.join("secret.key"), &old_secret)?;
113    fs::write(archive_dir.join("public.key"), old_verifying_key.to_bytes())?;
114
115    // Create and sign the rotation
116    let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis() as u64;
117    let new_pk_hex = hex::encode(new_verifying_key.to_bytes());
118    let payload = serde_json::json!({
119        "old_key_id": old_key_id,
120        "new_key_id": new_key_id,
121        "new_public_key_hex": new_pk_hex,
122        "timestamp": now,
123    });
124    let payload_bytes = serde_json::to_vec(&payload)?;
125    let signature = old_signing_key.sign(&payload_bytes);
126
127    let rotation = KeyRotation {
128        rotation_id: blake3::hash(&payload_bytes).to_hex().to_string(),
129        old_key_id,
130        new_key_id: new_key_id.clone(),
131        new_public_key_hex: new_pk_hex,
132        timestamp: now,
133        signature_hex: hex::encode(signature.to_bytes()),
134    };
135
136    // Save rotation record
137    let rotation_path = keys_dir
138        .join(ROTATIONS_DIR)
139        .join(format!("{}.json", rotation.rotation_id));
140    fs::write(&rotation_path, serde_json::to_string_pretty(&rotation)?)?;
141
142    // Create a record for the new key
143    let new_record = KeyRecord {
144        key_id: new_key_id.clone(),
145        public_key_hex: hex::encode(new_verifying_key.to_bytes()),
146        created_at: now,
147        previous_key_id: Some(rotation.old_key_id.clone()),
148    };
149    let record_path = keys_dir
150        .join(RECORDS_DIR)
151        .join(format!("{}.json", new_key_id));
152    fs::write(&record_path, serde_json::to_string_pretty(&new_record)?)?;
153
154    // Save new keypair as current
155    fs::write(keys_dir.join("secret.key"), new_signing_key.to_bytes())?;
156    fs::write(keys_dir.join("public.key"), new_verifying_key.to_bytes())?;
157    set_current_key(keys_dir, &new_key_id)?;
158
159    Ok(rotation)
160}
161
162/// Walk the key rotation chain for a given key_id, returning all rotation
163/// records in order (newest first). Starts from the rotation whose new_key_id
164/// matches key_id, then follows previous_key_id backward to genesis.
165pub fn collect_rotation_chain(keys_dir: &Path, key_id: &str) -> Result<Vec<KeyRotation>> {
166    let rotations = load_rotations(keys_dir)?;
167    // Build index: new_key_id -> rotation
168    let new_to_old: std::collections::HashMap<&str, &KeyRotation> = rotations
169        .iter()
170        .map(|r| (r.new_key_id.as_str(), r))
171        .collect();
172    let mut chain = Vec::new();
173    let mut current = key_id;
174    while let Some(rot) = new_to_old.get(current) {
175        chain.push((*rot).clone());
176        current = &rot.old_key_id;
177    }
178    Ok(chain)
179}
180
181/// Load all rotation records sorted by timestamp.
182pub fn load_rotations(keys_dir: &Path) -> Result<Vec<KeyRotation>> {
183    let rot_dir = keys_dir.join(ROTATIONS_DIR);
184    if !rot_dir.exists() {
185        return Ok(Vec::new());
186    }
187    let mut rotations = Vec::new();
188    for entry in fs::read_dir(&rot_dir)? {
189        let entry = entry?;
190        if entry.file_type()?.is_file() {
191            let data = fs::read(entry.path())?;
192            if let Ok(rot) = serde_json::from_slice::<KeyRotation>(&data) {
193                rotations.push(rot);
194            }
195        }
196    }
197    rotations.sort_by_key(|a| a.timestamp);
198    Ok(rotations)
199}
200
201/// Load all key records sorted by creation time.
202pub fn load_records(keys_dir: &Path) -> Result<Vec<KeyRecord>> {
203    let rec_dir = keys_dir.join(RECORDS_DIR);
204    if !rec_dir.exists() {
205        return Ok(Vec::new());
206    }
207    let mut records = Vec::new();
208    for entry in fs::read_dir(&rec_dir)? {
209        let entry = entry?;
210        if entry.file_type()?.is_file() {
211            let data = fs::read(entry.path())?;
212            if let Ok(record) = serde_json::from_slice::<KeyRecord>(&data) {
213                records.push(record);
214            }
215        }
216    }
217    records.sort_by_key(|a| a.created_at);
218    Ok(records)
219}
220
221/// Walk every rotation and verify its signature. Returns a list of errors.
222pub fn verify_keychain(keys_dir: &Path) -> Result<Vec<String>> {
223    let rotations = load_rotations(keys_dir)?;
224    let mut errors = Vec::new();
225    for rotation in &rotations {
226        let old_pk = resolve_public_key(keys_dir, &rotation.old_key_id)?;
227        if let Err(e) = rotation.verify(&old_pk) {
228            errors.push(format!("rotation {}: {}", rotation.rotation_id, e));
229        }
230    }
231    Ok(errors)
232}
233
234/// Find the ed25519 public key for a given key_id by searching:
235/// current key, archived keys, and key records.
236pub fn resolve_public_key(keys_dir: &Path, key_id: &str) -> Result<VerifyingKey> {
237    if let Ok(current_id) = get_current_key_id(keys_dir) {
238        if current_id == key_id {
239            let pub_bytes = fs::read(keys_dir.join("public.key"))?;
240            return Ok(VerifyingKey::from_bytes(pub_bytes.as_slice().try_into()?)?);
241        }
242    }
243
244    let archive_pub = keys_dir.join(ARCHIVE_DIR).join(key_id).join("public.key");
245    if archive_pub.exists() {
246        let pub_bytes = fs::read(&archive_pub)?;
247        return Ok(VerifyingKey::from_bytes(pub_bytes.as_slice().try_into()?)?);
248    }
249
250    let records = load_records(keys_dir)?;
251    for record in &records {
252        if record.key_id == key_id {
253            let pk_bytes = hex::decode(&record.public_key_hex)?;
254            return Ok(VerifyingKey::from_bytes(pk_bytes.as_slice().try_into()?)?);
255        }
256    }
257
258    anyhow::bail!("key_id {} not found in keychain", key_id)
259}
260
261/// Verify that `key_id` was an active (non-expired) key at the given
262/// Unix timestamp (seconds).  Keychain timestamps are stored in
263/// milliseconds, so we compare at second precision.
264pub fn key_was_valid_at(keys_dir: &Path, key_id: &str, timestamp_secs: u64) -> Result<()> {
265    let records = load_records(keys_dir)?;
266    let record = records
267        .iter()
268        .find(|r| r.key_id == key_id)
269        .ok_or_else(|| anyhow::anyhow!("key_id {} not found in keychain", key_id))?;
270
271    // Compare at second precision (div 1000) to align with commit timestamps.
272    let created_secs = record.created_at / 1000;
273    if created_secs > timestamp_secs {
274        anyhow::bail!(
275            "key {} created at {} (secs) but commit is at {} — key not yet valid",
276            key_id,
277            created_secs,
278            timestamp_secs
279        );
280    }
281
282    for next in &records {
283        if next.previous_key_id.as_deref() == Some(key_id) {
284            let next_secs = next.created_at / 1000;
285            // A key is valid for the entire second in which it rotates. Only
286            // reject if the rotation finished *before* the commit second.
287            if next_secs < timestamp_secs {
288                anyhow::bail!(
289                    "key {} rotated at {} (secs) but commit is at {} — key was already stale",
290                    key_id,
291                    next_secs,
292                    timestamp_secs
293                );
294            }
295            break;
296        }
297    }
298
299    Ok(())
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Write a freshly generated ed25519 keypair to `secret.key` / `public.key`
    /// in `keys_dir`, mimicking the state `init_keychain` expects.
    fn create_initial_keypair(keys_dir: &Path) {
        use rand::RngCore;
        let mut bytes = [0u8; 32];
        rand::rngs::OsRng.fill_bytes(&mut bytes);
        let sk = SigningKey::from_bytes(&bytes);
        let pk = sk.verifying_key();
        fs::write(keys_dir.join("secret.key"), sk.to_bytes()).unwrap();
        fs::write(keys_dir.join("public.key"), pk.to_bytes()).unwrap();
    }

    /// Create a temporary `keys` directory containing an initial keypair.
    ///
    /// The returned `TempDir` guard must be kept alive for as long as the
    /// path is used; dropping it removes the directory.
    fn new_keys_dir() -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempdir().unwrap();
        let keys_dir = dir.path().join("keys");
        fs::create_dir_all(&keys_dir).unwrap();
        create_initial_keypair(&keys_dir);
        (dir, keys_dir)
    }

    #[test]
    fn test_key_id_deterministic() {
        use rand::RngCore;
        let mut bytes = [0u8; 32];
        rand::rngs::OsRng.fill_bytes(&mut bytes);
        let sk = SigningKey::from_bytes(&bytes);
        let pk = sk.verifying_key();
        assert_eq!(key_id_from_public_key(&pk), key_id_from_public_key(&pk));
    }

    #[test]
    fn test_init_keychain_creates_record_and_ref() {
        let (_tmp, keys_dir) = new_keys_dir();

        let pub_bytes = fs::read(keys_dir.join("public.key")).unwrap();
        let pk = VerifyingKey::from_bytes(pub_bytes.as_slice().try_into().unwrap()).unwrap();

        let key_id = init_keychain(&keys_dir).unwrap();
        assert_eq!(key_id, key_id_from_public_key(&pk));

        let stored = fs::read_to_string(keys_dir.join(CURRENT_REF)).unwrap();
        assert_eq!(stored.trim(), key_id);

        let rec_path = keys_dir.join(RECORDS_DIR).join(format!("{}.json", key_id));
        assert!(rec_path.exists());
    }

    #[test]
    fn test_rotate_signing_key_creates_rotation_and_updates_current() {
        let (_tmp, keys_dir) = new_keys_dir();

        let old_key_id = init_keychain(&keys_dir).unwrap();

        let rotation = rotate_signing_key(&keys_dir).unwrap();
        assert_eq!(rotation.old_key_id, old_key_id);
        assert_ne!(rotation.new_key_id, old_key_id);

        let current_id = get_current_key_id(&keys_dir).unwrap();
        assert_eq!(current_id, rotation.new_key_id);

        let rot_path = keys_dir
            .join(ROTATIONS_DIR)
            .join(format!("{}.json", rotation.rotation_id));
        assert!(rot_path.exists());
    }

    #[test]
    fn test_rotation_verifies_with_old_key() {
        let (_tmp, keys_dir) = new_keys_dir();

        init_keychain(&keys_dir).unwrap();
        let rotation = rotate_signing_key(&keys_dir).unwrap();

        let old_pk = resolve_public_key(&keys_dir, &rotation.old_key_id).unwrap();
        assert!(rotation.verify(&old_pk).is_ok());

        let errors = verify_keychain(&keys_dir).unwrap();
        assert!(errors.is_empty(), "{:?}", errors);
    }

    #[test]
    fn test_key_was_valid_at() {
        let (_tmp, keys_dir) = new_keys_dir();

        let old_key_id = init_keychain(&keys_dir).unwrap();

        // Sleep to ensure rotation falls in a different second than init.
        std::thread::sleep(std::time::Duration::from_millis(1500));
        let rotation = rotate_signing_key(&keys_dir).unwrap();

        let rot_secs = rotation.timestamp / 1000;
        let old_created_secs = load_records(&keys_dir)
            .unwrap()
            .iter()
            .find(|r| r.key_id == old_key_id)
            .unwrap()
            .created_at
            / 1000;

        // Old key not valid before its creation second
        assert!(key_was_valid_at(&keys_dir, &old_key_id, old_created_secs - 1).is_err());

        // Old key valid at its creation second (which is < rot_secs with 1.5s sleep)
        assert!(key_was_valid_at(&keys_dir, &old_key_id, old_created_secs).is_ok());

        // Old key valid throughout the rotation second (grace window)
        assert!(key_was_valid_at(&keys_dir, &old_key_id, rot_secs).is_ok());

        // Old key invalid starting the second AFTER rotation
        assert!(key_was_valid_at(&keys_dir, &old_key_id, rot_secs + 1).is_err());

        // New key valid at rotation second onward
        assert!(key_was_valid_at(&keys_dir, &rotation.new_key_id, rot_secs).is_ok());

        // New key not valid before rotation second
        assert!(key_was_valid_at(&keys_dir, &rotation.new_key_id, rot_secs - 1).is_err());
    }

    #[test]
    fn test_resolve_public_key_after_rotation() {
        let (_tmp, keys_dir) = new_keys_dir();

        let old_key_id = init_keychain(&keys_dir).unwrap();
        let rotation = rotate_signing_key(&keys_dir).unwrap();

        // Old key resolvable from archive
        let old_pk = resolve_public_key(&keys_dir, &old_key_id).unwrap();
        assert_eq!(key_id_from_public_key(&old_pk), old_key_id);

        // New key resolvable from current
        let new_pk = resolve_public_key(&keys_dir, &rotation.new_key_id).unwrap();
        assert_eq!(key_id_from_public_key(&new_pk), rotation.new_key_id);
    }

    #[test]
    fn test_tampered_rotation_is_detected() {
        let (_tmp, keys_dir) = new_keys_dir();

        init_keychain(&keys_dir).unwrap();
        rotate_signing_key(&keys_dir).unwrap();

        // Tamper every rotation file by replacing its signature with zeros.
        let rot_dir = keys_dir.join(ROTATIONS_DIR);
        for entry in fs::read_dir(&rot_dir).unwrap() {
            let entry = entry.unwrap();
            if entry.file_type().unwrap().is_file() {
                let data = fs::read(entry.path()).unwrap();
                if let Ok(mut rot) = serde_json::from_slice::<KeyRotation>(&data) {
                    rot.signature_hex = hex::encode([0u8; 64]);
                    fs::write(entry.path(), serde_json::to_string_pretty(&rot).unwrap()).unwrap();
                }
            }
        }

        let errors = verify_keychain(&keys_dir).unwrap();
        assert!(!errors.is_empty(), "tampered rotation must fail");
    }

    #[test]
    fn test_double_rotation() {
        let (_tmp, keys_dir) = new_keys_dir();

        let key1 = init_keychain(&keys_dir).unwrap();

        // Sleeps keep each rotation in its own second.
        std::thread::sleep(std::time::Duration::from_millis(1500));
        let rot1 = rotate_signing_key(&keys_dir).unwrap();
        let key2 = rot1.new_key_id.clone();

        std::thread::sleep(std::time::Duration::from_millis(1500));
        let rot2 = rotate_signing_key(&keys_dir).unwrap();
        let key3 = rot2.new_key_id.clone();

        assert_ne!(key1, key2);
        assert_ne!(key2, key3);
        assert_ne!(key1, key3);

        let current = get_current_key_id(&keys_dir).unwrap();
        assert_eq!(current, key3);

        let errors = verify_keychain(&keys_dir).unwrap();
        assert!(errors.is_empty(), "{:?}", errors);

        let r1s = rot1.timestamp / 1000;
        let r2s = rot2.timestamp / 1000;

        // key1 valid throughout rot1 second (grace window)
        assert!(key_was_valid_at(&keys_dir, &key1, r1s).is_ok());
        // key1 invalid the second after rot1
        assert!(key_was_valid_at(&keys_dir, &key1, r1s + 1).is_err());
        // key2 valid during rot2 second
        assert!(key_was_valid_at(&keys_dir, &key2, r2s).is_ok());
        // key2 invalid the second after rot2
        assert!(key_was_valid_at(&keys_dir, &key2, r2s + 1).is_err());
        // key3 valid at rot2 second onward
        assert!(key_was_valid_at(&keys_dir, &key3, r2s).is_ok());
    }
}