striped_lock/
std.rs

// Copyright (c) 2024 Mek101
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
6
7use std::{
8    hash::{BuildHasher, BuildHasherDefault, DefaultHasher, Hash},
9    marker::PhantomData,
10    num::NonZeroUsize,
11    sync::{Mutex, MutexGuard},
12};
13
14use crate::batch::{KeyBatch, MAX_BATCH_KEYS};
15
/// The inner mutex of a stripe is poisoned, i.e. a thread panicked while holding a
/// guard for that stripe.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StripedPoisonError;

impl std::fmt::Display for StripedPoisonError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("the inner mutex of the striped lock is poisoned")
    }
}

// Implementing `Error` lets callers propagate this with `?` into `Box<dyn Error>`.
impl std::error::Error for StripedPoisonError {}
18
/// Guard returned by [`StripedLock::lock`].
///
/// The stripe the key maps to stays locked until this guard is dropped.
#[derive(Debug)]
#[must_use = "if unused the stripe will immediately unlock"]
pub struct StripedLockGuard<'l> {
    _guard: MutexGuard<'l, ()>,
}
22
23pub struct StripedBatchLockGuard<'l> {
24    _guards: [Option<MutexGuard<'l, ()>>; MAX_BATCH_KEYS],
25}
26
/// A fixed set of mutexes ("stripes") shared by an unbounded key space.
///
/// Each key is hashed with `H` and mapped onto one of the inner [`Mutex`]es, so
/// distinct keys may contend on the same stripe. More stripes mean fewer collisions.
///
/// Keys are only hashed, never stored. The marker is therefore `PhantomData<fn() -> K>`
/// rather than `PhantomData<K>`. This keeps the lock `Send`/`Sync` for any `K` (e.g.
/// raw-pointer or `Rc` keys), while remaining covariant in `K` as before.
pub struct StripedLock<K, H = BuildHasherDefault<DefaultHasher>> {
    /// Factory for the hashers that map keys onto stripes.
    hasher_builder: H,
    /// The stripes; never empty, since constructors take a `NonZeroUsize` count.
    locks: Box<[Mutex<()>]>,
    phantom: PhantomData<fn() -> K>,
}
36
37impl<K> StripedLock<K, BuildHasherDefault<DefaultHasher>>
38where
39    K: Hash,
40{
41    /// Create a new [`StripedLock`] instance with rust's default hasher.
42    ///
43    /// # Arguments
44    ///
45    /// * `locks` - The number of inner locks. Increase to reduce collisions.
46    pub fn new(locks: NonZeroUsize) -> Self {
47        Self::with_hasher(BuildHasherDefault::default(), locks)
48    }
49}
50
51impl<K, H> StripedLock<K, H>
52where
53    K: Hash,
54    H: BuildHasher,
55{
56    /// Create a new [`StripedLock`] instance.
57    ///
58    /// # Arguments
59    ///
60    /// * `locks` - The number of inner locks. Increase to reduce collisions.
61    /// * `hasher_builder` - The factory of hashers.
62    pub fn with_hasher(hasher_builder: H, locks: NonZeroUsize) -> Self {
63        let locks = (0..locks.get())
64            .map(|_| Mutex::new(()))
65            .collect::<Vec<_>>()
66            .into_boxed_slice();
67
68        Self {
69            hasher_builder,
70            locks,
71            phantom: PhantomData::default(),
72        }
73    }
74
75    /// Lock on the key.
76    /// Use `lock_batch` if you want to lock on multiple keys.
77    ///
78    /// # Arguments
79    ///
80    /// * `key` - The key to lock on.
81    pub fn lock(&self, key: K) -> Result<StripedLockGuard, StripedPoisonError> {
82        fn inner(locks: &[Mutex<()>], key: u64) -> Result<StripedLockGuard, StripedPoisonError> {
83            let idx = (key % locks.len() as u64) as usize;
84            let lock = &locks[idx];
85
86            match lock.lock() {
87                Ok(guard) => Ok(StripedLockGuard { _guard: guard }),
88                Err(_) => Err(StripedPoisonError),
89            }
90        }
91
92        let hash = self.hasher_builder.hash_one(key);
93        inner(&self.locks, hash)
94    }
95
96    /// Lock on the key.
97    /// Use `lock_batch` if you want to lock on multiple keys.
98    ///
99    /// # Arguments
100    ///
101    /// * `batch` - The batch of keys to lock on. May be up to 4.
102    ///
103    /// # Example
104    ///
105    /// ```
106    /// # use std::hash::{BuildHasherDefault, DefaultHasher};
107    /// # use std::num::NonZeroUsize;
108    /// # use striped_lock::std::StripedLock;
109    /// let sl: StripedLock<char> = StripedLock::new(NonZeroUsize::new(4).unwrap());
110    /// sl.lock_batch(('a', 'b', 'c', 'd'));
111    /// ```
112    pub fn lock_batch<B>(&self, batch: B) -> Result<StripedBatchLockGuard, StripedPoisonError>
113    where
114        B: KeyBatch<K, H>,
115    {
116        fn inner<'l>(
117            locks: &'l [Mutex<()>],
118            batch: &mut [u64],
119        ) -> Result<StripedBatchLockGuard<'l>, StripedPoisonError> {
120            const ARRAY_REPEAT_VALUE: Option<MutexGuard<()>> = None;
121
122            assert!(batch.len() > 0);
123            assert!(batch.len() <= MAX_BATCH_KEYS);
124
125            // "Normalize".
126            for key in batch.iter_mut() {
127                *key %= locks.len() as u64;
128            }
129
130            // Sort such that we always obtain the locks in the same order.
131            batch.sort_unstable();
132
133            let mut guards = [ARRAY_REPEAT_VALUE; MAX_BATCH_KEYS];
134
135            guards[0] = Some(
136                locks[batch[0] as usize]
137                    .lock()
138                    .map_err(|_| StripedPoisonError)?,
139            );
140
141            for i in 1..batch.len() {
142                // Skip duplicates since locks are not re-entrant.
143                if batch[i] != batch[i - 1] {
144                    guards[i] = Some(
145                        locks[batch[i] as usize]
146                            .lock()
147                            .map_err(|_| StripedPoisonError)?,
148                    );
149                }
150            }
151
152            Ok(StripedBatchLockGuard { _guards: guards })
153        }
154
155        let (mut arr, filled) = batch.into_hash_array(&self.hasher_builder);
156        let batch = &mut arr[..filled];
157        inner(&self.locks, batch)
158    }
159
160    /// Check if the mutex at the given key is poisoned.
161    pub fn is_poisoned(&self, key: K) -> bool {
162        fn inner(locks: &[Mutex<()>], key: u64) -> bool {
163            let idx = (key % locks.len() as u64) as usize;
164            let lock = &locks[idx];
165            lock.is_poisoned()
166        }
167
168        let key = self.hasher_builder.hash_one(key);
169        inner(&self.locks, key)
170    }
171
172    // Remove the poisoned status from the mutex at the given key.
173    pub fn clear_poison(&self, key: K) {
174        fn inner(locks: &[Mutex<()>], key: u64) {
175            let idx = (key % locks.len() as u64) as usize;
176            let lock = &locks[idx];
177            lock.clear_poison();
178        }
179
180        let key = self.hasher_builder.hash_one(key);
181        inner(&self.locks, key);
182    }
183}