rustfst/algorithms/lazy/
state_table.rs1use std::fmt;
2use std::hash::Hash;
3use std::sync::Mutex;
4
5use crate::StateId;
6use std::collections::hash_map::Entry;
7use std::collections::hash_map::RandomState;
8use std::collections::HashMap;
9use std::hash::BuildHasher;
10
11use crate::parsers::nom_utils::NomCustomError;
12use crate::parsers::{parse_bin_u64, write_bin_u64, SerializeBinary};
13use nom::multi::{count, fold_many_m_n};
14use nom::IResult;
15use std::io::Write;
16
17use anyhow::{anyhow, Result};
18
19#[derive(Clone, Debug, Default)]
20pub(crate) struct BiHashMap<T: Hash + Eq + Clone, H: BuildHasher = RandomState> {
21 tuple_to_id: HashMap<T, StateId, H>,
22 id_to_tuple: Vec<T>,
23}
24
25impl<T: Hash + Eq + Clone, H: BuildHasher> PartialEq for BiHashMap<T, H> {
26 fn eq(&self, other: &Self) -> bool {
27 self.tuple_to_id.eq(&other.tuple_to_id) && self.id_to_tuple.eq(&other.id_to_tuple)
28 }
29}
30
31impl<T: Hash + Eq + Clone> BiHashMap<T> {
32 pub fn new() -> Self {
33 Self {
34 tuple_to_id: HashMap::new(),
35 id_to_tuple: Vec::new(),
36 }
37 }
38}
39
40impl<T: Hash + Eq + Clone, H: BuildHasher> BiHashMap<T, H> {
41 #[allow(unused)]
42 pub fn with_hasher(hash_builder: H) -> Self {
43 Self {
44 tuple_to_id: HashMap::with_hasher(hash_builder),
45 id_to_tuple: Vec::new(),
46 }
47 }
48
49 pub fn get_id_or_insert(&mut self, tuple: T) -> StateId {
50 match self.tuple_to_id.entry(tuple) {
51 Entry::Occupied(e) => *e.get(),
52 Entry::Vacant(e) => {
53 let n = self.id_to_tuple.len() as StateId;
54 self.id_to_tuple.push(e.key().clone());
55 e.insert(n);
56 n
57 }
58 }
59 }
60
61 pub fn get_tuple_unchecked(&self, id: StateId) -> T {
62 self.id_to_tuple[id as usize].clone()
63 }
64}
65
66pub struct StateTable<T: Hash + Eq + Clone> {
67 pub(crate) table: Mutex<BiHashMap<T>>,
68}
69
70impl<T: Hash + Eq + Clone> Clone for StateTable<T> {
71 fn clone(&self) -> Self {
72 Self {
73 table: Mutex::new(self.table.lock().unwrap().clone()),
74 }
75 }
76}
77
78impl<T: Hash + Eq + Clone> Default for StateTable<T> {
79 fn default() -> Self {
80 Self {
81 table: Mutex::new(BiHashMap::new()),
82 }
83 }
84}
85
86impl<T: Hash + Eq + Clone + fmt::Debug> fmt::Debug for StateTable<T> {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 write!(
89 f,
90 "StateTable {{ table : {:?} }}",
91 self.table.lock().unwrap()
92 )
93 }
94}
95
96impl<T: Hash + Eq + Clone + PartialEq> PartialEq for StateTable<T> {
97 fn eq(&self, other: &Self) -> bool {
98 self.table.lock().unwrap().eq(&*other.table.lock().unwrap())
99 }
100}
101
102impl<T: Hash + Eq + Clone> StateTable<T> {
103 pub fn new() -> Self {
104 Self {
105 table: Mutex::new(BiHashMap::new()),
106 }
107 }
108
109 pub fn find_id_from_ref(&self, tuple: &T) -> StateId {
111 let mut table = self.table.lock().unwrap();
112 table.get_id_or_insert(tuple.clone())
113 }
114
115 pub fn find_id(&self, tuple: T) -> StateId {
116 let mut table = self.table.lock().unwrap();
117 table.get_id_or_insert(tuple)
118 }
119
120 pub fn find_tuple(&self, tuple_id: StateId) -> T {
122 let table = self.table.lock().unwrap();
123 table.get_tuple_unchecked(tuple_id)
124 }
125}
126
127impl<T: SerializeBinary + Hash + Eq + Clone> SerializeBinary for StateTable<T> {
128 fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
130 let (i, tuple_to_id_len) = parse_bin_u64(i)?;
131 let (i, tuple_to_id) = fold_many_m_n(
132 tuple_to_id_len as usize,
133 tuple_to_id_len as usize,
134 parse_tuple_to_id,
135 HashMap::<T, StateId>::new,
136 |mut acc, item| {
137 acc.insert(item.0, item.1);
138 acc
139 },
140 )(i)?;
141
142 let (i, id_to_tuple_len) = parse_bin_u64(i)?;
143 let (i, id_to_tuple) = count(T::parse_binary, id_to_tuple_len as usize)(i)?;
144 Ok((
145 i,
146 StateTable {
147 table: Mutex::new(BiHashMap {
148 tuple_to_id,
149 id_to_tuple,
150 }),
151 },
152 ))
153 }
154 fn write_binary<WB: Write>(&self, writer: &mut WB) -> Result<()> {
156 let table = self.table.lock().map_err(|err| anyhow!("{}", err))?;
157 write_bin_u64(writer, table.tuple_to_id.len() as u64)?;
158
159 for (tuple, state) in table.tuple_to_id.iter() {
161 (*tuple).write_binary(writer)?;
162 write_bin_u64(writer, *state as u64)?;
163 }
164
165 write_bin_u64(writer, table.id_to_tuple.len() as u64)?;
166 for tuple in table.id_to_tuple.iter() {
167 (*tuple).write_binary(writer)?;
168 }
169
170 Ok(())
171 }
172}
173
174fn parse_tuple_to_id<T: SerializeBinary>(
175 i: &[u8],
176) -> IResult<&[u8], (T, StateId), NomCustomError<&[u8]>> {
177 let (i, tuple) = T::parse_binary(i)?;
178 let (i, state) = parse_bin_u64(i)?;
179
180 Ok((i, (tuple, state as StateId)))
181}
182
#[cfg(test)]
mod test {
    use super::*;
    use crate::algorithms::compose::filter_states::{FilterState, IntegerFilterState};
    use crate::algorithms::compose::ComposeStateTuple;
    use crate::StateId;
    use anyhow::Result;

