rustfst/algorithms/lazy/state_table.rs

1use std::fmt;
2use std::hash::Hash;
3use std::sync::Mutex;
4
5use crate::StateId;
6use std::collections::hash_map::Entry;
7use std::collections::hash_map::RandomState;
8use std::collections::HashMap;
9use std::hash::BuildHasher;
10
11use crate::parsers::nom_utils::NomCustomError;
12use crate::parsers::{parse_bin_u64, write_bin_u64, SerializeBinary};
13use nom::multi::{count, fold_many_m_n};
14use nom::IResult;
15use std::io::Write;
16
17use anyhow::{anyhow, Result};
18
19#[derive(Clone, Debug, Default)]
20pub(crate) struct BiHashMap<T: Hash + Eq + Clone, H: BuildHasher = RandomState> {
21    tuple_to_id: HashMap<T, StateId, H>,
22    id_to_tuple: Vec<T>,
23}
24
25impl<T: Hash + Eq + Clone, H: BuildHasher> PartialEq for BiHashMap<T, H> {
26    fn eq(&self, other: &Self) -> bool {
27        self.tuple_to_id.eq(&other.tuple_to_id) && self.id_to_tuple.eq(&other.id_to_tuple)
28    }
29}
30
31impl<T: Hash + Eq + Clone> BiHashMap<T> {
32    pub fn new() -> Self {
33        Self {
34            tuple_to_id: HashMap::new(),
35            id_to_tuple: Vec::new(),
36        }
37    }
38}
39
40impl<T: Hash + Eq + Clone, H: BuildHasher> BiHashMap<T, H> {
41    #[allow(unused)]
42    pub fn with_hasher(hash_builder: H) -> Self {
43        Self {
44            tuple_to_id: HashMap::with_hasher(hash_builder),
45            id_to_tuple: Vec::new(),
46        }
47    }
48
49    pub fn get_id_or_insert(&mut self, tuple: T) -> StateId {
50        match self.tuple_to_id.entry(tuple) {
51            Entry::Occupied(e) => *e.get(),
52            Entry::Vacant(e) => {
53                let n = self.id_to_tuple.len() as StateId;
54                self.id_to_tuple.push(e.key().clone());
55                e.insert(n);
56                n
57            }
58        }
59    }
60
61    pub fn get_tuple_unchecked(&self, id: StateId) -> T {
62        self.id_to_tuple[id as usize].clone()
63    }
64}
65
66pub struct StateTable<T: Hash + Eq + Clone> {
67    pub(crate) table: Mutex<BiHashMap<T>>,
68}
69
70impl<T: Hash + Eq + Clone> Clone for StateTable<T> {
71    fn clone(&self) -> Self {
72        Self {
73            table: Mutex::new(self.table.lock().unwrap().clone()),
74        }
75    }
76}
77
78impl<T: Hash + Eq + Clone> Default for StateTable<T> {
79    fn default() -> Self {
80        Self {
81            table: Mutex::new(BiHashMap::new()),
82        }
83    }
84}
85
86impl<T: Hash + Eq + Clone + fmt::Debug> fmt::Debug for StateTable<T> {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        write!(
89            f,
90            "StateTable {{ table : {:?} }}",
91            self.table.lock().unwrap()
92        )
93    }
94}
95
96impl<T: Hash + Eq + Clone + PartialEq> PartialEq for StateTable<T> {
97    fn eq(&self, other: &Self) -> bool {
98        self.table.lock().unwrap().eq(&*other.table.lock().unwrap())
99    }
100}
101
102impl<T: Hash + Eq + Clone> StateTable<T> {
103    pub fn new() -> Self {
104        Self {
105            table: Mutex::new(BiHashMap::new()),
106        }
107    }
108
109    /// Looks up integer ID from entry. If it doesn't exist and insert
110    pub fn find_id_from_ref(&self, tuple: &T) -> StateId {
111        let mut table = self.table.lock().unwrap();
112        table.get_id_or_insert(tuple.clone())
113    }
114
115    pub fn find_id(&self, tuple: T) -> StateId {
116        let mut table = self.table.lock().unwrap();
117        table.get_id_or_insert(tuple)
118    }
119
120    /// Looks up tuple from integer ID.
121    pub fn find_tuple(&self, tuple_id: StateId) -> T {
122        let table = self.table.lock().unwrap();
123        table.get_tuple_unchecked(tuple_id)
124    }
125}
126
127impl<T: SerializeBinary + Hash + Eq + Clone> SerializeBinary for StateTable<T> {
128    /// Parse a struct of type Self from a binary buffer.
129    fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
130        let (i, tuple_to_id_len) = parse_bin_u64(i)?;
131        let (i, tuple_to_id) = fold_many_m_n(
132            tuple_to_id_len as usize,
133            tuple_to_id_len as usize,
134            parse_tuple_to_id,
135            HashMap::<T, StateId>::new,
136            |mut acc, item| {
137                acc.insert(item.0, item.1);
138                acc
139            },
140        )(i)?;
141
142        let (i, id_to_tuple_len) = parse_bin_u64(i)?;
143        let (i, id_to_tuple) = count(T::parse_binary, id_to_tuple_len as usize)(i)?;
144        Ok((
145            i,
146            StateTable {
147                table: Mutex::new(BiHashMap {
148                    tuple_to_id,
149                    id_to_tuple,
150                }),
151            },
152        ))
153    }
154    /// Writes a struct to a writable buffer.
155    fn write_binary<WB: Write>(&self, writer: &mut WB) -> Result<()> {
156        let table = self.table.lock().map_err(|err| anyhow!("{}", err))?;
157        write_bin_u64(writer, table.tuple_to_id.len() as u64)?;
158
159        // Final weights serialization
160        for (tuple, state) in table.tuple_to_id.iter() {
161            (*tuple).write_binary(writer)?;
162            write_bin_u64(writer, *state as u64)?;
163        }
164
165        write_bin_u64(writer, table.id_to_tuple.len() as u64)?;
166        for tuple in table.id_to_tuple.iter() {
167            (*tuple).write_binary(writer)?;
168        }
169
170        Ok(())
171    }
172}
173
174fn parse_tuple_to_id<T: SerializeBinary>(
175    i: &[u8],
176) -> IResult<&[u8], (T, StateId), NomCustomError<&[u8]>> {
177    let (i, tuple) = T::parse_binary(i)?;
178    let (i, state) = parse_bin_u64(i)?;
179
180    Ok((i, (tuple, state as StateId)))
181}
182
#[cfg(test)]
mod test {
    use super::*;
    use crate::algorithms::compose::filter_states::{FilterState, IntegerFilterState};
    use crate::algorithms::compose::ComposeStateTuple;
    use crate::StateId;
    use anyhow::Result;

    /// Concrete table type exercised by these tests.
    type Table = StateTable<ComposeStateTuple<IntegerFilterState>>;

    /// Serializes `table` into an in-memory buffer and parses it back.
    fn round_trip(table: &Table) -> Result<Table> {
        let mut bytes = Vec::new();
        table.write_binary(&mut bytes)?;
        let (_, parsed) = Table::parse_binary(&bytes).map_err(|err| anyhow!("{}", err))?;
        Ok(parsed)
    }

    #[test]
    fn test_read_write_state_table_empty() -> Result<()> {
        let original = Table::new();
        assert_eq!(original, round_trip(&original)?);
        Ok(())
    }

    #[test]
    fn test_read_write_state_table() -> Result<()> {
        // Both tuples share the same state pair and differ only in filter state.
        let (s1, s2): (StateId, StateId) = (1, 2);
        let original = Table::new();
        original.find_id(ComposeStateTuple {
            fs: IntegerFilterState::new(1),
            s1,
            s2,
        });
        original.find_id(ComposeStateTuple {
            fs: IntegerFilterState::new(2),
            s1,
            s2,
        });

        assert_eq!(original, round_trip(&original)?);
        Ok(())
    }
}