snowflake_rs_impl/
snowflake.rs

1use std::time::Instant;
2use std::sync::atomic::{AtomicI64, Ordering};
3use std::error::Error;
4use std::fmt;
5
6use log::{debug, warn};
7use serde::{Deserialize, Serialize};
8
// Widths, in bits, of the three fields packed into a Snowflake ID
// (41 + 10 + 12 = 63 bits of a u64).

/// Width of the node (machine) field.
const NODE_BITS: u8 = 10;
/// Width of the per-millisecond sequence field.
const STEP_BITS: u8 = 12;
/// Width of the timestamp field (milliseconds since the generator's epoch).
const TIMESTAMP_BITS: u8 = 41;

/// Largest node ID representable in `NODE_BITS` bits (1023).
const NODE_MAX: u16 = u16::MAX >> (16 - NODE_BITS);
/// Largest sequence value representable in `STEP_BITS` bits (4095).
const STEP_MAX: u16 = u16::MAX >> (16 - STEP_BITS);

/// The timestamp sits above both the node and the sequence fields.
const TIMESTAMP_SHIFT: u8 = STEP_BITS + NODE_BITS;
/// The node sits directly above the sequence field.
const NODE_SHIFT: u8 = STEP_BITS;

/// Default epoch: 2021-01-01T00:00:00Z, in milliseconds since the Unix epoch.
const DEFAULT_EPOCH: i64 = 1_609_459_200_000;
24
25/// Errors that can occur during Snowflake ID generation
26#[derive(Debug,Serialize,Deserialize)]
27pub enum SnowflakeError {
28    /// Indicates that the system clock has moved backwards
29    ClockMovedBackwards,
30    /// Indicates that the provided machine ID is out of the valid range
31    MachineIdOutOfRange,
32    /// Indicates that the sequence number has overflowed
33    SequenceOverflow,
34}
35
36impl fmt::Display for SnowflakeError {
37    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38        match self {
39            SnowflakeError::ClockMovedBackwards => write!(f, "Clock moved backwards"),
40            SnowflakeError::MachineIdOutOfRange => write!(f, "Machine ID is out of range"),
41            SnowflakeError::SequenceOverflow => write!(f, "Sequence overflow"),
42        }
43    }
44}
45
46impl Error for SnowflakeError {}
47
/// Snowflake ID generator.
///
/// Produces 64-bit IDs laid out, from most to least significant, as:
/// - Timestamp: milliseconds since the generator's epoch (41 bits)
/// - Node ID (10 bits)
/// - Sequence number within the millisecond (12 bits)
///
/// All methods take `&self` and generation state lives in a single atomic, so one
/// instance can be shared between threads (for example behind an `Arc`).
#[derive(Debug)]
pub struct Snowflake {
    /// Node (machine) ID embedded in every generated ID; range-checked by the constructor.
    node: u16,
    /// Epoch in milliseconds since the Unix epoch; subtracted from timestamps.
    epoch_ms: i64,
    /// Last issued timestamp and sequence, packed into one value so both can be
    /// advanced together with a single compare-and-swap.
    last_timestamp_and_sequence: AtomicI64,
    /// Instant at which the generator was created.
    // NOTE(review): this field is written by the constructor but never read in this
    // file; either use it (e.g. for a monotonic clock) or remove it.
    start: Instant,
}
61
62impl Snowflake {
63    /// Creates a new Snowflake instance
64    ///
65    /// # Arguments
66    ///
67    /// * `node` - A unique identifier for the node generating the IDs (0-1023)
68    /// * `epoch` - An optional custom epoch in milliseconds. If None, DEFAULT_EPOCH is used.
69    ///
70    /// # Returns
71    ///
72    /// A Result containing the new Snowflake instance or a SnowflakeError
73    ///
74    /// # Errors
75    ///
76    /// Returns SnowflakeError::MachineIdOutOfRange if the node ID is greater than 1023
77    pub fn new(node: u16, epoch: Option<i64>) -> Result<Self, SnowflakeError> {
78        if node > NODE_MAX {
79            return Err(SnowflakeError::MachineIdOutOfRange);
80        }
81        let epoch_ms = epoch.unwrap_or(DEFAULT_EPOCH);
82        Ok(Snowflake {
83            node,
84            epoch_ms,
85            last_timestamp_and_sequence: AtomicI64::new(0),
86            start: Instant::now(),
87        })
88    }
89
90    /// Generates a new Snowflake ID
91    ///
92    /// # Returns
93    ///
94    /// A Result containing the generated ID as a u64 or a SnowflakeError
95    ///
96    /// # Errors
97    ///
98    /// - SnowflakeError::ClockMovedBackwards if the system time moves backwards
99    /// - SnowflakeError::SequenceOverflow if unable to generate a unique ID within 5 seconds
100    pub fn generate(&self) -> Result<u64, SnowflakeError> {
101        let current_timestamp = self.current_time_millis();
102        let mut last_timestamp_and_sequence = self.last_timestamp_and_sequence.load(Ordering::Acquire);
103
104        loop {
105            let (last_timestamp, last_sequence) = decode_timestamp_and_sequence(last_timestamp_and_sequence);
106            if current_timestamp < last_timestamp {
107                return Err(SnowflakeError::ClockMovedBackwards);
108            }
109            let (new_timestamp, new_sequence) = if current_timestamp == last_timestamp {
110                let new_sequence = (last_sequence + 1) & STEP_MAX as i64;
111                if new_sequence == 0 {
112                    (self.wait_next_millis(last_timestamp)?, 0)
113                } else {
114                    (current_timestamp, new_sequence)
115                }
116            } else {
117                (current_timestamp, 0)
118            };
119            let new_timestamp_and_sequence = encode_timestamp_and_sequence(new_timestamp, new_sequence);
120            match self.last_timestamp_and_sequence.compare_exchange_weak(
121                last_timestamp_and_sequence,
122                new_timestamp_and_sequence,
123                Ordering::AcqRel,
124                Ordering::Acquire,
125            ) {
126                Ok(_) => {
127                    let id = self.create_id(new_timestamp, new_sequence as u16);
128                    return Ok(id);
129                }
130                Err(actual) => {
131                    last_timestamp_and_sequence = actual;
132                }
133            }
134        }
135    }
136    /// Parses a Snowflake ID into its components
137    /// # Arguments
138    /// * `id` - The Snowflake ID to parse
139    /// # Returns
140    /// A tuple containing the timestamp, node ID, and sequence number
141    /// # Example
142    /// ```
143    /// let (timestamp, node, sequence) = Snowflake::parse_id(1234567890);
144    /// println!("Timestamp: {}, Node: {}, Sequence: {}", timestamp, node, sequence);
145    /// ```
146    pub fn parse_id(id: u64) -> (u64, u16, u16) {
147        let timestamp = (id >> TIMESTAMP_SHIFT) & ((1 << TIMESTAMP_BITS) - 1);
148        let node = ((id >> NODE_SHIFT) & ((1 << NODE_BITS) - 1)) as u16;
149        let sequence = (id & ((1 << STEP_BITS) - 1)) as u16;
150        (timestamp, node, sequence)
151    }
152    // Waits until the next millisecond
153    fn wait_next_millis(&self, last_timestamp: i64) -> Result<i64, SnowflakeError> {
154        let start = Instant::now();
155        loop {
156            let current_timestamp = self.current_time_millis();
157            if current_timestamp > last_timestamp {
158                return Ok(current_timestamp);
159            }
160            if start.elapsed().as_millis() > 5000 { // 5 seconds max wait
161                return Err(SnowflakeError::SequenceOverflow);
162            }
163            std::thread::yield_now();
164        }
165    }
166
167    // Creates the final ID by combining timestamp, node ID, and sequence
168    fn create_id(&self, timestamp: i64, sequence: u16) -> u64 {
169        (((timestamp - self.epoch_ms) as u64) << TIMESTAMP_SHIFT)
170            | ((self.node as u64) << NODE_SHIFT)
171            | sequence as u64
172    }
173
174    // Returns the current timestamp in milliseconds
175    fn current_time_millis(&self) -> i64 {
176        use std::time::{SystemTime, UNIX_EPOCH};
177        SystemTime::now()
178            .duration_since(UNIX_EPOCH)
179            .expect("Time went backwards")
180            .as_millis() as i64
181    }
182}
183
184// Encodes timestamp and sequence into a single i64 value
185fn encode_timestamp_and_sequence(timestamp: i64, sequence: i64) -> i64 {
186    (timestamp << STEP_BITS) | sequence
187}
188
189// Decodes timestamp and sequence from a single i64 value
190fn decode_timestamp_and_sequence(value: i64) -> (i64, i64) {
191    let timestamp = value >> STEP_BITS;
192    let sequence = value & STEP_MAX as i64;
193    (timestamp, sequence)
194}