snowflake_rs/
snowflake.rs

1use time;
2use std::sync::{Arc, Mutex};
3
/// Default custom epoch, in milliseconds since the Unix epoch.
///
/// 1_514_736_000_000 ms is 2017-12-31T16:00:00Z, i.e. 2018-01-01 00:00 at UTC+8.
/// Pass it as the `standard` argument of `SnowFlakeId::new` so that the packed
/// timestamp counts milliseconds from that instant instead of from 1970.
#[allow(dead_code)] // within this file only the tests read it
pub const STANDARD_EPOCH: u64 = 1_514_736_000_000;

/// Number of bits reserved for the worker (machine) id.
const WORKER_ID_BITS: u8 = 10;
/// Number of bits reserved for the per-millisecond sequence number.
const SEQUENCE_ID_BITS: u8 = 12;
/// Number of bits reserved for the millisecond timestamp; together with the
/// worker and sequence fields this fills all 64 bits of the id.
#[allow(dead_code)] // documents the layout; the generator itself does not read it
const TIMESTAMP_ID_BITS: u8 = 42;

/// The worker id sits directly above the sequence number.
const WORKER_ID_SHIFT: u8 = SEQUENCE_ID_BITS;
/// The timestamp sits directly above the worker id.
const TIMESTAMP_LEFT_SHIFT: u8 = SEQUENCE_ID_BITS + WORKER_ID_BITS;

/// Mask that wraps the sequence number within `SEQUENCE_ID_BITS` bits.
const SEQUENCE_MASK: u16 = (1u16 << SEQUENCE_ID_BITS) - 1;
23
/// Snowflake-style 64-bit id generator for a single worker.
///
/// Not synchronized by itself (`generate_id` takes `&mut self`); use
/// `SnowFlakeId::new_multi_thread` to share one generator between threads.
#[derive(Debug, Default)]
pub struct SnowFlakeId {
    /// Custom epoch in milliseconds since the Unix epoch; it is subtracted from the
    /// current time before the timestamp is packed into an id.
    standard_epoch: u64,
    /// Worker (machine) id packed into every generated id.
    worker_id: u16,
    /// Counter that distinguishes ids generated within the same millisecond.
    sequence: u16,
    /// Millisecond timestamp (since the Unix epoch) of the most recently generated id.
    last_timestamp: u64,
}
32
33#[allow(dead_code)]
34impl SnowFlakeId {
35    pub fn new(worker_id: u16, standard: u64) -> Self {
36        SnowFlakeId {
37            standard_epoch: standard,
38            worker_id : worker_id,
39            sequence : 0,
40            last_timestamp: 0,
41        }
42    }
43
44    pub fn new_multi_thread(worker_id: u16, standard: u64) -> Arc<Mutex<Self>> {
45        Arc::new(Mutex::new(SnowFlakeId {
46            standard_epoch: standard,
47            worker_id : worker_id,
48            sequence : 0,
49            last_timestamp: 0,
50        }))
51    }
52
53    pub fn generate_id(&mut self) -> Result<u64, String> {
54        let mut curr_timestamp = SnowFlakeId::curr_time_millisec();
55
56        if curr_timestamp < self.last_timestamp {
57            return  Err(format!("Clock moved backwards.  Refusing to generate id for {} milliseconds", self.last_timestamp));
58        }
59
60        if curr_timestamp == self.last_timestamp {
61            self.sequence = (self.sequence + 1) & SEQUENCE_MASK;
62            if self.sequence == 0 {
63                if curr_timestamp == self.last_timestamp {
64                    curr_timestamp = self.wait_for_next_milli_sec();
65                }
66            }
67        } else {
68            self.sequence = 0u16;
69        }
70
71        self.last_timestamp = curr_timestamp;
72        let uid: u64 = (self.last_timestamp - self.standard_epoch) << TIMESTAMP_LEFT_SHIFT |
73            (self.worker_id as u64) << WORKER_ID_SHIFT |
74            (self.sequence as u64);
75
76        Ok(uid)
77    }
78
79    fn wait_for_next_milli_sec(&self) -> u64 {
80        let mut curr_timestamp = SnowFlakeId::curr_time_millisec();
81
82        while self.last_timestamp >= curr_timestamp {
83            curr_timestamp = SnowFlakeId::curr_time_millisec();
84        }
85
86        curr_timestamp
87    }
88
89    fn curr_time_millisec() -> u64 {
90        let time_spec = time::get_time();
91        let mut milli_sec = (time_spec.sec as u64) * 1000u64;
92        milli_sec += (time_spec.nsec as u64) / 1000_000u64;
93        milli_sec
94    }
95}
96
#[cfg(test)]
mod test {
    use super::{SnowFlakeId, STANDARD_EPOCH};
    use std::collections::HashSet;
    use std::thread;
    use std::time::Instant;

    /// A single generator must hand out strictly increasing ids.
    #[test]
    fn loop_test() {
        let mut id_gen = SnowFlakeId::new(1, STANDARD_EPOCH);
        println!("{:?}", &id_gen);
        let now = Instant::now();
        let mut prev: Option<u64> = None;
        for _ in 0..1000 {
            let id = id_gen.generate_id().expect("id generation failed");
            if let Some(p) = prev {
                assert!(id > p, "ids must increase: {} after {}", id, p);
            }
            prev = Some(id);
        }
        let elapsed = now.elapsed();
        println!(
            "single thread generate 1000 ids cost {}.{:09} s",
            elapsed.as_secs(),
            elapsed.subsec_nanos()
        );
    }

    /// Threads sharing one generator must never receive the same id.
    #[test]
    fn multi_thread() {
        let id_gen = SnowFlakeId::new_multi_thread(2, STANDARD_EPOCH);
        let handles: Vec<_> = (1..10)
            .map(|i| {
                let shared = id_gen.clone();
                thread::spawn(move || {
                    let now = Instant::now();
                    let mut ids = Vec::with_capacity(1000);
                    for _ in 0..1000 {
                        let mut generator = shared.lock().unwrap();
                        ids.push(generator.generate_id().expect("id generation failed"));
                    }
                    let elapsed = now.elapsed();
                    println!(
                        "multi thread:[{}] generate 1000 ids cost {}.{:09} s",
                        i,
                        elapsed.as_secs(),
                        elapsed.subsec_nanos()
                    );
                    ids
                })
            })
            .collect();

        let mut seen = HashSet::new();
        for handle in handles {
            // join() returns Err when the worker panicked; surface it so a failed
            // assertion inside a worker fails this test instead of being ignored.
            for id in handle.join().expect("worker thread panicked") {
                assert!(seen.insert(id), "duplicate id {}", id);
            }
        }
        assert_eq!(seen.len(), 9 * 1000);
    }
}