rustfst_ffi/algorithms/
replace.rs1use anyhow::{anyhow, Result};
2
3use crate::fst::CFst;
4use crate::CLabel;
5use crate::{get, wrap, RUSTFST_FFI_RESULT};
6
7use ffi_convert::RawPointerConverter;
8use rustfst::algorithms::replace::replace;
9use rustfst::prelude::{Label, TropicalWeight, VectorFst};
10
11#[repr(C)]
12#[derive(Debug)]
13pub struct CLabelFstPair {
14 pub label: CLabel,
15 pub fst: *const CFst,
16}
17
18#[no_mangle]
22pub unsafe extern "C" fn fst_replace(
23 root: CLabel,
24 fst_list_ptr: *mut CLabelFstPair,
25 fst_list_ptr_len: libc::size_t,
26 epsilon_on_replace: bool,
27 replaced_fst: *mut *const CFst,
28) -> RUSTFST_FFI_RESULT {
29 wrap(|| {
30 let label_fst_pairs =
31 unsafe { std::slice::from_raw_parts_mut(fst_list_ptr, fst_list_ptr_len) };
32 let fst_list = label_fst_pairs
33 .iter_mut()
34 .map(|pair| -> Result<(CLabel, &VectorFst<TropicalWeight>)> {
35 let fst_ptr = pair.fst;
36 let fst = get!(CFst, fst_ptr);
37 let vec_fst: &VectorFst<TropicalWeight> = fst
38 .downcast_ref()
39 .ok_or_else(|| anyhow!("Could not downcast to vector FST"))?;
40 Ok((pair.label as Label, vec_fst))
41 })
42 .collect::<Result<Vec<(CLabel, &VectorFst<TropicalWeight>)>>>()?;
43 let res_fst: VectorFst<TropicalWeight> = replace::<
44 TropicalWeight,
45 VectorFst<TropicalWeight>,
46 _,
47 _,
48 >(fst_list, root, epsilon_on_replace)?;
49 unsafe { *replaced_fst = CFst(Box::new(res_fst)).into_raw_pointer() };
50 Ok(())
51 })
52}