    /// Concrete table type exercised by the tests in this module.
    type ComposeTable = StateTable<ComposeStateTuple<IntegerFilterState>>;

    /// Serializes `expected`, parses the bytes back and asserts that the
    /// decoded table equals the one that was written.
    fn assert_round_trip(expected: &ComposeTable) -> Result<()> {
        let mut bytes = Vec::new();
        expected.write_binary(&mut bytes)?;

        let (_, decoded) = ComposeTable::parse_binary(&bytes).map_err(|err| anyhow!("{}", err))?;

        assert_eq!(*expected, decoded);
        Ok(())
    }

    /// An empty table survives a write/parse round trip.
    #[test]
    fn test_read_write_state_table_empty() -> Result<()> {
        let state_table = ComposeTable::new();
        assert_round_trip(&state_table)
    }

    /// A table holding two tuples that differ only in their filter state
    /// survives a write/parse round trip.
    #[test]
    fn test_read_write_state_table() -> Result<()> {
        let state_table = ComposeTable::new();

        state_table.find_id(ComposeStateTuple {
            fs: IntegerFilterState::new(1),
            s1: 1 as StateId,
            s2: 2 as StateId,
        });
        state_table.find_id(ComposeStateTuple {
            fs: IntegerFilterState::new(2),
            s1: 1 as StateId,
            s2: 2 as StateId,
        });

        assert_round_trip(&state_table)
    }
}