1use std::time::SystemTime;
18
19use crate::builder::Builder;
20use crate::loom::{AtomicU64, Ordering};
21use crate::time::{DefaultTime, Time};
22use crate::{const_panic_new, Components, Snowflake, INSTANCE_MAX};
23
24#[derive(Debug)]
47pub struct Generator {
48 internal: InternalGenerator<SystemTime>,
49}
50
51impl Generator {
52 #[cfg(not(loom))]
58 #[inline]
59 pub const fn new(instance: u16) -> Self {
60 match Self::new_checked(instance) {
61 Some(this) => this,
62 None => const_panic_new(),
63 }
64 }
65
66 #[cfg(not(loom))]
69 #[inline]
70 pub const fn new_checked(instance: u16) -> Option<Self> {
71 if instance > INSTANCE_MAX {
72 None
73 } else {
74 Some(Self::new_unchecked(instance))
75 }
76 }
77
78 #[cfg(not(loom))]
84 #[inline]
85 pub const fn new_unchecked(instance: u16) -> Self {
86 Self {
87 internal: InternalGenerator::new_unchecked(instance),
88 }
89 }
90
91 #[inline]
106 pub const fn builder() -> Builder {
107 Builder::new()
108 }
109
110 #[inline]
122 pub fn instance(&self) -> u16 {
123 self.internal.instance()
124 }
125
126 #[inline]
140 pub fn epoch(&self) -> SystemTime {
141 self.internal.epoch()
142 }
143
144 pub fn generate<T>(&self) -> T
146 where
147 T: Snowflake,
148 {
149 self.internal.generate(std::hint::spin_loop)
150 }
151}
152
153impl From<Builder> for Generator {
154 fn from(builder: Builder) -> Self {
155 let internal = InternalGenerator {
156 components: AtomicU64::new(Components::new(builder.instance as u64).to_bits()),
157 epoch: builder.epoch,
158 };
159
160 Self { internal }
161 }
162}
163
164#[derive(Debug)]
165struct InternalGenerator<T>
166where
167 T: Time,
168{
169 components: AtomicU64,
170 epoch: T,
171}
172
173impl<T> InternalGenerator<T>
174where
175 T: Time,
176{
177 #[cfg(not(loom))]
178 #[inline]
179 const fn new_unchecked(instance: u16) -> Self
180 where
181 T: DefaultTime,
182 {
183 Self {
184 components: AtomicU64::new(Components::new(instance as u64).to_bits()),
185 epoch: T::DEFAULT,
186 }
187 }
188
189 #[cfg(loom)]
192 #[inline]
193 fn new_unchecked(instance: u16) -> Self
194 where
195 T: DefaultTime,
196 {
197 Self {
198 components: AtomicU64::new(Components::new(instance as u64).to_bits()),
199 epoch: T::DEFAULT,
200 }
201 }
202
203 #[cfg(loom)]
204 #[inline]
205 fn new_unchecked_with_epoch(instance: u16, epoch: T) -> Self {
206 Self {
207 components: AtomicU64::new(Components::new(instance as u64).to_bits()),
208 epoch,
209 }
210 }
211
212 #[inline]
213 fn instance(&self) -> u16 {
214 let bits = self.components.load(Ordering::Relaxed);
215 Components::from_bits(bits).instance() as u16
216 }
217
218 #[inline]
219 fn epoch(&self) -> T
220 where
221 T: Copy,
222 {
223 self.epoch
224 }
225
226 fn generate<S, F>(&self, tick_wait: F) -> S
227 where
228 S: Snowflake,
229 F: Fn(),
230 {
231 use std::cmp;
232
233 let mut id = None;
238
239 let _ = self
240 .components
241 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
242 let mut components = Components::from_bits(bits);
243 let mut now = self.epoch.elapsed().as_millis() as u64;
244 let instance = components.instance();
245 match now.cmp(&components.timestamp()) {
246 cmp::Ordering::Less => {
247 panic!("Clock has moved backwards! This is not supported");
248 }
249 cmp::Ordering::Greater => {
250 components.reset_sequence();
251 components.set_timestamp(now);
252 id = Some(S::from_parts(now, instance, 0));
253 Some(components.to_bits())
254 }
255 cmp::Ordering::Equal => {
256 let sequence = components.take_sequence();
257 if sequence == 0 {
258 now = Self::wait_until_next_millisecond(&self.epoch, now, &tick_wait);
259 }
260 components.set_timestamp(now);
261 id = Some(S::from_parts(now, instance, sequence));
262 Some(components.to_bits())
263 }
264 }
265 });
266 id.expect("ID should have been set within the fetch_update closure.")
267 }
268
269 fn wait_until_next_millisecond<F>(epoch: &T, last_millisecond: u64, tick_wait: F) -> u64
270 where
271 F: Fn(),
272 {
273 loop {
274 let now = epoch.elapsed().as_millis() as u64;
275 if now > last_millisecond {
276 return now;
277 }
278 tick_wait();
279 }
280 }
281}
282
283#[cfg(all(test, not(loom)))]
284mod tests {
285 use std::sync::mpsc;
286 use std::thread;
287
288 use super::Generator;
289 use crate::Snowflake;
290
291 #[test]
292 fn test_generate() {
293 const INSTANCE: u64 = 0;
294
295 let mut last_id = None;
296 let generator = Generator::new_unchecked(INSTANCE as u16);
297
298 for _ in 0..10_000 {
299 let id: u64 = generator.generate();
300 assert_eq!(id.instance(), INSTANCE);
301 assert!(
302 last_id < Some(id),
303 "expected {:?} to be less than {:?}",
304 last_id,
305 Some(id)
306 );
307 last_id = Some(id);
308 }
309 }
310
311 #[test]
312 fn test_generate_threads() {
313 const INSTANCE: u64 = 0;
314 const THREADS: usize = 4;
315
316 static GENERATOR: Generator = Generator::new_unchecked(INSTANCE as u16);
317
318 let (tx, rx) = mpsc::sync_channel::<Vec<u64>>(THREADS);
319
320 for _ in 0..THREADS {
321 let tx = tx.clone();
322 thread::spawn(move || {
323 let mut ids = Vec::with_capacity(10_000);
324
325 for _ in 0..10_000 {
326 ids.push(GENERATOR.generate());
327 }
328
329 tx.send(ids).unwrap();
330 });
331 }
332
333 let mut ids = Vec::with_capacity(10_000 * THREADS);
334 for _ in 0..THREADS {
335 ids.extend(rx.recv().unwrap());
336 }
337
338 for (index, id) in ids.iter().enumerate() {
339 for (index2, id2) in ids.iter().enumerate() {
340 if index != index2 && id == id2 {
341 panic!(
342 "Found duplicate id {} (SEQ {}, INS {}, TS {}) at index {} and {}",
343 id,
344 id.sequence(),
345 id.instance(),
346 id.timestamp(),
347 index,
348 index2
349 );
350 }
351 }
352 }
353 }
354
355 #[test]
356 fn test_generate_no_duplicates() {
357 let generator = Generator::new_unchecked(0);
358 let mut ids: Vec<u64> = Vec::with_capacity(10_000);
359
360 for _ in 0..ids.capacity() {
361 ids.push(generator.generate());
362 }
363
364 for (index, id) in ids.iter().enumerate() {
365 for (index2, id2) in ids.iter().enumerate() {
366 if index != index2 && id == id2 {
367 panic!(
368 "Found duplicate id {} (SEQ {}, INS {}, TS {}) at index {} and {}",
369 id,
370 id.sequence(),
371 id.instance(),
372 id.timestamp(),
373 index,
374 index2
375 );
376 }
377 }
378 }
379 }
380
381 }
393
#[cfg(all(test, loom))]
mod loom_tests {
    use std::sync::{mpsc, Arc, Mutex};
    use std::time::Duration;

