rustfst_ffi/algorithms/compose.rs

1use anyhow::{anyhow, Result};
2
3use super::EnumConversionError;
4use crate::fst::CFst;
5use crate::{get, wrap, CLabel, RUSTFST_FFI_RESULT};
6
7use ffi_convert::*;
8use rustfst::algorithms::compose::matchers::MatcherRewriteMode;
9use rustfst::algorithms::compose::{
10    compose, compose_with_config, ComposeConfig, ComposeFilterEnum, MatcherConfig,
11    SigmaMatcherConfig,
12};
13use rustfst::fst_impls::VectorFst;
14use rustfst::semirings::TropicalWeight;
15use rustfst::Label;
16
17#[derive(RawPointerConverter, Debug)]
18pub struct CComposeFilterEnum(pub(crate) usize);
19
20impl AsRust<ComposeFilterEnum> for CComposeFilterEnum {
21    fn as_rust(&self) -> Result<ComposeFilterEnum, AsRustError> {
22        match self.0 {
23            0 => Ok(ComposeFilterEnum::AutoFilter),
24            1 => Ok(ComposeFilterEnum::NullFilter),
25            2 => Ok(ComposeFilterEnum::TrivialFilter),
26            3 => Ok(ComposeFilterEnum::SequenceFilter),
27            4 => Ok(ComposeFilterEnum::AltSequenceFilter),
28            5 => Ok(ComposeFilterEnum::MatchFilter),
29            6 => Ok(ComposeFilterEnum::NoMatchFilter),
30            _ => Err(AsRustError::Other(Box::new(EnumConversionError {}))),
31        }
32    }
33}
34
35impl CDrop for CComposeFilterEnum {
36    fn do_drop(&mut self) -> Result<(), CDropError> {
37        Ok(())
38    }
39}
40
41impl CReprOf<ComposeFilterEnum> for CComposeFilterEnum {
42    fn c_repr_of(value: ComposeFilterEnum) -> Result<CComposeFilterEnum, CReprOfError> {
43        let variant = match value {
44            ComposeFilterEnum::AutoFilter => 0,
45            ComposeFilterEnum::NullFilter => 1,
46            ComposeFilterEnum::TrivialFilter => 2,
47            ComposeFilterEnum::SequenceFilter => 3,
48            ComposeFilterEnum::AltSequenceFilter => 4,
49            ComposeFilterEnum::MatchFilter => 5,
50            ComposeFilterEnum::NoMatchFilter => 6,
51        };
52        Ok(CComposeFilterEnum(variant))
53    }
54}
55
56#[derive(RawPointerConverter, Debug, Clone)]
57pub struct CMatcherRewriteMode(pub(crate) usize);
58
59impl AsRust<MatcherRewriteMode> for CMatcherRewriteMode {
60    fn as_rust(&self) -> Result<MatcherRewriteMode, AsRustError> {
61        match self.0 {
62            0 => Ok(MatcherRewriteMode::MatcherRewriteAuto),
63            1 => Ok(MatcherRewriteMode::MatcherRewriteAlways),
64            2 => Ok(MatcherRewriteMode::MatcherRewriteNever),
65            _ => Err(AsRustError::Other(Box::new(EnumConversionError {}))),
66        }
67    }
68}
69
70impl CDrop for CMatcherRewriteMode {
71    fn do_drop(&mut self) -> Result<(), CDropError> {
72        Ok(())
73    }
74}
75
76impl CReprOf<MatcherRewriteMode> for CMatcherRewriteMode {
77    fn c_repr_of(value: MatcherRewriteMode) -> Result<CMatcherRewriteMode, CReprOfError> {
78        let variant = match value {
79            MatcherRewriteMode::MatcherRewriteAuto => 0,
80            MatcherRewriteMode::MatcherRewriteAlways => 1,
81            MatcherRewriteMode::MatcherRewriteNever => 2,
82        };
83        Ok(CMatcherRewriteMode(variant))
84    }
85}
86
87#[derive(RawPointerConverter, Debug, Clone)]
88pub struct CSigmaMatcherConfig {
89    pub sigma_label: CLabel,
90    pub rewrite_mode: CMatcherRewriteMode,
91    pub sigma_allowed_matches: Option<Vec<CLabel>>,
92}
93
94impl AsRust<SigmaMatcherConfig> for CSigmaMatcherConfig {
95    fn as_rust(&self) -> Result<SigmaMatcherConfig, AsRustError> {
96        Ok(SigmaMatcherConfig {
97            sigma_label: self.sigma_label.as_rust()?,
98            rewrite_mode: self.rewrite_mode.as_rust()?,
99            sigma_allowed_matches: self.sigma_allowed_matches.clone(),
100        })
101    }
102}
103
104impl CDrop for CSigmaMatcherConfig {
105    fn do_drop(&mut self) -> Result<(), CDropError> {
106        Ok(())
107    }
108}
109
110impl CReprOf<SigmaMatcherConfig> for CSigmaMatcherConfig {
111    fn c_repr_of(input: SigmaMatcherConfig) -> Result<Self, CReprOfError> {
112        Ok(CSigmaMatcherConfig {
113            sigma_label: <Label as CReprOf<_>>::c_repr_of(input.sigma_label)?,
114            rewrite_mode: CMatcherRewriteMode::c_repr_of(input.rewrite_mode)?,
115            sigma_allowed_matches: input.sigma_allowed_matches,
116        })
117    }
118}
119
120#[derive(RawPointerConverter, Debug, Clone, Default)]
121pub struct CMatcherConfig {
122    pub sigma_matcher_config: Option<CSigmaMatcherConfig>,
123}
124
125impl AsRust<MatcherConfig> for CMatcherConfig {
126    fn as_rust(&self) -> Result<MatcherConfig, AsRustError> {
127        if let Some(v) = &self.sigma_matcher_config {
128            Ok(MatcherConfig {
129                sigma_matcher_config: Some(v.as_rust()?),
130            })
131        } else {
132            Ok(MatcherConfig {
133                sigma_matcher_config: None,
134            })
135        }
136    }
137}
138
139impl CDrop for CMatcherConfig {
140    fn do_drop(&mut self) -> Result<(), CDropError> {
141        self.sigma_matcher_config
142            .as_mut()
143            .map(|v| v.do_drop())
144            .transpose()?;
145        Ok(())
146    }
147}
148
149impl CReprOf<MatcherConfig> for CMatcherConfig {
150    fn c_repr_of(input: MatcherConfig) -> Result<Self, CReprOfError> {
151        if let Some(v) = input.sigma_matcher_config {
152            Ok(Self {
153                sigma_matcher_config: Some(CReprOf::c_repr_of(v)?),
154            })
155        } else {
156            Ok(Self {
157                sigma_matcher_config: None,
158            })
159        }
160    }
161}
162
163#[derive(AsRust, CReprOf, CDrop, RawPointerConverter, Debug)]
164#[target_type(ComposeConfig)]
165pub struct CComposeConfig {
166    pub compose_filter: CComposeFilterEnum,
167    pub connect: bool,
168    pub matcher1_config: CMatcherConfig,
169    pub matcher2_config: CMatcherConfig,
170}
171
/// A borrowed `u32` array passed across the C ABI, as a pointer to the first element
/// plus an element count.
///
/// It does not own the memory and carries no lifetime, so it must not outlive the buffer
/// that `data` points into.
#[repr(C)]
#[derive(Debug)]
pub struct CIntArray {
    /// Pointer to the first element. When built with `From`, this is the slice's `as_ptr()`.
    pub data: *const u32,
    /// Number of `u32` elements reachable from `data`.
    pub size: usize,
}

