snowflaked/
sync.rs

1//! Thread-safe Snowflake Generator
2//!
//! This module provides [`Generator`], which can safely be shared between threads. Its
//! constructor is also `const`, which allows it to be used in a `static` context.
5//!
6//! # Example
7//! ```
8//! use snowflaked::sync::Generator;
9//!
10//! static GENERATOR: Generator = Generator::new(0);
11//!
12//! fn generate_id() -> u64 {
13//!     GENERATOR.generate()
14//! }
15//! ```
16
17use std::time::SystemTime;
18
19use crate::builder::Builder;
20use crate::loom::{AtomicU64, Ordering};
21use crate::time::{DefaultTime, Time};
22use crate::{const_panic_new, Components, Snowflake, INSTANCE_MAX};
23
24/// A generator for unique snowflake ids. Since [`generate`] accepts a `&self` reference this can
25/// be used in a `static` context.
26///
27/// # Cloning
28///
29/// Cloning a `Generator` will create a second one with the same state as the original one. The
30/// internal state is copied and not shared. If you need to share a single `Generator` you need to
31/// manually wrap it in an [`Arc`] (or similar).
32///
33/// # Example
34/// ```
35/// use snowflaked::sync::Generator;
36///
37/// static GENERATOR: Generator = Generator::new(0);
38///
39/// fn generate_id() -> u64 {
40///     GENERATOR.generate()
41/// }
42/// ```
43///
44/// [`generate`]: Self::generate
45/// [`Arc`]: std::sync::Arc
46#[derive(Debug)]
47pub struct Generator {
48    internal: InternalGenerator<SystemTime>,
49}
50
51impl Generator {
52    /// Creates a new `Generator` using the given `instance`.
53    ///
54    /// # Panics
55    ///
56    /// Panics if `instance` exceeds the maximum value (2^10 - 1).
57    #[cfg(not(loom))]
58    #[inline]
59    pub const fn new(instance: u16) -> Self {
60        match Self::new_checked(instance) {
61            Some(this) => this,
62            None => const_panic_new(),
63        }
64    }
65
66    /// Creates a new `Generator` using the given `instance`. Returns `None` if the instance
67    /// exceeds the maximum instance value (2^10 - 1).
68    #[cfg(not(loom))]
69    #[inline]
70    pub const fn new_checked(instance: u16) -> Option<Self> {
71        if instance > INSTANCE_MAX {
72            None
73        } else {
74            Some(Self::new_unchecked(instance))
75        }
76    }
77
78    /// Creates a new `Generator` using the given `instance` without checking if it exceeds the
79    /// maximum value (2^10 - 1).
80    ///
81    /// Note: If `instance` exceeds the maximum value the `Generator` will create incorrect
82    /// snowflakes.
83    #[cfg(not(loom))]
84    #[inline]
85    pub const fn new_unchecked(instance: u16) -> Self {
86        Self {
87            internal: InternalGenerator::new_unchecked(instance),
88        }
89    }
90
91    /// Creates a new `Builder` used to configure a `Generator`.
92    ///
93    /// # Examples
94    ///
95    /// ```
96    /// # use snowflaked::sync::Generator;
97    /// use std::time::SystemTime;
98    ///
99    /// let epoch = SystemTime::now();
100    /// let generator: Generator = Generator::builder().instance(123).epoch(epoch).build();
101    ///
102    /// assert_eq!(generator.instance(), 123);
103    /// assert_eq!(generator.epoch(), epoch);
104    /// ```
105    #[inline]
106    pub const fn builder() -> Builder {
107        Builder::new()
108    }
109
110    /// Returns the configured instance component of this `Generator`.
111    ///
112    /// # Examples
113    ///
114    /// ```
115    /// # use snowflaked::sync::Generator;
116    /// #
117    /// let mut generator = Generator::new(123);
118    ///
119    /// assert_eq!(generator.instance(), 123);
120    /// ```
121    #[inline]
122    pub fn instance(&self) -> u16 {
123        self.internal.instance()
124    }
125
126    /// Returns the configured epoch of this `Generator`. By default this is [`UNIX_EPOCH`].
127    ///
128    /// # Examples
129    ///
130    /// ```
131    /// # use snowflaked::sync::Generator;
132    /// use std::time::UNIX_EPOCH;
133    ///
134    /// let generator = Generator::new(123);
135    /// assert_eq!(generator.epoch(), UNIX_EPOCH);
136    /// ```
137    ///
138    /// [`UNIX_EPOCH`]: std::time::UNIX_EPOCH
139    #[inline]
140    pub fn epoch(&self) -> SystemTime {
141        self.internal.epoch()
142    }
143
144    /// Generate a new unique snowflake id.
145    pub fn generate<T>(&self) -> T
146    where
147        T: Snowflake,
148    {
149        self.internal.generate(std::hint::spin_loop)
150    }
151}
152
153impl From<Builder> for Generator {
154    fn from(builder: Builder) -> Self {
155        let internal = InternalGenerator {
156            components: AtomicU64::new(Components::new(builder.instance as u64).to_bits()),
157            epoch: builder.epoch,
158        };
159
160        Self { internal }
161    }
162}
163
164#[derive(Debug)]
165struct InternalGenerator<T>
166where
167    T: Time,
168{
169    components: AtomicU64,
170    epoch: T,
171}
172
173impl<T> InternalGenerator<T>
174where
175    T: Time,
176{
177    #[cfg(not(loom))]
178    #[inline]
179    const fn new_unchecked(instance: u16) -> Self
180    where
181        T: DefaultTime,
182    {
183        Self {
184            components: AtomicU64::new(Components::new(instance as u64).to_bits()),
185            epoch: T::DEFAULT,
186        }
187    }
188
189    // AtomicU64 is not const, we have to choose a different code path
190    // than the regular `new_unchecked`.
191    #[cfg(loom)]
192    #[inline]
193    fn new_unchecked(instance: u16) -> Self
194    where
195        T: DefaultTime,
196    {
197        Self {
198            components: AtomicU64::new(Components::new(instance as u64).to_bits()),
199            epoch: T::DEFAULT,
200        }
201    }
202
203    #[cfg(loom)]
204    #[inline]
205    fn new_unchecked_with_epoch(instance: u16, epoch: T) -> Self {
206        Self {
207            components: AtomicU64::new(Components::new(instance as u64).to_bits()),
208            epoch,
209        }
210    }
211
212    #[inline]
213    fn instance(&self) -> u16 {
214        let bits = self.components.load(Ordering::Relaxed);
215        Components::from_bits(bits).instance() as u16
216    }
217
218    #[inline]
219    fn epoch(&self) -> T
220    where
221        T: Copy,
222    {
223        self.epoch
224    }
225
226    fn generate<S, F>(&self, tick_wait: F) -> S
227    where
228        S: Snowflake,
229        F: Fn(),
230    {
231        use std::cmp;
232
233        // Since `fetch_update` doesn't return a result,
234        // we store the result in this mutable variable.
235        // Note that using MaybeUninit is not necessary
236        // as the compiler is smart enough to elide the Option for us.
237        let mut id = None;
238
239        let _ = self
240            .components
241            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
242                let mut components = Components::from_bits(bits);
243                let mut now = self.epoch.elapsed().as_millis() as u64;
244                let instance = components.instance();
245                match now.cmp(&components.timestamp()) {
246                    cmp::Ordering::Less => {
247                        panic!("Clock has moved backwards! This is not supported");
248                    }
249                    cmp::Ordering::Greater => {
250                        components.reset_sequence();
251                        components.set_timestamp(now);
252                        id = Some(S::from_parts(now, instance, 0));
253                        Some(components.to_bits())
254                    }
255                    cmp::Ordering::Equal => {
256                        let sequence = components.take_sequence();
257                        if sequence == 0 {
258                            now = Self::wait_until_next_millisecond(&self.epoch, now, &tick_wait);
259                        }
260                        components.set_timestamp(now);
261                        id = Some(S::from_parts(now, instance, sequence));
262                        Some(components.to_bits())
263                    }
264                }
265            });
266        id.expect("ID should have been set within the fetch_update closure.")
267    }
268
269    fn wait_until_next_millisecond<F>(epoch: &T, last_millisecond: u64, tick_wait: F) -> u64
270    where
271        F: Fn(),
272    {
273        loop {
274            let now = epoch.elapsed().as_millis() as u64;
275            if now > last_millisecond {
276                return now;
277            }
278            tick_wait();
279        }
280    }
281}
282
#[cfg(all(test, not(loom)))]
mod tests {
    use std::collections::HashMap;
    use std::sync::mpsc;
    use std::thread;

