rust_queries_core/
join.rs

1//! Join query implementation for combining multiple collections.
2//!
3//! This module provides the `JoinQuery` struct which enables SQL-like JOIN operations
4//! between collections using type-safe key-paths.
5
6use key_paths_core::KeyPaths;
7use std::collections::HashMap;
8
/// A query builder for joining two borrowed collections.
///
/// Supports inner joins (optionally filtered), left joins, right joins and
/// cross joins. Join conditions are expressed with key-paths, so the join
/// fields are checked at compile time.
///
/// The query only borrows the two slices; nothing is copied until the
/// caller's mapper builds a result value.
///
/// # Type Parameters
///
/// * `'a` - The lifetime of the borrowed collections
/// * `L` - The type of items in the left collection
/// * `R` - The type of items in the right collection
///
/// # Example
///
/// ```ignore
/// let user_orders = JoinQuery::new(&users, &orders)
///     .inner_join(
///         User::id(),
///         Order::user_id(),
///         |user, order| (user.name.clone(), order.total)
///     );
/// ```
#[derive(Debug)]
pub struct JoinQuery<'a, L: 'static, R: 'static> {
    /// Items on the left-hand side of the join.
    left: &'a [L],
    /// Items on the right-hand side of the join.
    right: &'a [R],
}
34
35impl<'a, L: 'static, R: 'static> JoinQuery<'a, L, R> {
36    /// Creates a new join query from two collections.
37    ///
38    /// **Note**: No `Clone` required on `L` or `R`. The mapper function 
39    /// handles any cloning needed for the result type.
40    ///
41    /// # Arguments
42    ///
43    /// * `left` - The left collection to join
44    /// * `right` - The right collection to join
45    ///
46    /// # Example
47    ///
48    /// ```ignore
49    /// let join = JoinQuery::new(&users, &orders);
50    /// ```
51    pub fn new(left: &'a [L], right: &'a [R]) -> Self {
52        Self { left, right }
53    }
54
55    /// Performs an inner join between two collections.
56    ///
57    /// Returns only the pairs where the join keys match. Uses a hash-based
58    /// algorithm for O(n + m) performance.
59    ///
60    /// # Arguments
61    ///
62    /// * `left_key` - Key-path to the join field in the left collection
63    /// * `right_key` - Key-path to the join field in the right collection
64    /// * `mapper` - Function to transform matching pairs into the result type
65    ///
66    /// # Example
67    ///
68    /// ```ignore
69    /// let results = JoinQuery::new(&users, &orders)
70    ///     .inner_join(
71    ///         User::id(),
72    ///         Order::user_id(),
73    ///         |user, order| UserOrder {
74    ///             user_name: user.name.clone(),
75    ///             order_total: order.total,
76    ///         }
77    ///     );
78    /// ```
79    pub fn inner_join<K, O, F>(&self, left_key: KeyPaths<L, K>, right_key: KeyPaths<R, K>, mapper: F) -> Vec<O>
80    where
81        K: Eq + std::hash::Hash + Clone + 'static,
82        F: Fn(&L, &R) -> O,
83    {
84        // Build index for right side for O(n) lookup
85        let mut right_index: HashMap<K, Vec<&R>> = HashMap::new();
86        for item in self.right.iter() {
87            if let Some(key) = right_key.get(item).cloned() {
88                right_index.entry(key).or_insert_with(Vec::new).push(item);
89            }
90        }
91
92        // Join left with indexed right
93        let mut results = Vec::new();
94        for left_item in self.left.iter() {
95            if let Some(key) = left_key.get(left_item).cloned() {
96                if let Some(right_items) = right_index.get(&key) {
97                    for right_item in right_items {
98                        results.push(mapper(left_item, right_item));
99                    }
100                }
101            }
102        }
103
104        results
105    }
106
107    /// Performs a left join between two collections.
108    ///
109    /// Returns all items from the left collection with optional matching items
110    /// from the right collection. If no match is found, the right item is `None`.
111    ///
112    /// # Arguments
113    ///
114    /// * `left_key` - Key-path to the join field in the left collection
115    /// * `right_key` - Key-path to the join field in the right collection
116    /// * `mapper` - Function to transform pairs into the result type (right item may be None)
117    ///
118    /// # Example
119    ///
120    /// ```ignore
121    /// let results = JoinQuery::new(&users, &orders)
122    ///     .left_join(
123    ///         User::id(),
124    ///         Order::user_id(),
125    ///         |user, order| match order {
126    ///             Some(o) => format!("{} has order {}", user.name, o.id),
127    ///             None => format!("{} has no orders", user.name),
128    ///         }
129    ///     );
130    /// ```
131    pub fn left_join<K, O, F>(&self, left_key: KeyPaths<L, K>, right_key: KeyPaths<R, K>, mapper: F) -> Vec<O>
132    where
133        K: Eq + std::hash::Hash + Clone + 'static,
134        F: Fn(&L, Option<&R>) -> O,
135    {
136        // Build index for right side
137        let mut right_index: HashMap<K, Vec<&R>> = HashMap::new();
138        for item in self.right.iter() {
139            if let Some(key) = right_key.get(item).cloned() {
140                right_index.entry(key).or_insert_with(Vec::new).push(item);
141            }
142        }
143
144        // Join left with indexed right
145        let mut results = Vec::new();
146        for left_item in self.left.iter() {
147            if let Some(key) = left_key.get(left_item).cloned() {
148                if let Some(right_items) = right_index.get(&key) {
149                    for right_item in right_items {
150                        results.push(mapper(left_item, Some(right_item)));
151                    }
152                } else {
153                    results.push(mapper(left_item, None));
154                }
155            } else {
156                results.push(mapper(left_item, None));
157            }
158        }
159
160        results
161    }
162
163    /// Performs an inner join with an additional filter predicate.
164    ///
165    /// Like `inner_join`, but only includes pairs that satisfy both the join
166    /// condition and the additional predicate.
167    ///
168    /// # Arguments
169    ///
170    /// * `left_key` - Key-path to the join field in the left collection
171    /// * `right_key` - Key-path to the join field in the right collection
172    /// * `predicate` - Additional condition that must be true for pairs to be included
173    /// * `mapper` - Function to transform matching pairs into the result type
174    ///
175    /// # Example
176    ///
177    /// ```ignore
178    /// // Join orders with products, but only high-value orders
179    /// let results = JoinQuery::new(&orders, &products)
180    ///     .inner_join_where(
181    ///         Order::product_id(),
182    ///         Product::id(),
183    ///         |order, _product| order.total > 100.0,
184    ///         |order, product| (product.name.clone(), order.total)
185    ///     );
186    /// ```
187    pub fn inner_join_where<K, O, F, P>(
188        &self,
189        left_key: KeyPaths<L, K>,
190        right_key: KeyPaths<R, K>,
191        predicate: P,
192        mapper: F,
193    ) -> Vec<O>
194    where
195        K: Eq + std::hash::Hash + Clone + 'static,
196        F: Fn(&L, &R) -> O,
197        P: Fn(&L, &R) -> bool,
198    {
199        // Build index for right side
200        let mut right_index: HashMap<K, Vec<&R>> = HashMap::new();
201        for item in self.right.iter() {
202            if let Some(key) = right_key.get(item).cloned() {
203                right_index.entry(key).or_insert_with(Vec::new).push(item);
204            }
205        }
206
207        // Join left with indexed right, applying predicate
208        let mut results = Vec::new();
209        for left_item in self.left.iter() {
210            if let Some(key) = left_key.get(left_item).cloned() {
211                if let Some(right_items) = right_index.get(&key) {
212                    for right_item in right_items {
213                        if predicate(left_item, right_item) {
214                            results.push(mapper(left_item, right_item));
215                        }
216                    }
217                }
218            }
219        }
220
221        results
222    }
223
224    /// Performs a right join between two collections.
225    ///
226    /// Returns all items from the right collection with optional matching items
227    /// from the left collection. If no match is found, the left item is `None`.
228    ///
229    /// # Arguments
230    ///
231    /// * `left_key` - Key-path to the join field in the left collection
232    /// * `right_key` - Key-path to the join field in the right collection
233    /// * `mapper` - Function to transform pairs into the result type (left item may be None)
234    ///
235    /// # Example
236    ///
237    /// ```ignore
238    /// let results = JoinQuery::new(&users, &orders)
239    ///     .right_join(
240    ///         User::id(),
241    ///         Order::user_id(),
242    ///         |user, order| match user {
243    ///             Some(u) => format!("Order {} by {}", order.id, u.name),
244    ///             None => format!("Order {} by unknown user", order.id),
245    ///         }
246    ///     );
247    /// ```
248    pub fn right_join<K, O, F>(&self, left_key: KeyPaths<L, K>, right_key: KeyPaths<R, K>, mapper: F) -> Vec<O>
249    where
250        K: Eq + std::hash::Hash + Clone + 'static,
251        F: Fn(Option<&L>, &R) -> O,
252    {
253        // Build index for left side
254        let mut left_index: HashMap<K, Vec<&L>> = HashMap::new();
255        for item in self.left.iter() {
256            if let Some(key) = left_key.get(item).cloned() {
257                left_index.entry(key).or_insert_with(Vec::new).push(item);
258            }
259        }
260
261        // Join right with indexed left
262        let mut results = Vec::new();
263        for right_item in self.right.iter() {
264            if let Some(key) = right_key.get(right_item).cloned() {
265                if let Some(left_items) = left_index.get(&key) {
266                    for left_item in left_items {
267                        results.push(mapper(Some(left_item), right_item));
268                    }
269                } else {
270                    results.push(mapper(None, right_item));
271                }
272            } else {
273                results.push(mapper(None, right_item));
274            }
275        }
276
277        results
278    }
279
280    /// Performs a cross join (Cartesian product) between two collections.
281    ///
282    /// Returns all possible pairs of items from both collections.
283    /// **Warning**: This can produce very large result sets (size = left.len() * right.len()).
284    ///
285    /// # Arguments
286    ///
287    /// * `mapper` - Function to transform pairs into the result type
288    ///
289    /// # Example
290    ///
291    /// ```ignore
292    /// let all_combinations = JoinQuery::new(&colors, &sizes)
293    ///     .cross_join(|color, size| ProductVariant {
294    ///         color: color.clone(),
295    ///         size: size.clone(),
296    ///     });
297    /// ```
298    pub fn cross_join<O, F>(&self, mapper: F) -> Vec<O>
299    where
300        F: Fn(&L, &R) -> O,
301    {
302        let mut results = Vec::new();
303        for left_item in self.left.iter() {
304            for right_item in self.right.iter() {
305                results.push(mapper(left_item, right_item));
306            }
307        }
308        results
309    }
310}
311
#[cfg(feature = "parallel")]
mod parallel_join {
    use super::JoinQuery;
    use key_paths_core::KeyPaths;
    use rayon::prelude::*;
    use std::collections::HashMap;

