rust_queries_core/
lock_query.rs

//! Full SQL-like query support for locked data structures.
//!
//! This module provides a complete Query API for collections of locked values,
//! enabling WHERE, SELECT, ORDER BY, GROUP BY, and aggregation operations
//! without copying data unnecessarily: only the values a query returns are cloned.
//!
//! # Example
//!
//! ```ignore
//! use rust_queries_core::LockQuery;
//! use std::sync::{Arc, RwLock};
//! use std::collections::HashMap;
//!
//! let products: HashMap<String, Arc<RwLock<Product>>> = /* ... */;
//!
//! // Full SQL-like syntax on locked data!
//! let expensive = LockQuery::from_locks(products.values().collect())
//!     .where_(Product::category(), |cat| cat == "Electronics")
//!     .where_(Product::price(), |&p| p > 500.0)
//!     .order_by_float(Product::rating());
//! // `order_by_float` returns a `Vec`, so take the top 10 from it directly.
//! let top_10: Vec<_> = expensive.into_iter().take(10).collect();
//! ```
23
24use crate::locks::LockValue;
25use key_paths_core::KeyPaths;
26use std::collections::HashMap;
27use std::sync::{Arc, RwLock, Mutex};
28
29/// A query builder for locked data structures.
30///
31/// Provides full SQL-like query operations (WHERE, SELECT, ORDER BY, GROUP BY)
32/// on collections of locked values without unnecessary copying.
33pub struct LockQuery<'a, T: 'static, L>
34where
35    L: LockValue<T> + 'a,
36{
37    locks: Vec<&'a L>,
38    filters: Vec<Box<dyn Fn(&T) -> bool + 'a>>,
39    _phantom: std::marker::PhantomData<T>,
40}
41
42impl<'a, T: 'static, L> LockQuery<'a, T, L>
43where
44    L: LockValue<T> + 'a,
45{
46    /// Create a new lock query from a collection of locks.
47    ///
48    /// # Example
49    ///
50    /// ```ignore
51    /// let query = LockQuery::from_locks(product_map.values().collect());
52    /// ```
53    pub fn from_locks(locks: Vec<&'a L>) -> Self {
54        Self {
55            locks,
56            filters: Vec::new(),
57            _phantom: std::marker::PhantomData,
58        }
59    }
60
61    /// Add a WHERE clause using a key-path.
62    ///
63    /// # Example
64    ///
65    /// ```ignore
66    /// let query = LockQuery::new(&products)
67    ///     .where_(Product::category(), |cat| cat == "Electronics");
68    /// ```
69    pub fn where_<F>(mut self, path: KeyPaths<T, F>, predicate: impl Fn(&F) -> bool + 'a) -> Self
70    where
71        F: 'static,
72    {
73        self.filters.push(Box::new(move |item| {
74            path.get(item).map_or(false, |val| predicate(val))
75        }));
76        self
77    }
78
79    /// Get all matching items (collects by cloning).
80    ///
81    /// # Example
82    ///
83    /// ```ignore
84    /// let results: Vec<Product> = query.all();
85    /// ```
86    pub fn all(&self) -> Vec<T>
87    where
88        T: Clone,
89    {
90        self.locks
91            .iter()
92            .filter_map(|lock| {
93                lock.with_value(|item| {
94                    if self.filters.iter().all(|f| f(item)) {
95                        Some(item.clone())
96                    } else {
97                        None
98                    }
99                })
100                .flatten()
101            })
102            .collect()
103    }
104
105    /// Get the first matching item.
106    ///
107    /// # Example
108    ///
109    /// ```ignore
110    /// let first = query.first();
111    /// ```
112    pub fn first(&self) -> Option<T>
113    where
114        T: Clone,
115    {
116        self.locks
117            .iter()
118            .find_map(|lock| {
119                lock.with_value(|item| {
120                    if self.filters.iter().all(|f| f(item)) {
121                        Some(item.clone())
122                    } else {
123                        None
124                    }
125                })
126                .flatten()
127            })
128    }
129
130    /// Count matching items.
131    ///
132    /// # Example
133    ///
134    /// ```ignore
135    /// let count = query.count();
136    /// ```
137    pub fn count(&self) -> usize {
138        self.locks
139            .iter()
140            .filter(|lock| {
141                lock.with_value(|item| self.filters.iter().all(|f| f(item)))
142                    .unwrap_or(false)
143            })
144            .count()
145    }
146
147    /// Check if any items match.
148    ///
149    /// # Example
150    ///
151    /// ```ignore
152    /// let exists = query.exists();
153    /// ```
154    pub fn exists(&self) -> bool {
155        self.locks
156            .iter()
157            .any(|lock| {
158                lock.with_value(|item| self.filters.iter().all(|f| f(item)))
159                    .unwrap_or(false)
160            })
161    }
162
163    /// Limit results to first N items.
164    ///
165    /// # Example
166    ///
167    /// ```ignore
168    /// let first_10 = query.limit(10);
169    /// ```
170    pub fn limit(&self, n: usize) -> Vec<T>
171    where
172        T: Clone,
173    {
174        self.locks
175            .iter()
176            .filter_map(|lock| {
177                lock.with_value(|item| {
178                    if self.filters.iter().all(|f| f(item)) {
179                        Some(item.clone())
180                    } else {
181                        None
182                    }
183                })
184                .flatten()
185            })
186            .take(n)
187            .collect()
188    }
189
190    /// Select/project a field.
191    ///
192    /// # Example
193    ///
194    /// ```ignore
195    /// let names: Vec<String> = query.select(Product::name());
196    /// ```
197    pub fn select<F>(&self, path: KeyPaths<T, F>) -> Vec<F>
198    where
199        F: Clone + 'static,
200    {
201        self.locks
202            .iter()
203            .filter_map(|lock| {
204                lock.with_value(|item| {
205                    if self.filters.iter().all(|f| f(item)) {
206                        path.get(item).cloned()
207                    } else {
208                        None
209                    }
210                })
211                .flatten()
212            })
213            .collect()
214    }
215
216    /// Sum a numeric field.
217    ///
218    /// # Example
219    ///
220    /// ```ignore
221    /// let total = query.sum(Product::price());
222    /// ```
223    pub fn sum<F>(&self, path: KeyPaths<T, F>) -> F
224    where
225        F: Clone + std::ops::Add<Output = F> + Default + 'static,
226    {
227        self.locks
228            .iter()
229            .filter_map(|lock| {
230                lock.with_value(|item| {
231                    if self.filters.iter().all(|f| f(item)) {
232                        path.get(item).cloned()
233                    } else {
234                        None
235                    }
236                })
237                .flatten()
238            })
239            .fold(F::default(), |acc, val| acc + val)
240    }
241
242    /// Calculate average of f64 field.
243    ///
244    /// # Example
245    ///
246    /// ```ignore
247    /// let avg = query.avg(Product::price());
248    /// ```
249    pub fn avg(&self, path: KeyPaths<T, f64>) -> Option<f64> {
250        let values: Vec<f64> = self.select(path);
251        if values.is_empty() {
252            None
253        } else {
254            Some(values.iter().sum::<f64>() / values.len() as f64)
255        }
256    }
257
258    /// Find minimum value.
259    ///
260    /// # Example
261    ///
262    /// ```ignore
263    /// let min = query.min(Product::stock());
264    /// ```
265    pub fn min<F>(&self, path: KeyPaths<T, F>) -> Option<F>
266    where
267        F: Ord + Clone + 'static,
268    {
269        self.select(path).into_iter().min()
270    }
271
272    /// Find maximum value.
273    ///
274    /// # Example
275    ///
276    /// ```ignore
277    /// let max = query.max(Product::stock());
278    /// ```
279    pub fn max<F>(&self, path: KeyPaths<T, F>) -> Option<F>
280    where
281        F: Ord + Clone + 'static,
282    {
283        self.select(path).into_iter().max()
284    }
285
286    /// Find minimum float value.
287    pub fn min_float(&self, path: KeyPaths<T, f64>) -> Option<f64> {
288        self.select(path)
289            .into_iter()
290            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
291    }
292
293    /// Find maximum float value.
294    pub fn max_float(&self, path: KeyPaths<T, f64>) -> Option<f64> {
295        self.select(path)
296            .into_iter()
297            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
298    }
299
300    /// Order by a field (requires collecting data).
301    ///
302    /// # Example
303    ///
304    /// ```ignore
305    /// let sorted = query.order_by(Product::name());
306    /// ```
307    pub fn order_by<F>(&self, path: KeyPaths<T, F>) -> Vec<T>
308    where
309        F: Ord + Clone + 'static,
310        T: Clone,
311    {
312        let mut results = self.all();
313        results.sort_by_key(|item| path.get(item).cloned());
314        results
315    }
316
317    /// Order by a field descending.
318    pub fn order_by_desc<F>(&self, path: KeyPaths<T, F>) -> Vec<T>
319    where
320        F: Ord + Clone + 'static,
321        T: Clone,
322    {
323        let mut results = self.all();
324        results.sort_by(|a, b| {
325            let a_val = path.get(a).cloned();
326            let b_val = path.get(b).cloned();
327            b_val.cmp(&a_val)
328        });
329        results
330    }
331
332    /// Order by float field.
333    pub fn order_by_float(&self, path: KeyPaths<T, f64>) -> Vec<T>
334    where
335        T: Clone,
336    {
337        let mut results = self.all();
338        results.sort_by(|a, b| {
339            let a_val = path.get(a).cloned().unwrap_or(0.0);
340            let b_val = path.get(b).cloned().unwrap_or(0.0);
341            a_val.partial_cmp(&b_val).unwrap_or(std::cmp::Ordering::Equal)
342        });
343        results
344    }
345
346    /// Order by float field descending.
347    pub fn order_by_float_desc(&self, path: KeyPaths<T, f64>) -> Vec<T>
348    where
349        T: Clone,
350    {
351        let mut results = self.all();
352        results.sort_by(|a, b| {
353            let a_val = path.get(a).cloned().unwrap_or(0.0);
354            let b_val = path.get(b).cloned().unwrap_or(0.0);
355            b_val.partial_cmp(&a_val).unwrap_or(std::cmp::Ordering::Equal)
356        });
357        results
358    }
359
360    /// Group by a field.
361    ///
362    /// # Example
363    ///
364    /// ```ignore
365    /// let groups = query.group_by(Product::category());
366    /// ```
367    pub fn group_by<F>(&self, path: KeyPaths<T, F>) -> HashMap<F, Vec<T>>
368    where
369        F: Eq + std::hash::Hash + Clone + 'static,
370        T: Clone,
371    {
372        let mut groups: HashMap<F, Vec<T>> = HashMap::new();
373
374        for lock in &self.locks {
375            if let Some(item) = lock.with_value(|item| {
376                if self.filters.iter().all(|f| f(item)) {
377                    Some(item.clone())
378                } else {
379                    None
380                }
381            })
382            .flatten()
383            {
384                if let Some(key) = path.get(&item).cloned() {
385                    groups.entry(key).or_insert_with(Vec::new).push(item);
386                }
387            }
388        }
389
390        groups
391    }
392}
393
394/// Helper to create LockQuery from HashMap.
395pub trait LockQueryable<T, L>
396where
397    L: LockValue<T>,
398{
399    /// Create a LockQuery for SQL-like operations.
400    fn lock_query(&self) -> LockQuery<'_, T, L>;
401}
402
403// Implementation for HashMap<K, Arc<RwLock<V>>>
404impl<K, V> LockQueryable<V, Arc<RwLock<V>>> for HashMap<K, Arc<RwLock<V>>>
405where
406    K: Eq + std::hash::Hash,
407{
408    fn lock_query(&self) -> LockQuery<'_, V, Arc<RwLock<V>>> {
409        LockQuery::from_locks(self.values().collect())
410    }
411}
412
413// Implementation for HashMap<K, Arc<Mutex<V>>>
414impl<K, V> LockQueryable<V, Arc<Mutex<V>>> for HashMap<K, Arc<Mutex<V>>>
415where
416    K: Eq + std::hash::Hash,
417{
418    fn lock_query(&self) -> LockQuery<'_, V, Arc<Mutex<V>>> {
419        LockQuery::from_locks(self.values().collect())
420    }
421}
422
423// Implementation for Vec<Arc<RwLock<T>>>
424impl<T> LockQueryable<T, Arc<RwLock<T>>> for Vec<Arc<RwLock<T>>> {
425    fn lock_query(&self) -> LockQuery<'_, T, Arc<RwLock<T>>> {
426        LockQuery::from_locks(self.iter().collect())
427    }
428}
429
430// Implementation for Vec<Arc<Mutex<T>>>
431impl<T> LockQueryable<T, Arc<Mutex<T>>> for Vec<Arc<Mutex<T>>> {
432    fn lock_query(&self) -> LockQuery<'_, T, Arc<Mutex<T>>> {
433        LockQuery::from_locks(self.iter().collect())
434    }
435}
436
437// Extension trait for creating lazy lock queries
438use crate::lock_lazy::LockLazyQuery;
439
440/// Extension trait for creating lazy lock queries.
441pub trait LockLazyQueryable<T, L>
442where
443    L: LockValue<T>,
444{
445    /// Create a lazy lock query.
446    fn lock_lazy_query(&self) -> LockLazyQuery<'_, T, L, impl Iterator<Item = &L>>;
447}
448
449// Implementation for HashMap<K, Arc<RwLock<V>>>
450impl<K, V> LockLazyQueryable<V, Arc<RwLock<V>>> for HashMap<K, Arc<RwLock<V>>>
451where
452    K: Eq + std::hash::Hash,
453{
454    fn lock_lazy_query(&self) -> LockLazyQuery<'_, V, Arc<RwLock<V>>, impl Iterator<Item = &Arc<RwLock<V>>>> {
455        LockLazyQuery::new(self.values())
456    }
457}
458
459// Implementation for HashMap<K, Arc<Mutex<V>>>
460impl<K, V> LockLazyQueryable<V, Arc<Mutex<V>>> for HashMap<K, Arc<Mutex<V>>>
461where
462    K: Eq + std::hash::Hash,
463{
464    fn lock_lazy_query(&self) -> LockLazyQuery<'_, V, Arc<Mutex<V>>, impl Iterator<Item = &Arc<Mutex<V>>>> {
465        LockLazyQuery::new(self.values())
466    }
467}
468
469// Implementation for Vec<Arc<RwLock<T>>>
470impl<T> LockLazyQueryable<T, Arc<RwLock<T>>> for Vec<Arc<RwLock<T>>> {
471    fn lock_lazy_query(&self) -> LockLazyQuery<'_, T, Arc<RwLock<T>>, impl Iterator<Item = &Arc<RwLock<T>>>> {
472        LockLazyQuery::new(self.iter())
473    }
474}
475
476// Implementation for Vec<Arc<Mutex<T>>>
477impl<T> LockLazyQueryable<T, Arc<Mutex<T>>> for Vec<Arc<Mutex<T>>> {
478    fn lock_lazy_query(&self) -> LockLazyQuery<'_, T, Arc<Mutex<T>>, impl Iterator<Item = &Arc<Mutex<T>>>> {
479        LockLazyQuery::new(self.iter())
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;
    use key_paths_derive::Keypath;
    use std::sync::{Arc, RwLock};

    /// Minimal record used as the locked value in these tests.
    #[derive(Clone, Keypath)]
    struct Product {
        id: u32,
        name: String,
        price: f64,
        category: String,
    }

    /// Wraps a freshly built product in the `Arc<RwLock<_>>` shape the queries scan.
    fn locked(id: u32, name: &str, price: f64, category: &str) -> Arc<RwLock<Product>> {
        Arc::new(RwLock::new(Product {
            id,
            name: name.to_string(),
            price,
            category: category.to_string(),
        }))
    }

    /// Three products: Laptop and Mouse ("Electronics") plus Chair ("Furniture").
    fn create_test_map() -> HashMap<String, Arc<RwLock<Product>>> {
        let mut catalog = HashMap::new();
        catalog.insert("p1".to_string(), locked(1, "Laptop", 999.99, "Electronics"));
        catalog.insert("p2".to_string(), locked(2, "Chair", 299.99, "Furniture"));
        catalog.insert("p3".to_string(), locked(3, "Mouse", 29.99, "Electronics"));
        catalog
    }

    #[test]
    fn test_lock_query_where() {
        let catalog = create_test_map();
        let electronics = catalog
            .lock_query()
            .where_(Product::category(), |cat| cat == "Electronics")
            .count();
        assert_eq!(electronics, 2);
    }

    #[test]
    fn test_lock_query_select() {
        let catalog = create_test_map();
        let query = catalog.lock_query();
        let all_names = query.select(Product::name());
        assert_eq!(all_names.len(), 3);
    }

    #[test]
    fn test_lock_query_sum() {
        let catalog = create_test_map();
        let query = catalog.lock_query();
        let grand_total = query.sum(Product::price());
        assert!((grand_total - 1329.97).abs() < 0.01);
    }

    #[test]
    fn test_lock_query_group_by() {
        let catalog = create_test_map();
        let query = catalog.lock_query();
        let by_category = query.group_by(Product::category());
        assert_eq!(by_category.len(), 2);
        assert_eq!(by_category.get("Electronics").unwrap().len(), 2);
    }

    #[test]
    fn test_lock_query_order_by() {
        let catalog = create_test_map();
        let query = catalog.lock_query();
        let cheapest_first = query.order_by_float(Product::price());
        assert_eq!(cheapest_first[0].price, 29.99);
        assert_eq!(cheapest_first[2].price, 999.99);
    }
}
577