    use super::Generator;
    use crate::Snowflake;

    /// Number of ids generated per test loop (and per worker thread).
    const IDS_PER_RUN: usize = 10_000;

    /// Panics if `ids` contains the same id more than once.
    ///
    /// Runs in a single pass by remembering the first index of every id, instead of comparing
    /// every pair of ids.
    fn assert_no_duplicates(ids: &[u64]) {
        let mut first_seen = HashMap::with_capacity(ids.len());

        for (index2, id) in ids.iter().enumerate() {
            // `insert` returns the previously stored index if this id was already seen.
            if let Some(index) = first_seen.insert(*id, index2) {
                panic!(
                    "Found duplicate id {} (SEQ {}, INS {}, TS {}) at index {} and {}",
                    id,
                    id.sequence(),
                    id.instance(),
                    id.timestamp(),
                    index,
                    index2
                );
            }
        }
    }

    /// Ids generated on a single thread carry the configured instance and strictly increase.
    #[test]
    fn test_generate() {
        const INSTANCE: u64 = 0;

        let mut last_id = None;
        let generator = Generator::new_unchecked(INSTANCE as u16);

        for _ in 0..IDS_PER_RUN {
            let id: u64 = generator.generate();
            assert_eq!(id.instance(), INSTANCE);
            assert!(
                last_id < Some(id),
                "expected {:?} to be less than {:?}",
                last_id,
                Some(id)
            );
            last_id = Some(id);
        }
    }

