standing_relations/core/op/
reduce.rs

1use std::{
2    cell::{Ref, RefCell},
3    collections::{HashMap, HashSet},
4    hash::Hash,
5};
6
7use crate::core::{
8    relation::RelationInner, ContextTracker, CountMap, CreationContext, ExecutionContext,
9    Observable, Op, Op_, Relation, Save, Saved,
10};
11
12use self::map::{InsertResult, OutputMap};
13
14mod map;
15
16pub struct Reduce<
17    K,
18    X,
19    C: Op<D = (K, X)>,
20    M: CountMap<X> + Observable,
21    Y,
22    OM: OutputMap<K, Y>,
23    F: Fn(&K, &M) -> Y,
24> {
25    inner: RelationInner<C>,
26    in_map: HashMap<K, M>,
27    out_map: OM,
28    f: F,
29}
30
31impl<
32        K: Clone + Eq + Hash,
33        X,
34        C: Op<D = (K, X)>,
35        M: CountMap<X> + Observable,
36        Y: Clone + Eq,
37        OM: OutputMap<K, Y>,
38        F: Fn(&K, &M) -> Y,
39    > Op_ for Reduce<K, X, C, M, Y, OM, F>
40{
41    type T = ((K, Y), isize);
42
43    fn foreach<'a>(&'a mut self, mut continuation: impl FnMut(Self::T) + 'a) {
44        let Reduce {
45            inner,
46            in_map,
47            out_map,
48            f,
49        } = self;
50        let mut changed_keys = HashSet::new();
51        inner.foreach(|((k, v), count)| {
52            in_map.add((k.clone(), v), count);
53            changed_keys.insert(k);
54        });
55        'keys: for k in changed_keys {
56            match in_map.get(&k) {
57                None => {
58                    if let Some(old_val) = out_map.remove(&k) {
59                        continuation(((k, old_val), -1))
60                    }
61                }
62                Some(m) => {
63                    let new_val = f(&k, m);
64                    match out_map.insert_if_different(k.clone(), new_val.clone()) {
65                        InsertResult::NoOldValue => (),
66                        InsertResult::OldValue(old_val) => continuation(((k.clone(), old_val), -1)),
67                        InsertResult::Unchanged => continue 'keys,
68                    };
69                    continuation(((k, new_val), 1));
70                }
71            }
72        }
73    }
74
75    fn get_type_name() -> &'static str {
76        "reduce"
77    }
78}
79
80impl<C: Op<D = (K, X)>, K: Clone + Eq + Hash, X> Relation<C> {
81    pub fn reduce_with_output_<
82        M: CountMap<X> + Observable,
83        OM: OutputMap<K, Y> + Default,
84        Y: Clone + Eq,
85        F: Fn(&K, &M) -> Y,
86    >(
87        self,
88        f: F,
89    ) -> Relation<Reduce<K, X, C, M, Y, OM, F>> {
90        self.context_tracker.add_relation(
91            self.dirty,
92            Reduce {
93                inner: self.inner,
94                in_map: HashMap::new(),
95                out_map: Default::default(),
96                f,
97            },
98            vec![self.tracking_index],
99        )
100    }
101}
102
/// Operators that own an output map which can be borrowed for inspection.
///
/// Implemented in this file for [`Reduce`]; it is the bound used by
/// [`ReduceProbe`] and [`ProbeRef`].
pub trait IsReduce: Op_ {
    /// Type of the map handed out by [`IsReduce::get_map`].
    type OM;

    /// Borrows the operator's output map.
    fn get_map(&self) -> &Self::OM;
}
108
109impl<
110        K: Clone + Eq + Hash,
111        X,
112        C: Op<D = (K, X)>,
113        M: CountMap<X> + Observable,
114        Y: Clone + Eq,
115        OM: OutputMap<K, Y>,
116        F: Fn(&K, &M) -> Y,
117    > IsReduce for Reduce<K, X, C, M, Y, OM, F>
118{
119    type OM = OM;
120
121    fn get_map(&self) -> &OM {
122        &self.out_map
123    }
124}
125
126impl<C: IsReduce> Relation<C> {
127    pub fn probe(self, context: &CreationContext) -> ReduceProbe<C> {
128        assert_eq!(&self.context_tracker, context.tracker(), "Context mismatch");
129        ReduceProbe {
130            context_tracker: self.context_tracker.clone(),
131            inner: RefCell::new(Saved::new(self)),
132        }
133    }
134}
135
136pub struct ReduceProbe<C: IsReduce> {
137    context_tracker: ContextTracker,
138    inner: RefCell<Saved<C>>,
139}
140
141impl<C: IsReduce> ReduceProbe<C> {
142    pub fn get_relation(&self) -> Relation<Save<C>>
143    where
144        C::T: Clone,
145    {
146        self.inner.borrow().clone().get()
147    }
148    pub fn inspect<'a>(&'a self, context: &'a ExecutionContext<'_>) -> ProbeRef<'a, C> {
149        assert_eq!(&self.context_tracker, context.tracker(), "Context mismatch");
150        self.inner.borrow_mut().propagate();
151        ProbeRef(self.inner.borrow())
152    }
153}
154
155pub struct ProbeRef<'a, C: IsReduce>(Ref<'a, Saved<C>>);
156
157impl<'a, C: IsReduce> ProbeRef<'a, C> {
158    pub fn get(&self) -> Ref<'_, C::OM> {
159        Ref::map(self.0.borrow(), |x| x.inner.get_map())
160    }
161}
162
163impl<C: IsReduce> Clone for ReduceProbe<C> {
164    fn clone(&self) -> Self {
165        ReduceProbe {
166            context_tracker: self.context_tracker.clone(),
167            inner: self.inner.clone(),
168        }
169    }
170}