1pub mod arena;
2
3use std::cell::UnsafeCell;
4use std::fmt::Debug;
5use std::iter::Iterator;
6use std::ptr;
7use std::ptr::{null_mut, NonNull};
8use std::sync::atomic::{AtomicPtr, Ordering};
9use std::sync::{Arc, Mutex};
10use rand::{Rng, SeedableRng};
11use rand::rngs::StdRng;
12use crate::arena::Arena;
13
// Maximum tower height for any node (same bound LevelDB uses).
const MAX_HEIGHT: usize = 12;
// Each extra level is taken with probability 1/K_BRANCHING in random_height().
const K_BRANCHING: usize = 4;
16
/// A skiplist node: an immutable key plus a tower of forward pointers,
/// one per level. The pointers are atomic so readers can traverse the
/// list concurrently with a single writer.
pub struct Node<K> {
    key: K,
    next: Vec<AtomicPtr<Node<K>>>,
}

impl<K> Node<K> {
    /// Creates a node with `height` forward pointers, all initially null.
    fn new(key: K, height: usize) -> Self {
        let next = (0..height)
            .map(|_| AtomicPtr::new(ptr::null_mut()))
            .collect();
        Node { key, next }
    }

    /// Loads the successor at `level` with Acquire ordering, so the
    /// contents of the returned node are fully visible to the caller.
    fn next(&self, level: usize) -> *mut Node<K> {
        self.next[level].load(Ordering::Acquire)
    }

    /// Publishes `node` as the successor at `level` with Release ordering.
    fn set_next(&self, level: usize, node: *mut Node<K>) {
        self.next[level].store(node, Ordering::Release);
    }

    /// Relaxed load; only valid where a happens-before edge exists elsewhere.
    fn no_barrier_next(&self, level: usize) -> *mut Node<K> {
        self.next[level].load(Ordering::Relaxed)
    }