    /// Ids generated concurrently from one shared `static` generator are all distinct.
    #[test]
    fn test_generate_threads() {
        const INSTANCE: u64 = 0;
        const THREADS: usize = 4;

        static GENERATOR: Generator = Generator::new_unchecked(INSTANCE as u16);

        let (tx, rx) = mpsc::sync_channel::<Vec<u64>>(THREADS);

        for _ in 0..THREADS {
            let tx = tx.clone();
            thread::spawn(move || {
                let mut ids = Vec::with_capacity(IDS_PER_RUN);

                for _ in 0..IDS_PER_RUN {
                    ids.push(GENERATOR.generate());
                }

                tx.send(ids).unwrap();
            });
        }

        // Drop our own sender: if a worker panics before sending, `recv` below then fails
        // instead of blocking forever on a channel that can never receive another message.
        drop(tx);

        let mut ids = Vec::with_capacity(IDS_PER_RUN * THREADS);
        for _ in 0..THREADS {
            ids.extend(
                rx.recv()
                    .expect("a worker thread exited without sending its ids"),
            );
        }

        assert_no_duplicates(&ids);
    }

    /// Ids generated on a single thread are all distinct.
    #[test]
    fn test_generate_no_duplicates() {
        let generator = Generator::new_unchecked(0);
        let mut ids: Vec<u64> = Vec::with_capacity(IDS_PER_RUN);

        // Loop over the intended count; `Vec::capacity` may be larger than what was requested.
        for _ in 0..IDS_PER_RUN {
            ids.push(generator.generate());
        }

        assert_no_duplicates(&ids);
    }

    // NOTE: disabled. `Generator` only derives `Debug` (there is no `Clone` impl in this file)
    // and it has no `components` field of its own, so this test does not compile as written.
    //
    // #[test]
    // fn test_generator_clone() {
    //     let orig = Generator::new_unchecked(0);

    //     let cloned = orig.clone();

