1use crate::locks::LockValue;
23use key_paths_core::KeyPaths;
24
25pub struct LockJoinQuery<'a, L, R, LL, LR>
29where
30 LL: LockValue<L> + 'a,
31 LR: LockValue<R> + 'a,
32{
33 left: Vec<&'a LL>,
34 right: Vec<&'a LR>,
35 _phantom: std::marker::PhantomData<(L, R)>,
36}
37
38impl<'a, L: 'static, R: 'static, LL, LR> LockJoinQuery<'a, L, R, LL, LR>
39where
40 LL: LockValue<L> + 'a,
41 LR: LockValue<R> + 'a,
42{
43 pub fn new(left: Vec<&'a LL>, right: Vec<&'a LR>) -> Self {
45 Self {
46 left,
47 right,
48 _phantom: std::marker::PhantomData,
49 }
50 }
51
52 pub fn inner_join<LK, RK, M, Out>(&self, left_key: KeyPaths<L, LK>, right_key: KeyPaths<R, RK>, mapper: M) -> Vec<Out>
67 where
68 LK: Eq + Clone + 'static,
69 RK: Eq + Clone + 'static,
70 LK: PartialEq<RK>,
71 M: Fn(&L, &R) -> Out,
72 L: Clone,
73 R: Clone,
74 {
75 let mut results = Vec::new();
76
77 for left_lock in &self.left {
78 let left_data = left_lock.with_value(|l| (left_key.get(l).cloned(), l.clone()));
79 if let Some((Some(left_k), left_item)) = left_data {
80 for right_lock in &self.right {
81 let right_data = right_lock.with_value(|r| (right_key.get(r).cloned(), r.clone()));
82 if let Some((Some(right_k), right_item)) = right_data {
83 if left_k == right_k {
84 results.push(mapper(&left_item, &right_item));
85 }
86 }
87 }
88 }
89 }
90
91 results
92 }
93
94 pub fn left_join<LK, RK, M, Out>(&self, left_key: KeyPaths<L, LK>, right_key: KeyPaths<R, RK>, mapper: M) -> Vec<Out>
112 where
113 LK: Eq + Clone + 'static,
114 RK: Eq + Clone + 'static,
115 LK: PartialEq<RK>,
116 M: Fn(&L, Option<&R>) -> Out,
117 L: Clone,
118 R: Clone,
119 {
120 let mut results = Vec::new();
121
122 for left_lock in &self.left {
123 let left_data = left_lock.with_value(|l| (left_key.get(l).cloned(), l.clone()));
124 if let Some((Some(left_key_val), left_item)) = left_data {
125 let mut found_match = false;
126
127 for right_lock in &self.right {
128 let right_data = right_lock.with_value(|r| (right_key.get(r).cloned(), r.clone()));
129 if let Some((Some(right_key_val), right_item)) = right_data {
130 if left_key_val == right_key_val {
131 results.push(mapper(&left_item, Some(&right_item)));
132 found_match = true;
133 }
134 }
135 }
136
137 if !found_match {
138 results.push(mapper(&left_item, None));
139 }
140 }
141 }
142
143 results
144 }
145
146 pub fn right_join<LK, RK, M, Out>(&self, left_key: KeyPaths<L, LK>, right_key: KeyPaths<R, RK>, mapper: M) -> Vec<Out>
150 where
151 LK: Eq + Clone + 'static,
152 RK: Eq + Clone + 'static,
153 LK: PartialEq<RK>,
154 M: Fn(Option<&L>, &R) -> Out,
155 L: Clone,
156 R: Clone,
157 {
158 let mut results = Vec::new();
159
160 for right_lock in &self.right {
161 let right_data = right_lock.with_value(|r| (right_key.get(r).cloned(), r.clone()));
162 if let Some((Some(right_key_val), right_item)) = right_data {
163 let mut found_match = false;
164
165 for left_lock in &self.left {
166 let left_data = left_lock.with_value(|l| (left_key.get(l).cloned(), l.clone()));
167 if let Some((Some(left_key_val), left_item)) = left_data {
168 if left_key_val == right_key_val {
169 results.push(mapper(Some(&left_item), &right_item));
170 found_match = true;
171 }
172 }
173 }
174
175 if !found_match {
176 results.push(mapper(None, &right_item));
177 }
178 }
179 }
180
181 results
182 }
183
184 pub fn cross_join<M, Out>(&self, mapper: M) -> Vec<Out>
188 where
189 M: Fn(&L, &R) -> Out,
190 L: Clone,
191 R: Clone,
192 {
193 let mut results = Vec::new();
194
195 for left_lock in &self.left {
196 if let Some(left_item) = left_lock.with_value(|l| l.clone()) {
197 for right_lock in &self.right {
198 if let Some(right_item) = right_lock.with_value(|r| r.clone()) {
199 results.push(mapper(&left_item, &right_item));
200 }
201 }
202 }
203 }
204
205 results
206 }
207
208}
209
210pub trait LockJoinable<T, L>
212where
213 L: LockValue<T>,
214{
215 fn lock_join<'a, R, LR>(&'a self, right: &'a impl LockJoinableCollection<R, LR>) -> LockJoinQuery<'a, T, R, L, LR>
217 where
218 LR: LockValue<R> + 'a;
219}
220
221pub trait LockJoinableCollection<T, L>
223where
224 L: LockValue<T>,
225{
226 fn get_locks(&self) -> Vec<&L>;
228}
229
230use std::collections::HashMap;
232use std::sync::{Arc, RwLock, Mutex};
233
234impl<K, V> LockJoinableCollection<V, Arc<RwLock<V>>> for HashMap<K, Arc<RwLock<V>>>
235where
236 K: Eq + std::hash::Hash,
237{
238 fn get_locks(&self) -> Vec<&Arc<RwLock<V>>> {
239 self.values().collect()
240 }
241}
242
243impl<K, V> LockJoinableCollection<V, Arc<Mutex<V>>> for HashMap<K, Arc<Mutex<V>>>
244where
245 K: Eq + std::hash::Hash,
246{
247 fn get_locks(&self) -> Vec<&Arc<Mutex<V>>> {
248 self.values().collect()
249 }
250}
251
252impl<T> LockJoinableCollection<T, Arc<RwLock<T>>> for Vec<Arc<RwLock<T>>> {
254 fn get_locks(&self) -> Vec<&Arc<RwLock<T>>> {
255 self.iter().collect()
256 }
257}
258
259impl<T> LockJoinableCollection<T, Arc<Mutex<T>>> for Vec<Arc<Mutex<T>>> {
260 fn get_locks(&self) -> Vec<&Arc<Mutex<T>>> {
261 self.iter().collect()
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, RwLock};
    use std::collections::HashMap;
    use key_paths_derive::Keypath;

    #[derive(Clone, Keypath)]
    struct User {
        id: u32,
        name: String,
    }

    #[derive(Clone, Keypath)]
    struct Order {
        id: u32,
        user_id: u32,
        total: f64,
    }

    /// Fixture: Alice (id 1) has orders 101 and 102, Bob (id 2) has none,
    /// and order 103 points at a user id (3) that does not exist.
    fn create_test_data() -> (HashMap<String, Arc<RwLock<User>>>, HashMap<String, Arc<RwLock<Order>>>) {
        let mut users = HashMap::new();
        users.insert("u1".to_string(), Arc::new(RwLock::new(User { id: 1, name: "Alice".to_string() })));
        users.insert("u2".to_string(), Arc::new(RwLock::new(User { id: 2, name: "Bob".to_string() })));

        let mut orders = HashMap::new();
        orders.insert("o1".to_string(), Arc::new(RwLock::new(Order { id: 101, user_id: 1, total: 99.99 })));
        orders.insert("o2".to_string(), Arc::new(RwLock::new(Order { id: 102, user_id: 1, total: 149.99 })));
        orders.insert("o3".to_string(), Arc::new(RwLock::new(Order { id: 103, user_id: 3, total: 199.99 })));

        (users, orders)
    }

    #[test]
    fn test_inner_join() {
        let (users, orders) = create_test_data();

        let user_locks: Vec<_> = users.values().collect();
        let order_locks: Vec<_> = orders.values().collect();

        let results = LockJoinQuery::new(user_locks, order_locks)
            .inner_join(
                User::id(),
                Order::user_id(),
                |user, order| (user.name.clone(), order.total)
            );

        // Only Alice's two orders match; Bob and order 103 are dropped.
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(name, _)| name == "Alice"));

        // HashMap iteration order is unspecified, so compare sorted totals.
        let mut totals: Vec<f64> = results.iter().map(|(_, total)| *total).collect();
        totals.sort_by(|a, b| a.partial_cmp(b).expect("totals are finite"));
        assert_eq!(totals, vec![99.99, 149.99]);
    }

    #[test]
    fn test_left_join() {
        let (users, orders) = create_test_data();

        let user_locks: Vec<_> = users.values().collect();
        let order_locks: Vec<_> = orders.values().collect();

        let results = LockJoinQuery::new(user_locks, order_locks)
            .left_join(
                User::id(),
                Order::user_id(),
                |user, order_opt| match order_opt {
                    Some(_) => format!("{} has order", user.name),
                    None => format!("{} no orders", user.name),
                }
            );

        // Two rows for Alice's orders plus one unmatched row for Bob.
        assert_eq!(results.len(), 3);
        assert_eq!(results.iter().filter(|row| *row == "Alice has order").count(), 2);
        assert_eq!(results.iter().filter(|row| *row == "Bob no orders").count(), 1);
    }

    #[test]
    fn test_right_join() {
        let (users, orders) = create_test_data();

        let user_locks: Vec<_> = users.values().collect();
        let order_locks: Vec<_> = orders.values().collect();

        let results = LockJoinQuery::new(user_locks, order_locks)
            .right_join(
                User::id(),
                Order::user_id(),
                |user_opt, order| (user_opt.map(|user| user.name.clone()), order.id)
            );

        // Every order appears once: 101 and 102 with Alice, 103 unmatched.
        assert_eq!(results.len(), 3);

        let mut matched_ids: Vec<u32> = results
            .iter()
            .filter(|(name, _)| name.as_deref() == Some("Alice"))
            .map(|(_, order_id)| *order_id)
            .collect();
        matched_ids.sort_unstable();
        assert_eq!(matched_ids, vec![101, 102]);

        let unmatched_ids: Vec<u32> = results
            .iter()
            .filter(|(name, _)| name.is_none())
            .map(|(_, order_id)| *order_id)
            .collect();
        assert_eq!(unmatched_ids, vec![103]);
    }

    #[test]
    fn test_cross_join() {
        let (users, orders) = create_test_data();

        let user_locks: Vec<_> = users.values().collect();
        let order_locks: Vec<_> = orders.values().collect();

        // No key paths are passed here, so the item types are spelled out.
        let results = LockJoinQuery::<User, Order, _, _>::new(user_locks, order_locks)
            .cross_join(|user, order| (user.id, order.id));

        // 2 users x 3 orders, each combination exactly once.
        assert_eq!(results.len(), 6);
        let distinct: std::collections::HashSet<(u32, u32)> = results.iter().copied().collect();
        assert_eq!(distinct.len(), 6);
    }
}
335