    /// Extension trait for parallel join operations.
    ///
    /// Provides rayon-backed versions of the join operations. Key-paths are
    /// evaluated sequentially up front, and only the per-left-item matching and
    /// mapping run on the thread pool. Results therefore come out in the same
    /// order as the sequential joins. Because results are produced on worker
    /// threads, the result type `O` must be `Send`.
    ///
    /// # Example
    ///
    /// ```ignore
    /// use rust_queries_core::join::ParallelJoinExt;
    ///
    /// let results: Vec<_> = JoinQuery::new(&users, &orders)
    ///     .inner_join_parallel(
    ///         User::id(),
    ///         Order::user_id(),
    ///         |user, order| (user.name.clone(), order.total)
    ///     );
    /// ```
    pub trait ParallelJoinExt<'a, L: 'static + Send + Sync, R: 'static + Send + Sync> {
        /// Performs a parallel inner join between two collections.
        ///
        /// Uses rayon to process the join across multiple CPU cores. It is
        /// intended for large inputs; on small collections the thread-pool
        /// overhead may outweigh the gain over the sequential `inner_join`.
        ///
        /// # Arguments
        ///
        /// * `left_key` - Key-path to the join field in the left collection
        /// * `right_key` - Key-path to the join field in the right collection
        /// * `mapper` - Function to transform matching pairs into the result type (must be Send + Sync)
        ///
        /// # Example
        ///
        /// ```ignore
        /// let results: Vec<_> = JoinQuery::new(&users, &orders)
        ///     .inner_join_parallel(
        ///         User::id(),
        ///         Order::user_id(),
        ///         |user, order| (user.name.clone(), order.total)
        ///     );
        /// ```
        fn inner_join_parallel<K, O, F>(
            &self,
            left_key: KeyPaths<L, K>,
            right_key: KeyPaths<R, K>,
            mapper: F,
        ) -> Vec<O>
        where
            K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
            O: Send,
            F: Fn(&L, &R) -> O + Send + Sync;