    //     let orig_bits = Components::from_bits(orig.components.load(Ordering::Relaxed));
    //     let cloned_bits = Components::from_bits(cloned.components.load(Ordering::Relaxed));

    //     assert_eq!(orig_bits, cloned_bits);
    // }
}
393
#[cfg(all(test, loom))]
mod loom_tests {
    use std::sync::{mpsc, Arc, Mutex};
    use std::time::Duration;

    use loom::thread;

    use super::InternalGenerator;
    use crate::time::{DefaultTime, Time};
    use crate::Components;

    /// Time source that always reports the same fixed number of elapsed milliseconds.
    #[derive(Copy, Clone, Debug)]
    pub struct TestTime(u64);

    impl Time for TestTime {
        fn elapsed(&self) -> Duration {
            Duration::from_millis(self.0)
        }
    }

    impl DefaultTime for TestTime {
        const DEFAULT: Self = Self(0);
    }

    /// `tick_wait` callback for tests in which the generator must never wait.
    fn panic_on_wait() {
        panic!("unexpected wait");
    }

    /// Number of threads generating ids concurrently in each model.
    const THREADS: usize = 2;

    /// Two threads generating one id each from a fresh generator get different ids, and
    /// neither of them has to wait for the next millisecond.
    #[test]
    fn no_duplicates_no_wrap() {
        loom::model(|| {
            let generator = Arc::new(InternalGenerator::<TestTime>::new_unchecked(0));
            let (tx, rx) = mpsc::channel();

            let threads: Vec<_> = (0..THREADS)
                .map(|_| {
                    let generator = generator.clone();
                    let tx = tx.clone();

                    thread::spawn(move || {
                        let id: u64 = generator.generate(panic_on_wait);
                        tx.send(id).unwrap();
                    })
                })
                .collect();

            for th in threads {
                th.join().unwrap();
            }

            let id1 = rx.recv().unwrap();
            let id2 = rx.recv().unwrap();
            assert_ne!(id1, id2);
        });
    }

    /// Two threads generating one id each get different ids even when the generator starts
    /// out in a wrapping state.
    #[test]
    fn no_duplicates_wrap() {
        // Time source whose elapsed milliseconds are advanced by hand. The value is shared
        // through an `Arc<Mutex<_>>` so that the `tick_wait` callbacks below can move the
        // clock forward while the generator is waiting.
        #[derive(Clone, Debug)]
        struct TestTimeWrap(Arc<Mutex<u64>>);

        impl Time for TestTimeWrap {
            fn elapsed(&self) -> Duration {
                let ms = self.0.lock().unwrap();
                Duration::from_millis(*ms)
            }
        }

        loom::model(|| {
            // Makes sure the clock is advanced only once, no matter which thread waits first.
            let ticked = Arc::new(Mutex::new(false));
            let time = Arc::new(Mutex::new(0));

            let mut generator =
                InternalGenerator::new_unchecked_with_epoch(0, TestTimeWrap(time.clone()));

            // Move the generator into a wrapping state.
            generator.components.with_mut(|bits| {
                let mut components = Components::from_bits(*bits);
                components.set_sequence(4095);
                *bits = components.to_bits();
            });

            let generator = Arc::new(generator);
            let (tx, rx) = mpsc::channel();

            let threads: Vec<_> = (0..THREADS)
                .map(|_| {
                    let ticked = ticked.clone();
                    let time = time.clone();

                    let generator = generator.clone();
                    let tx = tx.clone();

                    thread::spawn(move || {
                        let id: u64 = generator.generate(move || {
                            let mut ticked = ticked.lock().unwrap();

                            // The first wait advances the clock by one millisecond; any
                            // later wait leaves it untouched.
                            if !*ticked {
                                *ticked = true;

                                let mut ms = time.lock().unwrap();
                                *ms += 1;
                            }
                        });

                        tx.send(id).unwrap();
                    })
                })
                .collect();

            for th in threads {
                th.join().unwrap();
            }

            let id1 = rx.recv().unwrap();
            let id2 = rx.recv().unwrap();
            assert_ne!(id1, id2);
        });
    }
}