Skip to main content

tf_types/
session_migration.rs

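//! Session-migration helpers: signed `SessionMigration` records that move a
//! session from one transport binding to another, plus an HKDF-SHA256 key
//! `Ratchet`.
//!
//! This is the Rust mirror of the TS (TypeScript) implementation.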
2
3use base64::engine::general_purpose::STANDARD;
4use base64::Engine;
5use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
6use hkdf::Hkdf;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use sha2::{Digest, Sha256};
10
11use crate::canonicalize;
12
13#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
14pub struct TransportBinding {
15    pub binding_version: String,
16    pub kind: String,
17    #[serde(skip_serializing_if = "Option::is_none", default)]
18    pub endpoint: Option<String>,
19    #[serde(skip_serializing_if = "Option::is_none", default)]
20    pub exporter_key: Option<String>,
21    #[serde(skip_serializing_if = "Option::is_none", default)]
22    pub peer_cert_fingerprint: Option<String>,
23    #[serde(skip_serializing_if = "Option::is_none", default)]
24    pub tls_alpn: Option<String>,
25    #[serde(skip_serializing_if = "Option::is_none", default)]
26    pub established_at: Option<String>,
27    #[serde(skip_serializing_if = "Option::is_none", default)]
28    pub metadata: Option<Value>,
29}
30
31#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
32pub struct SessionMigration {
33    pub migration_version: String,
34    pub session_id: String,
35    pub generation: u64,
36    pub from_binding: TransportBinding,
37    pub to_binding: TransportBinding,
38    #[serde(skip_serializing_if = "Option::is_none", default)]
39    pub preserved_capabilities: Option<Vec<Value>>,
40    #[serde(skip_serializing_if = "Option::is_none", default)]
41    pub rotated_keys: Option<bool>,
42    pub migrated_at: String,
43    #[serde(skip_serializing_if = "Option::is_none", default)]
44    pub reason: Option<String>,
45    pub signer: String,
46    pub signature: SignatureEnvelope,
47}
48
49#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
50pub struct SignatureEnvelope {
51    pub algorithm: String,
52    pub signer: String,
53    pub signature: String,
54}
55
56pub fn migration_signing_bytes(m: &SessionMigration) -> [u8; 32] {
57    let mut value = serde_json::to_value(m).unwrap_or(Value::Null);
58    if let Value::Object(map) = &mut value {
59        map.remove("signature");
60    }
61    let canonical = canonicalize(&value).unwrap_or_default();
62    Sha256::digest(canonical.as_bytes()).into()
63}
64
65#[allow(clippy::too_many_arguments)]
66pub fn migrate_session(
67    session_id: &str,
68    generation: u64,
69    from_binding: TransportBinding,
70    to_binding: TransportBinding,
71    rotated_keys: bool,
72    reason: Option<&str>,
73    signer: &str,
74    private_key: &[u8; 32],
75    migrated_at: Option<&str>,
76) -> SessionMigration {
77    let migrated_at = migrated_at.map(str::to_string).unwrap_or_else(now_iso8601);
78    let mut m = SessionMigration {
79        migration_version: "1".into(),
80        session_id: session_id.into(),
81        generation,
82        from_binding,
83        to_binding,
84        preserved_capabilities: None,
85        rotated_keys: if rotated_keys { Some(true) } else { None },
86        migrated_at,
87        reason: reason.map(str::to_string),
88        signer: signer.into(),
89        signature: SignatureEnvelope {
90            algorithm: "ed25519".into(),
91            signer: signer.into(),
92            signature: String::new(),
93        },
94    };
95    let digest = migration_signing_bytes(&m);
96    let signing = SigningKey::from_bytes(private_key);
97    let sig: Signature = signing.sign(&digest);
98    m.signature.signature = STANDARD.encode(sig.to_bytes());
99    m
100}
101
/// Outcome of [`verify_session_migration`].
///
/// As produced by `verify_session_migration`, `ok` is `true` exactly when
/// `reason` is `None`; on rejection `reason` explains the first failed check.
///
/// Derives the same comparison/clone traits as the other public types in this
/// module so results can be compared, cloned and asserted on directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VerifyMigrationResult {
    /// Whether the migration passed every check.
    pub ok: bool,
    /// Why the migration was rejected; `None` on success.
    pub reason: Option<String>,
}
107
108pub fn verify_session_migration(
109    m: &SessionMigration,
110    public_key: &[u8; 32],
111    last_generation: Option<u64>,
112    expected_session_id: Option<&str>,
113) -> VerifyMigrationResult {
114    let rejected = |r: &str| VerifyMigrationResult {
115        ok: false,
116        reason: Some(r.to_string()),
117    };
118    if m.migration_version != "1" {
119        return rejected(&format!(
120            "unsupported migration_version {}",
121            m.migration_version
122        ));
123    }
124    if m.signature.signer != m.signer {
125        return rejected("signature signer does not match signer");
126    }
127    if m.signature.algorithm != "ed25519" {
128        return rejected(&format!(
129            "unsupported signature algorithm {}",
130            m.signature.algorithm
131        ));
132    }
133    if let Some(expected) = expected_session_id {
134        if m.session_id != expected {
135            return rejected("session_id mismatch");
136        }
137    }
138    if let Some(last) = last_generation {
139        if m.generation <= last {
140            return rejected(&format!(
141                "generation {} <= last seen {} (replay)",
142                m.generation, last
143            ));
144        }
145    }
146    let digest = migration_signing_bytes(m);
147    let sig_bytes = match STANDARD.decode(&m.signature.signature) {
148        Ok(b) => b,
149        Err(e) => return rejected(&format!("signature base64 decode: {}", e)),
150    };
151    let sig = match Signature::from_slice(&sig_bytes) {
152        Ok(s) => s,
153        Err(e) => return rejected(&format!("signature parse: {}", e)),
154    };
155    let vk = match VerifyingKey::from_bytes(public_key) {
156        Ok(v) => v,
157        Err(e) => return rejected(&format!("verifying key: {}", e)),
158    };
159    if vk.verify(&digest, &sig).is_err() {
160        return rejected("migration signature did not verify");
161    }
162    VerifyMigrationResult {
163        ok: true,
164        reason: None,
165    }
166}
167
/// HKDF `info` input used by `Ratchet::rotate` to derive each successive key.
/// Presumably must match the TS mirror byte-for-byte — confirm before changing.
const RATCHET_INFO: &[u8] = "tf-session/ratchet".as_bytes();
169
170#[derive(Debug)]
171pub struct Ratchet {
172    current_key: [u8; 32],
173    rotation_count: u64,
174    messages_since_rotation: u32,
175    max_messages: u32,
176}
177
178impl Ratchet {
179    pub fn new(initial_key: [u8; 32], max_messages: Option<u32>) -> Self {
180        Ratchet {
181            current_key: initial_key,
182            rotation_count: 0,
183            messages_since_rotation: 0,
184            max_messages: max_messages.unwrap_or(1024),
185        }
186    }
187
188    pub fn key(&self) -> [u8; 32] {
189        self.current_key
190    }
191
192    pub fn generation(&self) -> u64 {
193        self.rotation_count
194    }
195
196    pub fn observe_message(&mut self) -> bool {
197        self.messages_since_rotation += 1;
198        if self.messages_since_rotation >= self.max_messages {
199            self.rotate();
200            true
201        } else {
202            false
203        }
204    }
205
206    pub fn rotate(&mut self) {
207        let hk = Hkdf::<Sha256>::new(None, &self.current_key);
208        let mut next = [0u8; 32];
209        hk.expand(RATCHET_INFO, &mut next).expect("hkdf");
210        self.current_key = next;
211        self.rotation_count += 1;
212        self.messages_since_rotation = 0;
213    }
214}
215
216fn now_iso8601() -> String {
217    let secs = std::time::SystemTime::now()
218        .duration_since(std::time::UNIX_EPOCH)
219        .unwrap_or_default()
220        .as_secs() as i64;
221    let (y, m, d, h, mi, s) = secs_to_ymdhms(secs);
222    format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", y, m, d, h, mi, s)
223}
224
/// Splits a signed count of seconds since 1970-01-01T00:00:00Z into
/// `(year, month, day, hour, minute, second)` in the proleptic Gregorian
/// calendar (UTC). Negative inputs count backwards from the epoch.
///
/// The date part is Howard Hinnant's `civil_from_days`. Days are re-based
/// onto 0000-03-01, so each 400-year era has exactly 146 097 days and the
/// leap day falls at the very end of the shifted year.
fn secs_to_ymdhms(secs: i64) -> (i32, u32, u32, u32, u32, u32) {
    const SECS_PER_DAY: i64 = 86_400;
    const DAYS_PER_ERA: i64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: i64 = 719_468;

    let day_number = secs.div_euclid(SECS_PER_DAY);
    let second_of_day = secs.rem_euclid(SECS_PER_DAY);
    let hh = (second_of_day / 3_600) as u32;
    let mm = (second_of_day / 60 % 60) as u32;
    let ss = (second_of_day % 60) as u32;

    let shifted = day_number + EPOCH_SHIFT;
    // Floor division, so days before 0000-03-01 land in negative eras.
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = (shifted - era * DAYS_PER_ERA) as u64; // 0..=146_096
    let year_of_era = (day_of_era - day_of_era / 1_460 + day_of_era / 36_524
        - day_of_era / 146_096)
        / 365; // 0..=399
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100); // 0..=365
    let march_month = (5 * day_of_year + 2) / 153; // 0 = March ... 11 = February
    let day = (day_of_year - (153 * march_month + 2) / 5 + 1) as u32;
    let month = (if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    }) as u32;

    // Shifted years begin in March, so January and February belong to the
    // following civil year.
    let year = year_of_era as i64 + era * 400 + i64::from(month <= 2);
    (year as i32, month, day, hh, mm, ss)
}