1use base64::engine::general_purpose::STANDARD;
4use base64::Engine;
5use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
6use hkdf::Hkdf;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use sha2::{Digest, Sha256};
10
11use crate::canonicalize;
12
13#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
14pub struct TransportBinding {
15 pub binding_version: String,
16 pub kind: String,
17 #[serde(skip_serializing_if = "Option::is_none", default)]
18 pub endpoint: Option<String>,
19 #[serde(skip_serializing_if = "Option::is_none", default)]
20 pub exporter_key: Option<String>,
21 #[serde(skip_serializing_if = "Option::is_none", default)]
22 pub peer_cert_fingerprint: Option<String>,
23 #[serde(skip_serializing_if = "Option::is_none", default)]
24 pub tls_alpn: Option<String>,
25 #[serde(skip_serializing_if = "Option::is_none", default)]
26 pub established_at: Option<String>,
27 #[serde(skip_serializing_if = "Option::is_none", default)]
28 pub metadata: Option<Value>,
29}
30
31#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
32pub struct SessionMigration {
33 pub migration_version: String,
34 pub session_id: String,
35 pub generation: u64,
36 pub from_binding: TransportBinding,
37 pub to_binding: TransportBinding,
38 #[serde(skip_serializing_if = "Option::is_none", default)]
39 pub preserved_capabilities: Option<Vec<Value>>,
40 #[serde(skip_serializing_if = "Option::is_none", default)]
41 pub rotated_keys: Option<bool>,
42 pub migrated_at: String,
43 #[serde(skip_serializing_if = "Option::is_none", default)]
44 pub reason: Option<String>,
45 pub signer: String,
46 pub signature: SignatureEnvelope,
47}
48
49#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
50pub struct SignatureEnvelope {
51 pub algorithm: String,
52 pub signer: String,
53 pub signature: String,
54}
55
56pub fn migration_signing_bytes(m: &SessionMigration) -> [u8; 32] {
57 let mut value = serde_json::to_value(m).unwrap_or(Value::Null);
58 if let Value::Object(map) = &mut value {
59 map.remove("signature");
60 }
61 let canonical = canonicalize(&value).unwrap_or_default();
62 Sha256::digest(canonical.as_bytes()).into()
63}
64
65#[allow(clippy::too_many_arguments)]
66pub fn migrate_session(
67 session_id: &str,
68 generation: u64,
69 from_binding: TransportBinding,
70 to_binding: TransportBinding,
71 rotated_keys: bool,
72 reason: Option<&str>,
73 signer: &str,
74 private_key: &[u8; 32],
75 migrated_at: Option<&str>,
76) -> SessionMigration {
77 let migrated_at = migrated_at.map(str::to_string).unwrap_or_else(now_iso8601);
78 let mut m = SessionMigration {
79 migration_version: "1".into(),
80 session_id: session_id.into(),
81 generation,
82 from_binding,
83 to_binding,
84 preserved_capabilities: None,
85 rotated_keys: if rotated_keys { Some(true) } else { None },
86 migrated_at,
87 reason: reason.map(str::to_string),
88 signer: signer.into(),
89 signature: SignatureEnvelope {
90 algorithm: "ed25519".into(),
91 signer: signer.into(),
92 signature: String::new(),
93 },
94 };
95 let digest = migration_signing_bytes(&m);
96 let signing = SigningKey::from_bytes(private_key);
97 let sig: Signature = signing.sign(&digest);
98 m.signature.signature = STANDARD.encode(sig.to_bytes());
99 m
100}
101
/// Outcome of [`verify_session_migration`].
///
/// `ok` is `true` only when every check passed, and `reason` is then `None`.
/// On rejection `ok` is `false` and `reason` holds a human-readable message.
///
/// The type is `#[must_use]` because discarding a verification result almost
/// always means a migration was accepted without being checked.
#[must_use = "a migration verification result must be checked before trusting the migration"]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VerifyMigrationResult {
    /// Whether the migration passed all checks.
    pub ok: bool,
    /// Rejection reason; `None` when `ok` is `true`.
    pub reason: Option<String>,
}
107
108pub fn verify_session_migration(
109 m: &SessionMigration,
110 public_key: &[u8; 32],
111 last_generation: Option<u64>,
112 expected_session_id: Option<&str>,
113) -> VerifyMigrationResult {
114 let rejected = |r: &str| VerifyMigrationResult {
115 ok: false,
116 reason: Some(r.to_string()),
117 };
118 if m.migration_version != "1" {
119 return rejected(&format!(
120 "unsupported migration_version {}",
121 m.migration_version
122 ));
123 }
124 if m.signature.signer != m.signer {
125 return rejected("signature signer does not match signer");
126 }
127 if m.signature.algorithm != "ed25519" {
128 return rejected(&format!(
129 "unsupported signature algorithm {}",
130 m.signature.algorithm
131 ));
132 }
133 if let Some(expected) = expected_session_id {
134 if m.session_id != expected {
135 return rejected("session_id mismatch");
136 }
137 }
138 if let Some(last) = last_generation {
139 if m.generation <= last {
140 return rejected(&format!(
141 "generation {} <= last seen {} (replay)",
142 m.generation, last
143 ));
144 }
145 }
146 let digest = migration_signing_bytes(m);
147 let sig_bytes = match STANDARD.decode(&m.signature.signature) {
148 Ok(b) => b,
149 Err(e) => return rejected(&format!("signature base64 decode: {}", e)),
150 };
151 let sig = match Signature::from_slice(&sig_bytes) {
152 Ok(s) => s,
153 Err(e) => return rejected(&format!("signature parse: {}", e)),
154 };
155 let vk = match VerifyingKey::from_bytes(public_key) {
156 Ok(v) => v,
157 Err(e) => return rejected(&format!("verifying key: {}", e)),
158 };
159 if vk.verify(&digest, &sig).is_err() {
160 return rejected("migration signature did not verify");
161 }
162 VerifyMigrationResult {
163 ok: true,
164 reason: None,
165 }
166}
167
168const RATCHET_INFO: &[u8] = b"tf-session/ratchet";
169
170#[derive(Debug)]
171pub struct Ratchet {
172 current_key: [u8; 32],
173 rotation_count: u64,
174 messages_since_rotation: u32,
175 max_messages: u32,
176}
177
178impl Ratchet {
179 pub fn new(initial_key: [u8; 32], max_messages: Option<u32>) -> Self {
180 Ratchet {
181 current_key: initial_key,
182 rotation_count: 0,
183 messages_since_rotation: 0,
184 max_messages: max_messages.unwrap_or(1024),
185 }
186 }
187
188 pub fn key(&self) -> [u8; 32] {
189 self.current_key
190 }
191
192 pub fn generation(&self) -> u64 {
193 self.rotation_count
194 }
195
196 pub fn observe_message(&mut self) -> bool {
197 self.messages_since_rotation += 1;
198 if self.messages_since_rotation >= self.max_messages {
199 self.rotate();
200 true
201 } else {
202 false
203 }
204 }
205
206 pub fn rotate(&mut self) {
207 let hk = Hkdf::<Sha256>::new(None, &self.current_key);
208 let mut next = [0u8; 32];
209 hk.expand(RATCHET_INFO, &mut next).expect("hkdf");
210 self.current_key = next;
211 self.rotation_count += 1;
212 self.messages_since_rotation = 0;
213 }
214}
215
216fn now_iso8601() -> String {
217 let secs = std::time::SystemTime::now()
218 .duration_since(std::time::UNIX_EPOCH)
219 .unwrap_or_default()
220 .as_secs() as i64;
221 let (y, m, d, h, mi, s) = secs_to_ymdhms(secs);
222 format!("{:04}-{:02}-{:02}T{:02}:{:02}:{:02}Z", y, m, d, h, mi, s)
223}
224
/// Splits Unix seconds into a proleptic-Gregorian UTC
/// `(year, month, day, hour, minute, second)` tuple.
///
/// Negative inputs (before 1970) are handled via floor division. The date
/// part uses Howard Hinnant's `civil_from_days`: days are shifted so that
/// eras of 400 years (146 097 days) start on 0000-03-01, which puts the leap
/// day at the end of each computed year.
fn secs_to_ymdhms(secs: i64) -> (i32, u32, u32, u32, u32, u32) {
    const SECS_PER_DAY: i64 = 86_400;
    const DAYS_PER_ERA: i64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: i64 = 719_468;

    let day_number = secs.div_euclid(SECS_PER_DAY);
    let sec_of_day = secs.rem_euclid(SECS_PER_DAY) as u32;
    let (hour, minute, second) = (sec_of_day / 3600, sec_of_day % 3600 / 60, sec_of_day % 60);

    let shifted = day_number + EPOCH_SHIFT;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA) as u64;
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March (0 = March, 11 = February).
    let month_index = (5 * day_of_year + 2) / 153;
    let day = (day_of_year - (153 * month_index + 2) / 5 + 1) as u32;
    let month = (if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    }) as u32;
    // January and February belong to the next civil year.
    let year = year_of_era as i64 + era * 400 + i64::from(month <= 2);

    (year as i32, month, day, hour, minute, second)
}