1use std::sync::{Arc, RwLock, Mutex};
32use std::collections::HashMap;
33
/// Read access to a value guarded by some kind of lock.
///
/// The implementations in this module acquire their lock, run the closure on
/// the protected value and release the lock before returning. `None` means the
/// value could not be accessed (for the std `RwLock`/`Mutex` implementations
/// here: the lock is poisoned).
pub trait LockValue<T> {
    /// Runs `f` on the protected value and returns its result, or `None` when
    /// the value cannot be accessed.
    fn with_value<F: FnOnce(&T) -> R, R>(&self, f: F) -> Option<R>;
}
43
44impl<T> LockValue<T> for Arc<RwLock<T>> {
46 fn with_value<F, R>(&self, f: F) -> Option<R>
47 where
48 F: FnOnce(&T) -> R,
49 {
50 self.read().ok().map(|guard| f(&*guard))
51 }
52}
53
54impl<T> LockValue<T> for Arc<Mutex<T>> {
56 fn with_value<F, R>(&self, f: F) -> Option<R>
57 where
58 F: FnOnce(&T) -> R,
59 {
60 self.lock().ok().map(|guard| f(&*guard))
61 }
62}
63
64impl<T> LockValue<T> for RwLock<T> {
66 fn with_value<F, R>(&self, f: F) -> Option<R>
67 where
68 F: FnOnce(&T) -> R,
69 {
70 self.read().ok().map(|guard| f(&*guard))
71 }
72}
73
74impl<T> LockValue<T> for Mutex<T> {
76 fn with_value<F, R>(&self, f: F) -> Option<R>
77 where
78 F: FnOnce(&T) -> R,
79 {
80 self.lock().ok().map(|guard| f(&*guard))
81 }
82}
83
84
85pub trait LockQueryExt<T, L>
90where
91 L: LockValue<T>,
92{
93 fn lock_iter(&self) -> Box<dyn Iterator<Item = LockedValueRef<'_, T, L>> + '_>;
102}
103
/// Borrowed handle to one lock inside a collection, as yielded by
/// [`LockQueryExt::lock_iter`].
///
/// The handle stores only `&'a L`, so holding a handle never holds the lock.
/// With the std lock implementations in this module, the lock is acquired only
/// inside a value accessor call and released before that call returns.
pub struct LockedValueRef<'a, T, L> {
    lock: &'a L,
    // `fn() -> T` records `T` as a type-level marker only. The handle never
    // owns a `T`, so `T` must not affect drop-check or the handle's
    // `Send`/`Sync`-ness; those follow from `&'a L` alone.
    _phantom: std::marker::PhantomData<fn() -> T>,
}