    use loom::thread;

    use super::InternalGenerator;
    use crate::time::{DefaultTime, Time};
    use crate::Components;

    /// Frozen clock: `elapsed` always reports the wrapped number of milliseconds.
    #[derive(Copy, Clone, Debug)]
    pub struct TestTime(u64);

    impl Time for TestTime {
        fn elapsed(&self) -> Duration {
            Duration::from_millis(self.0)
        }
    }

    impl DefaultTime for TestTime {
        const DEFAULT: Self = Self(0);
    }

    /// `tick_wait` for scenarios in which the generator must never need to wait.
    fn panic_on_wait() {
        panic!("unexpected wait");
    }

    /// Number of concurrently generating loom threads per model run.
    const THREADS: usize = 2;

    /// Concurrent callers within one millisecond never receive the same ID.
    #[test]
    fn no_duplicates_no_wrap() {
        loom::model(|| {
            let generator = Arc::new(InternalGenerator::<TestTime>::new_unchecked(0));
            let (tx, rx) = mpsc::channel();

            let threads: Vec<_> = (0..THREADS)
                .map(|_| {
                    let generator = generator.clone();
                    let tx = tx.clone();

                    thread::spawn(move || {
                        let id: u64 = generator.generate(panic_on_wait);
                        tx.send(id).unwrap();
                    })
                })
                .collect();

            for th in threads {
                th.join().unwrap();
            }

            let id1 = rx.recv().unwrap();
            let id2 = rx.recv().unwrap();
            assert_ne!(id1, id2);
        });
    }

    /// IDs stay unique when the sequence runs out and one caller has to wait
    /// for the clock to tick.
    #[test]
    fn no_duplicates_wrap() {
        /// Clock whose millisecond value is shared and advanced by `tick_wait`.
        #[derive(Clone, Debug)]
        struct TestTimeWrap(Arc<Mutex<u64>>);

        impl Time for TestTimeWrap {
            fn elapsed(&self) -> Duration {
                let ms = self.0.lock().unwrap();
                Duration::from_millis(*ms)
            }
        }

        loom::model(|| {
            let ticked = Arc::new(Mutex::new(false));
            let time = Arc::new(Mutex::new(0));

            let mut generator =
                InternalGenerator::new_unchecked_with_epoch(0, TestTimeWrap(time.clone()));

            // Preload sequence 4095 (presumably the last value before wrapping)
            // so that the second ID taken in this millisecond forces a wait.
            generator.components.with_mut(|bits| {
                let mut components = Components::from_bits(*bits);
                components.set_sequence(4095);
                *bits = components.to_bits();
            });

            let generator = Arc::new(generator);
            let (tx, rx) = mpsc::channel();

            let threads: Vec<_> = (0..THREADS)
                .map(|_| {
                    let ticked = ticked.clone();
                    let time = time.clone();

                    let generator = generator.clone();
                    let tx = tx.clone();

                    thread::spawn(move || {
                        // The first wait advances the clock by exactly one millisecond.
                        let id: u64 = generator.generate(move || {
                            let mut ticked = ticked.lock().unwrap();

                            if !*ticked {
                                *ticked = true;

                                let mut ms = time.lock().unwrap();
                                *ms += 1;
                            }
                        });

                        tx.send(id).unwrap();
                    })
                })
                .collect();

            for th in threads {
                th.join().unwrap();
            }

            let id1 = rx.recv().unwrap();
            let id2 = rx.recv().unwrap();
            assert_ne!(id1, id2);
        });
    }
}