shadowforge_lib/domain/timelock/
mod.rs1use chrono::{DateTime, Utc};
12use num_bigint::BigUint;
13use num_traits::One;
14use sha2::{Digest, Sha256};
15
16use crate::domain::crypto::{decrypt_aes_gcm, encrypt_aes_gcm};
17use crate::domain::errors::TimeLockError;
18use crate::domain::types::{Payload, TimeLockPuzzle};
19
/// Length in bytes of the AES-GCM nonce derived from a puzzle solution.
const NONCE_LEN: usize = 12;

/// Bit length of the puzzle modulus `n = p * q`; each prime factor is drawn with half
/// of these bits.
///
/// NOTE(review): a 256-bit modulus is within reach of publicly available
/// number-field-sieve factoring tools, and knowing `p` and `q` lets a solver skip the
/// sequential squarings entirely. If the lock must resist a motivated attacker, consider
/// a larger modulus (note `random_biguint` currently assumes `bits <= 256`, and the
/// squaring rate below would need re-calibrating).
const MODULUS_BITS: u64 = 256;

/// Default sequential-squaring rate (squarings per second), exposed through
/// `default_squarings_per_sec`.
const DEFAULT_SQUARINGS_PER_SEC: u64 = 10_000_000;
29
30fn derive_key_and_nonce(solution: &BigUint) -> ([u8; 32], [u8; NONCE_LEN]) {
32 let solution_bytes = solution.to_bytes_be();
33
34 let key: [u8; 32] = Sha256::digest(&solution_bytes).into();
36
37 let mut nonce_input = b"nonce".to_vec();
39 nonce_input.extend_from_slice(&solution_bytes);
40 let nonce_hash = Sha256::digest(&nonce_input);
41 let mut nonce = [0u8; NONCE_LEN];
42 nonce.copy_from_slice(nonce_hash.get(..NONCE_LEN).unwrap_or(&[0u8; NONCE_LEN]));
43
44 (key, nonce)
45}
46
47fn sequential_square(g: &BigUint, n: &BigUint, t: u64) -> BigUint {
49 let mut result = g.clone();
50 for _ in 0..t {
51 result = (&result * &result) % n;
52 }
53 result
54}
55
56pub fn create_puzzle(
61 payload: &Payload,
62 unlock_at: DateTime<Utc>,
63 squarings_per_sec: u64,
64) -> Result<TimeLockPuzzle, TimeLockError> {
65 let now = Utc::now();
66 let duration_secs = (unlock_at - now).num_seconds().max(0).cast_unsigned();
67 let squarings_required = duration_secs.saturating_mul(squarings_per_sec);
68
69 let mut rng = rand::rng();
71 let half_bits = MODULUS_BITS / 2;
72 let p = generate_random_prime(&mut rng, half_bits);
73 let q = generate_random_prime(&mut rng, half_bits);
74 let n = &p * &q;
75
76 let g = {
78 let two = BigUint::from(2u32);
79 let upper = &n - &two;
80 let rand_val = random_biguint(&mut rng, MODULUS_BITS);
81 (rand_val % &upper) + &two
82 };
83
84 let phi = (&p - BigUint::one()) * (&q - BigUint::one());
87 let two = BigUint::from(2u32);
88 let exponent = two.modpow(&BigUint::from(squarings_required), &phi);
89 let solution = g.modpow(&exponent, &n);
90
91 let (key, nonce) = derive_key_and_nonce(&solution);
93
94 let ciphertext = encrypt_aes_gcm(&key, &nonce, payload.as_bytes()).map_err(|e| {
96 TimeLockError::ComputationFailed {
97 reason: format!("encryption failed: {e}"),
98 }
99 })?;
100
101 Ok(TimeLockPuzzle {
102 ciphertext,
103 modulus: n.to_bytes_be(),
104 start_value: g.to_bytes_be(),
105 squarings_required,
106 created_at: now,
107 unlock_at,
108 })
109}
110
111pub fn solve_puzzle(puzzle: &TimeLockPuzzle) -> Result<Payload, TimeLockError> {
118 let n = BigUint::from_bytes_be(&puzzle.modulus);
119 let g = BigUint::from_bytes_be(&puzzle.start_value);
120
121 let solution = sequential_square(&g, &n, puzzle.squarings_required);
123
124 let (key, nonce) = derive_key_and_nonce(&solution);
126
127 let plaintext = decrypt_aes_gcm(&key, &nonce, &puzzle.ciphertext)
128 .map_err(|source| TimeLockError::DecryptFailed { source })?;
129
130 Ok(Payload::from_bytes(plaintext.to_vec()))
131}
132
133pub fn try_solve_puzzle(
139 puzzle: &TimeLockPuzzle,
140 squarings_per_sec: u64,
141) -> Result<Option<Payload>, TimeLockError> {
142 let now = Utc::now();
144 let elapsed_secs = (now - puzzle.created_at)
145 .num_seconds()
146 .max(0)
147 .cast_unsigned();
148 let estimated_solvable_squarings = elapsed_secs.saturating_mul(squarings_per_sec);
149
150 if estimated_solvable_squarings < puzzle.squarings_required {
151 return Ok(None);
152 }
153
154 solve_puzzle(puzzle).map(Some)
156}
157
158fn generate_random_prime(rng: &mut impl rand::Rng, bits: u64) -> BigUint {
163 loop {
164 let mut candidate = random_biguint(rng, bits);
165 candidate |= BigUint::one();
167 candidate |= BigUint::one() << (bits - 1);
169
170 if is_probably_prime(&candidate) {
171 return candidate;
172 }
173 }
174}
175
176fn is_probably_prime(n: &BigUint) -> bool {
180 let two = BigUint::from(2u32);
181 if n < &two {
182 return false;
183 }
184 if n == &two || n == &BigUint::from(3u32) {
185 return true;
186 }
187 if n % &two == BigUint::ZERO {
188 return false;
189 }
190
191 let mut i = BigUint::from(3u32);
193 let limit = BigUint::from(10_000u32);
194 while &i * &i <= *n && i < limit {
195 if n % &i == BigUint::ZERO {
196 return false;
197 }
198 i += &two;
199 }
200 true
201}
202
203fn random_biguint(rng: &mut impl rand::Rng, bits: u64) -> BigUint {
205 #[expect(
206 clippy::cast_possible_truncation,
207 reason = "bits <= 256, always fits usize"
208 )]
209 let byte_count = bits.div_ceil(8) as usize;
210 let mut buf = vec![0u8; byte_count];
211 rng.fill_bytes(&mut buf);
212 let excess_bits = (byte_count * 8) as u64 - bits;
214 if excess_bits > 0
215 && let Some(first) = buf.first_mut()
216 {
217 *first &= 0xFF >> excess_bits;
218 }
219 BigUint::from_bytes_be(&buf)
220}
221
222#[must_use]
224pub const fn default_squarings_per_sec() -> u64 {
225 DEFAULT_SQUARINGS_PER_SEC
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    // `unlock_at` is read before `create_puzzle` reads the clock, so this puzzle needs
    // zero squarings: it checks the encrypt/decrypt plumbing, not the time lock.
    #[test]
    fn roundtrip_lock_then_unlock() -> TestResult {
        let payload = Payload::from_bytes(b"secret message".to_vec());
        let unlock_at = Utc::now();

        let puzzle = create_puzzle(&payload, unlock_at, DEFAULT_SQUARINGS_PER_SEC)?;

        let recovered = solve_puzzle(&puzzle)?;
        assert_eq!(recovered.as_bytes(), payload.as_bytes());
        Ok(())
    }

    // NOTE(review): despite its name, this test uses no delay. `unlock_at` is not in the
    // future, so `squarings_required` is 0, exactly as in `roundtrip_lock_then_unlock`.
    #[test]
    fn roundtrip_with_small_delay() -> TestResult {
        let payload = Payload::from_bytes(b"timed secret".to_vec());
        let unlock_at = Utc::now();

        let puzzle = create_puzzle(&payload, unlock_at, DEFAULT_SQUARINGS_PER_SEC)?;

        let recovered = solve_puzzle(&puzzle)?;
        assert_eq!(recovered.as_bytes(), payload.as_bytes());
        Ok(())
    }

    #[test]
    fn try_solve_returns_none_for_future_puzzle() -> TestResult {
        let payload = Payload::from_bytes(b"future secret".to_vec());
        let unlock_at = Utc::now() + chrono::Duration::hours(1);

        let puzzle = create_puzzle(&payload, unlock_at, DEFAULT_SQUARINGS_PER_SEC)?;

        // No time has elapsed since creation, so the estimated budget is zero.
        let result = try_solve_puzzle(&puzzle, DEFAULT_SQUARINGS_PER_SEC)?;
        assert!(result.is_none());
        Ok(())
    }

    #[test]
    fn puzzle_serialises_to_json() -> TestResult {
        let payload = Payload::from_bytes(b"json test".to_vec());
        let unlock_at = Utc::now();

        let puzzle = create_puzzle(&payload, unlock_at, DEFAULT_SQUARINGS_PER_SEC)?;

        let json = serde_json::to_string_pretty(&puzzle)?;
        assert!(json.contains("squarings_required"));
        assert!(json.contains("modulus"));

        let recovered: TimeLockPuzzle = serde_json::from_str(&json)?;
        assert_eq!(recovered.squarings_required, puzzle.squarings_required);
        Ok(())
    }

    // Each puzzle draws fresh primes, so the moduli differ regardless of the payload.
    #[test]
    fn different_payloads_produce_different_puzzles() -> TestResult {
        let p1 = Payload::from_bytes(b"payload one".to_vec());
        let p2 = Payload::from_bytes(b"payload two".to_vec());
        let unlock_at = Utc::now();

        let puzzle1 = create_puzzle(&p1, unlock_at, DEFAULT_SQUARINGS_PER_SEC)?;
        let puzzle2 = create_puzzle(&p2, unlock_at, DEFAULT_SQUARINGS_PER_SEC)?;

        assert_ne!(puzzle1.modulus, puzzle2.modulus);
        Ok(())
    }

    #[test]
    fn wrong_solution_fails_decrypt() -> TestResult {
        let payload = Payload::from_bytes(b"fail test".to_vec());
        let unlock_at = Utc::now();

        let mut puzzle = create_puzzle(&payload, unlock_at, DEFAULT_SQUARINGS_PER_SEC)?;

        // Corrupt the start value so the derived key no longer matches.
        if let Some(byte) = puzzle.start_value.first_mut() {
            *byte ^= 0xFF;
        }

        let result = solve_puzzle(&puzzle);
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn derive_key_and_nonce_is_deterministic() {
        let solution = BigUint::from(42u32);
        let (k1, n1) = derive_key_and_nonce(&solution);
        let (k2, n2) = derive_key_and_nonce(&solution);
        assert_eq!(k1, k2);
        assert_eq!(n1, n2);
    }

    #[test]
    fn sequential_square_identity() {
        let n = BigUint::from(143u32);
        let g = BigUint::from(2u32);

        let r0 = sequential_square(&g, &n, 0);
        assert_eq!(r0, g);

        let r1 = sequential_square(&g, &n, 1);
        assert_eq!(r1, BigUint::from(4u32));

        let r2 = sequential_square(&g, &n, 2);
        assert_eq!(r2, BigUint::from(16u32));
    }

    // The creator's shortcut g^(2^t mod phi(n)) mod n must equal t sequential squarings.
    // A modulus with known factors keeps this independent of prime generation, and
    // t > 0 exercises the path that the zero-delay roundtrip tests never reach.
    #[test]
    fn trapdoor_shortcut_matches_sequential_squaring() {
        // 143 = 11 * 13, so phi(143) = 10 * 12 = 120.
        let n = BigUint::from(143u32);
        let phi = BigUint::from(120u32);
        let two = BigUint::from(2u32);
        for base in [2u32, 3, 5, 7, 100] {
            let g = BigUint::from(base);
            for t in 0..20u64 {
                let shortcut = g.modpow(&two.modpow(&BigUint::from(t), &phi), &n);
                assert_eq!(shortcut, sequential_square(&g, &n, t), "g={base}, t={t}");
            }
        }
    }

    #[test]
    fn is_probably_prime_small_values() {
        for prime in [2u32, 3, 5, 97, 7919, 104_729] {
            assert!(is_probably_prime(&BigUint::from(prime)), "{prime} is prime");
        }
        // 7917 = 3 * 2639 and 10_403 = 101 * 103.
        for composite in [0u32, 1, 4, 9, 15, 7917, 10_403] {
            assert!(!is_probably_prime(&BigUint::from(composite)), "{composite} is composite");
        }
    }

    #[test]
    fn random_biguint_respects_bit_length() {
        let mut rng = rand::rng();
        let bound = BigUint::one() << 13u32;
        for _ in 0..100 {
            assert!(random_biguint(&mut rng, 13) < bound);
        }
    }

    #[test]
    fn generated_prime_has_requested_bit_length() {
        let mut rng = rand::rng();
        let p = generate_random_prime(&mut rng, 64);
        assert_eq!(p.bits(), 64);
        assert!(p.bit(0));
    }
}