1#![doc = include_str!("../README.md")]
9
10use std::{
11 ops::Deref,
12 sync::{
13 atomic::{AtomicU64, Ordering},
14 Arc,
15 },
16 time::Duration,
17};
18
19use futures::executor;
20use futures_timer::Delay;
21use rand::RngCore;
22
23pub mod provider;
24
/// Source of timestamps consumed by [`SnowflakeGenerator`].
///
/// Only the low 41 bits of the returned value end up in a [`Snowflake`]. The generator waits
/// (in 1 ms steps) whenever the returned value is behind the last timestamp it issued, so an
/// implementation should return a non-decreasing value.
pub trait TimeProvider {
    /// Returns the current timestamp.
    ///
    /// NOTE(review): the unit and epoch are chosen by the implementation (see the `provider`
    /// module); the generator's 1 ms retry delay suggests milliseconds — confirm.
    fn timestamp(&self) -> u64;
}
29
/// A Snowflake id stored as an `i64`.
///
/// Ids produced by [`SnowflakeGenerator`] use this bit layout (most significant first):
/// - bit 63: always zero, so the value is never negative;
/// - bits 22..=62: the low 41 bits of the provider timestamp;
/// - bits 12..=21: the low 10 bits of the configured identifier;
/// - bits 0..=11: a 12-bit per-timestamp sequence number.
///
/// The type is a plain integer wrapper, so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Snowflake(i64);

