Skip to main content

rustfst/algorithms/lazy/
lazy_fst.rs

1use std::collections::VecDeque;
2use std::fmt::Debug;
3use std::iter::{repeat, Map, Repeat, Zip};
4use std::marker::PhantomData;
5use std::ops::Deref;
6use std::path::Path;
7use std::sync::Arc;
8
9use anyhow::Result;
10use itertools::izip;
11
12use crate::algorithms::lazy::cache::CacheStatus;
13use crate::algorithms::lazy::fst_op::{AccessibleOpState, FstOp, SerializableOpState};
14use crate::algorithms::lazy::{FstCache, SerializableCache};
15use crate::fst_properties::FstProperties;
16use crate::fst_traits::{
17    AllocableFst, CoreFst, Fst, FstIterData, FstIterator, MutableFst, StateIterator,
18};
19use crate::semirings::{Semiring, SerializableSemiring};
20use crate::{StateId, SymbolTable, Trs, TrsVec};
21
/// An FST whose start state, final weights and transitions are computed on
/// demand by an operation (`Op`) and memoized in a cache (`Cache`).
#[derive(Debug, Clone)]
pub struct LazyFst<W: Semiring, Op: FstOp<W>, Cache> {
    // Memoizes results already produced by `op` (start state, final weights, transitions).
    cache: Cache,
    // Operation that computes the start state, final weights and transitions when the
    // cache does not hold them yet.
    pub(crate) op: Op,
    // Ties the struct to the semiring type `W` without storing a value of it.
    w: PhantomData<W>,
    // Optional input symbol table.
    isymt: Option<Arc<SymbolTable>>,
    // Optional output symbol table.
    osymt: Option<Arc<SymbolTable>>,
}
30
31impl<W: Semiring, Op: FstOp<W>, Cache: FstCache<W>> CoreFst<W> for LazyFst<W, Op, Cache> {
32    type TRS = TrsVec<W>;
33
34    fn start(&self) -> Option<StateId> {
35        match self.cache.get_start() {
36            CacheStatus::Computed(start) => start,
37            CacheStatus::NotComputed => {
38                // TODO: Need to return a Result
39                let start = self.op.compute_start().unwrap();
40                self.cache.insert_start(start);
41                start
42            }
43        }
44    }
45
46    fn final_weight(&self, state_id: StateId) -> Result<Option<W>> {
47        match self.cache.get_final_weight(state_id) {
48            CacheStatus::Computed(final_weight) => Ok(final_weight),
49            CacheStatus::NotComputed => {
50                let final_weight = self.op.compute_final_weight(state_id)?;
51                self.cache
52                    .insert_final_weight(state_id, final_weight.clone());
53                Ok(final_weight)
54            }
55        }
56    }
57
58    unsafe fn final_weight_unchecked(&self, state_id: StateId) -> Option<W> {
59        self.final_weight(state_id).unwrap_unchecked()
60    }
61
62    fn num_trs(&self, s: StateId) -> Result<usize> {
63        self.cache
64            .num_trs(s)
65            .ok_or_else(|| format_err!("State {:?} doesn't exist", s))
66    }
67
68    unsafe fn num_trs_unchecked(&self, s: StateId) -> usize {
69        self.cache.num_trs(s).unwrap_unchecked()
70    }
71
72    fn get_trs(&self, state_id: StateId) -> Result<Self::TRS> {
73        match self.cache.get_trs(state_id) {
74            CacheStatus::Computed(trs) => Ok(trs),
75            CacheStatus::NotComputed => {
76                let trs = self.op.compute_trs(state_id)?;
77                self.cache.insert_trs(state_id, trs.shallow_clone());
78                Ok(trs)
79            }
80        }
81    }
82
83    unsafe fn get_trs_unchecked(&self, state_id: StateId) -> Self::TRS {
84        self.get_trs(state_id).unwrap_unchecked()
85    }
86
87    fn properties(&self) -> FstProperties {
88        self.op.properties()
89    }
90
91    fn num_input_epsilons(&self, state: StateId) -> Result<usize> {
92        self.cache
93            .num_input_epsilons(state)
94            .ok_or_else(|| format_err!("State {:?} doesn't exist", state))
95    }
96
97    fn num_output_epsilons(&self, state: StateId) -> Result<usize> {
98        self.cache
99            .num_output_epsilons(state)
100            .ok_or_else(|| format_err!("State {:?} doesn't exist", state))
101    }
102}
103
104impl<'a, W, Op, Cache> StateIterator<'a> for LazyFst<W, Op, Cache>
105where
106    W: Semiring,
107    Op: FstOp<W> + 'a,
108    Cache: FstCache<W> + 'a,
109{
110    type Iter = StatesIteratorLazyFst<'a, Self>;
111
112    fn states_iter(&'a self) -> Self::Iter {
113        self.start();
114        StatesIteratorLazyFst { fst: self, s: 0 }
115    }
116}
117
/// Iterator over the states of a `LazyFst`.
///
/// It yields state ids from 0 upwards while they are below the cache's count
/// of known states, expanding each state as it is yielded.
#[derive(Clone)]
pub struct StatesIteratorLazyFst<'a, T> {
    // The FST whose states are being iterated.
    pub(crate) fst: &'a T,
    // Id of the next state to yield.
    pub(crate) s: StateId,
}
123
124impl<W, Op, Cache> Iterator for StatesIteratorLazyFst<'_, LazyFst<W, Op, Cache>>
125where
126    W: Semiring,
127    Op: FstOp<W>,
128    Cache: FstCache<W>,
129{
130    type Item = StateId;
131
132    fn next(&mut self) -> Option<Self::Item> {
133        let num_known_states = self.fst.cache.num_known_states();
134        if (self.s as usize) < num_known_states {
135            let s_cur = self.s;
136            // Force expansion of the state
137            self.fst.get_trs(self.s).unwrap();
138            self.s += 1;
139            Some(s_cur)
140        } else {
141            None
142        }
143    }
144}
145
/// Pairs each state yielded by the lazy state iterator with a repeated
/// reference to `SELF` (the FST being iterated).
type ZipIter<'a, W, Op, Cache, SELF> =
    Zip<<LazyFst<W, Op, Cache> as StateIterator<'a>>::Iter, Repeat<&'a SELF>>;
