Skip to main content

rustfst_ffi/
iterators.rs

1use crate::fst::{CFst, CVecFst};
2use crate::tr::CTr;
3use crate::{get, get_mut, wrap, CStateId, RUSTFST_FFI_RESULT};
4use anyhow::Result;
5use ffi_convert::*;
6use rustfst::fst_traits::MutableFst;
7use rustfst::prelude::{StateIterator, Tr, TropicalWeight, TrsVec};
8use rustfst::trs_iter_mut::TrsIterMut;
9use std::iter::Peekable;
10use std::ops::Range;
11
12#[derive(Debug)]
13pub struct TrsIterator {
14    trs: TrsVec<TropicalWeight>,
15    index: usize,
16}
17
18impl TrsIterator {
19    fn done(&self) -> bool {
20        self.trs.len() == self.index
21    }
22
23    fn reset(&mut self) {
24        self.index = 0
25    }
26}
27
28impl Iterator for TrsIterator {
29    type Item = Tr<TropicalWeight>;
30    fn next(&mut self) -> Option<Self::Item> {
31        let item = self.trs.get(self.index).cloned();
32        self.index += 1;
33        item
34    }
35}
36
37#[derive(RawPointerConverter)]
38pub struct CTrsIterator(pub(crate) TrsIterator);
39
40/// # Safety
41///
42/// The pointers should be valid.
43#[no_mangle]
44pub unsafe extern "C" fn trs_iterator_new(
45    fst_ptr: *mut CFst,
46    state_id: CStateId,
47    mut iter_ptr: *mut *const CTrsIterator,
48) -> RUSTFST_FFI_RESULT {
49    wrap(|| {
50        let fst = get!(CFst, fst_ptr);
51        fst.fst_get_trs(state_id)
52            .map(|trs| {
53                let raw_ptr = {
54                    let trs_iterator = TrsIterator { trs, index: 0 };
55                    CTrsIterator(trs_iterator).into_raw_pointer()
56                };
57
58                unsafe { *iter_ptr = raw_ptr };
59            })
60            .unwrap_or_else(|_| iter_ptr = std::ptr::null_mut());
61
62        Ok(())
63    })
64}
65
66/// # Safety
67///
68/// The pointers should be valid.
69#[no_mangle]
70pub unsafe extern "C" fn trs_iterator_next(
71    iter_ptr: *mut CTrsIterator,
72    mut tr_ptr: *mut *const CTr,
73) -> RUSTFST_FFI_RESULT {
74    wrap(|| {
75        let trs_iter = get_mut!(CTrsIterator, iter_ptr);
76        trs_iter
77            .next()
78            .map(|tr| {
79                let ctr = Box::into_raw(Box::new(CTr::c_repr_of(tr)?));
80                unsafe { *tr_ptr = ctr };
81                Ok(())
82            })
83            .unwrap_or_else(|| -> Result<()> {
84                tr_ptr = std::ptr::null_mut();
85                Ok(())
86            })?;
87
88        Ok(())
89    })
90}
91
92/// # Safety
93///
94/// The pointers should be valid.
95#[no_mangle]
96pub unsafe extern "C" fn trs_iterator_done(
97    iter_ptr: *const CTrsIterator,
98    done: *mut libc::size_t,
99) -> RUSTFST_FFI_RESULT {
100    wrap(|| {
101        let trs_iter = get!(CTrsIterator, iter_ptr);
102        let res = trs_iter.done();
103        unsafe { *done = res as libc::size_t };
104        Ok(())
105    })
106}
107
108/// # Safety
109///
110/// The pointers should be valid.
111#[no_mangle]
112pub unsafe extern "C" fn trs_iterator_reset(iter_ptr: *mut CTrsIterator) -> RUSTFST_FFI_RESULT {
113    wrap(|| {
114        let trs_iter = get_mut!(CTrsIterator, iter_ptr);
115        trs_iter.reset();
116        Ok(())
117    })
118}
119
120/// # Safety
121///
122/// The pointers should be valid.
123#[no_mangle]
124pub unsafe extern "C" fn trs_iterator_destroy(iter_ptr: *mut CTrsIterator) -> RUSTFST_FFI_RESULT {
125    wrap(|| {
126        if iter_ptr.is_null() {
127            return Ok(());
128        }
129
130        drop(unsafe { Box::from_raw(iter_ptr) });
131        Ok(())
132    })
133}
134
135pub struct MutTrsIterator<'a> {
136    trs: TrsIterMut<'a, TropicalWeight>,
137    index: usize,
138}
139
140impl MutTrsIterator<'_> {
141    pub fn done(&self) -> bool {
142        self.trs.len() == self.index
143    }
144
145    pub fn next(&mut self) {
146        self.index += 1
147    }
148
149    pub fn value(&self) -> Option<Tr<TropicalWeight>> {
150        self.trs.get(self.index).cloned()
151    }
152
153    pub fn set_value(&mut self, tr: Tr<TropicalWeight>) -> Result<()> {
154        self.trs.set_tr(self.index, tr)
155    }
156
157    pub fn reset(&mut self) {
158        self.index = 0
159    }
160}
161
162pub struct CMutTrsIterator<'a>(pub(crate) MutTrsIterator<'a>);
163
164impl<'a> RawPointerConverter<CMutTrsIterator<'a>> for CMutTrsIterator<'a> {
165    fn into_raw_pointer(self) -> *const CMutTrsIterator<'a> {
166        Box::into_raw(Box::new(self)) as _
167    }
168    fn into_raw_pointer_mut(self) -> *mut CMutTrsIterator<'a> {
169        Box::into_raw(Box::new(self))
170    }
171
172    unsafe fn from_raw_pointer(
173        input: *const CMutTrsIterator<'a>,
174    ) -> Result<Self, UnexpectedNullPointerError> {
175        if input.is_null() {
176            Err(UnexpectedNullPointerError)
177        } else {
178            Ok(*Box::from_raw(input as _))
179        }
180    }
181
182    unsafe fn from_raw_pointer_mut(
183        input: *mut CMutTrsIterator<'a>,
184    ) -> Result<Self, UnexpectedNullPointerError> {
185        if input.is_null() {
186            Err(UnexpectedNullPointerError)
187        } else {
188            Ok(*Box::from_raw(input))
189        }
190    }
191}
192
193/// # Safety
194///
195/// The pointers should be valid.
196#[no_mangle]
197pub unsafe extern "C" fn mut_trs_iterator_new(
198    fst_ptr: *mut CVecFst,
199    state_id: CStateId,
200    mut iter_ptr: *mut *const CMutTrsIterator,
201) -> RUSTFST_FFI_RESULT {
202    wrap(|| {
203        let fst = get_mut!(CVecFst, fst_ptr);
204        fst.tr_iter_mut(state_id)
205            .map(|trs| {
206                let raw_ptr = {
207                    let trs_iterator = MutTrsIterator { trs, index: 0 };
208                    CMutTrsIterator(trs_iterator).into_raw_pointer()
209                };
210
211                unsafe { *iter_ptr = raw_ptr };
212            })
213            .unwrap_or_else(|_| iter_ptr = std::ptr::null_mut());
214
215        Ok(())
216    })
217}
218
219/// # Safety
220///
221/// The pointers should be valid.
222#[no_mangle]
223pub unsafe extern "C" fn mut_trs_iterator_next(
224    iter_ptr: *mut CMutTrsIterator,
225) -> RUSTFST_FFI_RESULT {
226    wrap(|| {
227        let trs_iter = get_mut!(CMutTrsIterator, iter_ptr);
228        trs_iter.next();
229        Ok(())
230    })
231}
232
233/// # Safety
234///
235/// The pointers should be valid.
236#[no_mangle]
237pub unsafe extern "C" fn mut_trs_iterator_value(
238    iter_ptr: *mut CMutTrsIterator,
239    mut tr_ptr: *mut *const CTr,
240) -> RUSTFST_FFI_RESULT {
241    wrap(|| {
242        let trs_iter = get_mut!(CMutTrsIterator, iter_ptr);
243        trs_iter
244            .value()
245            .map(|tr| {
246                let ctr = Box::into_raw(Box::new(CTr::c_repr_of(tr)?));
247                unsafe { *tr_ptr = ctr };
248                Ok(())
249            })
250            .unwrap_or_else(|| -> Result<()> {
251                tr_ptr = std::ptr::null_mut();
252                Ok(())
253            })?;
254        Ok(())
255    })
256}
257
258/// # Safety
259///
260/// The pointers should be valid.
261#[no_mangle]
262pub unsafe extern "C" fn mut_trs_iterator_set_value(
263    iter_ptr: *mut CMutTrsIterator,
264    tr_ptr: *const CTr,
265) -> RUSTFST_FFI_RESULT {
266    wrap(|| {
267        let trs_iter = get_mut!(CMutTrsIterator, iter_ptr);
268        let tr = unsafe { <CTr as ffi_convert::RawBorrow<CTr>>::raw_borrow(tr_ptr)? }.as_rust()?;
269        trs_iter.set_value(tr)?;
270        Ok(())
271    })
272}
273
274/// # Safety
275///
276/// The pointers should be valid.
277#[no_mangle]
278pub unsafe extern "C" fn mut_trs_iterator_done(
279    iter_ptr: *const CMutTrsIterator,
280    done: *mut libc::size_t,
281) -> RUSTFST_FFI_RESULT {
282    wrap(|| {
283        let trs_iter = get!(CMutTrsIterator, iter_ptr);
284        let res = trs_iter.done();
285        unsafe { *done = res as libc::size_t };
286        Ok(())
287    })
288}
289
290/// # Safety
291///
292/// The pointers should be valid.
293#[no_mangle]
294pub unsafe extern "C" fn mut_trs_iterator_reset(
295    iter_ptr: *mut CMutTrsIterator,
296) -> RUSTFST_FFI_RESULT {
297    wrap(|| {
298        let trs_iter = get_mut!(CMutTrsIterator, iter_ptr);
299        trs_iter.reset();
300        Ok(())
301    })
302}
303
304/// # Safety
305///
306/// The pointers should be valid.
307#[no_mangle]
308pub unsafe extern "C" fn mut_trs_iterator_destroy(
309    iter_ptr: *mut CMutTrsIterator,
310) -> RUSTFST_FFI_RESULT {
311    wrap(|| {
312        if iter_ptr.is_null() {
313            return Ok(());
314        }
315
316        drop(unsafe { Box::from_raw(iter_ptr) });
317        Ok(())
318    })
319}
320
321#[derive(RawPointerConverter)]
322pub struct CStateIterator(pub(crate) Peekable<Range<CStateId>>);
323
324/// # Safety
325///
326/// The pointers should be valid.
327#[no_mangle]
328pub unsafe extern "C" fn state_iterator_new(
329    fst_ptr: *mut CVecFst,
330    iter_ptr: *mut *const CStateIterator,
331) -> RUSTFST_FFI_RESULT {
332    wrap(|| {
333        let fst = get!(CVecFst, fst_ptr);
334        let state_iter = fst.states_iter().peekable();
335        let raw_ptr = CStateIterator(state_iter).into_raw_pointer();
336        unsafe { *iter_ptr = raw_ptr };
337        Ok(())
338    })
339}
340
341/// # Safety
342///
343/// The pointers should be valid.
344#[no_mangle]
345pub unsafe extern "C" fn state_iterator_next(
346    iter_ptr: *mut CStateIterator,
347    mut state: *mut CStateId,
348) -> RUSTFST_FFI_RESULT {
349    wrap(|| {
350        let state_iter = get_mut!(CStateIterator, iter_ptr);
351        state_iter
352            .next()
353            .map(|it| unsafe { *state = it })
354            .unwrap_or_else(|| state = std::ptr::null_mut());
355        Ok(())
356    })
357}
358
359/// # Safety
360///
361/// The pointers should be valid.
362#[no_mangle]
363pub unsafe extern "C" fn state_iterator_done(
364    iter_ptr: *mut CStateIterator,
365    done: *mut libc::size_t,
366) -> RUSTFST_FFI_RESULT {
367    wrap(|| {
368        let trs_iter = get_mut!(CStateIterator, iter_ptr);
369        let res = trs_iter.peek().is_none();
370        unsafe { *done = res as libc::size_t };
371        Ok(())
372    })
373}
374
375/// # Safety
376///
377/// The pointers should be valid.
378#[no_mangle]
379pub unsafe extern "C" fn state_iterator_destroy(
380    iter_ptr: *mut CStateIterator,
381) -> RUSTFST_FFI_RESULT {
382    wrap(|| {
383        if iter_ptr.is_null() {
384            return Ok(());
385        }
386
387        drop(unsafe { Box::from_raw(iter_ptr) });
388        Ok(())
389    })
390}