1pub mod concat_fst;
2pub mod const_fst;
3pub mod utils;
4pub mod vector_fst;
5
6use crate::symbol_table::CSymbolTable;
7use crate::tr::CTr;
8use crate::trs::CTrs;
9use crate::{get, get_mut, wrap, CStateId, RUSTFST_FFI_RESULT};
10
11use anyhow::Result;
12use downcast_rs::Downcast;
13use ffi_convert::*;
14use rustfst::algorithms::concat::ConcatFst;
15use rustfst::fst_impls::{ConstFst, VectorFst};
16use rustfst::fst_traits::{Fst, MutableFst, SerializableFst};
17use rustfst::semirings::TropicalWeight;
18use rustfst::Semiring;
19use rustfst::{StateId, SymbolTable, Trs, TrsVec};
20use std::ffi::CStr;
21use std::sync::Arc;
22
23pub trait BindableFst: Downcast {
28 fn fst_start(&self) -> Option<StateId>;
29 fn fst_final_weight(&self, state: StateId) -> Result<Option<TropicalWeight>>;
30 fn fst_num_trs(&self, s: StateId) -> Result<usize>;
31
32 #[inline]
33 fn fst_is_final(&self, state_id: StateId) -> Result<bool> {
34 let w = self.fst_final_weight(state_id)?;
35 Ok(w.is_some())
36 }
37
38 #[inline]
39 fn fst_is_start(&self, state_id: StateId) -> bool {
40 Some(state_id) == self.fst_start()
41 }
42
43 fn fst_get_trs(&self, state_id: StateId) -> Result<TrsVec<TropicalWeight>>;
44 fn fst_input_symbols(&self) -> Option<Arc<SymbolTable>>;
45 fn fst_output_symbols(&self) -> Option<Arc<SymbolTable>>;
46 fn fst_set_input_symbols(&mut self, symt: Arc<SymbolTable>);
47 fn fst_set_output_symbols(&mut self, symt: Arc<SymbolTable>);
48 fn fst_take_input_symbols(&mut self) -> Option<Arc<SymbolTable>>;
49 fn fst_take_output_symbols(&mut self) -> Option<Arc<SymbolTable>>;
50}
51
52downcast_rs::impl_downcast!(BindableFst);
53
54impl<F: Fst<TropicalWeight> + 'static> BindableFst for F {
55 fn fst_start(&self) -> Option<StateId> {
56 self.start()
57 }
58 fn fst_final_weight(&self, state: StateId) -> Result<Option<TropicalWeight>> {
59 self.final_weight(state)
60 }
61 fn fst_num_trs(&self, s: StateId) -> Result<usize> {
62 self.num_trs(s)
63 }
64 fn fst_get_trs(&self, state_id: StateId) -> Result<TrsVec<TropicalWeight>> {
65 self.get_trs(state_id).map(|it| it.to_trs_vec())
66 }
67 fn fst_input_symbols(&self) -> Option<Arc<SymbolTable>> {
68 self.input_symbols().cloned()
69 }
70 fn fst_output_symbols(&self) -> Option<Arc<SymbolTable>> {
71 self.output_symbols().cloned()
72 }
73 fn fst_set_input_symbols(&mut self, symt: Arc<SymbolTable>) {
74 self.set_input_symbols(symt)
75 }
76 fn fst_set_output_symbols(&mut self, symt: Arc<SymbolTable>) {
77 self.set_output_symbols(symt)
78 }
79 fn fst_take_input_symbols(&mut self) -> Option<Arc<SymbolTable>> {
80 self.take_input_symbols()
81 }
82 fn fst_take_output_symbols(&mut self) -> Option<Arc<SymbolTable>> {
83 self.take_output_symbols()
84 }
85}
86
/// C-facing handle to any FST implementing [`BindableFst`] (type-erased).
/// Passed across the FFI boundary as a raw pointer via `RawPointerConverter`.
#[derive(RawPointerConverter)]
pub struct CFst(pub Box<dyn BindableFst>);

/// C-facing handle owning a boxed [`VectorFst`] over [`TropicalWeight`].
#[derive(RawPointerConverter)]
pub struct CVecFst(pub Box<VectorFst<TropicalWeight>>);