// Manual impls: `#[derive]` would demand `T: Clone` and `L: Clone`, but copying
// the handle only copies a shared reference.
impl<T, L> Clone for LockedValueRef<'_, T, L> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, L> Copy for LockedValueRef<'_, T, L> {}
115
116impl<'a, T, L> LockedValueRef<'a, T, L>
117where
118 L: LockValue<T>,
119{
120 pub fn new(lock: &'a L) -> Self {
121 Self {
122 lock,
123 _phantom: std::marker::PhantomData,
124 }
125 }
126
127 pub fn with_value<F, R>(&self, f: F) -> Option<R>
129 where
130 F: FnOnce(&T) -> R,
131 {
132 self.lock.with_value(f)
133 }
134
135 pub fn map<F, R>(&self, f: F) -> Option<R>
137 where
138 F: FnOnce(&T) -> R,
139 {
140 self.lock.with_value(f)
141 }
142
143 pub fn matches<F>(&self, predicate: F) -> bool
145 where
146 F: FnOnce(&T) -> bool,
147 {
148 self.lock.with_value(predicate).unwrap_or(false)
149 }
150
151}
152
153impl<K, V> LockQueryExt<V, Arc<RwLock<V>>> for HashMap<K, Arc<RwLock<V>>>
155where
156 K: Eq + std::hash::Hash,
157{
158 fn lock_iter(&self) -> Box<dyn Iterator<Item = LockedValueRef<'_, V, Arc<RwLock<V>>>> + '_> {
159 Box::new(self.values().map(|lock| LockedValueRef::new(lock)))
160 }
161}
162
163impl<K, V> LockQueryExt<V, Arc<Mutex<V>>> for HashMap<K, Arc<Mutex<V>>>
165where
166 K: Eq + std::hash::Hash,
167{
168 fn lock_iter(&self) -> Box<dyn Iterator<Item = LockedValueRef<'_, V, Arc<Mutex<V>>>> + '_> {
169 Box::new(self.values().map(|lock| LockedValueRef::new(lock)))
170 }
171}
172
173impl<T> LockQueryExt<T, Arc<RwLock<T>>> for Vec<Arc<RwLock<T>>> {
175 fn lock_iter(&self) -> Box<dyn Iterator<Item = LockedValueRef<'_, T, Arc<RwLock<T>>>> + '_> {
176 Box::new(self.iter().map(|lock| LockedValueRef::new(lock)))
177 }
178}
179
180impl<T> LockQueryExt<T, Arc<Mutex<T>>> for Vec<Arc<Mutex<T>>> {
182 fn lock_iter(&self) -> Box<dyn Iterator<Item = LockedValueRef<'_, T, Arc<Mutex<T>>>> + '_> {
183 Box::new(self.iter().map(|lock| LockedValueRef::new(lock)))
184 }
185}
186
187impl<T> LockQueryExt<T, Arc<RwLock<T>>> for [Arc<RwLock<T>>] {
189 fn lock_iter(&self) -> Box<dyn Iterator<Item = LockedValueRef<'_, T, Arc<RwLock<T>>>> + '_> {
190 Box::new(self.iter().map(|lock| LockedValueRef::new(lock)))
191 }
192}
193
194impl<T> LockQueryExt<T, Arc<Mutex<T>>> for [Arc<Mutex<T>>] {
196 fn lock_iter(&self) -> Box<dyn Iterator<Item = LockedValueRef<'_, T, Arc<Mutex<T>>>> + '_> {
197 Box::new(self.iter().map(|lock| LockedValueRef::new(lock)))
198 }
199}
200
201pub struct LockFilterIter<'a, T, L, I, F>
203where
204 L: LockValue<T> + 'a,
205 I: Iterator<Item = LockedValueRef<'a, T, L>>,
206 F: Fn(&T) -> bool,
207{
208 iter: I,
209 predicate: F,
210 _phantom: std::marker::PhantomData<(&'a T, L)>,
211}
212
213impl<'a, T, L, I, F> Iterator for LockFilterIter<'a, T, L, I, F>
214where
215 L: LockValue<T> + 'a,
216 I: Iterator<Item = LockedValueRef<'a, T, L>>,
217 F: Fn(&T) -> bool,
218{
219 type Item = LockedValueRef<'a, T, L>;
220
221 fn next(&mut self) -> Option<Self::Item> {
222 self.iter.find(|locked_ref| locked_ref.matches(&self.predicate))
223 }
224}
225
226pub trait LockIterExt<'a, T: 'a, L>: Iterator<Item = LockedValueRef<'a, T, L>> + Sized
228where
229 L: LockValue<T> + 'a,
230{
231 fn filter_locked<F>(self, predicate: F) -> LockFilterIter<'a, T, L, Self, F>
233 where
234 F: Fn(&T) -> bool,
235 {
236 LockFilterIter {
237 iter: self,
238 predicate,
239 _phantom: std::marker::PhantomData,
240 }
241 }
242
243 fn map_locked<F, R>(self, f: F) -> impl Iterator<Item = R> + 'a
245 where
246 F: Fn(&T) -> R + 'a,
247 Self: 'a,
248 {
249 self.filter_map(move |locked_ref| locked_ref.map(&f))
250 }
251
252 fn count_locked<F>(self, predicate: F) -> usize
254 where
255 F: Fn(&T) -> bool,
256 Self: 'a,
257 {
258 self.filter(|locked_ref| locked_ref.matches(&predicate))
259 .count()
260 }
261
262 fn find_locked<F>(mut self, predicate: F) -> Option<LockedValueRef<'a, T, L>>
264 where
265 F: Fn(&T) -> bool,
266 {
267 self.find(|locked_ref| locked_ref.matches(&predicate))
268 }
269
270 fn any_locked<F>(mut self, predicate: F) -> bool
272 where
273 F: Fn(&T) -> bool,
274 {
275 self.any(|locked_ref| locked_ref.matches(&predicate))
276 }
277
278 fn collect_cloned(self) -> Vec<T>
280 where
281 T: Clone,
282 Self: 'a,
283 {
284 self.filter_map(|locked_ref| {
285 locked_ref.with_value(|v| v.clone())
286 })
287 .collect()
288 }
289
290}
291
292impl<'a, T: 'a, L, I> LockIterExt<'a, T, L> for I
294where
295 L: LockValue<T> + 'a,
296 I: Iterator<Item = LockedValueRef<'a, T, L>>,
297{
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Poisons `lock` by panicking on a helper thread while its write guard is held.
    fn poison<T: Send + Sync + 'static>(lock: &Arc<RwLock<T>>) {
        let shared = Arc::clone(lock);
        let outcome = std::thread::spawn(move || {
            let _guard = shared.write().unwrap();
            panic!("intentionally poisoning the lock");
        })
        .join();
        assert!(outcome.is_err());
    }

    #[test]
    fn test_rwlock_lock_value() {
        let data = Arc::new(RwLock::new(42));
        let result = data.with_value(|v| *v * 2);
        assert_eq!(result, Some(84));
    }

    #[test]
    fn test_mutex_lock_value() {
        let data = Arc::new(Mutex::new("hello"));
        let result = data.with_value(|v| v.len());
        assert_eq!(result, Some(5));
    }

    #[test]
    fn test_hashmap_lock_query() {
        let mut map: HashMap<String, Arc<RwLock<i32>>> = HashMap::new();
        map.insert("a".to_string(), Arc::new(RwLock::new(10)));
        map.insert("b".to_string(), Arc::new(RwLock::new(20)));
        map.insert("c".to_string(), Arc::new(RwLock::new(30)));

        let sum: i32 = map.lock_iter().map_locked(|v| *v).sum();

        assert_eq!(sum, 60);
    }

    #[test]
    fn test_lock_filter() {
        let mut map: HashMap<String, Arc<RwLock<i32>>> = HashMap::new();
        map.insert("a".to_string(), Arc::new(RwLock::new(10)));
        map.insert("b".to_string(), Arc::new(RwLock::new(20)));
        map.insert("c".to_string(), Arc::new(RwLock::new(30)));

        let count = map.lock_iter().count_locked(|v| *v > 15);

        assert_eq!(count, 2);
    }

    #[test]
    fn test_lock_any() {
        let mut map: HashMap<String, Arc<RwLock<i32>>> = HashMap::new();
        map.insert("a".to_string(), Arc::new(RwLock::new(10)));
        map.insert("b".to_string(), Arc::new(RwLock::new(20)));

        let has_large = map.lock_iter().any_locked(|v| *v > 15);

        assert!(has_large);
    }

    #[test]
    fn test_vec_mutex_filter_then_map() {
        let locks: Vec<Arc<Mutex<i32>>> = (1..=5).map(|n| Arc::new(Mutex::new(n))).collect();

        let odd: Vec<i32> = locks
            .lock_iter()
            .filter_locked(|v| v % 2 == 1)
            .map_locked(|v| *v)
            .collect();

        assert_eq!(odd, vec![1, 3, 5]);
    }

    #[test]
    fn test_find_locked() {
        let locks: Vec<Arc<RwLock<i32>>> = vec![Arc::new(RwLock::new(3)), Arc::new(RwLock::new(8))];

        let found = locks
            .lock_iter()
            .find_locked(|v| *v > 5)
            .and_then(|handle| handle.with_value(|v| *v));

        assert_eq!(found, Some(8));
        assert!(locks.lock_iter().find_locked(|v| *v > 100).is_none());
    }

    #[test]
    fn test_slice_collect_cloned() {
        let locks: Vec<Arc<RwLock<String>>> = vec![
            Arc::new(RwLock::new(String::from("x"))),
            Arc::new(RwLock::new(String::from("y"))),
        ];

        let values = locks[..].lock_iter().collect_cloned();

        assert_eq!(values, vec!["x".to_string(), "y".to_string()]);
    }

    #[test]
    fn test_poisoned_locks_are_skipped() {
        let healthy = Arc::new(RwLock::new(1));
        let poisoned = Arc::new(RwLock::new(2));
        poison(&poisoned);
        assert!(poisoned.is_poisoned());

        let locks = vec![healthy, poisoned];

        assert_eq!(locks.lock_iter().map_locked(|v| *v).collect::<Vec<_>>(), vec![1]);
        assert_eq!(locks.lock_iter().count_locked(|_| true), 1);
        assert!(!locks.lock_iter().any_locked(|v| *v == 2));
    }
}
361