rustfst/
trs.rs

1use crate::semirings::Semiring;
2use crate::Tr;
3use std::fmt::Debug;
4use std::sync::Arc;
5
6pub trait Trs<W: Semiring>: std::ops::Deref<Target = [Tr<W>]> + Debug {
7    fn trs(&self) -> &[Tr<W>];
8    fn to_trs_vec(&self) -> TrsVec<W>;
9    fn shallow_clone(&self) -> Self;
10}
11
12#[derive(Debug, PartialOrd, PartialEq, Eq)]
13pub struct TrsVec<W: Semiring>(pub Arc<Vec<Tr<W>>>);
14
15impl<W: Semiring> Trs<W> for TrsVec<W> {
16    fn trs(&self) -> &[Tr<W>] {
17        self.0.as_slice()
18    }
19
20    fn to_trs_vec(&self) -> TrsVec<W> {
21        self.shallow_clone()
22    }
23
24    fn shallow_clone(&self) -> Self {
25        Self(Arc::clone(&self.0))
26    }
27}
28
29impl<W: Semiring> TrsVec<W> {
30    pub fn remove(&mut self, index: usize) -> Tr<W> {
31        Arc::make_mut(&mut self.0).remove(index)
32    }
33    pub fn push(&mut self, tr: Tr<W>) {
34        Arc::make_mut(&mut self.0).push(tr)
35    }
36
37    pub fn clear(&mut self) {
38        Arc::make_mut(&mut self.0).clear()
39    }
40}
41
42impl<W: Semiring> Clone for TrsVec<W> {
43    fn clone(&self) -> Self {
44        Self(Arc::new((*self.0).clone()))
45    }
46}
47
48impl<W: Semiring> std::ops::Deref for TrsVec<W> {
49    type Target = [Tr<W>];
50    fn deref(&self) -> &Self::Target {
51        self.trs()
52    }
53}
54
55impl<W: Semiring> Default for TrsVec<W> {
56    fn default() -> Self {
57        Self(Arc::new(vec![]))
58    }
59}
60
61impl<W: Semiring> From<Vec<Tr<W>>> for TrsVec<W> {
62    fn from(v: Vec<Tr<W>>) -> Self {
63        Self(Arc::new(v))
64    }
65}
66
67#[derive(Debug, PartialOrd, PartialEq)]
68pub struct TrsConst<W: Semiring> {
69    pub(crate) trs: Arc<Vec<Tr<W>>>,
70    pub(crate) pos: usize,
71    pub(crate) n: usize,
72}
73
74impl<W: Semiring> Trs<W> for TrsConst<W> {
75    fn trs(&self) -> &[Tr<W>] {
76        &self.trs[self.pos..self.pos + self.n]
77    }
78
79    fn to_trs_vec(&self) -> TrsVec<W> {
80        TrsVec(Arc::new(self.trs().to_vec()))
81    }
82
83    // Doesn't clone the data, only the Arc
84    fn shallow_clone(&self) -> Self {
85        Self {
86            trs: Arc::clone(&self.trs),
87            pos: self.pos,
88            n: self.n,
89        }
90    }
91}
92
93impl<W: Semiring> Clone for TrsConst<W> {
94    fn clone(&self) -> Self {
95        Self {
96            trs: Arc::new((*self.trs).clone()),
97            n: self.n,
98            pos: self.pos,
99        }
100    }
101}
102
103impl<W: Semiring> std::ops::Deref for TrsConst<W> {
104    type Target = [Tr<W>];
105    fn deref(&self) -> &Self::Target {
106        self.trs()
107    }
108}
109
110impl<W: Semiring> Default for TrsConst<W> {
111    fn default() -> Self {
112        Self {
113            trs: Arc::new(vec![]),
114            pos: 0,
115            n: 0,
116        }
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    mod test_trs_const {
        use super::*;
        use crate::prelude::TropicalWeight;
        use anyhow::Result;

        /// Backing store whose transitions are pairwise distinct, so a view that
        /// reads from the wrong offset is detected rather than masked.
        fn distinct_trs() -> Arc<Vec<Tr<TropicalWeight>>> {
            Arc::new(vec![
                Tr::<TropicalWeight>::new(1, 1, TropicalWeight::one(), 0),
                Tr::<TropicalWeight>::new(2, 3, TropicalWeight::zero(), 1),
                Tr::<TropicalWeight>::new(4, 5, TropicalWeight::one(), 2),
            ])
        }

        #[test]
        fn test_to_trs_vec() -> Result<()> {
            let backing = distinct_trs();
            let trs = TrsConst {
                trs: Arc::clone(&backing),
                pos: 1,
                n: 1,
            };

            let tr_vec = trs.to_trs_vec();
            assert_eq!(tr_vec.len(), 1);
            // Must be the transition at `pos`, not the first one in storage.
            assert_eq!(tr_vec[0], backing[1]);

            Ok(())
        }

        #[test]
        fn test_trs_window() -> Result<()> {
            let backing = distinct_trs();
            let trs = TrsConst {
                trs: Arc::clone(&backing),
                pos: 1,
                n: 2,
            };

            assert_eq!(trs.trs(), &backing[1..3]);

            Ok(())
        }

        #[test]
        fn test_shallow_clone_shares_storage() -> Result<()> {
            let trs = TrsConst {
                trs: distinct_trs(),
                pos: 0,
                n: 3,
            };

            let shallow = trs.shallow_clone();
            assert!(Arc::ptr_eq(&trs.trs, &shallow.trs));
            assert_eq!(shallow.trs(), trs.trs());

            Ok(())
        }
    }
}