Skip to main content

teaql_runtime/
id.rs

1use std::sync::{Mutex, OnceLock};
2use std::time::{Duration, SystemTime, UNIX_EPOCH};
3
4use crate::RuntimeError;
5
6pub trait InternalIdGenerator: Send + Sync {
7    fn generate_id(&self, entity: &str) -> Result<u64, RuntimeError>;
8}
9
10#[derive(Debug)]
11pub struct SnowflakeIdGenerator {
12    epoch_millis: u64,
13    worker_id: u64,
14    datacenter_id: u64,
15    state: Mutex<SnowflakeState>,
16}
17
/// Mutable bookkeeping shared by every call to a [`SnowflakeIdGenerator`].
#[derive(Debug, Default)]
struct SnowflakeState {
    /// Unix-millisecond timestamp of the most recently issued id (0 until the first id).
    last_timestamp: u64,
    /// Sequence number of the most recently issued id within `last_timestamp`.
    sequence: u64,
}
23
24impl Default for SnowflakeIdGenerator {
25    fn default() -> Self {
26        Self::new(0, 0)
27    }
28}
29
30impl SnowflakeIdGenerator {
31    const DEFAULT_EPOCH_MILLIS: u64 = 1_288_834_974_657;
32    const WORKER_ID_BITS: u64 = 5;
33    const DATACENTER_ID_BITS: u64 = 5;
34    const SEQUENCE_BITS: u64 = 12;
35    const MAX_WORKER_ID: u64 = (1 << Self::WORKER_ID_BITS) - 1;
36    const MAX_DATACENTER_ID: u64 = (1 << Self::DATACENTER_ID_BITS) - 1;
37    const SEQUENCE_MASK: u64 = (1 << Self::SEQUENCE_BITS) - 1;
38    const WORKER_ID_SHIFT: u64 = Self::SEQUENCE_BITS;
39    const DATACENTER_ID_SHIFT: u64 = Self::SEQUENCE_BITS + Self::WORKER_ID_BITS;
40    const TIMESTAMP_SHIFT: u64 =
41        Self::SEQUENCE_BITS + Self::WORKER_ID_BITS + Self::DATACENTER_ID_BITS;
42
43    pub fn new(worker_id: u64, datacenter_id: u64) -> Self {
44        assert!(worker_id <= Self::MAX_WORKER_ID, "worker id out of range");
45        assert!(
46            datacenter_id <= Self::MAX_DATACENTER_ID,
47            "datacenter id out of range"
48        );
49
50        Self {
51            epoch_millis: Self::DEFAULT_EPOCH_MILLIS,
52            worker_id,
53            datacenter_id,
54            state: Mutex::new(SnowflakeState::default()),
55        }
56    }
57
58    fn current_millis() -> Result<u64, RuntimeError> {
59        let now = SystemTime::now()
60            .duration_since(UNIX_EPOCH)
61            .map_err(|err| RuntimeError::IdGeneration(err.to_string()))?;
62        Ok(now.as_millis() as u64)
63    }
64
65    fn wait_until_next_millis(last_timestamp: u64) -> Result<u64, RuntimeError> {
66        loop {
67            let timestamp = Self::current_millis()?;
68            if timestamp > last_timestamp {
69                return Ok(timestamp);
70            }
71            std::thread::sleep(Duration::from_millis(1));
72        }
73    }
74}
75
76impl InternalIdGenerator for SnowflakeIdGenerator {
77    fn generate_id(&self, _entity: &str) -> Result<u64, RuntimeError> {
78        let mut state = self
79            .state
80            .lock()
81            .map_err(|_| RuntimeError::IdGeneration("snowflake state poisoned".to_owned()))?;
82        let mut timestamp = Self::current_millis()?;
83
84        if timestamp < state.last_timestamp {
85            timestamp = Self::wait_until_next_millis(state.last_timestamp)?;
86        }
87
88        if timestamp == state.last_timestamp {
89            state.sequence = (state.sequence + 1) & Self::SEQUENCE_MASK;
90            if state.sequence == 0 {
91                timestamp = Self::wait_until_next_millis(state.last_timestamp)?;
92            }
93        } else {
94            state.sequence = 0;
95        }
96
97        state.last_timestamp = timestamp;
98
99        let relative_timestamp = timestamp.checked_sub(self.epoch_millis).ok_or_else(|| {
100            RuntimeError::IdGeneration("system clock is before snowflake epoch".to_owned())
101        })?;
102
103        Ok((relative_timestamp << Self::TIMESTAMP_SHIFT)
104            | (self.datacenter_id << Self::DATACENTER_ID_SHIFT)
105            | (self.worker_id << Self::WORKER_ID_SHIFT)
106            | state.sequence)
107    }
108}
109
110pub(crate) fn local_id_generator() -> &'static SnowflakeIdGenerator {
111    static LOCAL_ID_GENERATOR: OnceLock<SnowflakeIdGenerator> = OnceLock::new();
112    LOCAL_ID_GENERATOR.get_or_init(SnowflakeIdGenerator::default)
113}