1use std::collections::VecDeque;
2use std::fmt::Debug;
3use std::iter::{repeat, Map, Repeat, Zip};
4use std::marker::PhantomData;
5use std::ops::Deref;
6use std::path::Path;
7use std::sync::Arc;
8
9use anyhow::Result;
10use itertools::izip;
11
12use crate::algorithms::lazy::cache::CacheStatus;
13use crate::algorithms::lazy::fst_op::{AccessibleOpState, FstOp, SerializableOpState};
14use crate::algorithms::lazy::{FstCache, SerializableCache};
15use crate::fst_properties::FstProperties;
16use crate::fst_traits::{
17 AllocableFst, CoreFst, Fst, FstIterData, FstIterator, MutableFst, StateIterator,
18};
19use crate::semirings::{Semiring, SerializableSemiring};
20use crate::{StateId, SymbolTable, Trs, TrsVec};
21
22#[derive(Debug, Clone)]
23pub struct LazyFst<W: Semiring, Op: FstOp<W>, Cache> {
24 cache: Cache,
25 pub(crate) op: Op,
26 w: PhantomData<W>,
27 isymt: Option<Arc<SymbolTable>>,
28 osymt: Option<Arc<SymbolTable>>,
29}
30
31impl<W: Semiring, Op: FstOp<W>, Cache: FstCache<W>> CoreFst<W> for LazyFst<W, Op, Cache> {
32 type TRS = TrsVec<W>;
33
34 fn start(&self) -> Option<StateId> {
35 match self.cache.get_start() {
36 CacheStatus::Computed(start) => start,
37 CacheStatus::NotComputed => {
38 let start = self.op.compute_start().unwrap();
40 self.cache.insert_start(start);
41 start
42 }
43 }
44 }
45
46 fn final_weight(&self, state_id: StateId) -> Result<Option<W>> {
47 match self.cache.get_final_weight(state_id) {
48 CacheStatus::Computed(final_weight) => Ok(final_weight),
49 CacheStatus::NotComputed => {
50 let final_weight = self.op.compute_final_weight(state_id)?;
51 self.cache
52 .insert_final_weight(state_id, final_weight.clone());
53 Ok(final_weight)
54 }
55 }
56 }
57
58 unsafe fn final_weight_unchecked(&self, state_id: StateId) -> Option<W> {
59 self.final_weight(state_id).unwrap_unchecked()
60 }
61
62 fn num_trs(&self, s: StateId) -> Result<usize> {
63 self.cache
64 .num_trs(s)
65 .ok_or_else(|| format_err!("State {:?} doesn't exist", s))
66 }
67
68 unsafe fn num_trs_unchecked(&self, s: StateId) -> usize {
69 self.cache.num_trs(s).unwrap_unchecked()
70 }
71
72 fn get_trs(&self, state_id: StateId) -> Result<Self::TRS> {
73 match self.cache.get_trs(state_id) {
74 CacheStatus::Computed(trs) => Ok(trs),
75 CacheStatus::NotComputed => {
76 let trs = self.op.compute_trs(state_id)?;
77 self.cache.insert_trs(state_id, trs.shallow_clone());
78 Ok(trs)
79 }
80 }
81 }
82
83 unsafe fn get_trs_unchecked(&self, state_id: StateId) -> Self::TRS {
84 self.get_trs(state_id).unwrap_unchecked()
85 }
86
87 fn properties(&self) -> FstProperties {
88 self.op.properties()
89 }
90
91 fn num_input_epsilons(&self, state: StateId) -> Result<usize> {
92 self.cache
93 .num_input_epsilons(state)
94 .ok_or_else(|| format_err!("State {:?} doesn't exist", state))
95 }
96
97 fn num_output_epsilons(&self, state: StateId) -> Result<usize> {
98 self.cache
99 .num_output_epsilons(state)
100 .ok_or_else(|| format_err!("State {:?} doesn't exist", state))
101 }
102}
103
104impl<'a, W, Op, Cache> StateIterator<'a> for LazyFst<W, Op, Cache>
105where
106 W: Semiring,
107 Op: FstOp<W> + 'a,
108 Cache: FstCache<W> + 'a,
109{
110 type Iter = StatesIteratorLazyFst<'a, Self>;
111
112 fn states_iter(&'a self) -> Self::Iter {
113 self.start();
114 StatesIteratorLazyFst { fst: self, s: 0 }
115 }
116}
117
118#[derive(Clone)]
119pub struct StatesIteratorLazyFst<'a, T> {
120 pub(crate) fst: &'a T,
121 pub(crate) s: StateId,
122}
123
124impl<W, Op, Cache> Iterator for StatesIteratorLazyFst<'_, LazyFst<W, Op, Cache>>
125where
126 W: Semiring,
127 Op: FstOp<W>,
128 Cache: FstCache<W>,
129{
130 type Item = StateId;
131
132 fn next(&mut self) -> Option<Self::Item> {
133 let num_known_states = self.fst.cache.num_known_states();
134 if (self.s as usize) < num_known_states {
135 let s_cur = self.s;
136 self.fst.get_trs(self.s).unwrap();
138 self.s += 1;
139 Some(s_cur)
140 } else {
141 None
142 }
143 }
144}
145
146type ZipIter<'a, W, Op, Cache, SELF> =
147 Zip<<LazyFst<W, Op, Cache> as StateIterator<'a>>::Iter, Repeat<&'a SELF>>;
148type MapFunction<'a, W, SELF, TRS> = Box<dyn FnMut((StateId, &'a SELF)) -> FstIterData<W, TRS>>;
149type MapIter<'a, W, Op, Cache, SELF, TRS> =
150 Map<ZipIter<'a, W, Op, Cache, SELF>, MapFunction<'a, W, SELF, TRS>>;
151
152impl<'a, W, Op, Cache> FstIterator<'a, W> for LazyFst<W, Op, Cache>
153where
154 W: Semiring,
155 Op: FstOp<W> + 'a,
156 Cache: FstCache<W> + 'a,
157{
158 type FstIter = MapIter<'a, W, Op, Cache, Self, Self::TRS>;
159
160 fn fst_iter(&'a self) -> Self::FstIter {
161 let it = repeat(self);
162 izip!(self.states_iter(), it).map(Box::new(|(state_id, p): (StateId, &'a Self)| {
163 FstIterData {
164 state_id,
165 trs: unsafe { p.get_trs_unchecked(state_id) },
166 final_weight: unsafe { p.final_weight_unchecked(state_id) },
167 num_trs: unsafe { p.num_trs_unchecked(state_id) },
168 }
169 }))
170 }
171}
172
173impl<W, Op, Cache> Fst<W> for LazyFst<W, Op, Cache>
174where
175 W: Semiring,
176 Op: FstOp<W> + 'static,
177 Cache: FstCache<W> + 'static,
178{
179 fn input_symbols(&self) -> Option<&Arc<SymbolTable>> {
180 self.isymt.as_ref()
181 }
182
183 fn output_symbols(&self) -> Option<&Arc<SymbolTable>> {
184 self.osymt.as_ref()
185 }
186
187 fn set_input_symbols(&mut self, symt: Arc<SymbolTable>) {
188 self.isymt = Some(symt);
189 }
190
191 fn set_output_symbols(&mut self, symt: Arc<SymbolTable>) {
192 self.osymt = Some(symt);
193 }
194
195 fn take_input_symbols(&mut self) -> Option<Arc<SymbolTable>> {
196 self.isymt.take()
197 }
198
199 fn take_output_symbols(&mut self) -> Option<Arc<SymbolTable>> {
200 self.osymt.take()
201 }
202}
203
204impl<W, Op, Cache> LazyFst<W, Op, Cache>
205where
206 W: Semiring,
207 Op: FstOp<W>,
208 Cache: FstCache<W>,
209{
210 pub fn from_op_and_cache(
211 op: Op,
212 cache: Cache,
213 isymt: Option<Arc<SymbolTable>>,
214 osymt: Option<Arc<SymbolTable>>,
215 ) -> Self {
216 Self {
217 op,
218 cache,
219 isymt,
220 osymt,
221 w: PhantomData,
222 }
223 }
224
225 pub fn compute<F2: MutableFst<W> + AllocableFst<W>>(&self) -> Result<F2> {
227 let start_state = self.start();
228 let mut fst_out = F2::new();
229 let start_state = match start_state {
230 Some(s) => s,
231 None => return Ok(fst_out),
232 };
233 fst_out.add_states(start_state as usize + 1);
234 fst_out.set_start(start_state)?;
235 let mut queue = VecDeque::new();
236 let mut visited_states = vec![];
237 visited_states.resize(start_state as usize + 1, false);
238 visited_states[start_state as usize] = true;
239 queue.push_back(start_state);
240 while let Some(s) = queue.pop_front() {
241 let trs_owner = self.get_trs(s)?;
242 for tr in trs_owner.trs() {
243 if (tr.nextstate as usize) >= visited_states.len() {
244 visited_states.resize(tr.nextstate as usize + 1, false);
245 }
246 if !visited_states[tr.nextstate as usize] {
247 queue.push_back(tr.nextstate);
248 visited_states[tr.nextstate as usize] = true;
249 }
250 let n = fst_out.num_states();
251 if (tr.nextstate as usize) >= n {
252 fst_out.add_states(tr.nextstate as usize - n + 1)
253 }
254 }
255 unsafe { fst_out.set_trs_unchecked(s, trs_owner.trs().to_vec()) };
256 if let Some(f_w) = self.final_weight(s)? {
257 fst_out.set_final(s, f_w)?;
258 }
259 }
260 fst_out.set_properties(self.properties());
261
262 if let Some(isymt) = &self.isymt {
263 fst_out.set_input_symbols(Arc::clone(isymt));
264 }
265 if let Some(osymt) = &self.osymt {
266 fst_out.set_output_symbols(Arc::clone(osymt));
267 }
268 Ok(fst_out)
269 }
270}
271
272impl<W, Op, Cache> SerializableLazyFst for LazyFst<W, Op, Cache>
273where
274 W: SerializableSemiring,
275 Op: FstOp<W> + AccessibleOpState,
276 Op::FstOpState: SerializableOpState,
277 Cache: FstCache<W> + SerializableCache,
278{
279 fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()> {
281 self.cache.write(cache_dir)?;
282 self.op.get_op_state().write(op_state_dir)?;
283 Ok(())
284 }
285}
286
287pub trait SerializableLazyFst {
288 fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()>;
290}
291
292impl<C: SerializableLazyFst, CP: Deref<Target = C> + Debug> SerializableLazyFst for CP {
293 fn write<P: AsRef<Path>>(&self, cache_dir: P, op_state_dir: P) -> Result<()> {
294 self.deref().write(cache_dir, op_state_dir)
295 }
296}