1use crate::locks::LockValue;
25use key_paths_core::KeyPaths;
26use std::collections::HashMap;
27use std::sync::{Arc, RwLock, Mutex};
28
29pub struct LockQuery<'a, T: 'static, L>
34where
35 L: LockValue<T> + 'a,
36{
37 locks: Vec<&'a L>,
38 filters: Vec<Box<dyn Fn(&T) -> bool + 'a>>,
39 _phantom: std::marker::PhantomData<T>,
40}
41
42impl<'a, T: 'static, L> LockQuery<'a, T, L>
43where
44 L: LockValue<T> + 'a,
45{
46 pub fn from_locks(locks: Vec<&'a L>) -> Self {
54 Self {
55 locks,
56 filters: Vec::new(),
57 _phantom: std::marker::PhantomData,
58 }
59 }
60
61 pub fn where_<F>(mut self, path: KeyPaths<T, F>, predicate: impl Fn(&F) -> bool + 'a) -> Self
70 where
71 F: 'static,
72 {
73 self.filters.push(Box::new(move |item| {
74 path.get(item).map_or(false, |val| predicate(val))
75 }));
76 self
77 }
78
79 pub fn all(&self) -> Vec<T>
87 where
88 T: Clone,
89 {
90 self.locks
91 .iter()
92 .filter_map(|lock| {
93 lock.with_value(|item| {
94 if self.filters.iter().all(|f| f(item)) {
95 Some(item.clone())
96 } else {
97 None
98 }
99 })
100 .flatten()
101 })
102 .collect()
103 }
104
105 pub fn first(&self) -> Option<T>
113 where
114 T: Clone,
115 {
116 self.locks
117 .iter()
118 .find_map(|lock| {
119 lock.with_value(|item| {
120 if self.filters.iter().all(|f| f(item)) {
121 Some(item.clone())
122 } else {
123 None
124 }
125 })
126 .flatten()
127 })
128 }
129
130 pub fn count(&self) -> usize {
138 self.locks
139 .iter()
140 .filter(|lock| {
141 lock.with_value(|item| self.filters.iter().all(|f| f(item)))
142 .unwrap_or(false)
143 })
144 .count()
145 }
146
147 pub fn exists(&self) -> bool {
155 self.locks
156 .iter()
157 .any(|lock| {
158 lock.with_value(|item| self.filters.iter().all(|f| f(item)))
159 .unwrap_or(false)
160 })
161 }
162
163 pub fn limit(&self, n: usize) -> Vec<T>
171 where
172 T: Clone,
173 {
174 self.locks
175 .iter()
176 .filter_map(|lock| {
177 lock.with_value(|item| {
178 if self.filters.iter().all(|f| f(item)) {
179 Some(item.clone())
180 } else {
181 None
182 }
183 })
184 .flatten()
185 })
186 .take(n)
187 .collect()
188 }
189
190 pub fn select<F>(&self, path: KeyPaths<T, F>) -> Vec<F>
198 where
199 F: Clone + 'static,
200 {
201 self.locks
202 .iter()
203 .filter_map(|lock| {
204 lock.with_value(|item| {
205 if self.filters.iter().all(|f| f(item)) {
206 path.get(item).cloned()
207 } else {
208 None
209 }
210 })
211 .flatten()
212 })
213 .collect()
214 }
215
216 pub fn sum<F>(&self, path: KeyPaths<T, F>) -> F
224 where
225 F: Clone + std::ops::Add<Output = F> + Default + 'static,
226 {
227 self.locks
228 .iter()
229 .filter_map(|lock| {
230 lock.with_value(|item| {
231 if self.filters.iter().all(|f| f(item)) {
232 path.get(item).cloned()
233 } else {
234 None
235 }
236 })
237 .flatten()
238 })
239 .fold(F::default(), |acc, val| acc + val)
240 }
241
242 pub fn avg(&self, path: KeyPaths<T, f64>) -> Option<f64> {
250 let values: Vec<f64> = self.select(path);
251 if values.is_empty() {
252 None
253 } else {
254 Some(values.iter().sum::<f64>() / values.len() as f64)
255 }
256 }
257
258 pub fn min<F>(&self, path: KeyPaths<T, F>) -> Option<F>
266 where
267 F: Ord + Clone + 'static,
268 {
269 self.select(path).into_iter().min()
270 }
271
272 pub fn max<F>(&self, path: KeyPaths<T, F>) -> Option<F>
280 where
281 F: Ord + Clone + 'static,
282 {
283 self.select(path).into_iter().max()
284 }
285
286 pub fn min_float(&self, path: KeyPaths<T, f64>) -> Option<f64> {
288 self.select(path)
289 .into_iter()
290 .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
291 }
292
293 pub fn max_float(&self, path: KeyPaths<T, f64>) -> Option<f64> {
295 self.select(path)
296 .into_iter()
297 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
298 }
299
300 pub fn order_by<F>(&self, path: KeyPaths<T, F>) -> Vec<T>
308 where
309 F: Ord + Clone + 'static,
310 T: Clone,
311 {
312 let mut results = self.all();
313 results.sort_by_key(|item| path.get(item).cloned());
314 results
315 }
316
317 pub fn order_by_desc<F>(&self, path: KeyPaths<T, F>) -> Vec<T>
319 where
320 F: Ord + Clone + 'static,
321 T: Clone,
322 {
323 let mut results = self.all();
324 results.sort_by(|a, b| {
325 let a_val = path.get(a).cloned();
326 let b_val = path.get(b).cloned();
327 b_val.cmp(&a_val)
328 });
329 results
330 }
331
332 pub fn order_by_float(&self, path: KeyPaths<T, f64>) -> Vec<T>
334 where
335 T: Clone,
336 {
337 let mut results = self.all();
338 results.sort_by(|a, b| {
339 let a_val = path.get(a).cloned().unwrap_or(0.0);
340 let b_val = path.get(b).cloned().unwrap_or(0.0);
341 a_val.partial_cmp(&b_val).unwrap_or(std::cmp::Ordering::Equal)
342 });
343 results
344 }
345
346 pub fn order_by_float_desc(&self, path: KeyPaths<T, f64>) -> Vec<T>
348 where
349 T: Clone,
350 {
351 let mut results = self.all();
352 results.sort_by(|a, b| {
353 let a_val = path.get(a).cloned().unwrap_or(0.0);
354 let b_val = path.get(b).cloned().unwrap_or(0.0);
355 b_val.partial_cmp(&a_val).unwrap_or(std::cmp::Ordering::Equal)
356 });
357 results
358 }
359
360 pub fn group_by<F>(&self, path: KeyPaths<T, F>) -> HashMap<F, Vec<T>>
368 where
369 F: Eq + std::hash::Hash + Clone + 'static,
370 T: Clone,
371 {
372 let mut groups: HashMap<F, Vec<T>> = HashMap::new();
373
374 for lock in &self.locks {
375 if let Some(item) = lock.with_value(|item| {
376 if self.filters.iter().all(|f| f(item)) {
377 Some(item.clone())
378 } else {
379 None
380 }
381 })
382 .flatten()
383 {
384 if let Some(key) = path.get(&item).cloned() {
385 groups.entry(key).or_insert_with(Vec::new).push(item);
386 }
387 }
388 }
389
390 groups
391 }
392}
393
394pub trait LockQueryable<T, L>
396where
397 L: LockValue<T>,
398{
399 fn lock_query(&self) -> LockQuery<'_, T, L>;
401}
402
403impl<K, V> LockQueryable<V, Arc<RwLock<V>>> for HashMap<K, Arc<RwLock<V>>>
405where
406 K: Eq + std::hash::Hash,
407{
408 fn lock_query(&self) -> LockQuery<'_, V, Arc<RwLock<V>>> {
409 LockQuery::from_locks(self.values().collect())
410 }
411}
412
413impl<K, V> LockQueryable<V, Arc<Mutex<V>>> for HashMap<K, Arc<Mutex<V>>>
415where
416 K: Eq + std::hash::Hash,
417{
418 fn lock_query(&self) -> LockQuery<'_, V, Arc<Mutex<V>>> {
419 LockQuery::from_locks(self.values().collect())
420 }
421}
422
423impl<T> LockQueryable<T, Arc<RwLock<T>>> for Vec<Arc<RwLock<T>>> {
425 fn lock_query(&self) -> LockQuery<'_, T, Arc<RwLock<T>>> {
426 LockQuery::from_locks(self.iter().collect())
427 }
428}
429
430impl<T> LockQueryable<T, Arc<Mutex<T>>> for Vec<Arc<Mutex<T>>> {
432 fn lock_query(&self) -> LockQuery<'_, T, Arc<Mutex<T>>> {
433 LockQuery::from_locks(self.iter().collect())
434 }
435}
436
437use crate::lock_lazy::LockLazyQuery;
439
440pub trait LockLazyQueryable<T, L>
442where
443 L: LockValue<T>,
444{
445 fn lock_lazy_query(&self) -> LockLazyQuery<'_, T, L, impl Iterator<Item = &L>>;
447}
448
449impl<K, V> LockLazyQueryable<V, Arc<RwLock<V>>> for HashMap<K, Arc<RwLock<V>>>
451where
452 K: Eq + std::hash::Hash,
453{
454 fn lock_lazy_query(&self) -> LockLazyQuery<'_, V, Arc<RwLock<V>>, impl Iterator<Item = &Arc<RwLock<V>>>> {
455 LockLazyQuery::new(self.values())
456 }
457}
458
459impl<K, V> LockLazyQueryable<V, Arc<Mutex<V>>> for HashMap<K, Arc<Mutex<V>>>
461where
462 K: Eq + std::hash::Hash,
463{
464 fn lock_lazy_query(&self) -> LockLazyQuery<'_, V, Arc<Mutex<V>>, impl Iterator<Item = &Arc<Mutex<V>>>> {
465 LockLazyQuery::new(self.values())
466 }
467}
468
469impl<T> LockLazyQueryable<T, Arc<RwLock<T>>> for Vec<Arc<RwLock<T>>> {
471 fn lock_lazy_query(&self) -> LockLazyQuery<'_, T, Arc<RwLock<T>>, impl Iterator<Item = &Arc<RwLock<T>>>> {
472 LockLazyQuery::new(self.iter())
473 }
474}
475
476impl<T> LockLazyQueryable<T, Arc<Mutex<T>>> for Vec<Arc<Mutex<T>>> {
478 fn lock_lazy_query(&self) -> LockLazyQuery<'_, T, Arc<Mutex<T>>, impl Iterator<Item = &Arc<Mutex<T>>>> {
479 LockLazyQuery::new(self.iter())
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use key_paths_derive::Keypath;
    use std::sync::{Arc, Mutex, RwLock};

    #[derive(Clone, Keypath)]
    struct Product {
        id: u32,
        name: String,
        price: f64,
        category: String,
    }

    /// The three fixture products shared by every test, in id order.
    fn sample_products() -> Vec<Product> {
        vec![
            Product {
                id: 1,
                name: "Laptop".to_string(),
                price: 999.99,
                category: "Electronics".to_string(),
            },
            Product {
                id: 2,
                name: "Chair".to_string(),
                price: 299.99,
                category: "Furniture".to_string(),
            },
            Product {
                id: 3,
                name: "Mouse".to_string(),
                price: 29.99,
                category: "Electronics".to_string(),
            },
        ]
    }

    /// The fixtures keyed as "p1".."p3", each behind its own `RwLock`.
    fn create_test_map() -> HashMap<String, Arc<RwLock<Product>>> {
        sample_products()
            .into_iter()
            .map(|p| (format!("p{}", p.id), Arc::new(RwLock::new(p))))
            .collect()
    }

    #[test]
    fn test_lock_query_where() {
        let map = create_test_map();
        let query = map.lock_query();
        let count = query
            .where_(Product::category(), |cat| cat == "Electronics")
            .count();
        assert_eq!(count, 2);
    }

    #[test]
    fn test_lock_query_multiple_filters() {
        let map = create_test_map();
        let names = map
            .lock_query()
            .where_(Product::category(), |cat| cat == "Electronics")
            .where_(Product::price(), |price| *price > 100.0)
            .select(Product::name());
        assert_eq!(names, vec!["Laptop".to_string()]);
    }

    #[test]
    fn test_lock_query_select() {
        let map = create_test_map();
        let names = map.lock_query().select(Product::name());
        assert_eq!(names.len(), 3);
    }

    #[test]
    fn test_lock_query_first_and_exists() {
        let map = create_test_map();
        let chair = map
            .lock_query()
            .where_(Product::category(), |cat| cat == "Furniture")
            .first();
        assert_eq!(chair.map(|p| p.name), Some("Chair".to_string()));

        assert!(map
            .lock_query()
            .where_(Product::price(), |price| *price > 500.0)
            .exists());
        assert!(!map
            .lock_query()
            .where_(Product::price(), |price| *price > 5000.0)
            .exists());
    }

    #[test]
    fn test_lock_query_limit() {
        let map = create_test_map();
        let query = map.lock_query();
        assert_eq!(query.limit(0).len(), 0);
        assert_eq!(query.limit(2).len(), 2);
        assert_eq!(query.limit(10).len(), 3);
    }

    #[test]
    fn test_lock_query_sum() {
        let map = create_test_map();
        let total = map.lock_query().sum(Product::price());
        assert!((total - 1329.97).abs() < 0.01);
    }

    #[test]
    fn test_lock_query_aggregates() {
        let map = create_test_map();
        let query = map.lock_query();

        let avg = query.avg(Product::price()).expect("three products match");
        assert!((avg - 1329.97 / 3.0).abs() < 0.01);

        assert_eq!(query.min(Product::id()), Some(1));
        assert_eq!(query.max(Product::id()), Some(3));
        assert_eq!(query.min_float(Product::price()), Some(29.99));
        assert_eq!(query.max_float(Product::price()), Some(999.99));
    }

    #[test]
    fn test_lock_query_avg_empty() {
        let map = create_test_map();
        let avg = map
            .lock_query()
            .where_(Product::category(), |cat| cat == "Toys")
            .avg(Product::price());
        assert_eq!(avg, None);
    }

    #[test]
    fn test_lock_query_group_by() {
        let map = create_test_map();
        let groups = map.lock_query().group_by(Product::category());
        assert_eq!(groups.len(), 2);
        assert_eq!(groups.get("Electronics").unwrap().len(), 2);
    }

    #[test]
    fn test_lock_query_order_by() {
        let map = create_test_map();
        let sorted = map.lock_query().order_by_float(Product::price());
        assert_eq!(sorted[0].price, 29.99);
        assert_eq!(sorted[2].price, 999.99);
    }

    #[test]
    fn test_lock_query_order_by_variants() {
        let map = create_test_map();
        let query = map.lock_query();

        let by_name: Vec<String> = query
            .order_by(Product::name())
            .into_iter()
            .map(|p| p.name)
            .collect();
        assert_eq!(by_name, ["Chair", "Laptop", "Mouse"]);

        let by_id_desc: Vec<u32> = query
            .order_by_desc(Product::id())
            .into_iter()
            .map(|p| p.id)
            .collect();
        assert_eq!(by_id_desc, vec![3, 2, 1]);

        let by_price_desc = query.order_by_float_desc(Product::price());
        assert_eq!(by_price_desc[0].price, 999.99);
        assert_eq!(by_price_desc[2].price, 29.99);
    }

    #[test]
    fn test_lock_query_vec_of_mutexes() {
        let locks: Vec<Arc<Mutex<Product>>> = sample_products()
            .into_iter()
            .map(|p| Arc::new(Mutex::new(p)))
            .collect();
        let query = locks.lock_query();
        assert_eq!(query.count(), 3);
        assert_eq!(query.sum(Product::id()), 6);
        // A Vec is visited in order, so the first unfiltered item is id 1.
        assert_eq!(query.first().map(|p| p.id), Some(1));
    }
}
577