/// Boxed closure turning a `(state, fst)` pair into the `FstIterData` of that state.
type MapFunction<'a, W, SELF, TRS> = Box<dyn FnMut((StateId, &'a SELF)) -> FstIterData<W, TRS>>;
/// Iterator type returned by `FstIterator::fst_iter` for `LazyFst`.
type MapIter<'a, W, Op, Cache, SELF, TRS> =
    Map<ZipIter<'a, W, Op, Cache, SELF>, MapFunction<'a, W, SELF, TRS>>;
151
152impl<'a, W, Op, Cache> FstIterator<'a, W> for LazyFst<W, Op, Cache>
153where
154    W: Semiring,
155    Op: FstOp<W> + 'a,
156    Cache: FstCache<W> + 'a,
157{
158    type FstIter = MapIter<'a, W, Op, Cache, Self, Self::TRS>;
159
160    fn fst_iter(&'a self) -> Self::FstIter {
161        let it = repeat(self);
162        izip!(self.states_iter(), it).map(Box::new(|(state_id, p): (StateId, &'a Self)| {
163            FstIterData {
164                state_id,
165                trs: unsafe { p.get_trs_unchecked(state_id) },
166                final_weight: unsafe { p.final_weight_unchecked(state_id) },
167                num_trs: unsafe { p.num_trs_unchecked(state_id) },
168            }
169        }))
170    }
171}
172
173impl<W, Op, Cache> Fst<W> for LazyFst<W, Op, Cache>
174where
175    W: Semiring,
176    Op: FstOp<W> + 'static,
177    Cache: FstCache<W> + 'static,
178{
179    fn input_symbols(&self) -> Option<&Arc<SymbolTable>> {
180        self.isymt.as_ref()
181    }
182
183    fn output_symbols(&self) -> Option<&Arc<SymbolTable>> {
184        self.osymt.as_ref()
185    }
186
187    fn set_input_symbols(&mut self, symt: Arc<SymbolTable>) {
188        self.isymt = Some(symt);
189    }
190
191    fn set_output_symbols(&mut self, symt: Arc<SymbolTable>) {
192        self.osymt = Some(symt);
193    }
194
195    fn take_input_symbols(&mut self) -> Option<Arc<SymbolTable>> {
196        self.isymt.take()
197    }
198
199    fn take_output_symbols(&mut self) -> Option<Arc<SymbolTable>> {
200        self.osymt.take()
201    }
202}
203
204impl<W, Op, Cache> LazyFst<W, Op, Cache>
205where
206    W: Semiring,
207    Op: FstOp<W>,
208    Cache: FstCache<W>,
209{
210    pub fn from_op_and_cache(
211        op: Op,
212        cache: Cache,
213        isymt: Option<Arc<SymbolTable>>,
214        osymt: Option<Arc<SymbolTable>>,
215    ) -> Self {
216        Self {
217            op,
218            cache,
219            isymt,
220            osymt,
221            w: PhantomData,
222        }
223    }
224
225    /// Turns the Lazy FST into a static one.
226    pub fn compute<F2: MutableFst<W> + AllocableFst<W>>(&self) -> Result<F2> {
227        let start_state = self.start();
228        let mut fst_out = F2::new();
229        let start_state = match start_state {
230            Some(s) => s,
231            None => return Ok(fst_out),
232        };
233        fst_out.add_states(start_state as usize + 1);
234        fst_out.set_start(start_state)?;
235        let mut queue = VecDeque::new();
236        let mut visited_states = vec![];
237        visited_states.resize(start_state as usize + 1, false);
238        visited_states[start_state as usize] = true;
239        queue.push_back(start_state);
240        while let Some(s) = queue.pop_front() {
241            let trs_owner = self.get_trs(s)?;
242            for tr in trs_owner.trs() {
243                if (tr.nextstate as usize) >= visited_states.len() {
244                    visited_states.resize(tr.nextstate as usize + 1, false);
245                }
246                if !visited_states[tr.nextstate as usize] {
247                    queue.push_back(tr.nextstate);
248                    visited_states[tr.nextstate as usize] = true;
249                }
250                let n = fst_out.num_states();
251                if (tr.nextstate as usize) >= n {
252                    fst_out.add_states(tr.nextstate as usize - n + 1)
253                }
254            }
255            unsafe { fst_out.set_trs_unchecked(s, trs_owner.trs().to_vec()) };
256            if let Some(f_w) = self.final_weight(s)? {
257                fst_out.set_final(s, f_w)?;
258            }
259        }
260        fst_out.set_properties(self.properties());
261
262        if let Some(isymt) = &self.isymt {
263            fst_out.set_input_symbols(Arc::clone(isymt));
264        }
265        if let Some(osymt) = &self.osymt {
266            fst_out.set_output_symbols(Arc::clone(osymt));
267        }
268        Ok(fst_out)
269    }
270}
271
272impl<W, Op, Cache> SerializableLazyFst for LazyFst<W, Op, Cache>
273where
274    W: SerializableSemiring,
275    Op: FstOp<W> + AccessibleOpState,
276    Op::FstOpState: SerializableOpState,
277    Cache: FstCache<W> + SerializableCache,
278{
279    /// Writes LazyFst interal states to a directory of files in binary format.
280    fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()> {
281        self.cache.write(cache_dir)?;
282        self.op.get_op_state().write(op_state_dir)?;
283        Ok(())
284    }
285}
286
287pub trait SerializableLazyFst {
288    /// Writes LazyFst interal states to a directory of files in binary format.
289    fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()>;
290}
291
292impl<C: SerializableLazyFst, CP: Deref<Target = C> + Debug> SerializableLazyFst for CP {
293    fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()> {
294        self.deref().write(cache_dir, op_state_dir)
295    }
296}