    /// Relaxed store; only valid where a happens-before edge exists elsewhere.
    fn no_barrier_set_next(&self, level: usize, node: *mut Node<K>) {
        self.next[level].store(node, Ordering::Relaxed);
    }
}
47
48
/// A raw-pointer cursor over a `SkipListImpl`.
/// `node` is null whenever the iterator is not positioned on an entry.
pub struct SkipListIterator<'a, K: Ord + Debug + Default> {
    // Current position; null == invalid.
    node: *mut Node<K>,
    // The list being iterated; keeps the arena (and all nodes) alive.
    list: &'a SkipListImpl<K>,
}
53
54impl<'a, K: Ord + Debug + Default> SkipListIterator<'a, K> {
55 pub fn new(list: &'a SkipListImpl<K>) -> Self {
56 SkipListIterator { node: null_mut(), list }
57 }
58
59 pub fn valid(&self) -> bool {
60 !self.node.is_null()
61 }
62
63 pub fn key(&self) -> &K {
64 assert!(self.valid());
65 unsafe { &self.node.as_ref().unwrap().key }
66 }
67
68 pub fn next(&mut self) {
69 assert!(self.valid());
70 self.node = unsafe { self.node.as_ref().unwrap().next(0) };
71 }
72
73 pub fn prev(&mut self) {
74 assert!(self.valid());
75 self.node = self.list.find_less_than(self.key()).as_ptr();
76 if self.node == self.list.head.as_ptr() {
77 self.node = null_mut();
78 }
79 }
80
81 pub fn seek(&mut self, target: &K) {
82 self.node = self.list.find_greater_or_equal(target, &mut None);
83 }
84
85 pub fn seek_to_first(&mut self) {
86 self.node = unsafe { self.list.head.as_ref().next(0) };
87 }
88
89 pub fn seek_to_last(&mut self) {
90 self.node = self.list.find_last().as_ptr();
91 if self.node == self.list.head.as_ptr() {
92 self.node = null_mut();
93 }
94 }
95}
96
/// Skiplist core modeled on LevelDB's `SkipList`: one writer at a time,
/// any number of lock-free readers. Nodes are allocated from `arena` and
/// are never freed or removed individually.
pub struct SkipListImpl<K: Ord + Debug + Default> {
    // Sentinel node (default key, full height); never exposed to callers.
    head: NonNull<Node<K>>,
    // Height of the tallest tower; only ever grows. Relaxed accesses are
    // enough because a stale value merely costs extra traversal work.
    max_height: std::sync::atomic::AtomicUsize,
    // RNG driving random_height(); fixed seed makes shapes reproducible.
    rnd: StdRng,
    // Backing allocator that owns all node memory for the list's lifetime.
    arena: Arena,
}
103
// SAFETY(review): raw `Node` pointers are only dereferenced while the list
// (and its arena) is alive, and all mutation goes through `insert(&mut self)`.
// Presumed sound under the single-writer/multi-reader protocol — confirm
// that `Arena` itself is safe to move across / share between threads.
unsafe impl<K: Ord + Debug + Default + Send> Send for SkipListImpl<K> {}
unsafe impl<K: Ord + Debug + Default + Sync> Sync for SkipListImpl<K> {}
106
107impl<K: Ord + Debug + Default> SkipListImpl<K> {
108 pub fn new(mut arena: Arena) -> SkipListImpl<K> {
109 let head = unsafe {
110 let layout = std::alloc::Layout::new::<Node<K>>();
111 let ptr = arena.allocate(layout.size()) as *mut Node<K>;
112 ptr::write(ptr, Node::new(K::default(), MAX_HEIGHT));
113 NonNull::new_unchecked(ptr)
114 };
115 let mut s = SkipListImpl {
116 head,
117 max_height: std::sync::atomic::AtomicUsize::new(1),
118 rnd: StdRng::seed_from_u64(0xdeadbeef),
119 arena,
120 };
121
122 for i in 0..MAX_HEIGHT {
123 unsafe {
124 s.head.as_mut().set_next(i, ptr::null_mut());
125 }
126 }
127 s
128 }
129
130 pub unsafe fn key_is_after_node(&self, key: &K, node: *const Node<K>) -> bool {
134 unsafe {
135 node.as_ref().map(|n| &n.key)
136 .map_or(false, |node_key| node_key < key)
137 }
138 }
139
140 pub fn find_greater_or_equal(&self, key: &K, prev: &mut Option<&mut Vec<*mut Node<K>>>) -> *mut Node<K> {
141 let mut x = self.head.as_ptr();
142 let mut level = self.get_max_height() - 1;
143 loop {
144 let next = unsafe { x.as_ref().unwrap().next(level) };
145 if unsafe { self.key_is_after_node(key, next) } {
146 x = next;
147 } else {
148 if let Some(prev_node) = prev {
149 prev_node[level] = x;
150 }
151 if level == 0 {
152 return next;
153 } else {
154 level -= 1;
155 }
156 }
157 }
158 }
159
160 pub fn find_less_than(&self, key: &K) -> NonNull<Node<K>> {
161 let mut x = self.head;
162 let mut level = self.get_max_height() - 1;
163 loop {
164 let next = unsafe { x.as_ref().next(level) };
165 if next.is_null() || unsafe { next.as_ref().unwrap().key >= *key } {
166 if level == 0 {
167 return x;
168 } else {
169 level -= 1;
170 }
171 } else {
172 x = unsafe { NonNull::new_unchecked(next) };
173 }
174 }
175 }
176
177 pub fn find_last(&self) -> NonNull<Node<K>> {
178 let mut x = self.head;
179 let mut level = self.get_max_height() - 1;
180 loop {
181 let next = unsafe { x.as_ref().next(level) };
182 if next.is_null() {
183 if level == 0 {
184 return x;
185 } else {
186 level -= 1;
187 }
188 } else {
189 x = unsafe { NonNull::new_unchecked(next) };
190 }
191 }
192 }
193
194 pub fn contains(&self, key: &K) -> bool {
195 let x = self.find_greater_or_equal(key, &mut None);
196 let x_ref = unsafe { x.as_ref() };
197 match x_ref {
198 None => false,
199 Some(x_ref) => x_ref.key == *key,
200 }
201 }
202
203 pub fn random_height(&mut self) -> usize {
204 let mut height = 1;
205 while height < MAX_HEIGHT && self.rnd.gen_range(0..K_BRANCHING) == 0 {
206 height += 1;
207 }
208 height
209 }
210
211 #[inline]
212 fn get_max_height(&self) -> usize {
213 self.max_height.load(Ordering::Relaxed)
214 }
215
216 pub fn insert(&mut self, key: K) {
217 let mut prev = vec![ptr::null_mut(); MAX_HEIGHT];
218 let x = self.find_greater_or_equal(&key, &mut Some(&mut prev));
219 assert!(x.is_null() || unsafe { x.as_ref().unwrap().key != key });
220
221 let height = self.random_height();
222 if height > self.get_max_height() {
223 let i = self.get_max_height();
224 for p in prev.iter_mut().take(height).skip(i) {
225 *p = self.head.as_ptr();
226 }
227 self.max_height.store(height, Ordering::Relaxed);
228 }
229
230 let new_node = unsafe {
231 let layout = std::alloc::Layout::new::<Node<K>>();
232 let ptr = self.arena.allocate(layout.size()) as *mut Node<K>;
233 ptr::write(ptr, Node::new(key, height));
234 &mut *ptr
235 };
236 for (i, p) in prev.iter().enumerate().take(height) {
237 unsafe {
238 new_node.no_barrier_set_next(i, p.as_ref().unwrap().no_barrier_next(i));
239 p.as_ref().unwrap().set_next(i, new_node);
240 }
241 }
242 }
243}
244
/// Thread-safe facade over `SkipListImpl`: writers are serialized by
/// `write_lock`, while readers go through the `UnsafeCell` without any
/// locking (single-writer / lock-free-reader protocol).
struct SkipList<K: Ord + Debug + Default> {
    // Shared, interior-mutable core; `Arc` lets clones share one list.
    skip_list: Arc<UnsafeCell<SkipListImpl<K>>>,
    // Guards all mutation; readers deliberately bypass it.
    write_lock: Mutex<()>,
}
249
// SAFETY(review): the `UnsafeCell` is only mutated under `write_lock`, and
// reads rely on the skiplist's Acquire/Release pointer protocol. Presumed
// sound for one writer plus concurrent readers — confirm.
unsafe impl<K: Ord + Debug + Default + Send + Sync> Send for SkipList<K> {}
unsafe impl<K: Ord + Debug + Default + Send + Sync> Sync for SkipList<K> {}
252
253impl<K: Ord + Debug + Default> SkipList<K> {
254 pub fn new(arena: Arena) -> Self {
255 SkipList {
256 skip_list: Arc::new(UnsafeCell::new(SkipListImpl::new(arena))),
257 write_lock: Mutex::new(()),
258 }
259 }
260
261 pub fn insert(&self, key: K) {
262 let _guard = self.write_lock.lock().unwrap();
263 unsafe {
264 (*self.skip_list.get()).insert(key);
265 }
266 }
267
268 pub fn contains(&self, key: &K) -> bool {
269 unsafe {
270 (*self.skip_list.get()).contains(key)
271 }
272 }
273
274 pub fn iter(&self) -> SkipListIterator<K> {
275 unsafe {
276 SkipListIterator::new(&*self.skip_list.get())
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Condvar, Mutex};
    use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
    use std::thread;
    use std::time::Duration;
    use rand::{random, Rng, SeedableRng};
    use crate::arena::Arena;
    use super::{SkipListImpl, SkipListIterator, SkipList};

    // An empty list contains nothing, and every kind of seek leaves the
    // iterator invalid.
    #[test]
    fn test_empty() {
        let arena = Arena::new();
        let list = super::SkipListImpl::new(arena);
        assert_eq!(list.contains(&10), false);

        let mut iter = SkipListIterator::new(&list);
        assert_eq!(iter.valid(), false);
        iter.seek_to_first();
        assert_eq!(iter.valid(), false);
        iter.seek(&100);
        assert_eq!(iter.valid(), false);
        iter.seek_to_last();
        assert_eq!(iter.valid(), false);
    }

    // Inserts random keys and cross-checks the list against a BTreeSet model:
    // membership, forward scans, seeks, and reverse iteration.
    #[test]
    fn insert_and_lookup() {
        let n = 2000;
        let r = 5000;
        let mut rnd = rand::thread_rng();
        let mut keys = std::collections::btree_set::BTreeSet::new();
        let arena = Arena::new();
        let mut list = SkipList::new(arena);

        // Insert each key only the first time the model accepts it — the
        // skiplist asserts (panics) on duplicate inserts.
        for _ in 0..r {
            let key = rnd.gen_range(0..r);
            if keys.insert(key) {
                list.insert(key);
                continue;
            }
        }

        // Membership must agree with the model over [0, n).
        for i in 0..n {
            if list.contains(&i) {
                assert!(keys.contains(&i));
            } else {
                assert!(!keys.contains(&i));
            }
        }

        // A full forward scan visits exactly the model's keys, in order.
        {
            let mut iter = list.iter();
            iter.seek_to_first();
            for i in 0..r {
                if keys.contains(&i) {
                    assert_eq!(iter.valid(), true);
                    assert_eq!(iter.key(), &i);
                    iter.next();
                }
            }
            assert_eq!(iter.valid(), false);
        }

        // seek / seek_to_first / seek_to_last agree with the model's extremes.
        {
            let mut iter = list.iter();
            assert!(!iter.valid());

            iter.seek(&0);
            assert!(iter.valid());
            assert_eq!(keys.iter().next().unwrap(), iter.key());

            iter.seek_to_first();
            assert!(iter.valid());
            assert_eq!(keys.iter().next().unwrap(), iter.key());

            iter.seek_to_last();
            assert!(iter.valid());
            assert_eq!(keys.iter().rev().next().unwrap(), iter.key());
        }

        // For every target, seek() followed by a short walk matches the
        // model's range iterator.
        for i in 0..r {
            let mut iter = list.iter();
            iter.seek(&i);
            let mut model_iter = keys.range(i..);

            for _ in 0..3 {
                let v = model_iter.next();
                if v.is_none() {
                    assert!(!iter.valid());
                    break;
                } else {
                    assert!(iter.valid());
                    assert_eq!(v.unwrap(), iter.key());
                    iter.next();
                }
            }
        }

        // Backward iteration from the last entry matches the model reversed.
        {
            let mut iter = list.iter();
            iter.seek_to_last();

            for k in keys.iter().rev() {
                assert!(iter.valid());
                assert_eq!(k, iter.key());
                iter.prev();
            }

            assert!(!iter.valid());
        }
    }

    // --- Concurrent-test machinery ---
    //
    // Each u64 test key packs three fields (see make_key):
    //   bits 40..  : logical key in 0..=K
    //   bits 8..40 : generation number
    //   bits 0..8  : hash of (key, gen), used to detect corrupt reads
    const K: u64 = 4;

    type Key = u64;

    // Field extractors for the packed layout above.
    fn key(key: Key) -> u64 { key >> 40 }
    fn gen(key: Key) -> u64 { (key >> 8) & 0xffffffff }
    fn hash(key: Key) -> u64 { key & 0xff }

    // Deterministic 64-bit hash of (logical key, generation).
    fn hash_numbers(k: u64, g: u64) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        k.hash(&mut hasher);
        g.hash(&mut hasher);
        hasher.finish()
    }

    // Packs (k, g) plus an 8-bit integrity hash into one u64 key.
    fn make_key(k: u64, g: u64) -> Key {
        assert!(k <= K);
        assert!(g <= 0xffffffff);
        (k << 40) | (g << 8) | (hash_numbers(k, g) & 0xff)
    }

    // A key is well-formed iff its embedded hash matches its fields.
    fn is_valid_key(k: Key) -> bool {
        hash(k) == (hash_numbers(key(k), gen(k)) & 0xff)
    }

    // Picks a seek target: occasionally the global start (0,0) or the end
    // sentinel (K,0), otherwise the start of a random logical key.
    fn random_target(rng: &mut impl Rng) -> Key {
        match rng.gen_range(0..10) {
            0 => make_key(0, 0),
            1 => make_key(K, 0),
            _ => make_key(rng.gen_range(0..K), 0),
        }
    }

    // Last published generation per logical key, shared across threads.
    struct State {
        generation: Vec<AtomicU64>,
    }

    impl State {
        fn new() -> Self {
            let generation = (0..K).map(|_| AtomicU64::new(0)).collect();
            State { generation }
        }

        fn set(&self, k: usize, v: u64) {
            self.generation[k].store(v, Ordering::Release);
        }

        fn get(&self, k: usize) -> u64 {
            self.generation[k].load(Ordering::Acquire)
        }
    }

    // Harness: a writer appends increasing generations per logical key while
    // readers verify snapshot consistency (keys are never removed).
    struct ConcurrentTest {
        current: State,
        list: SkipListImpl<Key>,
    }

    impl ConcurrentTest {
        fn new() -> Self {
            let arena = Arena::new();
            ConcurrentTest {
                current: State::new(),
                list: SkipListImpl::new(arena),
            }
        }

        // Writer step: insert the next generation for a random logical key,
        // then publish that generation in `current`.
        fn write_step(&mut self, rng: &mut impl Rng) {
            let k = rng.gen_range(0..K) as usize;
            let g = self.current.get(k) + 1;
            let key = make_key(k as u64, g);
            self.list.insert(key);
            self.current.set(k, g);
        }

        // Reader step: wander through the list checking that every observed
        // key is well-formed and that nothing present at the start of the
        // read has disappeared.
        fn read_step(&self, rng: &mut impl Rng) {
            // Snapshot the generation counters before touching the list.
            let initial_state = State::new();
            for k in 0..K as usize {
                initial_state.set(k, self.current.get(k));
            }

            let mut pos = random_target(rng);
            let mut iter = SkipListIterator::new(&self.list);
            iter.seek(&pos);

            loop {
                // An exhausted iterator behaves like the end sentinel (K, 0).
                let current = if iter.valid() {
                    *iter.key()
                } else {
                    make_key(K, 0)
                };

                assert!(is_valid_key(current));
                assert!(pos <= current, "should not go backwards");

                // Every key in [pos, current) is absent from the list, so it
                // must either be generation 0 or have been created after our
                // snapshot was taken.
                while pos < current {
                    assert!(key(pos) < K);

                    if gen(pos) != 0 {
                        assert!(gen(pos) > initial_state.get(key(pos) as usize) as u64);
                    }

                    // Advance `pos` to the next candidate key.
                    if key(pos) < key(current) {
                        pos = make_key(key(pos) + 1, 0);
                    } else {
                        pos = make_key(key(pos), gen(pos) + 1);
                    }
                }

                if !iter.valid() {
                    break;
                }

                // Randomly either step forward or re-seek to a later target.
                if rng.gen_bool(0.5) {
                    iter.next();
                    pos = make_key(key(pos), gen(pos) + 1);
                } else {
                    let new_target = random_target(rng);
                    if new_target > pos {
                        pos = new_target;
                        iter.seek(&new_target);
                    }
                }
            }
        }
    }

    // Sanity-check the harness itself with reads and writes interleaved on a
    // single thread.
    #[test]
    fn concurrent_without_threads() {
        let mut test = ConcurrentTest::new();
        let mut rng = rand::thread_rng();
        for _ in 0..10000 {
            test.read_step(&mut rng);
            test.write_step(&mut rng);
        }
    }

    // State shared between the writing test thread and the reader thread.
    struct TestState {
        t: Mutex<ConcurrentTest>,
        seed: u64,
        quit_flag: AtomicBool,
        state: Mutex<ReaderState>,
        state_cv: Condvar,
    }

    // Lifecycle of the reader thread, driven via condvar handshakes.
    #[derive(PartialEq, Eq)]
    enum ReaderState {
        Starting,
        Running,
        Done,
    }

    impl TestState {
        fn new(seed: u64) -> Self {
            TestState {
                t: Mutex::new(ConcurrentTest::new()),
                seed,
                quit_flag: AtomicBool::new(false),
                state: Mutex::new(ReaderState::Starting),
                state_cv: Condvar::new(),
            }
        }

        // Blocks until the reader reports state `s`.
        fn wait(&self, s: ReaderState) {
            let mut state = self.state.lock().unwrap();
            while *state != s {
                state = self.state_cv.wait(state).unwrap();
            }
        }

        // Publishes state `s` and wakes any waiter.
        fn change(&self, s: ReaderState) {
            let mut state = self.state.lock().unwrap();
            *state = s;
            self.state_cv.notify_all();
        }
    }

    // Reader thread body: read until told to quit.
    // NOTE(review): `t` is locked for every read_step, so reads and writes
    // are actually serialized by the mutex here rather than running truly
    // concurrently — confirm whether lock-free reads were intended.
    fn concurrent_reader(state: Arc<TestState>) {
        let mut rng = rand::rngs::StdRng::seed_from_u64(state.seed);
        state.change(ReaderState::Running);
        while !state.quit_flag.load(Ordering::Acquire) {
            state.t.lock().unwrap().read_step(&mut rng);
        }
        state.change(ReaderState::Done);
    }

    // Repeatedly pits one reader thread against this (writing) thread.
    fn run_concurrent(run: u64) {
        let seed = random::<u64>() + (run * 100);
        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        let n = 1000;
        let k_size = 1000;

        for i in 0..n {
            if i % 100 == 0 {
                println!("Run {} of {}", i, n);
            }
            let state = Arc::new(TestState::new(seed + 1));
            let state_clone = state.clone();
            thread::spawn(move || concurrent_reader(state_clone));

            // Wait for the reader to start, write a batch, then shut it down.
            state.wait(ReaderState::Running);
            for _ in 0..k_size {
                state.t.lock().unwrap().write_step(&mut rng);
            }
            state.quit_flag.store(true, Ordering::Release);
            state.wait(ReaderState::Done);
        }
    }

    #[test]
    fn concurrent_1() { run_concurrent(1); }
    #[test]
    fn concurrent_2() { run_concurrent(2); }
    #[test]
    fn concurrent_3() { run_concurrent(3); }
    #[test]
    fn concurrent_4() { run_concurrent(4); }
    #[test]
    fn concurrent_5() { run_concurrent(5); }

    // Five writer threads (disjoint key ranges) and three reader threads
    // hammer one shared list; afterwards the contents are printed.
    #[test]
    fn test_concurrent_write() {
        let arena = Arena::new();
        let skiplist = Arc::new(SkipList::new(arena));
        let mut write_handles = vec![];
        for i in 0..5 {
            let skiplist_clone = Arc::clone(&skiplist);
            let handle = thread::spawn(move || {
                let start = i * 100;
                let end = start + 100;
                for k in start..end {
                    skiplist_clone.insert(k);
                    println!("Thread {} inserted: {}", i, k);
                }
            });
            write_handles.push(handle);
        }

        let mut read_handles = vec![];
        for i in 0..3 {
            let skiplist_clone = Arc::clone(&skiplist);
            let handle = thread::spawn(move || {
                let mut rng = rand::thread_rng();
                let start = i * 100;
                let end = start + 100;
                for _ in start..end {
                    let key = rng.gen_range(0..1000);
                    let contains = skiplist_clone.contains(&key);
                    println!("Thread {} queried: {}, result: {}", i, key, contains);
                    thread::sleep(Duration::from_millis(1));
                }
            });
            read_handles.push(handle);
        }

        for handle in write_handles {
            handle.join().unwrap();
        }

        for handle in read_handles {
            handle.join().unwrap();
        }

        let mut iter = skiplist.iter();
        iter.seek_to_first();
        println!("Final SkipList contents:");
        while iter.valid() {
            println!("{:?}", iter.key());
            iter.next();
        }
    }
}