/// C-facing handle owning a boxed [`ConstFst`] over [`TropicalWeight`].
#[derive(RawPointerConverter)]
pub struct CConstFst(pub Box<ConstFst<TropicalWeight>>);

/// C-facing handle owning a boxed [`ConcatFst`] built from two
/// `VectorFst<TropicalWeight>` operands.
#[derive(RawPointerConverter)]
pub struct CConcatFst(pub Box<ConcatFst<TropicalWeight, VectorFst<TropicalWeight>>>);
98
/// Downcasts `$fst` (anything exposing `downcast_ref`, e.g. a
/// `&dyn BindableFst`) to `&$typ`.
///
/// On a type mismatch it early-returns an error from the *enclosing*
/// function through `?`. That function must therefore return a `Result`
/// whose error type can be built from an `anyhow::Error`. The error macro is
/// referenced by its absolute path so call sites do not need to import `anyhow!`.
macro_rules! as_fst {
    ($typ:ty, $fst:ident) => {{
        $fst.downcast_ref::<$typ>().ok_or_else(|| {
            ::anyhow::anyhow!("Could not downcast to {} FST", stringify!($typ))
        })?
    }};
}

/// Mutable counterpart of `as_fst!`: downcasts `$fst` to `&mut $typ`, with
/// the same early-return-on-mismatch behaviour.
macro_rules! as_mut_fst {
    ($typ:ty, $fst:ident) => {{
        $fst.downcast_mut::<$typ>().ok_or_else(|| {
            ::anyhow::anyhow!("Could not downcast to {} FST", stringify!($typ))
        })?
    }};
}

