1use crate::fst::{CFst, CVecFst};
2use crate::tr::CTr;
3use crate::{get, get_mut, wrap, CStateId, RUSTFST_FFI_RESULT};
4use anyhow::Result;
5use ffi_convert::*;
6use rustfst::fst_traits::MutableFst;
7use rustfst::prelude::{StateIterator, Tr, TropicalWeight, TrsVec};
8use rustfst::trs_iter_mut::TrsIterMut;
9use std::iter::Peekable;
10use std::ops::Range;
11
12#[derive(Debug)]
13pub struct TrsIterator {
14 trs: TrsVec<TropicalWeight>,
15 index: usize,
16}
17
18impl TrsIterator {
19 fn done(&self) -> bool {
20 self.trs.len() == self.index
21 }
22
23 fn reset(&mut self) {
24 self.index = 0
25 }
26}
27
28impl Iterator for TrsIterator {
29 type Item = Tr<TropicalWeight>;
30 fn next(&mut self) -> Option<Self::Item> {
31 let item = self.trs.get(self.index).cloned();
32 self.index += 1;
33 item
34 }
35}
36
37#[derive(RawPointerConverter)]
38pub struct CTrsIterator(pub(crate) TrsIterator);
39
40#[no_mangle]
44pub unsafe extern "C" fn trs_iterator_new(
45 fst_ptr: *mut CFst,
46 state_id: CStateId,
47 mut iter_ptr: *mut *const CTrsIterator,
48) -> RUSTFST_FFI_RESULT {
49 wrap(|| {
50 let fst = get!(CFst, fst_ptr);
51 fst.fst_get_trs(state_id)
52 .map(|trs| {
53 let raw_ptr = {
54 let trs_iterator = TrsIterator { trs, index: 0 };
55 CTrsIterator(trs_iterator).into_raw_pointer()
56 };
57
58 unsafe { *iter_ptr = raw_ptr };
59 })
60 .unwrap_or_else(|_| iter_ptr = std::ptr::null_mut());
61
62 Ok(())
63 })
64}
65
66#[no_mangle]
70pub unsafe extern "C" fn trs_iterator_next(
71 iter_ptr: *mut CTrsIterator,
72 mut tr_ptr: *mut *const CTr,
73) -> RUSTFST_FFI_RESULT {
74 wrap(|| {
75 let trs_iter = get_mut!(CTrsIterator, iter_ptr);
76 trs_iter
77 .next()
78 .map(|tr| {
79 let ctr = Box::into_raw(Box::new(CTr::c_repr_of(tr)?));
80 unsafe { *tr_ptr = ctr };
81 Ok(())
82 })
83 .unwrap_or_else(|| -> Result<()> {
84 tr_ptr = std::ptr::null_mut();
85 Ok(())
86 })?;
87
88 Ok(())
89 })
90}
91
92#[no_mangle]
96pub unsafe extern "C" fn trs_iterator_done(
97 iter_ptr: *const CTrsIterator,
98 done: *mut libc::size_t,
99) -> RUSTFST_FFI_RESULT {
100 wrap(|| {
101 let trs_iter = get!(CTrsIterator, iter_ptr);
102 let res = trs_iter.done();
103 unsafe { *done = res as libc::size_t };
104 Ok(())
105 })
106}
107
108#[no_mangle]
112pub unsafe extern "C" fn trs_iterator_reset(iter_ptr: *mut CTrsIterator) -> RUSTFST_FFI_RESULT {
113 wrap(|| {
114 let trs_iter = get_mut!(CTrsIterator, iter_ptr);
115 trs_iter.reset();
116 Ok(())
117 })
118}
119
120#[no_mangle]
124pub unsafe extern "C" fn trs_iterator_destroy(iter_ptr: *mut CTrsIterator) -> RUSTFST_FFI_RESULT {
125 wrap(|| {
126 if iter_ptr.is_null() {
127 return Ok(());
128 }
129
130 drop(unsafe { Box::from_raw(iter_ptr) });
131 Ok(())
132 })
133}
134
135pub struct MutTrsIterator<'a> {
136 trs: TrsIterMut<'a, TropicalWeight>,
137 index: usize,
138}
139
140impl MutTrsIterator<'_> {
141 pub fn done(&self) -> bool {
142 self.trs.len() == self.index
143 }
144
145 pub fn next(&mut self) {
146 self.index += 1
147 }
148
149 pub fn value(&self) -> Option<Tr<TropicalWeight>> {
150 self.trs.get(self.index).cloned()
151 }
152
153 pub fn set_value(&mut self, tr: Tr<TropicalWeight>) -> Result<()> {
154 self.trs.set_tr(self.index, tr)
155 }
156
157 pub fn reset(&mut self) {
158 self.index = 0
159 }
160}
161
162pub struct CMutTrsIterator<'a>(pub(crate) MutTrsIterator<'a>);
163
164impl<'a> RawPointerConverter<CMutTrsIterator<'a>> for CMutTrsIterator<'a> {
165 fn into_raw_pointer(self) -> *const CMutTrsIterator<'a> {
166 Box::into_raw(Box::new(self)) as _
167 }
168 fn into_raw_pointer_mut(self) -> *mut CMutTrsIterator<'a> {
169 Box::into_raw(Box::new(self))
170 }
171
172 unsafe fn from_raw_pointer(
173 input: *const CMutTrsIterator<'a>,
174 ) -> Result<Self, UnexpectedNullPointerError> {
175 if input.is_null() {
176 Err(UnexpectedNullPointerError)
177 } else {
178 Ok(*Box::from_raw(input as _))
179 }
180 }
181
182 unsafe fn from_raw_pointer_mut(
183 input: *mut CMutTrsIterator<'a>,
184 ) -> Result<Self, UnexpectedNullPointerError> {
185 if input.is_null() {
186 Err(UnexpectedNullPointerError)
187 } else {
188 Ok(*Box::from_raw(input))
189 }
190 }
191}
192
193#[no_mangle]
197pub unsafe extern "C" fn mut_trs_iterator_new(
198 fst_ptr: *mut CVecFst,
199 state_id: CStateId,
200 mut iter_ptr: *mut *const CMutTrsIterator,
201) -> RUSTFST_FFI_RESULT {
202 wrap(|| {
203 let fst = get_mut!(CVecFst, fst_ptr);
204 fst.tr_iter_mut(state_id)
205 .map(|trs| {
206 let raw_ptr = {
207 let trs_iterator = MutTrsIterator { trs, index: 0 };
208 CMutTrsIterator(trs_iterator).into_raw_pointer()
209 };
210
211 unsafe { *iter_ptr = raw_ptr };
212 })
213 .unwrap_or_else(|_| iter_ptr = std::ptr::null_mut());
214
215 Ok(())
216 })
217}
218
219#[no_mangle]
223pub unsafe extern "C" fn mut_trs_iterator_next(
224 iter_ptr: *mut CMutTrsIterator,
225) -> RUSTFST_FFI_RESULT {
226 wrap(|| {
227 let trs_iter = get_mut!(CMutTrsIterator, iter_ptr);
228 trs_iter.next();
229 Ok(())
230 })
231}
232
233#[no_mangle]
237pub unsafe extern "C" fn mut_trs_iterator_value(
238 iter_ptr: *mut CMutTrsIterator,
239 mut tr_ptr: *mut *const CTr,
240) -> RUSTFST_FFI_RESULT {
241 wrap(|| {
242 let trs_iter = get_mut!(CMutTrsIterator, iter_ptr);
243 trs_iter
244 .value()
245 .map(|tr| {
246 let ctr = Box::into_raw(Box::new(CTr::c_repr_of(tr)?));
247 unsafe { *tr_ptr = ctr };
248 Ok(())
249 })
250 .unwrap_or_else(|| -> Result<()> {
251 tr_ptr = std::ptr::null_mut();
252 Ok(())
253 })?;
254 Ok(())
255 })
256}
257
258#[no_mangle]
262pub unsafe extern "C" fn mut_trs_iterator_set_value(
263 iter_ptr: *mut CMutTrsIterator,
264 tr_ptr: *const CTr,
265) -> RUSTFST_FFI_RESULT {
266 wrap(|| {
267 let trs_iter = get_mut!(CMutTrsIterator, iter_ptr);
268 let tr = unsafe { <CTr as ffi_convert::RawBorrow<CTr>>::raw_borrow(tr_ptr)? }.as_rust()?;
269 trs_iter.set_value(tr)?;
270 Ok(())
271 })
272}
273
274#[no_mangle]
278pub unsafe extern "C" fn mut_trs_iterator_done(
279 iter_ptr: *const CMutTrsIterator,
280 done: *mut libc::size_t,
281) -> RUSTFST_FFI_RESULT {
282 wrap(|| {
283 let trs_iter = get!(CMutTrsIterator, iter_ptr);
284 let res = trs_iter.done();
285 unsafe { *done = res as libc::size_t };
286 Ok(())
287 })
288}
289
290#[no_mangle]
294pub unsafe extern "C" fn mut_trs_iterator_reset(
295 iter_ptr: *mut CMutTrsIterator,
296) -> RUSTFST_FFI_RESULT {
297 wrap(|| {
298 let trs_iter = get_mut!(CMutTrsIterator, iter_ptr);
299 trs_iter.reset();
300 Ok(())
301 })
302}
303
304#[no_mangle]
308pub unsafe extern "C" fn mut_trs_iterator_destroy(
309 iter_ptr: *mut CMutTrsIterator,
310) -> RUSTFST_FFI_RESULT {
311 wrap(|| {
312 if iter_ptr.is_null() {
313 return Ok(());
314 }
315
316 drop(unsafe { Box::from_raw(iter_ptr) });
317 Ok(())
318 })
319}
320
321#[derive(RawPointerConverter)]
322pub struct CStateIterator(pub(crate) Peekable<Range<CStateId>>);
323
324#[no_mangle]
328pub unsafe extern "C" fn state_iterator_new(
329 fst_ptr: *mut CVecFst,
330 iter_ptr: *mut *const CStateIterator,
331) -> RUSTFST_FFI_RESULT {
332 wrap(|| {
333 let fst = get!(CVecFst, fst_ptr);
334 let state_iter = fst.states_iter().peekable();
335 let raw_ptr = CStateIterator(state_iter).into_raw_pointer();
336 unsafe { *iter_ptr = raw_ptr };
337 Ok(())
338 })
339}
340
341#[no_mangle]
345pub unsafe extern "C" fn state_iterator_next(
346 iter_ptr: *mut CStateIterator,
347 mut state: *mut CStateId,
348) -> RUSTFST_FFI_RESULT {
349 wrap(|| {
350 let state_iter = get_mut!(CStateIterator, iter_ptr);
351 state_iter
352 .next()
353 .map(|it| unsafe { *state = it })
354 .unwrap_or_else(|| state = std::ptr::null_mut());
355 Ok(())
356 })
357}
358
359#[no_mangle]
363pub unsafe extern "C" fn state_iterator_done(
364 iter_ptr: *mut CStateIterator,
365 done: *mut libc::size_t,
366) -> RUSTFST_FFI_RESULT {
367 wrap(|| {
368 let trs_iter = get_mut!(CStateIterator, iter_ptr);
369 let res = trs_iter.peek().is_none();
370 unsafe { *done = res as libc::size_t };
371 Ok(())
372 })
373}
374
375#[no_mangle]
379pub unsafe extern "C" fn state_iterator_destroy(
380 iter_ptr: *mut CStateIterator,
381) -> RUSTFST_FFI_RESULT {
382 wrap(|| {
383 if iter_ptr.is_null() {
384 return Ok(());
385 }
386
387 drop(unsafe { Box::from_raw(iter_ptr) });
388 Ok(())
389 })
390}