skiplist_rust/
lib.rs

1pub mod arena;
2
3use std::cell::UnsafeCell;
4use std::fmt::Debug;
5use std::iter::Iterator;
6use std::ptr;
7use std::ptr::{null_mut, NonNull};
8use std::sync::atomic::{AtomicPtr, Ordering};
9use std::sync::{Arc, Mutex};
10use rand::{Rng, SeedableRng};
11use rand::rngs::StdRng;
12use crate::arena::Arena;
13
/// Upper bound on the number of levels any node (and therefore the list) may span.
const MAX_HEIGHT: usize = 12;
/// Branching factor: a node gains each additional level with probability `1 / K_BRANCHING`.
const K_BRANCHING: usize = 4;
16
/// One entry of the skip list: a key plus one forward link for every level it occupies.
pub struct Node<K> {
    key: K,
    /// `next[level]` is this node's successor on `level`; a null pointer ends that level.
    next: Vec<AtomicPtr<Node<K>>>,
}

impl<K> Node<K> {
    /// Creates a node holding `key` that spans `height` levels, with every link null.
    fn new(key: K, height: usize) -> Self {
        let next = (0..height)
            .map(|_| AtomicPtr::new(ptr::null_mut()))
            .collect::<Vec<_>>();
        Self { key, next }
    }

    /// Successor on `level`, loaded with `Acquire` so it pairs with the `Release`
    /// store in `set_next` that published it.
    fn next(&self, level: usize) -> *mut Node<K> {
        let link = &self.next[level];
        link.load(Ordering::Acquire)
    }

    /// Points `level` at `node` with a `Release` store, publishing everything written to
    /// `node` beforehand to readers that load this link with `next`.
    fn set_next(&self, level: usize, node: *mut Node<K>) {
        let link = &self.next[level];
        link.store(node, Ordering::Release);
    }

    /// Successor on `level` with a `Relaxed` load (no ordering guarantees).
    fn no_barrier_next(&self, level: usize) -> *mut Node<K> {
        let link = &self.next[level];
        link.load(Ordering::Relaxed)
    }