/// Raw integer form of a [`Snowflake`] (what `i64::from(snowflake)` returns).
pub type SnowflakeId = i64;
95
96impl From<Snowflake> for i64 {
97 fn from(value: Snowflake) -> Self {
98 value.0
99 }
100}
101
102impl Deref for Snowflake {
103 type Target = i64;
104
105 fn deref(&self) -> &Self::Target {
106 &self.0
107 }
108}
109
110impl AsRef<i64> for Snowflake {
111 fn as_ref(&self) -> &i64 {
112 self
113 }
114}
115
/// Settings for a [`SnowflakeGenerator`].
///
/// Derives `Clone`/`PartialEq`/`Eq` so a configuration can be duplicated and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SnowflakeConfiguration {
    /// Value embedded in the identifier field of every generated id.
    ///
    /// Only the low 10 bits are stored in a [`Snowflake`]; higher bits are discarded, so two
    /// configurations whose identifiers agree in their low 10 bits produce the same field.
    pub identifier: u64,
}
125
126impl SnowflakeConfiguration {
127 pub fn with_identifier(identifier: u64) -> Self {
128 Self { identifier }
129 }
130}
131
132impl Default for SnowflakeConfiguration {
133 fn default() -> Self {
134 Self {
135 identifier: rand::thread_rng().next_u64(),
136 }
137 }
138}
139
140unsafe impl Send for SnowflakeConfiguration {}
141
/// Overwrites bits 22..=62 of `sid` with the low 41 bits of `timestamp`.
///
/// Bits of `sid` outside that field are left untouched.
fn fill_timestamp(sid: u64, timestamp: u64) -> u64 {
    const SHIFT: u32 = 22;
    const WIDTH_MASK: u64 = (1u64 << 41) - 1;
    let field = WIDTH_MASK << SHIFT;
    // Shifting first and masking afterwards drops the same high bits as masking first.
    (sid & !field) | ((timestamp << SHIFT) & field)
}
149
/// Overwrites bits 12..=21 of `sid` with the low 10 bits of `identifier`.
///
/// Bits of `sid` outside that field are left untouched.
fn fill_identifier(sid: u64, identifier: u64) -> u64 {
    const SHIFT: u32 = 12;
    const WIDTH_MASK: u64 = (1u64 << 10) - 1;
    let field = WIDTH_MASK << SHIFT;
    let kept = sid & !field;
    let inserted = (identifier << SHIFT) & field;
    kept | inserted
}
157
/// Overwrites bits 0..=11 of `sid` with the low 12 bits of `sequence`.
///
/// Bits of `sid` above the sequence field are left untouched.
fn fill_sequence(sid: u64, sequence: u64) -> u64 {
    const SEQUENCE_MASK: u64 = 0xFFF;
    let preserved = sid & !SEQUENCE_MASK;
    let sequence_bits = sequence & SEQUENCE_MASK;
    preserved | sequence_bits
}
166
167pub fn filling<T0, T1, T2>(dest: u64, timestamp: T0, identifier: T1, sequence: T2) -> u64
168where
169 T0: Into<u64>,
170 T1: Into<u64>,
171 T2: Into<u64>,
172{
173 let sid = fill_timestamp(dest, timestamp.into());
174 let sid = fill_identifier(sid, identifier.into());
175 fill_sequence(sid, sequence.into())
176}
177
/// Generator of [`Snowflake`]s whose shared state lives in a single atomic word (no mutex).
///
/// `assign` claims each (timestamp, sequence) pair with a compare-and-swap on
/// `timestamp_sequence`. The derived `Default` starts from state 0 and uses
/// `SnowflakeConfiguration::default()`.
#[derive(Debug, Default)]
pub struct SnowflakeGenerator {
    // Packed state of the last id handed out: `(timestamp << 16) | sequence`.
    timestamp_sequence: AtomicU64,
    // Configuration whose `identifier` is embedded in every id.
    cfg: SnowflakeConfiguration,
}
190const MAX_SEQUENCE: u16 = 0xFFF; impl SnowflakeGenerator {
193 pub fn with_cfg(cfg: SnowflakeConfiguration) -> Self {
194 Self {
195 cfg,
196 timestamp_sequence: AtomicU64::new(0),
197 }
198 }
199
200 pub async fn assign<T>(&self, provider: &T) -> Snowflake
202 where
203 T: TimeProvider + Sync + Send,
204 {
205 loop {
206 let timestamp = provider.timestamp();
207 let current = self.timestamp_sequence.load(Ordering::Relaxed);
208 let current_timestamp = current >> 16;
209 let current_sequence = (current & 0xFFFF) as u16;
210
211 match current_timestamp.cmp(×tamp) {
212 std::cmp::Ordering::Less => {
213 let new_value = timestamp << 16;
215
216 if self
217 .timestamp_sequence
218 .compare_exchange(current, new_value, Ordering::SeqCst, Ordering::SeqCst)
219 .is_ok()
220 {
221 let sid = fill_timestamp(0, timestamp);
222 let sid = fill_identifier(sid, self.cfg.identifier);
223 let sid = fill_sequence(sid, 0);
224 return Snowflake(sid as i64);
225 }
226 }
227 std::cmp::Ordering::Equal => {
228 if current_sequence >= MAX_SEQUENCE {
229 Delay::new(Duration::from_millis(1)).await;
231 continue;
232 }
233
234 let new_sequence = current_sequence + 1;
235 let new_value = (timestamp << 16) | new_sequence as u64;
236
237 if self
238 .timestamp_sequence
239 .compare_exchange(current, new_value, Ordering::SeqCst, Ordering::SeqCst)
240 .is_ok()
241 {
242 let sid = fill_timestamp(0, timestamp);
243 let sid = fill_identifier(sid, self.cfg.identifier);
244 let sid = fill_sequence(sid, new_sequence as u64);
245 return Snowflake(sid as i64);
246 }
247 }
248 std::cmp::Ordering::Greater => Delay::new(Duration::from_millis(1)).await,
249 };
250 }
251 }
252
253 #[cfg(feature = "sync")]
255 pub fn assign_sync<T>(&self, provider: &T) -> Snowflake
256 where
257 T: TimeProvider + Sync + Send,
258 {
259 executor::block_on(self.assign(provider))
260 }
261}
262
/// A [`SnowflakeGenerator`] bundled with the [`TimeProvider`] it should use.
///
/// Both parts are held behind `Arc`, so cloning is cheap. Clones share the same generator
/// state, and cloning does not require `T: Clone`.
#[derive(Debug)]
pub struct PersistedSnowflakeGenerator<T> {
    // Shared generator; every clone assigns ids from the same atomic state.
    generator: Arc<SnowflakeGenerator>,
    // Time source handed to `SnowflakeGenerator::assign` on every call.
    provider: Arc<T>,
}
279
280impl<T> PersistedSnowflakeGenerator<T>
281where
282 T: TimeProvider + Send + Sync,
283{
284 pub fn new(generator: Arc<SnowflakeGenerator>, provider: Arc<T>) -> Self {
292 Self {
293 generator,
294 provider,
295 }
296 }
297
298 pub async fn assign(&self) -> Snowflake {
300 self.generator.assign(self.provider.as_ref()).await
301 }
302
303 #[cfg(feature = "sync")]
305 pub fn assign_sync(&self) -> Snowflake {
306 self.generator.assign_sync(self.provider.as_ref())
307 }
308}
309
310impl<T> Clone for PersistedSnowflakeGenerator<T> {
311 fn clone(&self) -> Self {
312 Self {
313 generator: self.generator.clone(),
314 provider: self.provider.clone(),
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use std::{collections::HashSet, sync::Arc};

    use parking_lot::RwLock;
    use provider::{StdProvider, STD_PROVIDER};

    use super::*;

    #[test]
    fn test_fill_timestamp() {
        // A timestamp that fits in 41 bits is shifted into place unchanged.
        let sid = 0u64;
        let timestamp = 0b101010;
        let expected = 42 << 22;
        let result = fill_timestamp(sid, timestamp);
        assert_eq!(result, expected);

        // Bits above the 41-bit field are discarded.
        let sid = 0u64;
        let timestamp = (1u64 << 42) - 1;
        let expected = ((1u64 << 41) - 1) << 22;
        let result = fill_timestamp(sid, timestamp);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_fill_identifier() {
        // An identifier that fits in 10 bits is shifted into place unchanged.
        let sid = 0u64;
        let identifier = 0b110101;
        let expected = 53 << 12;
        let result = fill_identifier(sid, identifier);
        assert_eq!(result, expected);

        // Bits above the 10-bit field are discarded.
        let sid = 0u64;
        let identifier = (1u64 << 11) - 1;
        let expected = ((1u64 << 10) - 1) << 12;
        let result = fill_identifier(sid, identifier);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_fill_sequence() {
        // A sequence that fits in 12 bits is stored unchanged.
        let sid = 0u64;
        let sequence = 0b1001;
        let expected = 9;
        let result = fill_sequence(sid, sequence);
        assert_eq!(result, expected);

        // Bits above the 12-bit field are discarded.
        let sid = 0u64;
        let sequence = (1u64 << 13) - 1;
        let expected = (1u64 << 12) - 1;
        let result = fill_sequence(sid, sequence);
        assert_eq!(result, expected);
    }

    #[test]
    fn test_filling() {
        let sid = 0u64;
        let timestamp = 0b10101010101010101010101010101010101010101u64;
        let identifier = 0b110101u64;
        let sequence = 0b1001u64;

        let expected = (timestamp << 22) | (identifier << 12) | sequence;

        let result = filling(sid, timestamp, identifier, sequence);
        assert_eq!(result, expected);
    }

    #[tokio::test]
    async fn test_assign() {
        let generator = Arc::new(SnowflakeGenerator::default());

        let mut seen = HashSet::new();
        for _ in 0..1024 {
            let id = generator.assign(&provider::TIME_CRATE_PROVIDER).await;
            assert!(seen.insert(id), "Duplicate `Snowflake` generated!");
        }
    }

    #[tokio::test]
    async fn test_assign_multithread() {
        let generator = Arc::new(SnowflakeGenerator::default());

        let mut handles = vec![];
        let id_set = Arc::new(RwLock::new(HashSet::new()));

        for _ in 0..1000 {
            let generator = Arc::clone(&generator);
            let id_set = Arc::clone(&id_set);
            let handle = tokio::spawn(async move {
                for _ in 0..1000 {
                    let id = generator.assign(&STD_PROVIDER).await;
                    // The write guard is a temporary, released before the next `.await`.
                    assert!(id_set.write().insert(id), "Duplicate `Snowflake` generated!");
                }
            });
            handles.push(handle);
        }

        // Surface panics from spawned tasks (e.g. a duplicate) instead of discarding them.
        for result in futures::future::join_all(handles).await {
            result.expect("id-generating task panicked");
        }

        assert_eq!(
            id_set.read().len(),
            1000 * 1000,
            "Some `Snowflake` were lost!"
        );
    }

    // `assign_sync` only exists with the `sync` feature, so this test must be gated the same way.
    #[cfg(feature = "sync")]
    #[test]
    fn test_persists() {
        let binding = Arc::new(SnowflakeGenerator::default());
        let persist = PersistedSnowflakeGenerator::new(binding.clone(), Arc::new(StdProvider));

        let snowflakes = (0..1000)
            .map(|_| persist.assign_sync())
            .collect::<HashSet<_>>();

        assert_eq!(snowflakes.len(), 1000);
    }

    #[tokio::test]
    async fn test_persists_multithread() {
        let binding = Arc::new(SnowflakeGenerator::default());

        let persist = Arc::new(PersistedSnowflakeGenerator::new(
            binding.clone(),
            Arc::new(StdProvider),
        ));

        let tasks = (0..1000).map(|_| {
            let persist = persist.clone();
            tokio::spawn(async move { persist.assign().await })
        });
        // Collect into a set so the length check actually verifies uniqueness.
        let snowflakes = futures::future::join_all(tasks)
            .await
            .into_iter()
            .map(|result| result.expect("id-generating task panicked"))
            .collect::<HashSet<_>>();

        assert_eq!(snowflakes.len(), 1000, "Duplicate `Snowflake` generated!");
    }
}