// standing_relations/core/op/join.rs

1use std::{collections::HashMap, hash::Hash};
2
3use crate::core::{
4    mborrowed::OrOwnedDefault, relation::RelationInner, CountMap, Op, Op_, Relation,
5};
6
7pub struct Join<K, V1, V2, C1: Op<D = (K, V1)>, C2: Op<D = (K, V2)>> {
8    left: RelationInner<C1>,
9    left_map: HashMap<K, HashMap<V1, isize>>,
10    right: RelationInner<C2>,
11    right_map: HashMap<K, HashMap<V2, isize>>,
12}
13
14impl<
15        K: Eq + Hash + Clone,
16        V1: Eq + Hash + Clone,
17        V2: Eq + Hash + Clone,
18        C1: Op<D = (K, V1)>,
19        C2: Op<D = (K, V2)>,
20    > Op_ for Join<K, V1, V2, C1, C2>
21{
22    type T = ((K, V1, V2), isize);
23
24    fn foreach<'a>(&'a mut self, mut continuation: impl FnMut(Self::T) + 'a) {
25        let Join {
26            left,
27            left_map,
28            right,
29            right_map,
30        } = self;
31        left.foreach(|((k, x), x_count)| {
32            for (y, y_count) in &*right_map.get(&k).or_owned_default() {
33                continuation(((k.clone(), x.clone(), y.clone()), x_count * y_count));
34            }
35            left_map.add((k, x), x_count);
36        });
37        right.foreach(|((k, y), y_count)| {
38            for (x, x_count) in &*left_map.get(&k).or_owned_default() {
39                continuation(((k.clone(), x.clone(), y.clone()), x_count * y_count));
40            }
41            right_map.add((k, y), y_count);
42        });
43    }
44
45    fn get_type_name() -> &'static str {
46        "join"
47    }
48}
49
50pub struct AntiJoin<K, V, C1: Op<D = (K, V)>, C2: Op<D = K>> {
51    left: RelationInner<C1>,
52    left_map: HashMap<K, HashMap<V, isize>>,
53    right: RelationInner<C2>,
54    right_map: HashMap<K, isize>,
55}
56
57impl<K: Eq + Hash + Clone, V: Eq + Hash + Clone, C1: Op<D = (K, V)>, C2: Op<D = K>> Op_
58    for AntiJoin<K, V, C1, C2>
59{
60    type T = ((K, V), isize);
61
62    fn foreach<'a>(&'a mut self, mut continuation: impl FnMut(Self::T) + 'a) {
63        let AntiJoin {
64            left,
65            left_map,
66            right,
67            right_map,
68        } = self;
69        left.foreach(|((k, x), x_count)| {
70            if !right_map.contains_key(&k) {
71                continuation(((k.clone(), x.clone()), x_count));
72            }
73            left_map.add((k, x), x_count);
74        });
75        right.foreach(|(k, y_count)| {
76            if y_count != 0 {
77                let old_count = right_map.get(&k).map(Clone::clone).unwrap_or(0);
78                if old_count == -y_count {
79                    for (x, &x_count) in &*left_map.get(&k).or_owned_default() {
80                        continuation(((k.clone(), x.clone()), x_count));
81                    }
82                } else if old_count == 0 {
83                    for (x, &x_count) in &*left_map.get(&k).or_owned_default() {
84                        continuation(((k.clone(), x.clone()), -x_count));
85                    }
86                }
87                right_map.add(k, y_count);
88            }
89        });
90    }
91
92    fn get_type_name() -> &'static str {
93        "antijoin"
94    }
95}
96
97impl<K: Clone + Eq + Hash, V1: Clone + Eq + Hash, C1: Op<D = (K, V1)>> Relation<C1> {
98    pub fn join<V2: Clone + Eq + Hash, C2: Op<D = (K, V2)>>(
99        self,
100        other: Relation<C2>,
101    ) -> Relation<Join<K, V1, V2, C1, C2>> {
102        assert_eq!(
103            self.context_tracker, other.context_tracker,
104            "Context mismatch"
105        );
106        self.context_tracker.add_relation(
107            self.dirty.or(other.dirty),
108            Join {
109                left: self.inner,
110                left_map: HashMap::new(),
111                right: other.inner,
112                right_map: HashMap::new(),
113            },
114            vec![self.tracking_index, other.tracking_index],
115        )
116    }
117
118    /// Retains only those keys which have count 0 in the argument relation.
119    pub fn antijoin<C2: Op<D = K>>(self, other: Relation<C2>) -> Relation<AntiJoin<K, V1, C1, C2>> {
120        assert_eq!(
121            self.context_tracker, other.context_tracker,
122            "Context mismatch"
123        );
124        self.context_tracker.add_relation(
125            self.dirty.or(other.dirty),
126            AntiJoin {
127                left: self.inner,
128                left_map: HashMap::new(),
129                right: other.inner,
130                right_map: HashMap::new(),
131            },
132            vec![self.tracking_index, other.tracking_index],
133        )
134    }
135}