pub(crate) use as_fst;
pub(crate) use as_mut_fst;
115#[no_mangle]
128pub unsafe fn fst_start(fst: *const CFst, mut state: *mut CStateId) -> RUSTFST_FFI_RESULT {
129 wrap(|| {
130 let fst = get!(CFst, fst);
131 fst.fst_start()
132 .map(|it| unsafe { *state = it })
133 .unwrap_or_else(|| state = std::ptr::null_mut());
134 Ok(())
135 })
136}
137
138#[no_mangle]
143pub unsafe fn fst_final_weight(
144 fst: *const CFst,
145 state_id: CStateId,
146 mut final_weight: *mut libc::c_float,
147) -> RUSTFST_FFI_RESULT {
148 wrap(|| {
149 let fst = get!(CFst, fst);
150 fst.fst_final_weight(state_id)?
151 .map(|it| unsafe { *final_weight = *it.value() })
152 .unwrap_or_else(|| final_weight = std::ptr::null_mut());
153 Ok(())
154 })
155}
156
157#[no_mangle]
162pub unsafe fn fst_num_trs(
163 fst: *const CFst,
164 state: CStateId,
165 num_trs: *mut libc::size_t,
166) -> RUSTFST_FFI_RESULT {
167 wrap(|| {
168 let fst = get!(CFst, fst);
169 let res = fst.fst_num_trs(state)?;
170 unsafe { *num_trs = res };
171 Ok(())
172 })
173}
174
175#[no_mangle]
179pub unsafe fn fst_get_trs(
180 fst: *const CFst,
181 state: CStateId,
182 trs: *mut *const CTrs,
183) -> RUSTFST_FFI_RESULT {
184 wrap(|| {
185 let fst = get!(CFst, fst);
186 let res = fst.fst_get_trs(state)?;
187 let trs_vec = CTrs(res).into_raw_pointer();
188 unsafe { *trs = trs_vec }
189 Ok(())
190 })
191}
192
193#[no_mangle]
198pub unsafe fn fst_is_final(
199 fst: *const CFst,
200 state: CStateId,
201 is_final: *mut libc::size_t,
202) -> RUSTFST_FFI_RESULT {
203 wrap(|| {
204 let fst = get!(CFst, fst);
205 let res = fst.fst_is_final(state)?;
206 unsafe { *is_final = res as usize }
207 Ok(())
208 })
209}
210
211#[no_mangle]
216pub unsafe fn fst_is_start(
217 fst: *const CFst,
218 state: CStateId,
219 is_start: *mut libc::size_t,
220) -> RUSTFST_FFI_RESULT {
221 wrap(|| {
222 let fst = get!(CFst, fst);
223 let res = fst.fst_is_start(state);
224 unsafe { *is_start = res as usize }
225 Ok(())
226 })
227}
228
229#[no_mangle]
243pub unsafe fn fst_input_symbols(
244 fst: *const CFst,
245 mut input_symt: *mut *const CSymbolTable,
246) -> RUSTFST_FFI_RESULT {
247 wrap(|| {
248 let fst = get!(CFst, fst);
249 fst.fst_input_symbols()
250 .map(|it| {
251 let symt = CSymbolTable(it).into_raw_pointer();
252 unsafe { *input_symt = symt }
253 })
254 .unwrap_or_else(|| input_symt = std::ptr::null_mut());
255 Ok(())
256 })
257}
258
259#[no_mangle]
265pub unsafe fn fst_output_symbols(
266 fst: *const CFst,
267 mut output_symt: *mut *const CSymbolTable,
268) -> RUSTFST_FFI_RESULT {
269 wrap(|| {
270 let fst = get!(CFst, fst);
271 fst.fst_output_symbols()
272 .map(|it| {
273 let symt = CSymbolTable(it).into_raw_pointer();
274 unsafe { *output_symt = symt }
275 })
276 .unwrap_or_else(|| output_symt = std::ptr::null_mut());
277 Ok(())
278 })
279}
280
281#[no_mangle]
287pub unsafe fn fst_set_input_symbols(
288 fst: *mut CFst,
289 symt: *const CSymbolTable,
290) -> RUSTFST_FFI_RESULT {
291 wrap(|| {
292 let fst = get_mut!(CFst, fst);
293 let symt = get!(CSymbolTable, symt);
294 fst.fst_set_input_symbols(symt.clone());
295 Ok(())
296 })
297}
298
299#[no_mangle]
305pub unsafe fn fst_set_output_symbols(
306 fst: *mut CFst,
307 symt: *const CSymbolTable,
308) -> RUSTFST_FFI_RESULT {
309 wrap(|| {
310 let fst = get_mut!(CFst, fst);
311 let symt = get!(CSymbolTable, symt);
312 fst.fst_set_output_symbols(symt.clone());
313 Ok(())
314 })
315}
316
317#[no_mangle]
322pub unsafe fn fst_unset_input_symbols(fst: *mut CFst) -> RUSTFST_FFI_RESULT {
323 wrap(|| {
324 let fst = get_mut!(CFst, fst);
325 fst.fst_take_input_symbols();
326 Ok(())
327 })
328}
329
330#[no_mangle]
335pub unsafe fn fst_unset_output_symbols(fst: *mut CFst) -> RUSTFST_FFI_RESULT {
336 wrap(|| {
337 let fst = get_mut!(CFst, fst);
338 fst.fst_take_output_symbols();
339 Ok(())
340 })
341}
342
343#[no_mangle]
351pub unsafe extern "C" fn fst_weight_one(weight_one: *mut libc::c_float) -> RUSTFST_FFI_RESULT {
352 wrap(|| {
353 let weight = TropicalWeight::one();
354 unsafe { *weight_one = *weight.value() };
355 Ok(())
356 })
357}
358
359#[no_mangle]
363pub unsafe extern "C" fn fst_weight_zero(weight_zero: *mut libc::c_float) -> RUSTFST_FFI_RESULT {
364 wrap(|| {
365 let weight = TropicalWeight::zero();
366 unsafe { *weight_zero = *weight.value() };
367 Ok(())
368 })
369}
370
371#[no_mangle]
376pub unsafe fn fst_destroy(fst_ptr: *mut CFst) -> RUSTFST_FFI_RESULT {
377 wrap(|| {
378 if fst_ptr.is_null() {
379 return Ok(());
380 }
381
382 drop(unsafe { Box::from_raw(fst_ptr) });
383 Ok(())
384 })
385}