Skip to main content

tf_types/
session_migration.rs

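//! Session-migration helpers (Rust mirror of the TS implementation): building, signing and
//! verifying `SessionMigration` records, plus a message-count-driven key ratchet.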
2
3use crate::encoding::STANDARD;
4use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
5use hkdf::Hkdf;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use sha2::{Digest, Sha256};
9
10use crate::canonicalize;
11
12#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
13pub struct TransportBinding {
14    pub binding_version: String,
15    pub kind: String,
16    #[serde(skip_serializing_if = "Option::is_none", default)]
17    pub endpoint: Option<String>,
18    #[serde(skip_serializing_if = "Option::is_none", default)]
19    pub exporter_key: Option<String>,
20    #[serde(skip_serializing_if = "Option::is_none", default)]
21    pub peer_cert_fingerprint: Option<String>,
22    #[serde(skip_serializing_if = "Option::is_none", default)]
23    pub tls_alpn: Option<String>,
24    #[serde(skip_serializing_if = "Option::is_none", default)]
25    pub established_at: Option<String>,
26    #[serde(skip_serializing_if = "Option::is_none", default)]
27    pub metadata: Option<Value>,
28}
29
30#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
31pub struct SessionMigration {
32    pub migration_version: String,
33    pub session_id: String,
34    pub generation: u64,
35    pub from_binding: TransportBinding,
36    pub to_binding: TransportBinding,
37    #[serde(skip_serializing_if = "Option::is_none", default)]
38    pub preserved_capabilities: Option<Vec<Value>>,
39    #[serde(skip_serializing_if = "Option::is_none", default)]
40    pub rotated_keys: Option<bool>,
41    pub migrated_at: String,
42    #[serde(skip_serializing_if = "Option::is_none", default)]
43    pub reason: Option<String>,
44    pub signer: String,
45    pub signature: SignatureEnvelope,
46}
47
48#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
49pub struct SignatureEnvelope {
50    pub algorithm: String,
51    pub signer: String,
52    pub signature: String,
53}
54
55pub fn migration_signing_bytes(m: &SessionMigration) -> [u8; 32] {
56    let mut value = serde_json::to_value(m).unwrap_or(Value::Null);
57    if let Value::Object(map) = &mut value {
58        map.remove("signature");
59    }
60    let canonical = canonicalize(&value).unwrap_or_default();
61    Sha256::digest(canonical.as_bytes()).into()
62}
63
64#[allow(clippy::too_many_arguments)]
65pub fn migrate_session(
66    session_id: &str,
67    generation: u64,
68    from_binding: TransportBinding,
69    to_binding: TransportBinding,
70    rotated_keys: bool,
71    reason: Option<&str>,
72    signer: &str,
73    private_key: &[u8; 32],
74    migrated_at: Option<&str>,
75) -> SessionMigration {
76    let migrated_at = migrated_at.map(str::to_string).unwrap_or_else(now_iso8601);
77    let mut m = SessionMigration {
78        migration_version: "1".into(),
79        session_id: session_id.into(),
80        generation,
81        from_binding,
82        to_binding,
83        preserved_capabilities: None,
84        rotated_keys: if rotated_keys { Some(true) } else { None },
85        migrated_at,
86        reason: reason.map(str::to_string),
87        signer: signer.into(),
88        signature: SignatureEnvelope {
89            algorithm: "ed25519".into(),
90            signer: signer.into(),
91            signature: String::new(),
92        },
93    };
94    let digest = migration_signing_bytes(&m);
95    let signing = SigningKey::from_bytes(private_key);
96    let sig: Signature = signing.sign(&digest);
97    m.signature.signature = STANDARD.encode(sig.to_bytes());
98    m
99}
100
/// Outcome of `verify_session_migration`.
///
/// `ok` is `true` only when every check passed, and then `reason` is `None`.
/// Otherwise `ok` is `false` and `reason` explains the first failing check.
///
/// Marked `#[must_use]`: dropping a verification result without inspecting it is
/// almost certainly a bug, since the migration would then be trusted unchecked.
#[must_use = "a session-migration verification result must be checked before trusting the migration"]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VerifyMigrationResult {
    /// Whether the migration passed all checks.
    pub ok: bool,
    /// Why verification failed; `None` on success.
    pub reason: Option<String>,
}
106
107pub fn verify_session_migration(
108    m: &SessionMigration,
109    public_key: &[u8; 32],
110    last_generation: Option<u64>,
111    expected_session_id: Option<&str>,
112) -> VerifyMigrationResult {
113    let rejected = |r: &str| VerifyMigrationResult {
114        ok: false,
115        reason: Some(r.to_string()),
116    };
117    if m.migration_version != "1" {
118        return rejected(&format!(
119            "unsupported migration_version {}",
120            m.migration_version
121        ));
122    }
123    if m.signature.signer != m.signer {
124        return rejected("signature signer does not match signer");
125    }
126    if m.signature.algorithm != "ed25519" {
127        return rejected(&format!(
128            "unsupported signature algorithm {}",
129            m.signature.algorithm
130        ));
131    }
132    if let Some(expected) = expected_session_id {
133        if m.session_id != expected {
134            return rejected("session_id mismatch");
135        }
136    }
137    if let Some(last) = last_generation {
138        if m.generation <= last {
139            return rejected(&format!(
140                "generation {} <= last seen {} (replay)",
141                m.generation, last
142            ));
143        }
144    }
145    let digest = migration_signing_bytes(m);
146    let sig_bytes = match STANDARD.decode(&m.signature.signature) {
147        Ok(b) => b,
148        Err(e) => return rejected(&format!("signature base64 decode: {}", e)),
149    };
150    let sig = match Signature::from_slice(&sig_bytes) {
151        Ok(s) => s,
152        Err(e) => return rejected(&format!("signature parse: {}", e)),
153    };
154    let vk = match VerifyingKey::from_bytes(public_key) {
155        Ok(v) => v,
156        Err(e) => return rejected(&format!("verifying key: {}", e)),
157    };
158    if vk.verify(&digest, &sig).is_err() {
159        return rejected("migration signature did not verify");
160    }
161    VerifyMigrationResult {
162        ok: true,
163        reason: None,
164    }
165}
166
/// HKDF `info` label used by `Ratchet::rotate` to derive each successive key.
const RATCHET_INFO: &[u8] = "tf-session/ratchet".as_bytes();
168
/// Symmetric key ratchet holding a 32-byte key that is replaced by a derived key
/// on each rotation (see `Ratchet::rotate` and `Ratchet::observe_message`).
///
/// `Debug` is implemented by hand rather than derived. A derived impl would print
/// `current_key`, leaking live key material into any log line or panic message
/// that formats the ratchet. Only the counters are shown; the key is redacted.
pub struct Ratchet {
    /// Key material for the current generation.
    current_key: [u8; 32],
    /// Number of rotations performed so far.
    rotation_count: u64,
    /// Messages observed since the last rotation.
    messages_since_rotation: u32,
    /// Message count at which `observe_message` triggers a rotation.
    max_messages: u32,
}

impl std::fmt::Debug for Ratchet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Ratchet")
            // Never print the secret key itself.
            .field("current_key", &"<redacted>")
            .field("rotation_count", &self.rotation_count)
            .field("messages_since_rotation", &self.messages_since_rotation)
            .field("max_messages", &self.max_messages)
            .finish()
    }
}
176
177impl Ratchet {
178    pub fn new(initial_key: [u8; 32], max_messages: Option<u32>) -> Self {
179        Ratchet {
180            current_key: initial_key,
181            rotation_count: 0,
182            messages_since_rotation: 0,
183            max_messages: max_messages.unwrap_or(1024),
184        }
185    }
186
187    pub fn key(&self) -> [u8; 32] {
188        self.current_key
189    }
190
191    pub fn generation(&self) -> u64 {
192        self.rotation_count
193    }
194
195    pub fn observe_message(&mut self) -> bool {
196        self.messages_since_rotation += 1;
197        if self.messages_since_rotation >= self.max_messages {
198            self.rotate();
199            true
200        } else {
201            false
202        }
203    }
204
205    pub fn rotate(&mut self) {
206        let hk = Hkdf::<Sha256>::new(None, &self.current_key);
207        let mut next = [0u8; 32];
208        hk.expand(RATCHET_INFO, &mut next).expect("hkdf");
209        self.current_key = next;
210        self.rotation_count += 1;
211        self.messages_since_rotation = 0;
212    }
213}
214
215fn now_iso8601() -> String {
216    let secs = std::time::SystemTime::now()
217        .duration_since(std::time::UNIX_EPOCH)
218        .unwrap_or_default()
219        .as_secs() as i64;
220    let (y, m, d, h, mi, s) = secs_to_ymdhms(secs);
221    format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", y, m, d, h, mi, s)
222}
223
/// Splits a Unix timestamp (seconds; may be negative) into proleptic-Gregorian
/// UTC `(year, month, day, hour, minute, second)`. `month` is in `1..=12` and
/// `day` in `1..=31`.
///
/// The date part is Howard Hinnant's `civil_from_days` algorithm. Days are
/// re-based to 0000-03-01 so the leap day falls at the end of each computed
/// year, then split into 400-year eras of 146 097 days.
fn secs_to_ymdhms(secs: i64) -> (i32, u32, u32, u32, u32, u32) {
    const SECS_PER_DAY: i64 = 86_400;
    const DAYS_PER_ERA: i64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: i64 = 719_468;

    // Floor division keeps pre-1970 instants on the correct (earlier) day.
    let day_number = secs.div_euclid(SECS_PER_DAY);
    let secs_of_day = secs.rem_euclid(SECS_PER_DAY);
    let hour = (secs_of_day / 3600) as u32;
    let minute = (secs_of_day / 60 % 60) as u32;
    let second = (secs_of_day % 60) as u32;

    let shifted = day_number + EPOCH_SHIFT;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA) as u64; // 0..=146_096
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March: 0 = March, ..., 11 = February.
    let march_month = (5 * day_of_year + 2) / 153;
    let day = (day_of_year - (153 * march_month + 2) / 5 + 1) as u32;
    let month = (if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    }) as u32;
    // January and February belong to the following civil year.
    let year = year_of_era as i64 + era * 400 + i64::from(month <= 2);

    (year as i32, month, day, hour, minute, second)
}