snowflake_ng/
lib.rs

1// Copyright 2024 Krysztal Huang
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8#![doc = include_str!("../README.md")]
9
10use std::{
11    ops::Deref,
12    sync::{
13        atomic::{AtomicU64, Ordering},
14        Arc,
15    },
16    time::Duration,
17};
18
19use futures::executor;
20use futures_timer::Delay;
21use rand::RngCore;
22
23pub mod provider;
24
/// Source of the timestamps that get embedded into generated Snowflake IDs.
///
/// NOTE(review): the generator waits one millisecond whenever the sequence
/// numbers of a timestamp are used up, so implementations presumably have to
/// return a millisecond-resolution value — confirm against the providers.
pub trait TimeProvider {
    /// Timestamp fetcher.
    ///
    /// Returns the current timestamp as an unsigned 64-bit integer.
    fn timestamp(&self) -> u64;
}
29
/// A generated Snowflake ID.
///
/// # Implementation
///
/// Here is a short description of the Snowflake ID (*SID* below).
///
/// First, the overall structure of a SID.
///
/// A SID is stored as an [`i64`]: 64 bits in total, 1 of which is the sign bit.
///
/// So it looks like this:
///
/// ```text
/// | sign |   data    | # sign not used.
/// | 1bit |   63bit   |
/// ```
///
/// Next comes the standard SID design. Why *standard*? Because several variants exist; they are ignored here and only Twitter's (now X) design is used.
///
/// The standard SID contains these fields:
///
/// - Timestamp: 41 bits
/// - Identifier (or machine ID): 10 bits
/// - Sequence number: 12 bits
///
/// So the SID structure looks like this:
/// ```text
/// | sign |                data                      |
/// |   0  | Timestamp | Identifier | Sequence Number |
/// | 1bit |   41bit   |    10bit   |     12bit       |
/// ```
///
/// ✨ That is all there is to the SID structure!
///
/// Now let's look at each field in more detail.
///
/// ## Timestamp
///
/// In the standard design, the timestamp can start counting from any point in time.
///
/// Here the timestamp has millisecond precision and occupies exactly 41 bits; only the low 41 bits of the provided timestamp are kept.
///
/// ## Identifier
///
/// In a distributed system, many machines (or instances) run at the same time.
///
/// They have to be told apart. Since the identifier has 10 bits, up to 1024 instances can generate SIDs at the same time.
///
/// ## Sequence Number
///
/// The sequence number has 12 bits, which means one instance can produce at most 4096 SIDs (one per message, or whatever else you are numbering) within one millisecond.
///
/// Put together: the entire system can produce at most `1024 * 4096 = 4194304` SIDs per millisecond.
///
/// ## Running out of sequence numbers
///
/// It is always possible that all SIDs for the current millisecond have already been assigned.
///
/// In that case the instance has to wait for the next millisecond, at which point another 4096 SIDs become available.
// `Copy` is derived because the type is a plain `i64` newtype; copying it is
// as cheap as copying the integer itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Snowflake(i64);
92
/// Type alias for [`i64`], the raw integer type wrapped by [`Snowflake`].
pub type SnowflakeId = i64;
95
96impl From<Snowflake> for i64 {
97    fn from(value: Snowflake) -> Self {
98        value.0
99    }
100}
101
102impl Deref for Snowflake {
103    type Target = i64;
104
105    fn deref(&self) -> &Self::Target {
106        &self.0
107    }
108}
109
110impl AsRef<i64> for Snowflake {
111    fn as_ref(&self) -> &i64 {
112        self
113    }
114}
115
/// Configuration of a [`SnowflakeGenerator`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SnowflakeConfiguration {
    /// Identifier of this generator instance.
    ///
    /// [`SnowflakeGenerator`] only uses the low **_10 bits_** of this value;
    /// higher bits are masked off when an ID is assembled.
    ///
    /// By default, `identifier` is set to a number generated by the `rand` crate.
    pub identifier: u64,
}
125
126impl SnowflakeConfiguration {
127    pub fn with_identifier(identifier: u64) -> Self {
128        Self { identifier }
129    }
130}
131
132impl Default for SnowflakeConfiguration {
133    fn default() -> Self {
134        Self {
135            identifier: rand::thread_rng().next_u64(),
136        }
137    }
138}
139
140unsafe impl Send for SnowflakeConfiguration {}
141
/// Writes `timestamp` into the timestamp field of `sid`.
///
/// The field is 41 bits wide and sits directly above the 22 low bits. Only the
/// low 41 bits of `timestamp` are stored; every bit of `sid` outside the field
/// is left untouched.
fn fill_timestamp(sid: u64, timestamp: u64) -> u64 {
    const WIDTH: u32 = 41;
    const OFFSET: u32 = 22;
    // All bits that belong to the timestamp field, already in position.
    const FIELD: u64 = ((1u64 << WIDTH) - 1) << OFFSET;

    let without_field = sid & !FIELD;
    // Bits shifted past the field (or out of the word) are dropped by the mask.
    let positioned = (timestamp << OFFSET) & FIELD;
    without_field | positioned
}
149
/// Writes `identifier` into the identifier field of `sid`.
///
/// The field is limited to 10 bits and sits directly above the 12 low bits.
/// Only the low 10 bits of `identifier` are stored; every bit of `sid` outside
/// the field is left untouched.
fn fill_identifier(sid: u64, identifier: u64) -> u64 {
    const WIDTH: u32 = 10;
    const OFFSET: u32 = 12;
    // All bits that belong to the identifier field, already in position.
    const FIELD: u64 = ((1u64 << WIDTH) - 1) << OFFSET;

    let without_field = sid & !FIELD;
    // Bits shifted past the field are dropped by the mask.
    let positioned = (identifier << OFFSET) & FIELD;
    without_field | positioned
}
157
/// Writes `sequence` into the sequence field of `sid`.
///
/// The field is the lowest 12 bits, so no shifting is required. Only the low
/// 12 bits of `sequence` are stored; the upper bits of `sid` are left
/// untouched.
fn fill_sequence(sid: u64, sequence: u64) -> u64 {
    const LOW_TWELVE: u64 = 0xFFF;

    let upper_bits = sid & !LOW_TWELVE;
    let lower_bits = sequence & LOW_TWELVE;
    upper_bits | lower_bits
}
166
167pub fn filling<T0, T1, T2>(dest: u64, timestamp: T0, identifier: T1, sequence: T2) -> u64
168where
169    T0: Into<u64>,
170    T1: Into<u64>,
171    T2: Into<u64>,
172{
173    let sid = fill_timestamp(dest, timestamp.into());
174    let sid = fill_identifier(sid, identifier.into());
175    fill_sequence(sid, sequence.into())
176}
177
178/// Generating [`Snowflake`](Snowflake)
179///
180/// Recommended keep this generator single-instance for one instance's SID generation.
181///
182/// # Thread safety
183///
184/// You can use [`::std::sync::Arc`](::std::sync::Arc) sharing ownership between thread.
185#[derive(Debug, Default)]
186pub struct SnowflakeGenerator {
187    timestamp_sequence: AtomicU64,
188    cfg: SnowflakeConfiguration,
189}
190const MAX_SEQUENCE: u16 = 0xFFF; // 12bit sequence
191
192impl SnowflakeGenerator {
193    pub fn with_cfg(cfg: SnowflakeConfiguration) -> Self {
194        Self {
195            cfg,
196            timestamp_sequence: AtomicU64::new(0),
197        }
198    }
199
200    /// Assign a [`Snowflake`](Snowflake) with [`TimeProvider`](TimeProvider)
201    pub async fn assign<T>(&self, provider: &T) -> Snowflake
202    where
203        T: TimeProvider + Sync + Send,
204    {
205        loop {
206            let timestamp = provider.timestamp();
207            let current = self.timestamp_sequence.load(Ordering::Relaxed);
208            let current_timestamp = current >> 16;
209            let current_sequence = (current & 0xFFFF) as u16;
210
211            match current_timestamp.cmp(&timestamp) {
212                std::cmp::Ordering::Less => {
213                    // update timestamp
214                    let new_value = timestamp << 16;
215
216                    if self
217                        .timestamp_sequence
218                        .compare_exchange(current, new_value, Ordering::SeqCst, Ordering::SeqCst)
219                        .is_ok()
220                    {
221                        let sid = fill_timestamp(0, timestamp);
222                        let sid = fill_identifier(sid, self.cfg.identifier);
223                        let sid = fill_sequence(sid, 0);
224                        return Snowflake(sid as i64);
225                    }
226                }
227                std::cmp::Ordering::Equal => {
228                    if current_sequence >= MAX_SEQUENCE {
229                        // Sequence reached MAX, waiting for next millisecond
230                        Delay::new(Duration::from_millis(1)).await;
231                        continue;
232                    }
233
234                    let new_sequence = current_sequence + 1;
235                    let new_value = (timestamp << 16) | new_sequence as u64;
236
237                    if self
238                        .timestamp_sequence
239                        .compare_exchange(current, new_value, Ordering::SeqCst, Ordering::SeqCst)
240                        .is_ok()
241                    {
242                        let sid = fill_timestamp(0, timestamp);
243                        let sid = fill_identifier(sid, self.cfg.identifier);
244                        let sid = fill_sequence(sid, new_sequence as u64);
245                        return Snowflake(sid as i64);
246                    }
247                }
248                std::cmp::Ordering::Greater => Delay::new(Duration::from_millis(1)).await,
249            };
250        }
251    }
252
253    /// Assign a new [`Snowflake`](Snowflake) but in synchronous way.
254    #[cfg(feature = "sync")]
255    pub fn assign_sync<T>(&self, provider: &T) -> Snowflake
256    where
257        T: TimeProvider + Sync + Send,
258    {
259        executor::block_on(self.assign(provider))
260    }
261}
262
/// A [`SnowflakeGenerator`] bundled together with its [`TimeProvider`].
///
/// Designed for easier contextualization: callers do not have to pass the
/// provider on every call.
///
/// # Thread safety
///
/// Both parts are held behind [`::std::sync::Arc`], so the value can be sent
/// between threads when the provider is `Send + Sync`.
///
/// # Clone
///
/// Clone is cheap: cloning invokes [`Arc::clone`] twice, once for the
/// generator and once for the provider.
#[derive(Debug)]
pub struct PersistedSnowflakeGenerator<T> {
    // Shared generator that holds the timestamp/sequence state.
    generator: Arc<SnowflakeGenerator>,
    // Shared time source handed to the generator on every assignment.
    provider: Arc<T>,
}
279
280impl<T> PersistedSnowflakeGenerator<T>
281where
282    T: TimeProvider + Send + Sync,
283{
284    /// Constructing new [`PersistedSnowflakeGenerator`](PersistedSnowflakeGenerator) from already instanced [`SnowflakeGenerator`](SnowflakeGenerator) and [`TimeProvider`](TimeProvider)
285    ///
286    /// The cost very low, please relax to constructing your [`PersistedSnowflakeGenerator`](PersistedSnowflakeGenerator).
287    ///
288    /// # Thread safety
289    ///
290    /// Yes, `time_provider` must be send and sync between threads and [`SnowflakeGenerator`](SnowflakeGenerator) are already thread safe.
291    pub fn new(generator: Arc<SnowflakeGenerator>, provider: Arc<T>) -> Self {
292        Self {
293            generator,
294            provider,
295        }
296    }
297
298    /// Assign a new [`Snowflake`](Snowflake)
299    pub async fn assign(&self) -> Snowflake {
300        self.generator.assign(self.provider.as_ref()).await
301    }
302
303    /// Assign a new [`Snowflake`](Snowflake) but in synchronous way.
304    #[cfg(feature = "sync")]
305    pub fn assign_sync(&self) -> Snowflake {
306        self.generator.assign_sync(self.provider.as_ref())
307    }
308}
309
310impl<T> Clone for PersistedSnowflakeGenerator<T> {
311    fn clone(&self) -> Self {
312        Self {
313            generator: self.generator.clone(),
314            provider: self.provider.clone(),
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use parking_lot::RwLock;
    use provider::{StdProvider, STD_PROVIDER};

    use super::*;

    /// The timestamp lands above the low 22 bits and is truncated to 41 bits.
    #[test]
    fn test_fill_timestamp() {
        // Case1
        let sid = 0u64;
        let timestamp = 0b101010;
        let expected = 42 << 22;
        let result = fill_timestamp(sid, timestamp);
        assert_eq!(result, expected);

        // Case2
        let sid = 0u64;
        let timestamp = (1u64 << 42) - 1;
        let expected = ((1u64 << 41) - 1) << 22;
        let result = fill_timestamp(sid, timestamp);
        assert_eq!(result, expected);
    }

    /// The identifier lands above the low 12 bits and is truncated to 10 bits.
    #[test]
    fn test_fill_identifier() {
        // Case1
        let sid = 0u64;
        let identifier = 0b110101;
        let expected = 53 << 12;
        let result = fill_identifier(sid, identifier);
        assert_eq!(result, expected);

        // Case2
        let sid = 0u64;
        let identifier = (1u64 << 11) - 1;
        let expected = ((1u64 << 10) - 1) << 12;
        let result = fill_identifier(sid, identifier);
        assert_eq!(result, expected);
    }

    /// The sequence occupies the low 12 bits and is truncated to 12 bits.
    #[test]
    fn test_fill_sequence() {
        // Case1
        let sid = 0u64;
        let sequence = 0b1001;
        let expected = 9;
        let result = fill_sequence(sid, sequence);
        assert_eq!(result, expected);

        // Case2
        let sid = 0u64;
        let sequence = (1u64 << 13) - 1;
        let expected = (1u64 << 12) - 1;
        let result = fill_sequence(sid, sequence);
        assert_eq!(result, expected);
    }

    /// All three fields are combined without overlapping each other.
    #[test]
    fn test_filling() {
        let sid = 0u64;
        let timestamp = 0b10101010101010101010101010101010101010101u64;
        let identifier = 0b110101u64;
        let sequence = 0b1001u64;

        let expected = (timestamp << 22) | (identifier << 12) | sequence;

        let result = filling(sid, timestamp, identifier, sequence);
        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_assign() {
        let generator = Arc::new(SnowflakeGenerator::default());

        for _ in 0..1024 {
            generator.assign(&provider::TIME_CRATE_PROVIDER).await;
        }
    }

    #[tokio::test]
    async fn test_assign_multithread() {
        let generator = Arc::new(SnowflakeGenerator::default());

        let mut handles = vec![];
        let id_set = Arc::new(RwLock::new(HashSet::new()));

        for _ in 0..1000 {
            let generator = Arc::clone(&generator);
            let id_set = Arc::clone(&id_set);
            let handle = tokio::spawn(async move {
                for _ in 0..1000 {
                    let id = generator.assign(&STD_PROVIDER).await;
                    // `insert` reports whether the value was new, so a single
                    // lookup is enough to detect duplicates.
                    let inserted = id_set.write().insert(id);
                    assert!(inserted, "Duplicate `Snowflake` generated!");
                }
            });
            handles.push(handle);
        }

        // A panic inside a spawned task only surfaces through its join
        // result, so every result has to be checked.
        for joined in futures::future::join_all(handles).await {
            joined.expect("worker task panicked");
        }

        assert_eq!(
            id_set.read().len(),
            1000 * 1000,
            "Some `Snowflake` were lost!"
        );
    }

    // `assign_sync` only exists with the `sync` feature, so the test must be
    // gated the same way or the test build fails without that feature.
    #[cfg(feature = "sync")]
    #[test]
    fn test_persists() {
        let binding = Arc::new(SnowflakeGenerator::default());
        let persist = PersistedSnowflakeGenerator::new(binding.clone(), Arc::new(StdProvider));

        let snowflakes = (0..1000)
            .map(|_| persist.assign_sync())
            .collect::<HashSet<_>>();

        assert_eq!(snowflakes.len(), 1000);
    }

    #[tokio::test]
    async fn test_persists_multithread() {
        let binding = Arc::new(SnowflakeGenerator::default());

        let persist = Arc::new(PersistedSnowflakeGenerator::new(
            binding.clone(),
            Arc::new(StdProvider),
        ));

        let tasks = (0..1000).map(|_| {
            let persist = persist.clone();
            tokio::spawn(async move { persist.assign().await })
        });

        // Collect into a set: the length of the joined vector is always 1000,
        // only the number of distinct IDs says anything about uniqueness.
        let snowflakes = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|joined| joined.expect("worker task panicked"))
            .collect::<HashSet<_>>();

        assert_eq!(snowflakes.len(), 1000);
    }
}