1use crate::constants::{DST_TIME_ADVANCE_MS_MAX, TIME_MS_PER_SEC};
7use chrono::{DateTime, Duration, Utc};
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use tokio::sync::Notify;
11
12#[derive(Debug, Clone)]
22pub struct SimClock {
23 current_ms: Arc<AtomicU64>,
25 notify: Arc<Notify>,
27}
28
29impl SimClock {
30 #[must_use]
39 pub fn new() -> Self {
40 Self {
41 current_ms: Arc::new(AtomicU64::new(0)),
42 notify: Arc::new(Notify::new()),
43 }
44 }
45
46 #[must_use]
48 pub fn at_ms(start_ms: u64) -> Self {
49 Self {
50 current_ms: Arc::new(AtomicU64::new(start_ms)),
51 notify: Arc::new(Notify::new()),
52 }
53 }
54
55 #[must_use]
57 pub fn at_datetime(dt: DateTime<Utc>) -> Self {
58 let ms = dt.timestamp_millis() as u64;
59 Self::at_ms(ms)
60 }
61
62 #[must_use]
64 pub fn from_epoch() -> Self {
65 Self::new()
66 }
67
68 #[must_use]
70 pub fn now_ms(&self) -> u64 {
71 self.current_ms.load(Ordering::SeqCst)
72 }
73
74 #[must_use]
76 pub fn now_secs(&self) -> u64 {
77 self.now_ms() / TIME_MS_PER_SEC
78 }
79
80 #[must_use]
82 pub fn now(&self) -> DateTime<Utc> {
83 let ms = self.now_ms() as i64;
84 DateTime::from_timestamp_millis(ms).unwrap_or_else(|| {
85 DateTime::from_timestamp(0, 0).unwrap()
87 })
88 }
89
90 pub fn advance_ms(&self, ms: u64) -> u64 {
98 assert!(
100 ms <= DST_TIME_ADVANCE_MS_MAX,
101 "advance_ms({}) exceeds max ({})",
102 ms,
103 DST_TIME_ADVANCE_MS_MAX
104 );
105
106 let old_time = self.current_ms.fetch_add(ms, Ordering::SeqCst);
107 let new_time = old_time.saturating_add(ms);
108
109 self.notify.notify_waiters();
111
112 assert!(new_time >= old_time, "time must not go backwards");
114
115 new_time
116 }
117
118 pub fn advance_secs(&self, secs: f64) -> u64 {
123 assert!(secs >= 0.0, "secs must be non-negative, got {}", secs);
125
126 let ms = (secs * 1000.0) as u64;
127 self.advance_ms(ms)
128 }
129
130 pub fn advance(&self, duration: Duration) {
132 debug_assert!(duration >= Duration::zero(), "cannot go back in time");
133
134 let delta_ms = duration.num_milliseconds() as u64;
135 self.advance_ms(delta_ms);
136 }
137
138 pub fn set_ms(&self, ms: u64) {
143 let current = self.now_ms();
144 assert!(
146 ms >= current,
147 "cannot set time backwards: {} < {}",
148 ms,
149 current
150 );
151
152 self.current_ms.store(ms, Ordering::SeqCst);
153 self.notify.notify_waiters();
154
155 assert_eq!(self.now_ms(), ms, "time must be set correctly");
157 }
158
159 pub fn set(&self, time: DateTime<Utc>) {
161 let ms = time.timestamp_millis() as u64;
162 self.set_ms(ms);
163 }
164
165 #[must_use]
170 pub fn elapsed_since(&self, since: u64) -> u64 {
171 let current = self.now_ms();
172 assert!(
174 since <= current,
175 "elapsed_since({}) is in the future (now={})",
176 since,
177 current
178 );
179
180 current - since
181 }
182
183 #[must_use]
185 pub fn has_elapsed(&self, since: u64, duration_ms: u64) -> bool {
186 self.elapsed_since(since) >= duration_ms
187 }
188
189 #[must_use]
191 pub fn is_past_ms(&self, deadline_ms: u64) -> bool {
192 self.now_ms() >= deadline_ms
193 }
194
195 #[must_use]
197 pub fn is_past(&self, deadline: DateTime<Utc>) -> bool {
198 self.now() >= deadline
199 }
200
201 #[must_use]
203 pub fn timestamp(&self) -> u64 {
204 self.now_ms()
205 }
206
207 pub async fn sleep_ms(&self, duration_ms: u64) {
212 let target_ms = self.now_ms() + duration_ms;
213
214 while self.now_ms() < target_ms {
215 self.notify.notified().await;
216 }
217 }
218
219 pub async fn sleep(&self, duration: Duration) {
221 let ms = duration.num_milliseconds() as u64;
222 self.sleep_ms(ms).await;
223 }
224
225 pub async fn sleep_until_ms(&self, deadline_ms: u64) {
227 while self.now_ms() < deadline_ms {
228 self.notify.notified().await;
229 }
230 }
231}
232
233impl Default for SimClock {
234 fn default() -> Self {
235 Self::new()
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_initial_time() {
        let sim = SimClock::new();
        assert_eq!((sim.now_ms(), sim.now_secs()), (0, 0));
    }

    #[test]
    fn test_at_ms() {
        let sim = SimClock::at_ms(5000);
        assert_eq!((sim.now_ms(), sim.now_secs()), (5000, 5));
    }

    #[test]
    fn test_at_datetime() {
        let new_year = DateTime::parse_from_rfc3339("2024-01-01T00:00:00Z")
            .unwrap()
            .to_utc();
        assert_eq!(SimClock::at_datetime(new_year).now(), new_year);
    }

    #[test]
    fn test_advance_ms() {
        let sim = SimClock::new();
        assert_eq!(sim.advance_ms(1000), 1000);
        assert_eq!(sim.now_ms(), 1000);
    }

    #[test]
    fn test_advance_secs() {
        let sim = SimClock::new();
        assert_eq!(sim.advance_secs(1.5), 1500);
        assert_eq!(sim.now_ms(), 1500);
    }

    #[test]
    fn test_advance_duration() {
        let sim = SimClock::new();
        sim.advance(Duration::seconds(10));
        assert_eq!(sim.now_ms(), 10_000);
    }

    #[test]
    fn test_multiple_advances() {
        let sim = SimClock::new();
        for step in [100, 200, 300] {
            sim.advance_ms(step);
        }
        assert_eq!(sim.now_ms(), 600);
    }

    #[test]
    #[should_panic(expected = "advance_ms")]
    fn test_advance_exceeds_max() {
        SimClock::new().advance_ms(DST_TIME_ADVANCE_MS_MAX + 1);
    }

    #[test]
    fn test_set_ms() {
        let sim = SimClock::new();
        sim.set_ms(5000);
        assert_eq!(sim.now_ms(), 5000);
    }

    #[test]
    #[should_panic(expected = "cannot set time backwards")]
    fn test_set_ms_backwards() {
        let sim = SimClock::new();
        sim.advance_ms(1000);
        sim.set_ms(500);
    }

    #[test]
    fn test_elapsed_since() {
        let sim = SimClock::new();
        let origin = sim.now_ms();
        sim.advance_ms(500);
        assert_eq!(sim.elapsed_since(origin), 500);
    }

    #[test]
    fn test_has_elapsed() {
        let sim = SimClock::new();
        let origin = sim.now_ms();

        assert!(!sim.has_elapsed(origin, 1000));

        // (advance by, expected has_elapsed(origin, 1000) afterwards)
        for (step, expected) in [(500, false), (500, true), (100, true)] {
            sim.advance_ms(step);
            assert_eq!(sim.has_elapsed(origin, 1000), expected, "after +{}ms", step);
        }
    }

    #[test]
    #[should_panic(expected = "is in the future")]
    fn test_elapsed_since_future() {
        let _ = SimClock::new().elapsed_since(1000);
    }

    #[test]
    fn test_timestamp() {
        let sim = SimClock::new();
        sim.advance_ms(12345);
        assert_eq!(sim.timestamp(), 12345);
    }

    #[test]
    fn test_is_past_ms() {
        let sim = SimClock::at_ms(1000);
        for (deadline, expected) in [(500, true), (1000, true), (1500, false)] {
            assert_eq!(sim.is_past_ms(deadline), expected, "deadline {}", deadline);
        }
    }

    #[test]
    fn test_now_datetime() {
        let epoch = DateTime::from_timestamp(0, 0).unwrap();
        assert_eq!(SimClock::at_ms(0).now(), epoch);
    }

    #[test]
    fn test_clone_shares_time() {
        let original = SimClock::new();
        let shared = original.clone();

        original.advance_ms(1000);

        // Both handles observe the advance because clones share the same time value.
        assert_eq!(original.now_ms(), 1000);
        assert_eq!(shared.now_ms(), 1000);
    }

    #[tokio::test]
    async fn test_sleep_ms() {
        let clock = SimClock::new();
        let sleeper = clock.clone();

        let handle = tokio::spawn(async move {
            sleeper.sleep_ms(100).await;
            sleeper.now_ms()
        });

        // Give the spawned task a chance to start sleeping, then advance in two
        // 50 ms steps, yielding after each so the sleeper can re-check the clock.
        tokio::task::yield_now().await;
        for _ in 0..2 {
            clock.advance_ms(50);
            tokio::task::yield_now().await;
        }

        let woke_at = handle.await.unwrap();
        assert!(woke_at >= 100);
    }

    #[tokio::test]
    async fn test_sleep_duration() {
        let clock = SimClock::new();
        let sleeper = clock.clone();

        let handle = tokio::spawn(async move {
            sleeper.sleep(Duration::milliseconds(200)).await;
            sleeper.now_ms()
        });

        tokio::task::yield_now().await;
        clock.advance_ms(200);
        tokio::task::yield_now().await;

        let woke_at = handle.await.unwrap();
        assert!(woke_at >= 200);
    }
}