impl From<&[u32]> for CIntArray {
    /// Capture the slice's start pointer and length without copying any element.
    fn from(slice: &[u32]) -> Self {
        let data = slice.as_ptr();
        let size = slice.len();
        CIntArray { data, size }
    }
}
187
188/// # Safety
189///
190/// The pointers should be valid.
191#[no_mangle]
192pub unsafe extern "C" fn fst_matcher_config_new(
193    sigma_label: libc::size_t,
194    rewrite_mode: libc::size_t,
195    sigma_allowed_matches: CIntArray,
196    config: *mut *const CMatcherConfig,
197) -> RUSTFST_FFI_RESULT {
198    wrap(|| {
199        let sigma_allowed_matches = unsafe {
200            std::slice::from_raw_parts(sigma_allowed_matches.data, sigma_allowed_matches.size)
201                .to_vec()
202        };
203        let sigma_allowed_matches = sigma_allowed_matches
204            .iter()
205            .map(|v| *v as CLabel)
206            .collect::<Vec<_>>();
207        let sigma_allowed_matches = if sigma_allowed_matches.is_empty() {
208            None
209        } else {
210            Some(sigma_allowed_matches)
211        };
212        let matcher_config = CMatcherConfig {
213            sigma_matcher_config: Some(CSigmaMatcherConfig {
214                sigma_label: sigma_label as CLabel,
215                rewrite_mode: CMatcherRewriteMode(rewrite_mode),
216                sigma_allowed_matches,
217            }),
218        };
219
220        unsafe { *config = matcher_config.into_raw_pointer() };
221        Ok(())
222    })
223}
224
225/// # Safety
226///
227/// The pointers should be valid.
228#[no_mangle]
229pub unsafe extern "C" fn fst_compose_config_new(
230    compose_filter: libc::size_t,
231    connect: bool,
232    matcher1_config: *const CMatcherConfig,
233    matcher2_config: *const CMatcherConfig,
234    config: *mut *const CComposeConfig,
235) -> RUSTFST_FFI_RESULT {
236    wrap(|| {
237        let matcher1_config = if matcher1_config.is_null() {
238            CMatcherConfig::default()
239        } else {
240            unsafe {
241                <CMatcherConfig as ffi_convert::RawBorrow<CMatcherConfig>>::raw_borrow(
242                    matcher1_config,
243                )?
244            }
245            .clone()
246        };
247
248        let matcher2_config = if matcher2_config.is_null() {
249            CMatcherConfig::default()
250        } else {
251            unsafe {
252                <CMatcherConfig as ffi_convert::RawBorrow<CMatcherConfig>>::raw_borrow(
253                    matcher2_config,
254                )?
255            }
256            .clone()
257        };
258
259        let compose_config = CComposeConfig {
260            matcher1_config,
261            matcher2_config,
262            compose_filter: CComposeFilterEnum(compose_filter),
263            connect,
264        };
265        unsafe { *config = compose_config.into_raw_pointer() };
266        Ok(())
267    })
268}
269
270/// # Safety
271///
272/// The pointers should be valid.
273#[no_mangle]
274pub unsafe extern "C" fn fst_matcher_config_destroy(
275    ptr: *mut CMatcherConfig,
276) -> RUSTFST_FFI_RESULT {
277    wrap(|| {
278        if ptr.is_null() {
279            return Ok(());
280        }
281
282        drop(unsafe { Box::from_raw(ptr) });
283        Ok(())
284    })
285}
286
287/// # Safety
288///
289/// The pointers should be valid.
290#[no_mangle]
291pub unsafe extern "C" fn fst_compose_config_destroy(
292    ptr: *mut CComposeConfig,
293) -> RUSTFST_FFI_RESULT {
294    wrap(|| {
295        if ptr.is_null() {
296            return Ok(());
297        }
298
299        drop(unsafe { Box::from_raw(ptr) });
300        Ok(())
301    })
302}
303
304/// # Safety
305///
306/// The pointers should be valid.
307#[no_mangle]
308pub unsafe extern "C" fn fst_compose(
309    fst_1: *const CFst,
310    fst_2: *const CFst,
311    composition_ptr: *mut *const CFst,
312) -> RUSTFST_FFI_RESULT {
313    wrap(|| {
314        let fst_1 = get!(CFst, fst_1);
315        let vec_fst1: &VectorFst<TropicalWeight> = fst_1
316            .downcast_ref()
317            .ok_or_else(|| anyhow!("Could not downcast to vector FST"))?;
318        let fst_2 = get!(CFst, fst_2);
319        let vec_fst2: &VectorFst<TropicalWeight> = fst_2
320            .downcast_ref()
321            .ok_or_else(|| anyhow!("Could not downcast to vector FST"))?;
322        let fst: VectorFst<TropicalWeight> = compose::<
323            TropicalWeight,
324            VectorFst<TropicalWeight>,
325            VectorFst<TropicalWeight>,
326            _,
327            _,
328            _,
329        >(vec_fst1, vec_fst2)?;
330        let fst_ptr = CFst(Box::new(fst)).into_raw_pointer();
331        unsafe { *composition_ptr = fst_ptr };
332        Ok(())
333    })
334}
335
336/// # Safety
337///
338/// The pointers should be valid.
339#[no_mangle]
340pub unsafe extern "C" fn fst_compose_with_config(
341    fst_1: *const CFst,
342    fst_2: *const CFst,
343    config: *const CComposeConfig,
344    composition_ptr: *mut *const CFst,
345) -> RUSTFST_FFI_RESULT {
346    wrap(|| {
347        let fst_1 = get!(CFst, fst_1);
348        let vec_fst1: &VectorFst<TropicalWeight> = fst_1
349            .downcast_ref()
350            .ok_or_else(|| anyhow!("Could not downcast to vector FST"))?;
351        let fst_2 = get!(CFst, fst_2);
352        let vec_fst2: &VectorFst<TropicalWeight> = fst_2
353            .downcast_ref()
354            .ok_or_else(|| anyhow!("Could not downcast to vector FST"))?;
355
356        let compose_config = unsafe {
357            <CComposeConfig as ffi_convert::RawBorrow<CComposeConfig>>::raw_borrow(config)?
358        };
359        let fst: VectorFst<TropicalWeight> =
360            compose_with_config::<
361                TropicalWeight,
362                VectorFst<TropicalWeight>,
363                VectorFst<TropicalWeight>,
364                _,
365                _,
366                _,
367            >(vec_fst1, vec_fst2, compose_config.as_rust()?)?;
368        let fst_ptr = CFst(Box::new(fst)).into_raw_pointer();
369        unsafe { *composition_ptr = fst_ptr };
370        Ok(())
371    })
372}