    /// Points `level` at `node` with a `Relaxed` store (no ordering guarantees).
    fn no_barrier_set_next(&self, level: usize, node: *mut Node<K>) {
        let link = &self.next[level];
        link.store(node, Ordering::Relaxed);
    }
}
47
48
49pub struct SkipListIterator<'a, K: Ord + Debug + Default> {
50    node: *mut Node<K>,
51    list: &'a SkipListImpl<K>,
52}
53
54impl<'a, K: Ord + Debug + Default> SkipListIterator<'a, K> {
55    pub fn new(list: &'a SkipListImpl<K>) -> Self {
56        SkipListIterator { node: null_mut(), list }
57    }
58
59    pub fn valid(&self) -> bool {
60        !self.node.is_null()
61    }
62
63    pub fn key(&self) -> &K {
64        assert!(self.valid());
65        unsafe { &self.node.as_ref().unwrap().key }
66    }
67
68    pub fn next(&mut self) {
69        assert!(self.valid());
70        self.node = unsafe { self.node.as_ref().unwrap().next(0) };
71    }
72
73    pub fn prev(&mut self) {
74        assert!(self.valid());
75        self.node = self.list.find_less_than(self.key()).as_ptr();
76        if self.node == self.list.head.as_ptr() {
77            self.node = null_mut();
78        }
79    }
80
81    pub fn seek(&mut self, target: &K) {
82        self.node = self.list.find_greater_or_equal(target, &mut None);
83    }
84
85    pub fn seek_to_first(&mut self) {
86        self.node = unsafe { self.list.head.as_ref().next(0) };
87    }
88
89    pub fn seek_to_last(&mut self) {
90        self.node = self.list.find_last().as_ptr();
91        if self.node == self.list.head.as_ptr() {
92            self.node = null_mut();
93        }
94    }
95}
96
97pub struct SkipListImpl<K: Ord + Debug + Default> {
98    head: NonNull<Node<K>>,
99    max_height: std::sync::atomic::AtomicUsize,
100    rnd: StdRng,
101    arena: Arena,
102}
103
104unsafe impl<K: Ord + Debug + Default + Send> Send for SkipListImpl<K> {}
105unsafe impl<K: Ord + Debug + Default + Sync> Sync for SkipListImpl<K> {}
106
107impl<K: Ord + Debug + Default> SkipListImpl<K> {
108    pub fn new(mut arena: Arena) -> SkipListImpl<K> {
109        let head = unsafe {
110            let layout = std::alloc::Layout::new::<Node<K>>();
111            let ptr = arena.allocate(layout.size()) as *mut Node<K>;
112            ptr::write(ptr, Node::new(K::default(), MAX_HEIGHT));
113            NonNull::new_unchecked(ptr)
114        };
115        let mut s = SkipListImpl {
116            head,
117            max_height: std::sync::atomic::AtomicUsize::new(1),
118            rnd: StdRng::seed_from_u64(0xdeadbeef),
119            arena,
120        };
121
122        for i in 0..MAX_HEIGHT {
123            unsafe {
124                s.head.as_mut().set_next(i, ptr::null_mut());
125            }
126        }
127        s
128    }
129
130    /// # Safety
131    ///
132    /// This function should not be called before data ready.
133    pub unsafe fn key_is_after_node(&self, key: &K, node: *const Node<K>) -> bool {
134        unsafe {
135            node.as_ref().map(|n| &n.key)
136                .map_or(false, |node_key| node_key < key)
137        }
138    }
139
140    pub fn find_greater_or_equal(&self, key: &K, prev: &mut Option<&mut Vec<*mut Node<K>>>) -> *mut Node<K> {
141        let mut x = self.head.as_ptr();
142        let mut level = self.get_max_height() - 1;
143        loop {
144            let next = unsafe { x.as_ref().unwrap().next(level) };
145            if unsafe { self.key_is_after_node(key, next) } {
146                x = next;
147            } else {
148                if let Some(prev_node) = prev {
149                    prev_node[level] = x;
150                }
151                if level == 0 {
152                    return next;
153                } else {
154                    level -= 1;
155                }
156            }
157        }
158    }
159
160    pub fn find_less_than(&self, key: &K) -> NonNull<Node<K>> {
161        let mut x = self.head;
162        let mut level = self.get_max_height() - 1;
163        loop {
164            let next = unsafe { x.as_ref().next(level) };
165            if next.is_null() || unsafe { next.as_ref().unwrap().key >= *key } {
166                if level == 0 {
167                    return x;
168                } else {
169                    level -= 1;
170                }
171            } else {
172                x = unsafe { NonNull::new_unchecked(next) };
173            }
174        }
175    }
176
177    pub fn find_last(&self) -> NonNull<Node<K>> {
178        let mut x = self.head;
179        let mut level = self.get_max_height() - 1;
180        loop {
181            let next = unsafe { x.as_ref().next(level) };
182            if next.is_null() {
183                if level == 0 {
184                    return x;
185                } else {
186                    level -= 1;
187                }
188            } else {
189                x = unsafe { NonNull::new_unchecked(next) };
190            }
191        }
192    }
193
194    pub fn contains(&self, key: &K) -> bool {
195        let x = self.find_greater_or_equal(key, &mut None);
196        let x_ref = unsafe { x.as_ref() };
197        match x_ref {
198            None => false,
199            Some(x_ref) => x_ref.key == *key,
200        }
201    }
202
203    pub fn random_height(&mut self) -> usize {
204        let mut height = 1;
205        while height < MAX_HEIGHT && self.rnd.gen_range(0..K_BRANCHING) == 0 {
206            height += 1;
207        }
208        height
209    }
210
211    #[inline]
212    fn get_max_height(&self) -> usize {
213        self.max_height.load(Ordering::Relaxed)
214    }
215
216    pub fn insert(&mut self, key: K) {
217        let mut prev = vec![ptr::null_mut(); MAX_HEIGHT];
218        let x = self.find_greater_or_equal(&key, &mut Some(&mut prev));
219        assert!(x.is_null() || unsafe { x.as_ref().unwrap().key != key });
220
221        let height = self.random_height();
222        if height > self.get_max_height() {
223            let i = self.get_max_height();
224            for p in prev.iter_mut().take(height).skip(i) {
225                *p = self.head.as_ptr();
226            }
227            self.max_height.store(height, Ordering::Relaxed);
228        }
229
230        let new_node = unsafe {
231            let layout = std::alloc::Layout::new::<Node<K>>();
232            let ptr = self.arena.allocate(layout.size()) as *mut Node<K>;
233            ptr::write(ptr, Node::new(key, height));
234            &mut *ptr
235        };
236        for (i, p) in prev.iter().enumerate().take(height) {
237            unsafe {
238                new_node.no_barrier_set_next(i, p.as_ref().unwrap().no_barrier_next(i));
239                p.as_ref().unwrap().set_next(i, new_node);
240            }
241        }
242    }
243}
244
245struct SkipList<K: Ord + Debug + Default> {
246    skip_list: Arc<UnsafeCell<SkipListImpl<K>>>,
247    write_lock: Mutex<()>,
248}
249
250unsafe impl<K: Ord + Debug + Default + Send + Sync> Send for SkipList<K> {}
251unsafe impl<K: Ord + Debug + Default + Send + Sync> Sync for SkipList<K> {}
252
253impl<K: Ord + Debug + Default> SkipList<K> {
254    pub fn new(arena: Arena) -> Self {
255        SkipList {
256            skip_list: Arc::new(UnsafeCell::new(SkipListImpl::new(arena))),
257            write_lock: Mutex::new(()),
258        }
259    }
260
261    pub fn insert(&self, key: K) {
262        let _guard = self.write_lock.lock().unwrap();
263        unsafe {
264            (*self.skip_list.get()).insert(key);
265        }
266    }
267
268    pub fn contains(&self, key: &K) -> bool {
269        unsafe {
270            (*self.skip_list.get()).contains(key)
271        }
272    }
273
274    pub fn iter(&self) -> SkipListIterator<K> {
275        unsafe {
276            SkipListIterator::new(&*self.skip_list.get())
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Condvar, Mutex};
    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
    use std::thread;
    use std::time::Duration;
    use rand::{random, Rng, SeedableRng};
    use crate::arena::Arena;
    use super::{SkipListImpl, SkipListIterator, SkipList};

    /// An empty list contains nothing, and every positioning call leaves the iterator
    /// invalid.
    #[test]
    fn test_empty() {
        let arena = Arena::new();
        let list = super::SkipListImpl::new(arena);
        assert_eq!(list.contains(&10), false);

        let mut iter = SkipListIterator::new(&list);
        assert_eq!(iter.valid(), false);
        iter.seek_to_first();
        assert_eq!(iter.valid(), false);
        iter.seek(&100);
        assert_eq!(iter.valid(), false);
        iter.seek_to_last();
        assert_eq!(iter.valid(), false);
    }

    /// Inserts random keys, then checks membership, full forward iteration, seeks and
    /// backward iteration against a `BTreeSet` model.
    #[test]
    fn insert_and_lookup() {
        let n = 2000;
        let r = 5000;
        let mut rnd = rand::thread_rng();
        let mut keys = std::collections::btree_set::BTreeSet::new();
        let arena = Arena::new();
        let list = SkipList::new(arena);

        for _ in 0..r {
            let key = rnd.gen_range(0..r);
            // Only keys new to the model are inserted: the list rejects duplicates.
            if keys.insert(key) {
                list.insert(key);
            }
        }

        for i in 0..n {
            if list.contains(&i) {
                assert!(keys.contains(&i));
            } else {
                assert!(!keys.contains(&i));
            }
        }

        {
            let mut iter = list.iter();
            iter.seek_to_first();
            for i in 0..r {
                if keys.contains(&i) {
                    assert_eq!(iter.valid(), true);
                    assert_eq!(iter.key(), &i);
                    iter.next();
                }
            }
            assert_eq!(iter.valid(), false);
        }

        {
            let mut iter = list.iter();
            assert!(!iter.valid());

            iter.seek(&0);
            assert!(iter.valid());
            assert_eq!(keys.iter().next().unwrap(), iter.key());

            iter.seek_to_first();
            assert!(iter.valid());
            assert_eq!(keys.iter().next().unwrap(), iter.key());

            iter.seek_to_last();
            assert!(iter.valid());
            assert_eq!(keys.iter().rev().next().unwrap(), iter.key());
        }

        // Forward iteration test
        for i in 0..r {
            let mut iter = list.iter();
            iter.seek(&i);
            let mut model_iter = keys.range(i..);

            for _ in 0..3 {
                let v = model_iter.next();
                if v.is_none() {
                    assert!(!iter.valid());
                    break;
                } else {
                    assert!(iter.valid());
                    assert_eq!(v.unwrap(), iter.key());
                    iter.next();
                }
            }
        }

        // Backward iteration test
        {
            let mut iter = list.iter();
            iter.seek_to_last();

            for k in keys.iter().rev() {
                assert!(iter.valid());
                assert_eq!(k, iter.key());
                iter.prev();
            }

            assert!(!iter.valid());
        }
    }

    // Keys for the concurrency checks pack three fields into a u64:
    //   bits 40.. : logical key, in 0..=K
    //   bits 8..40: generation of that key when it was written (at most 0xffffffff)
    //   bits 0..8 : low byte of a hash of (key, generation), used to detect torn reads
    const K: u64 = 4;

    type Key = u64;

    fn key(key: Key) -> u64 { key >> 40 }
    fn gen(key: Key) -> u64 { (key >> 8) & 0xffffffff }
    fn hash(key: Key) -> u64 { key & 0xff }

    fn hash_numbers(k: u64, g: u64) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        k.hash(&mut hasher);
        g.hash(&mut hasher);
        hasher.finish()
    }

    fn make_key(k: u64, g: u64) -> Key {
        assert!(k <= K);
        assert!(g <= 0xffffffff);
        (k << 40) | (g << 8) | (hash_numbers(k, g) & 0xff)
    }

    /// `true` if the hash byte matches the key and generation fields.
    fn is_valid_key(k: Key) -> bool {
        hash(k) == (hash_numbers(key(k), gen(k)) & 0xff)
    }

    /// Picks a seek target: the very first key, the past-the-end key, or the start of a
    /// random logical key.
    fn random_target(rng: &mut impl Rng) -> Key {
        match rng.gen_range(0..10) {
            0 => make_key(0, 0),
            1 => make_key(K, 0),
            _ => make_key(rng.gen_range(0..K), 0),
        }
    }

    /// Latest generation written for each logical key.
    struct State {
        generation: Vec<AtomicU64>,
    }

    impl State {
        fn new() -> Self {
            let generation = (0..K).map(|_| AtomicU64::new(0)).collect();
            State { generation }
        }

        fn set(&self, k: usize, v: u64) {
            self.generation[k].store(v, Ordering::Release);
        }

        fn get(&self, k: usize) -> u64 {
            self.generation[k].load(Ordering::Acquire)
        }
    }

    /// A list plus the record of what has been written to it.
    struct ConcurrentTest {
        current: State,
        list: SkipListImpl<Key>,
    }

    impl ConcurrentTest {
        fn new() -> Self {
            let arena = Arena::new();
            ConcurrentTest {
                current: State::new(),
                list: SkipListImpl::new(arena),
            }
        }

        /// Inserts the next generation of a random logical key, then records it.
        fn write_step(&mut self, rng: &mut impl Rng) {
            let k = rng.gen_range(0..K) as usize;
            let g = self.current.get(k) + 1;
            let key = make_key(k as u64, g);
            self.list.insert(key);
            self.current.set(k, g);
        }

        /// Walks the list from a random position. It checks that only well-formed keys are
        /// seen, in increasing order. It also checks that every (key, generation) skipped
        /// over had not yet been written when this step took its snapshot.
        fn read_step(&self, rng: &mut impl Rng) {
            let initial_state = State::new();
            for k in 0..K as usize {
                initial_state.set(k, self.current.get(k));
            }

            let mut pos = random_target(rng);
            let mut iter = SkipListIterator::new(&self.list);
            iter.seek(&pos);

            loop {
                let current = if iter.valid() {
                    *iter.key()
                } else {
                    make_key(K, 0)
                };

                assert!(is_valid_key(current));
                assert!(pos <= current, "should not go backwards");

                while pos < current {
                    assert!(key(pos) < K);

                    if gen(pos) != 0 {
                        assert!(gen(pos) > initial_state.get(key(pos) as usize));
                    }

                    if key(pos) < key(current) {
                        pos = make_key(key(pos) + 1, 0);
                    } else {
                        pos = make_key(key(pos), gen(pos) + 1);
                    }
                }

                if !iter.valid() {
                    break;
                }

                if rng.gen_bool(0.5) {
                    iter.next();
                    pos = make_key(key(pos), gen(pos) + 1);
                } else {
                    let new_target = random_target(rng);
                    if new_target > pos {
                        pos = new_target;
                        iter.seek(&new_target);
                    }
                }
            }
        }
    }

    #[test]
    fn concurrent_without_threads() {
        let mut test = ConcurrentTest::new();
        let mut rng = rand::thread_rng();
        for _ in 0..10000 {
            test.read_step(&mut rng);
            test.write_step(&mut rng);
        }
    }

    /// Shared state between the writer (test thread) and one reader thread.
    struct TestState {
        // NOTE(review): reader and writer both lock `t` for a whole step, so reads and
        // writes never actually overlap; this exercises thread hand-off, not concurrent
        // traversal during an insert.
        t: Mutex<ConcurrentTest>,
        seed: u64,
        quit_flag: AtomicBool,
        state: Mutex<ReaderState>,
        state_cv: Condvar,
    }

    #[derive(PartialEq, Eq)]
    enum ReaderState {
        Starting,
        Running,
        Done,
    }

    impl TestState {
        fn new(seed: u64) -> Self {
            TestState {
                t: Mutex::new(ConcurrentTest::new()),
                seed,
                quit_flag: AtomicBool::new(false),
                state: Mutex::new(ReaderState::Starting),
                state_cv: Condvar::new(),
            }
        }

        /// Blocks until the reader reports state `s`.
        fn wait(&self, s: ReaderState) {
            let mut state = self.state.lock().unwrap();
            while *state != s {
                state = self.state_cv.wait(state).unwrap();
            }
        }

        /// Records state `s` and wakes every waiter.
        fn change(&self, s: ReaderState) {
            let mut state = self.state.lock().unwrap();
            *state = s;
            self.state_cv.notify_all();
        }
    }

    /// Reader thread body: runs `read_step` until `quit_flag` is set.
    fn concurrent_reader(state: Arc<TestState>) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(state.seed);
        state.change(ReaderState::Running);
        while !state.quit_flag.load(Ordering::Acquire) {
            state.t.lock().unwrap().read_step(&mut rng);
        }
        state.change(ReaderState::Done);
    }

    fn run_concurrent(run: u64) {
        // Wrapping: a random seed near u64::MAX must not overflow-panic in debug builds.
        let seed = random::<u64>().wrapping_add(run * 100);
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let n = 1000;
        let k_size = 1000;

        for i in 0..n {
            if i % 100 == 0 {
                println!("Run {} of {}", i, n);
            }
            let state = Arc::new(TestState::new(seed.wrapping_add(1)));
            let state_clone = state.clone();
            let reader = thread::spawn(move || concurrent_reader(state_clone));

            state.wait(ReaderState::Running);
            for _ in 0..k_size {
                state.t.lock().unwrap().write_step(&mut rng);
            }
            state.quit_flag.store(true, Ordering::Release);
            // Join instead of relying only on `Done`: a reader that panics (failed check)
            // never reports `Done`, and waiting for it would hang the test forever.
            reader.join().expect("reader thread panicked");
            state.wait(ReaderState::Done);
        }
    }

    #[test]
    fn concurrent_1() { run_concurrent(1); }
    #[test]
    fn concurrent_2() { run_concurrent(2); }
    #[test]
    fn concurrent_3() { run_concurrent(3); }
    #[test]
    fn concurrent_4() { run_concurrent(4); }
    #[test]
    fn concurrent_5() { run_concurrent(5); }

    /// Several writer threads insert disjoint key ranges while reader threads query at
    /// random through the shared `SkipList` handle.
    #[test]
    fn test_concurrent_write() {
        let arena = Arena::new();
        let skiplist = Arc::new(SkipList::new(arena));
        let mut write_handles = vec![];
        for i in 0..5 {
            let skiplist_clone = Arc::clone(&skiplist);
            let handle = thread::spawn(move || {
                let start = i * 100;
                let end = start + 100;
                for k in start..end {
                    skiplist_clone.insert(k);
                    println!("Thread {} inserted: {}", i, k);
                }
            });
            write_handles.push(handle);
        }

        let mut read_handles = vec![];
        for i in 0..3 {
            let skiplist_clone = Arc::clone(&skiplist);
            let handle = thread::spawn(move || {
                let mut rng = rand::thread_rng();
                let start = i * 100;
                let end = start + 100;
                for _ in start..end {
                    let key = rng.gen_range(0..1000);
                    let contains =  skiplist_clone.contains(&key);
                    println!("Thread {} queried: {}, result: {}", i, key, contains);
                    thread::sleep(Duration::from_millis(1));
                }
            });
            read_handles.push(handle);
        }

        for handle in write_handles {
            handle.join().unwrap();
        }

        for handle in read_handles {
            handle.join().unwrap();
        }

        let mut iter = skiplist.iter();
        iter.seek_to_first();
        println!("Final SkipList contents:");
        while iter.valid() {
            println!("{:?}", iter.key());
            iter.next();
        }
    }
}
667}