rustfst_ffi/algorithms/replace.rs

1use anyhow::{anyhow, Result};
2
3use crate::fst::CFst;
4use crate::CLabel;
5use crate::{get, wrap, RUSTFST_FFI_RESULT};
6
7use ffi_convert::RawPointerConverter;
8use rustfst::algorithms::replace::replace;
9use rustfst::prelude::{Label, TropicalWeight, VectorFst};
10
11#[repr(C)]
12#[derive(Debug)]
13pub struct CLabelFstPair {
14    pub label: CLabel,
15    pub fst: *const CFst,
16}
17
18/// # Safety
19///
20/// The pointers should be valid.
21#[no_mangle]
22pub unsafe extern "C" fn fst_replace(
23    root: CLabel,
24    fst_list_ptr: *mut CLabelFstPair,
25    fst_list_ptr_len: libc::size_t,
26    epsilon_on_replace: bool,
27    replaced_fst: *mut *const CFst,
28) -> RUSTFST_FFI_RESULT {
29    wrap(|| {
30        let label_fst_pairs =
31            unsafe { std::slice::from_raw_parts_mut(fst_list_ptr, fst_list_ptr_len) };
32        let fst_list = label_fst_pairs
33            .iter_mut()
34            .map(|pair| -> Result<(CLabel, &VectorFst<TropicalWeight>)> {
35                let fst_ptr = pair.fst;
36                let fst = get!(CFst, fst_ptr);
37                let vec_fst: &VectorFst<TropicalWeight> = fst
38                    .downcast_ref()
39                    .ok_or_else(|| anyhow!("Could not downcast to vector FST"))?;
40                Ok((pair.label as Label, vec_fst))
41            })
42            .collect::<Result<Vec<(CLabel, &VectorFst<TropicalWeight>)>>>()?;
43        let res_fst: VectorFst<TropicalWeight> = replace::<
44            TropicalWeight,
45            VectorFst<TropicalWeight>,
46            _,
47            _,
48        >(fst_list, root, epsilon_on_replace)?;
49        unsafe { *replaced_fst = CFst(Box::new(res_fst)).into_raw_pointer() };
50        Ok(())
51    })
52}