1use std::sync::{Mutex, OnceLock};
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
3
4use crate::RuntimeError;
5
/// Source of internal numeric (`u64`) identifiers.
///
/// The `Send + Sync` bound lets one implementor instance be shared by
/// reference across threads.
pub trait InternalIdGenerator: Send + Sync {
    /// Produces an identifier.
    ///
    /// `entity` is a caller-supplied label; whether it influences the result is
    /// up to the implementation (the snowflake implementation in this file
    /// ignores it).
    ///
    /// # Errors
    /// Returns a [`RuntimeError`] when an identifier cannot be produced.
    fn generate_id(&self, entity: &str) -> Result<u64, RuntimeError>;
}
9
/// Snowflake-style 64-bit ID generator.
///
/// Each ID packs, from most to least significant bits: milliseconds elapsed
/// since `epoch_millis`, a 5-bit datacenter id, a 5-bit worker id, and a
/// 12-bit per-millisecond sequence number.
#[derive(Debug)]
pub struct SnowflakeIdGenerator {
    /// Epoch, in Unix milliseconds, that ID timestamps are measured from.
    epoch_millis: u64,
    /// Worker id embedded in every ID (validated to be at most 31 on construction).
    worker_id: u64,
    /// Datacenter id embedded in every ID (validated to be at most 31 on construction).
    datacenter_id: u64,
    /// Last timestamp/sequence bookkeeping; the mutex serializes concurrent
    /// `generate_id` calls.
    state: Mutex<SnowflakeState>,
}
17
/// Mutable bookkeeping for [`SnowflakeIdGenerator`]; both fields start at 0
/// via `Default`.
#[derive(Debug, Default)]
struct SnowflakeState {
    /// Most recent Unix-millisecond timestamp recorded by `generate_id`.
    last_timestamp: u64,
    /// Sequence number paired with `last_timestamp`, kept within the
    /// generator's sequence mask.
    sequence: u64,
}
23
24impl Default for SnowflakeIdGenerator {
25 fn default() -> Self {
26 Self::new(0, 0)
27 }
28}
29
30impl SnowflakeIdGenerator {
31 const DEFAULT_EPOCH_MILLIS: u64 = 1_288_834_974_657;
32 const WORKER_ID_BITS: u64 = 5;
33 const DATACENTER_ID_BITS: u64 = 5;
34 const SEQUENCE_BITS: u64 = 12;
35 const MAX_WORKER_ID: u64 = (1 << Self::WORKER_ID_BITS) - 1;
36 const MAX_DATACENTER_ID: u64 = (1 << Self::DATACENTER_ID_BITS) - 1;
37 const SEQUENCE_MASK: u64 = (1 << Self::SEQUENCE_BITS) - 1;
38 const WORKER_ID_SHIFT: u64 = Self::SEQUENCE_BITS;
39 const DATACENTER_ID_SHIFT: u64 = Self::SEQUENCE_BITS + Self::WORKER_ID_BITS;
40 const TIMESTAMP_SHIFT: u64 =
41 Self::SEQUENCE_BITS + Self::WORKER_ID_BITS + Self::DATACENTER_ID_BITS;
42
43 pub fn new(worker_id: u64, datacenter_id: u64) -> Self {
44 assert!(worker_id <= Self::MAX_WORKER_ID, "worker id out of range");
45 assert!(
46 datacenter_id <= Self::MAX_DATACENTER_ID,
47 "datacenter id out of range"
48 );
49
50 Self {
51 epoch_millis: Self::DEFAULT_EPOCH_MILLIS,
52 worker_id,
53 datacenter_id,
54 state: Mutex::new(SnowflakeState::default()),
55 }
56 }
57
58 fn current_millis() -> Result<u64, RuntimeError> {
59 let now = SystemTime::now()
60 .duration_since(UNIX_EPOCH)
61 .map_err(|err| RuntimeError::IdGeneration(err.to_string()))?;
62 Ok(now.as_millis() as u64)
63 }
64
65 fn wait_until_next_millis(last_timestamp: u64) -> Result<u64, RuntimeError> {
66 loop {
67 let timestamp = Self::current_millis()?;
68 if timestamp > last_timestamp {
69 return Ok(timestamp);
70 }
71 std::thread::sleep(Duration::from_millis(1));
72 }
73 }
74}
75
76impl InternalIdGenerator for SnowflakeIdGenerator {
77 fn generate_id(&self, _entity: &str) -> Result<u64, RuntimeError> {
78 let mut state = self
79 .state
80 .lock()
81 .map_err(|_| RuntimeError::IdGeneration("snowflake state poisoned".to_owned()))?;
82 let mut timestamp = Self::current_millis()?;
83
84 if timestamp < state.last_timestamp {
85 timestamp = Self::wait_until_next_millis(state.last_timestamp)?;
86 }
87
88 if timestamp == state.last_timestamp {
89 state.sequence = (state.sequence + 1) & Self::SEQUENCE_MASK;
90 if state.sequence == 0 {
91 timestamp = Self::wait_until_next_millis(state.last_timestamp)?;
92 }
93 } else {
94 state.sequence = 0;
95 }
96
97 state.last_timestamp = timestamp;
98
99 let relative_timestamp = timestamp.checked_sub(self.epoch_millis).ok_or_else(|| {
100 RuntimeError::IdGeneration("system clock is before snowflake epoch".to_owned())
101 })?;
102
103 Ok((relative_timestamp << Self::TIMESTAMP_SHIFT)
104 | (self.datacenter_id << Self::DATACENTER_ID_SHIFT)
105 | (self.worker_id << Self::WORKER_ID_SHIFT)
106 | state.sequence)
107 }
108}
109
110pub(crate) fn local_id_generator() -> &'static SnowflakeIdGenerator {
111 static LOCAL_ID_GENERATOR: OnceLock<SnowflakeIdGenerator> = OnceLock::new();
112 LOCAL_ID_GENERATOR.get_or_init(SnowflakeIdGenerator::default)
113}