small_db/
types.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard},
4};
5
6use crate::{error::SmallError, utils::HandyRwLock};
7
// `Pod<T>` is a type alias, not a new type: it *is* `Arc<RwLock<T>>`, so every
// `Arc` and `RwLock` method can be called on it directly, but no inherent
// methods can be defined on it.

/// A reference-counted, `RwLock`-protected shared value.
pub type Pod<T> = Arc<RwLock<T>>;

// Alternative considered: a newtype would allow defining methods on `Pod`, but
// it would be a distinct type from `Arc<RwLock<T>>`, so the underlying `Arc` /
// `RwLock` methods could no longer be called on it without forwarding:
// pub struct Pod<T>(Arc<RwLock<T>>);

/// Result of an operation that yields a [`Pod`] or fails with a [`SmallError`].
pub type ResultPod<T> = Result<Pod<T>, SmallError>;
/// Result of an operation that yields no value but may fail with a [`SmallError`].
pub type SmallResult = Result<(), SmallError>;
17
18pub struct ConcurrentHashMap<K, V> {
19    map: Arc<RwLock<HashMap<K, V>>>,
20}
21
22impl<K, V> ConcurrentHashMap<K, V> {
23    pub fn new() -> Self {
24        Self {
25            map: Arc::new(RwLock::new(HashMap::new())),
26        }
27    }
28
29    pub fn get_inner(&self) -> Arc<RwLock<HashMap<K, V>>> {
30        self.map.clone()
31    }
32
33    pub fn get_inner_rl(&self) -> RwLockReadGuard<HashMap<K, V>> {
34        self.map.rl()
35    }
36
37    pub fn get_inner_wl(&self) -> RwLockWriteGuard<HashMap<K, V>> {
38        self.map.wl()
39    }
40
41    pub fn get_or_insert(
42        &self,
43        key: &K,
44        value_gen_fn: impl Fn(&K) -> Result<V, SmallError>,
45    ) -> Result<V, SmallError>
46    where
47        K: std::cmp::Eq + std::hash::Hash + Clone,
48        V: Clone,
49    {
50        let mut buffer = self.map.wl();
51        match buffer.get(&key) {
52            Some(v) => Ok(v.clone()),
53            None => {
54                let v = value_gen_fn(key)?;
55                buffer.insert(key.clone(), v.clone());
56                Ok(v)
57            }
58        }
59    }
60
61    pub fn alter_value(
62        &self,
63        key: &K,
64        alter_fn: impl Fn(&mut V) -> Result<(), SmallError>,
65    ) -> Result<(), SmallError>
66    where
67        K: std::cmp::Eq + std::hash::Hash + Clone,
68        V: Clone + std::default::Default,
69    {
70        let mut map = self.map.wl();
71
72        if let Some(v) = map.get_mut(key) {
73            alter_fn(v)
74        } else {
75            let mut new_v = Default::default();
76            alter_fn(&mut new_v)?;
77            map.insert(key.clone(), new_v);
78            Ok(())
79        }
80    }
81
82    /// Return true if `map[&k] == v`, or `map[&k]` is not exist.
83    ///
84    /// Return false if `map[&k] != v`.
85    pub fn exact_or_empty(&self, k: &K, v: &V) -> bool
86    where
87        K: std::cmp::Eq + std::hash::Hash,
88        V: std::cmp::Eq,
89    {
90        let map = self.map.rl();
91        map.get(k).map_or(true, |v2| v == v2)
92    }
93
94    pub fn clear(&self) {
95        self.map.wl().clear();
96    }
97
98    pub fn remove(&self, key: &K) -> Option<V>
99    where
100        K: std::cmp::Eq + std::hash::Hash,
101    {
102        self.map.wl().remove(key)
103    }
104
105    pub fn insert(&self, key: K, value: V) -> Option<V>
106    where
107        K: std::cmp::Eq + std::hash::Hash,
108    {
109        self.map.wl().insert(key, value)
110    }
111
112    pub fn keys(&self) -> Vec<K>
113    where
114        K: std::cmp::Eq + std::hash::Hash + Clone,
115    {
116        self.map.rl().keys().cloned().collect()
117    }
118}
119
/// A named mutual-exclusion lock that guards no data of its own.
///
/// [`SmallLock::lock`] serializes a critical section: only one guard can exist
/// at a time, and the lock is released when that guard is dropped. `name` is
/// used in diagnostics: the poison panic message and the message printed when
/// the lock itself is dropped.
pub struct SmallLock {
    name: String,
    // Owned directly: a `SmallLock` is never cloned, so wrapping the mutex in
    // an `Arc` only added a heap allocation and a refcount.
    lock: Mutex<()>,
}

impl SmallLock {
    /// Creates an unlocked lock identified by `name`.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            lock: Mutex::new(()),
        }
    }

    /// Blocks until the lock is acquired and returns its guard.
    ///
    /// The lock stays held until the guard is dropped. Bind it to a named
    /// variable (`let _guard = lock.lock();`); `let _ = lock.lock();` drops
    /// the guard, and so releases the lock, immediately.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. a previous holder panicked while
    /// holding it. The panic message names the lock.
    pub fn lock(&self) -> std::sync::MutexGuard<'_, ()> {
        self.lock.lock().unwrap_or_else(|_| {
            panic!(
                "SmallLock `{}` is poisoned: a previous holder panicked",
                self.name
            )
        })
    }
}

impl Drop for SmallLock {
    /// Prints a trace line naming the lock to stdout.
    fn drop(&mut self) {
        println!("> Dropping {}", self.name);
    }
}
143
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        thread::{self, sleep},
        time::Duration,
    };

    use log::debug;

    use crate::utils::init_log;

    /// Checks two things:
    /// - a `SmallLock` guard is released with its scope;
    /// - across threads, the lock admits at most one holder at a time.
    #[test]
    fn test_small_lock() {
        init_log();
        {
            let lock = super::SmallLock::new("test");
            let _guard = lock.lock();
            debug!("Locking");
            // Locals drop in reverse declaration order, so `_guard` is
            // released before `lock` itself is dropped.
        }
        debug!("Dropped");

        let global_lock = super::SmallLock::new("global");
        // Number of threads currently inside the critical section, and the
        // highest value that count ever reached.
        let holders = AtomicUsize::new(0);
        let max_holders = AtomicUsize::new(0);

        thread::scope(|s| {
            let mut threads = vec![];
            for _ in 0..5 {
                let handle = s.spawn(|| {
                    let thread_name = format!(
                        "thread-{:?}",
                        thread::current().id()
                    );
                    debug!("{}: start", thread_name);
                    {
                        // The guard must be bound to a name. `let _ = ...`
                        // would drop it right away, releasing the lock and
                        // leaving the rest of this block unprotected.
                        let _guard = global_lock.lock();
                        let now = holders.fetch_add(1, Ordering::SeqCst) + 1;
                        max_holders.fetch_max(now, Ordering::SeqCst);
                        sleep(Duration::from_millis(10));
                        debug!("{}: lock acquired", thread_name);
                        // Long enough for the other threads to queue on the
                        // lock, short enough to keep the test fast.
                        sleep(Duration::from_millis(20));
                        holders.fetch_sub(1, Ordering::SeqCst);
                    }
                    debug!("{}: end", thread_name);
                });
                threads.push(handle);
            }

            for handle in threads {
                handle.join().unwrap();
            }
        });

        assert_eq!(holders.load(Ordering::SeqCst), 0);
        assert_eq!(
            max_holders.load(Ordering::SeqCst),
            1,
            "more than one thread held the lock at the same time"
        );
    }
}