1use crate::locks::LockValue;
23use key_paths_core::KeyPaths;
24
/// A join over two lists of borrowed lock handles.
///
/// `left` and `right` hold references to lock handles of type `LL` and `LR`.
/// The query methods (on the `impl` block) require `LL: LockValue<L>` and
/// `LR: LockValue<R>`, read values out through those handles, and clone them.
/// `L` and `R` appear here only in `PhantomData`.
///
/// Trait bounds live on the `impl` block rather than on the struct, so the
/// type itself imposes no trait requirements on its parameters. The
/// `LL: 'a` / `LR: 'a` requirements follow from the `&'a` fields.
pub struct LockJoinQuery<'a, L, R, LL, LR> {
    left: Vec<&'a LL>,
    right: Vec<&'a LR>,
    _phantom: std::marker::PhantomData<(L, R)>,
}
37
38impl<'a, L: 'static, R: 'static, LL, LR> LockJoinQuery<'a, L, R, LL, LR>
39where
40 LL: LockValue<L> + 'a,
41 LR: LockValue<R> + 'a,
42{
43 pub fn new(left: Vec<&'a LL>, right: Vec<&'a LR>) -> Self {
45 Self {
46 left,
47 right,
48 _phantom: std::marker::PhantomData,
49 }
50 }
51
52 pub fn inner_join<LK, RK, M, Out>(&self, left_key: KeyPaths<L, LK>, right_key: KeyPaths<R, RK>, mapper: M) -> Vec<Out>
67 where
68 LK: Eq + Clone + 'static,
69 RK: Eq + Clone + 'static,
70 LK: PartialEq<RK>,
71 M: Fn(&L, &R) -> Out,
72 L: Clone,
73 R: Clone,
74 {
75 let mut results = Vec::new();
76
77 for left_lock in &self.left {
78 let left_data = left_lock.with_value(|l| (left_key.get(l).cloned(), l.clone()));
79 if let Some((Some(left_k), left_item)) = left_data {
80 for right_lock in &self.right {
81 let right_data = right_lock.with_value(|r| (right_key.get(r).cloned(), r.clone()));
82 if let Some((Some(right_k), right_item)) = right_data {
83 if left_k == right_k {
84 results.push(mapper(&left_item, &right_item));
85 }
86 }
87 }
88 }
89 }
90
91 results
92 }
93
94 pub fn left_join<LK, RK, M, Out>(&self, left_key: KeyPaths<L, LK>, right_key: KeyPaths<R, RK>, mapper: M) -> Vec<Out>
112 where
113 LK: Eq + Clone + 'static,
114 RK: Eq + Clone + 'static,
115 LK: PartialEq<RK>,
116 M: Fn(&L, Option<&R>) -> Out,
117 L: Clone,
118 R: Clone,
119 {
120 let mut results = Vec::new();
121
122 for left_lock in &self.left {
123 let left_data = left_lock.with_value(|l| (left_key.get(l).cloned(), l.clone()));
124 if let Some((Some(left_key_val), left_item)) = left_data {
125 let mut found_match = false;
126
127 for right_lock in &self.right {
128 let right_data = right_lock.with_value(|r| (right_key.get(r).cloned(), r.clone()));
129 if let Some((Some(right_key_val), right_item)) = right_data {
130 if left_key_val == right_key_val {
131 results.push(mapper(&left_item, Some(&right_item)));
132 found_match = true;
133 }
134 }
135 }
136
137 if !found_match {
138 results.push(mapper(&left_item, None));
139 }
140 }
141 }
142
143 results
144 }
145
146 pub fn right_join<LK, RK, M, Out>(&self, left_key: KeyPaths<L, LK>, right_key: KeyPaths<R, RK>, mapper: M) -> Vec<Out>
150 where
151 LK: Eq + Clone + 'static,
152 RK: Eq + Clone + 'static,
153 LK: PartialEq<RK>,
154 M: Fn(Option<&L>, &R) -> Out,
155 L: Clone,
156 R: Clone,
157 {
158 let mut results = Vec::new();
159
160 for right_lock in &self.right {
161 let right_data = right_lock.with_value(|r| (right_key.get(r).cloned(), r.clone()));
162 if let Some((Some(right_key_val), right_item)) = right_data {
163 let mut found_match = false;
164
165 for left_lock in &self.left {
166 let left_data = left_lock.with_value(|l| (left_key.get(l).cloned(), l.clone()));
167 if let Some((Some(left_key_val), left_item)) = left_data {
168 if left_key_val == right_key_val {
169 results.push(mapper(Some(&left_item), &right_item));
170 found_match = true;
171 }
172 }
173 }
174
175 if !found_match {
176 results.push(mapper(None, &right_item));
177 }
178 }
179 }
180
181 results
182 }
183
184 pub fn cross_join<M, Out>(&self, mapper: M) -> Vec<Out>
188 where
189 M: Fn(&L, &R) -> Out,
190 L: Clone,
191 R: Clone,
192 {
193 let mut results = Vec::new();
194
195 for left_lock in &self.left {
196 if let Some(left_item) = left_lock.with_value(|l| l.clone()) {
197 for right_lock in &self.right {
198 if let Some(right_item) = right_lock.with_value(|r| r.clone()) {
199 results.push(mapper(&left_item, &right_item));
200 }
201 }
202 }
203 }
204
205 results
206 }
207}
208
209pub trait LockJoinable<T, L>
211where
212 L: LockValue<T>,
213{
214 fn lock_join<'a, R, LR>(&'a self, right: &'a impl LockJoinableCollection<R, LR>) -> LockJoinQuery<'a, T, R, L, LR>
216 where
217 LR: LockValue<R> + 'a;
218}
219
220pub trait LockJoinableCollection<T, L>
222where
223 L: LockValue<T>,
224{
225 fn get_locks(&self) -> Vec<&L>;
227}
228
229use std::collections::HashMap;
231use std::sync::{Arc, RwLock, Mutex};
232
233impl<K, V> LockJoinableCollection<V, Arc<RwLock<V>>> for HashMap<K, Arc<RwLock<V>>>
234where
235 K: Eq + std::hash::Hash,
236{
237 fn get_locks(&self) -> Vec<&Arc<RwLock<V>>> {
238 self.values().collect()
239 }
240}
241
242impl<K, V> LockJoinableCollection<V, Arc<Mutex<V>>> for HashMap<K, Arc<Mutex<V>>>
243where
244 K: Eq + std::hash::Hash,
245{
246 fn get_locks(&self) -> Vec<&Arc<Mutex<V>>> {
247 self.values().collect()
248 }
249}
250
251impl<T> LockJoinableCollection<T, Arc<RwLock<T>>> for Vec<Arc<RwLock<T>>> {
253 fn get_locks(&self) -> Vec<&Arc<RwLock<T>>> {
254 self.iter().collect()
255 }
256}
257
258impl<T> LockJoinableCollection<T, Arc<Mutex<T>>> for Vec<Arc<Mutex<T>>> {
259 fn get_locks(&self) -> Vec<&Arc<Mutex<T>>> {
260 self.iter().collect()
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, RwLock};
    use std::collections::HashMap;
    use key_paths_derive::Keypath;

    #[derive(Clone, Keypath)]
    struct User {
        id: u32,
        name: String,
    }

    #[derive(Clone, Keypath)]
    struct Order {
        id: u32,
        user_id: u32,
        total: f64,
    }

    /// Two users (ids 1 and 2) and three orders: two for user 1 and one for
    /// user 3, who has no matching user.
    fn create_test_data() -> (
        HashMap<String, Arc<RwLock<User>>>,
        HashMap<String, Arc<RwLock<Order>>>,
    ) {
        let mut users = HashMap::new();
        users.insert("u1".to_string(), Arc::new(RwLock::new(User { id: 1, name: "Alice".to_string() })));
        users.insert("u2".to_string(), Arc::new(RwLock::new(User { id: 2, name: "Bob".to_string() })));

        let mut orders = HashMap::new();
        orders.insert("o1".to_string(), Arc::new(RwLock::new(Order { id: 101, user_id: 1, total: 99.99 })));
        orders.insert("o2".to_string(), Arc::new(RwLock::new(Order { id: 102, user_id: 1, total: 149.99 })));
        orders.insert("o3".to_string(), Arc::new(RwLock::new(Order { id: 103, user_id: 3, total: 199.99 })));

        (users, orders)
    }

    // HashMap iteration order is unspecified, so every test sorts the join
    // output before comparing it with the expected rows.

    #[test]
    fn test_inner_join() {
        let (users, orders) = create_test_data();

        let user_locks: Vec<_> = users.values().collect();
        let order_locks: Vec<_> = orders.values().collect();

        let mut results = LockJoinQuery::new(user_locks, order_locks).inner_join(
            User::id(),
            Order::user_id(),
            |user, order| (user.name.clone(), order.id),
        );
        results.sort();

        assert_eq!(
            results,
            vec![("Alice".to_string(), 101), ("Alice".to_string(), 102)]
        );
    }

    #[test]
    fn test_left_join() {
        let (users, orders) = create_test_data();

        let user_locks: Vec<_> = users.values().collect();
        let order_locks: Vec<_> = orders.values().collect();

        let mut results = LockJoinQuery::new(user_locks, order_locks).left_join(
            User::id(),
            Order::user_id(),
            |user, order_opt| match order_opt {
                Some(_) => format!("{} has order", user.name),
                None => format!("{} no orders", user.name),
            },
        );
        results.sort();

        assert_eq!(
            results,
            vec!["Alice has order", "Alice has order", "Bob no orders"]
        );
    }

    #[test]
    fn test_right_join() {
        let (users, orders) = create_test_data();

        let mut results = LockJoinQuery::new(users.get_locks(), orders.get_locks()).right_join(
            User::id(),
            Order::user_id(),
            |user_opt, order| (order.id, user_opt.map(|user| user.name.clone())),
        );
        results.sort();

        assert_eq!(
            results,
            vec![
                (101, Some("Alice".to_string())),
                (102, Some("Alice".to_string())),
                (103, None),
            ]
        );
    }

    #[test]
    fn test_cross_join() {
        let (users, orders) = create_test_data();

        let mut results = LockJoinQuery::new(users.get_locks(), orders.get_locks())
            .cross_join(|user: &User, order: &Order| (user.id, order.id));
        results.sort();

        assert_eq!(
            results,
            vec![(1, 101), (1, 102), (1, 103), (2, 101), (2, 102), (2, 103)]
        );
    }
}
334