        /// Performs a parallel left join between two collections.
        ///
        /// Uses rayon to process the join across multiple CPU cores. Every left
        /// item yields at least one result; unmatched items (or items without
        /// a key) are passed to the mapper with `None`.
        ///
        /// # Arguments
        ///
        /// * `left_key` - Key-path to the join field in the left collection
        /// * `right_key` - Key-path to the join field in the right collection
        /// * `mapper` - Function to transform pairs into the result type (right item may be None, must be Send + Sync)
        ///
        /// # Example
        ///
        /// ```ignore
        /// let results: Vec<_> = JoinQuery::new(&users, &orders)
        ///     .left_join_parallel(
        ///         User::id(),
        ///         Order::user_id(),
        ///         |user, order| match order {
        ///             Some(o) => format!("{} has order {}", user.name, o.id),
        ///             None => format!("{} has no orders", user.name),
        ///         }
        ///     );
        /// ```
        fn left_join_parallel<K, O, F>(
            &self,
            left_key: KeyPaths<L, K>,
            right_key: KeyPaths<R, K>,
            mapper: F,
        ) -> Vec<O>
        where
            K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
            O: Send,
            F: Fn(&L, Option<&R>) -> O + Send + Sync;

        /// Performs a parallel inner join with an additional filter predicate.
        ///
        /// Like `inner_join_parallel`, but only includes pairs that satisfy both the join
        /// condition and the additional predicate.
        ///
        /// # Arguments
        ///
        /// * `left_key` - Key-path to the join field in the left collection
        /// * `right_key` - Key-path to the join field in the right collection
        /// * `predicate` - Additional condition that must be true for pairs to be included (must be Send + Sync)
        /// * `mapper` - Function to transform matching pairs into the result type (must be Send + Sync)
        ///
        /// # Example
        ///
        /// ```ignore
        /// let results: Vec<_> = JoinQuery::new(&orders, &products)
        ///     .inner_join_where_parallel(
        ///         Order::product_id(),
        ///         Product::id(),
        ///         |order, _product| order.total > 100.0,
        ///         |order, product| (product.name.clone(), order.total)
        ///     );
        /// ```
        fn inner_join_where_parallel<K, O, F, P>(
            &self,
            left_key: KeyPaths<L, K>,
            right_key: KeyPaths<R, K>,
            predicate: P,
            mapper: F,
        ) -> Vec<O>
        where
            K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
            O: Send,
            F: Fn(&L, &R) -> O + Send + Sync,
            P: Fn(&L, &R) -> bool + Send + Sync;
    }

    /// Evaluates `key_path` on every item and returns `(position, key)` pairs
    /// for the items whose key-path produced a value.
    ///
    /// This runs on the calling thread. `KeyPaths` uses `Rc` internally, so it
    /// cannot be shared with rayon's worker threads; only the extracted owned
    /// keys are.
    fn keyed_positions<T, K>(items: &[T], key_path: &KeyPaths<T, K>) -> Vec<(usize, K)>
    where
        T: 'static,
        K: Clone + 'static,
    {
        items
            .iter()
            .enumerate()
            .filter_map(|(idx, item)| key_path.get(item).cloned().map(|key| (idx, key)))
            .collect()
    }

    /// Groups the slice positions of `items` by the key `key_path` yields.
    ///
    /// Items without a key are skipped. Positions within a bucket stay in
    /// ascending slice order. Like `keyed_positions`, this runs on the calling
    /// thread because the key-path cannot cross threads.
    fn index_positions<T, K>(items: &[T], key_path: &KeyPaths<T, K>) -> HashMap<K, Vec<usize>>
    where
        T: 'static,
        K: Eq + std::hash::Hash + Clone + 'static,
    {
        let mut index: HashMap<K, Vec<usize>> = HashMap::new();
        for (idx, item) in items.iter().enumerate() {
            if let Some(key) = key_path.get(item) {
                index.entry(key.clone()).or_default().push(idx);
            }
        }
        index
    }

    impl<'a, L: 'static + Send + Sync, R: 'static + Send + Sync> ParallelJoinExt<'a, L, R>
        for JoinQuery<'a, L, R>
    {
        fn inner_join_parallel<K, O, F>(
            &self,
            left_key: KeyPaths<L, K>,
            right_key: KeyPaths<R, K>,
            mapper: F,
        ) -> Vec<O>
        where
            K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
            O: Send,
            F: Fn(&L, &R) -> O + Send + Sync,
        {
            // A plain inner join is a filtered one whose filter accepts every pair.
            self.inner_join_where_parallel(left_key, right_key, |_, _| true, mapper)
        }

        fn left_join_parallel<K, O, F>(
            &self,
            left_key: KeyPaths<L, K>,
            right_key: KeyPaths<R, K>,
            mapper: F,
        ) -> Vec<O>
        where
            K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
            O: Send,
            F: Fn(&L, Option<&R>) -> O + Send + Sync,
        {
            // Unlike the inner joins, keep key-less left items: they still
            // produce a `None` row.
            let left_keys: Vec<(usize, Option<K>)> = self
                .left
                .iter()
                .enumerate()
                .map(|(idx, item)| (idx, left_key.get(item).cloned()))
                .collect();
            let right_index = index_positions(self.right, &right_key);
            let (left, right) = (self.left, self.right);

            // The closure only borrows shared, `Sync` data (slices, index and
            // mapper), so no `Arc` is needed. The indexed collect preserves left order.
            let per_left: Vec<Vec<O>> = left_keys
                .into_par_iter()
                .map(|(left_idx, key)| {
                    let left_item = &left[left_idx];
                    match key.as_ref().and_then(|key| right_index.get(key)) {
                        Some(positions) => positions
                            .iter()
                            .map(|&right_idx| mapper(left_item, Some(&right[right_idx])))
                            .collect(),
                        // No key, or no right item shares it.
                        None => vec![mapper(left_item, None)],
                    }
                })
                .collect();

            per_left.into_iter().flatten().collect()
        }

        fn inner_join_where_parallel<K, O, F, P>(
            &self,
            left_key: KeyPaths<L, K>,
            right_key: KeyPaths<R, K>,
            predicate: P,
            mapper: F,
        ) -> Vec<O>
        where
            K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
            O: Send,
            F: Fn(&L, &R) -> O + Send + Sync,
            P: Fn(&L, &R) -> bool + Send + Sync,
        {
            let left_keyed = keyed_positions(self.left, &left_key);
            let right_index = index_positions(self.right, &right_key);
            let (left, right) = (self.left, self.right);

            // The closure only borrows shared, `Sync` data (slices, index,
            // predicate and mapper). The indexed collect preserves left order.
            let per_left: Vec<Vec<O>> = left_keyed
                .into_par_iter()
                .map(|(left_idx, key)| {
                    let left_item = &left[left_idx];
                    match right_index.get(&key) {
                        Some(positions) => positions
                            .iter()
                            .map(|&right_idx| &right[right_idx])
                            .filter(|&right_item| predicate(left_item, right_item))
                            .map(|right_item| mapper(left_item, right_item))
                            .collect(),
                        None => Vec::new(),
                    }
                })
                .collect();

            per_left.into_iter().flatten().collect()
        }
    }
}
629
630#[cfg(feature = "parallel")]
631pub use parallel_join::ParallelJoinExt;
632