1use std::{
2 collections::HashMap,
3 sync::{Arc, Mutex, RwLock, RwLockReadGuard, RwLockWriteGuard},
4};
5
6use crate::{error::SmallError, utils::HandyRwLock};
7
/// Shared, lock-protected handle to a `T`: a reference-counted `RwLock<T>`.
pub type Pod<T> = std::sync::Arc<std::sync::RwLock<T>>;
10
11pub type ResultPod<T> = Result<Pod<T>, SmallError>;
16pub type SmallResult = Result<(), SmallError>;
17
18pub struct ConcurrentHashMap<K, V> {
19 map: Arc<RwLock<HashMap<K, V>>>,
20}
21
22impl<K, V> ConcurrentHashMap<K, V> {
23 pub fn new() -> Self {
24 Self {
25 map: Arc::new(RwLock::new(HashMap::new())),
26 }
27 }
28
29 pub fn get_inner(&self) -> Arc<RwLock<HashMap<K, V>>> {
30 self.map.clone()
31 }
32
33 pub fn get_inner_rl(&self) -> RwLockReadGuard<HashMap<K, V>> {
34 self.map.rl()
35 }
36
37 pub fn get_inner_wl(&self) -> RwLockWriteGuard<HashMap<K, V>> {
38 self.map.wl()
39 }
40
41 pub fn get_or_insert(
42 &self,
43 key: &K,
44 value_gen_fn: impl Fn(&K) -> Result<V, SmallError>,
45 ) -> Result<V, SmallError>
46 where
47 K: std::cmp::Eq + std::hash::Hash + Clone,
48 V: Clone,
49 {
50 let mut buffer = self.map.wl();
51 match buffer.get(&key) {
52 Some(v) => Ok(v.clone()),
53 None => {
54 let v = value_gen_fn(key)?;
55 buffer.insert(key.clone(), v.clone());
56 Ok(v)
57 }
58 }
59 }
60
61 pub fn alter_value(
62 &self,
63 key: &K,
64 alter_fn: impl Fn(&mut V) -> Result<(), SmallError>,
65 ) -> Result<(), SmallError>
66 where
67 K: std::cmp::Eq + std::hash::Hash + Clone,
68 V: Clone + std::default::Default,
69 {
70 let mut map = self.map.wl();
71
72 if let Some(v) = map.get_mut(key) {
73 alter_fn(v)
74 } else {
75 let mut new_v = Default::default();
76 alter_fn(&mut new_v)?;
77 map.insert(key.clone(), new_v);
78 Ok(())
79 }
80 }
81
82 pub fn exact_or_empty(&self, k: &K, v: &V) -> bool
86 where
87 K: std::cmp::Eq + std::hash::Hash,
88 V: std::cmp::Eq,
89 {
90 let map = self.map.rl();
91 map.get(k).map_or(true, |v2| v == v2)
92 }
93
94 pub fn clear(&self) {
95 self.map.wl().clear();
96 }
97
98 pub fn remove(&self, key: &K) -> Option<V>
99 where
100 K: std::cmp::Eq + std::hash::Hash,
101 {
102 self.map.wl().remove(key)
103 }
104
105 pub fn insert(&self, key: K, value: V) -> Option<V>
106 where
107 K: std::cmp::Eq + std::hash::Hash,
108 {
109 self.map.wl().insert(key, value)
110 }
111
112 pub fn keys(&self) -> Vec<K>
113 where
114 K: std::cmp::Eq + std::hash::Hash + Clone,
115 {
116 self.map.rl().keys().cloned().collect()
117 }
118}
119
/// A named mutual-exclusion lock that guards no data (`Mutex<()>`).
///
/// Dropping a `SmallLock` prints `> Dropping <name>` to stdout.
pub struct SmallLock {
    name: String,
    lock: Arc<Mutex<()>>,
}

impl SmallLock {
    /// Creates an unlocked lock labelled `name`; the label appears in the
    /// message printed on drop.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            lock: Arc::new(Mutex::new(())),
        }
    }

    /// Blocks until the lock is acquired and returns its guard. The lock is
    /// released when the guard is dropped.
    ///
    /// A previous holder that panicked leaves the mutex poisoned. Because the
    /// mutex protects `()`, there is no state that could have been left
    /// inconsistent. The poison is therefore ignored instead of making every
    /// later `lock()` call panic.
    pub fn lock(&self) -> std::sync::MutexGuard<()> {
        self.lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Drop for SmallLock {
    /// Prints `> Dropping <name>` to stdout.
    fn drop(&mut self) {
        println!("> Dropping {}", self.name);
    }
}
143
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicUsize, Ordering},
        thread::{self, sleep},
        time::Duration,
    };

    use log::debug;

    use crate::utils::init_log;

    /// Checks that a `SmallLock` and its guard drop cleanly, and that the lock
    /// admits only one thread at a time into its critical section.
    #[test]
    fn test_small_lock() {
        init_log();
        {
            let lock = super::SmallLock::new("test");
            let _guard = lock.lock();
            debug!("Locking");
        }
        debug!("Dropped");

        const THREADS: usize = 5;

        let global_lock = super::SmallLock::new("global");
        // Number of threads currently inside the critical section; with a
        // working lock it is never more than one.
        let holders = AtomicUsize::new(0);

        thread::scope(|s| {
            let mut threads = Vec::with_capacity(THREADS);
            for _ in 0..THREADS {
                let handle = s.spawn(|| {
                    let thread_name = format!("thread-{:?}", thread::current().id());
                    debug!("{}: start", thread_name);
                    {
                        let _guard = global_lock.lock();
                        let already_inside = holders.fetch_add(1, Ordering::SeqCst);
                        assert_eq!(
                            already_inside, 0,
                            "another thread held the lock concurrently"
                        );
                        debug!("{}: lock acquired", thread_name);
                        // Hold the lock briefly so the other threads contend
                        // for it; kept short to keep the test fast.
                        sleep(Duration::from_millis(20));
                        holders.fetch_sub(1, Ordering::SeqCst);
                    }
                    debug!("{}: end", thread_name);
                });
                threads.push(handle);
            }

            // Join explicitly so a failed assertion in a worker fails the test.
            for handle in threads {
                handle.join().unwrap();
            }
        });
    }
}