1use super::defs::MAX_SEQ;
2use super::utils::ms_since_epoch;
3use super::Id;
4use std::sync::{Arc, Mutex};
5
/// Mutable generator state shared by all clones of a `Context` (behind its mutex).
struct State {
    /// Sequence counter within the millisecond recorded in `ts`.
    seq: u16,
    /// Relative timestamp most recently observed by `Context::next`.
    ts: u64,
}
10
11#[derive(Clone)]
12pub struct Context {
13 epoch: u128, worker_id: u16, state: Arc<Mutex<State>>, }
17
18impl Context {
19 pub fn new(epoch: u128, worker_id: u16) -> Result<Context, String> {
20 let now = ms_since_epoch()?;
21
22 if now < epoch {
23 return Err(format!("bad epoch {}", epoch));
24 }
25
26 if now == epoch {
28 loop {
29 if ms_since_epoch()? != epoch {
30 break;
31 }
32 }
33 }
34
35 Ok(Context {
36 epoch,
37 worker_id,
38 state: Arc::new(Mutex::new(State { seq: 0, ts: 0 })),
39 })
40 }
41
42 pub fn next(&self) -> Result<Id, String> {
43 let mut state = self.state.lock().unwrap();
44 let ts = self.get_ts()?;
45
46 if ts == state.ts {
47 (*state).seq += 1;
48 } else {
49 (*state).seq = 0;
50 (*state).ts = ts;
51 };
52
53 let seq = state.seq;
54
55 if seq >= MAX_SEQ {
56 return Err(format!("bad seq {}", seq));
57 }
58
59 Ok(Id::new(ts, self.worker_id, seq))
60 }
61
62 pub fn next_id(&self) -> Id {
63 loop {
64 if let Ok(id) = self.next() {
65 return id;
66 }
67 }
68 }
69
70 fn get_ts(&self) -> Result<u64, String> {
71 let now = ms_since_epoch()?;
72 Ok((now - self.epoch) as u64)
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::Context;

    #[test]
    fn test_create_context_fail() {
        // An epoch of u128::MAX always lies in the future.
        assert!(Context::new(!0, 0).is_err());
    }

    #[test]
    fn test_create_context_success() {
        // Both a past epoch and the zero epoch must be accepted, in this order.
        for &epoch in &[1_234_567_891_011_u128, 0] {
            assert!(Context::new(epoch, 0).is_ok());
        }
    }

    #[test]
    fn test_generate_id_success() {
        let id = Context::new(1_548_067_209_841, 0)
            .unwrap()
            .next_id();

        assert_ne!(id, 0.into());
    }
}