yoin_core/dic/fst/
mod.rs

1use std::borrow::Borrow;
2use std::iter::IntoIterator;
3
4mod mast;
5mod op;
6
/// A compiled transducer, held as the raw bytecode that `Iter` interprets.
///
/// `T` may be any container that can be borrowed as a byte slice, so the
/// bytecode can be owned (`Vec<u8>`) or merely borrowed (`&[u8]`).
#[derive(Clone, Debug)]
pub struct Fst<T: Borrow<[u8]>> {
    /// Instruction sequence handed to the interpreter on every run.
    bytecode: T,
}
13
14impl<'a> Fst<&'a [u8]> {
15    pub unsafe fn from_bytes(bytes: &'a [u8]) -> Self {
16        Fst { bytecode: bytes }
17    }
18}
19
20impl<T: Borrow<[u8]>> Fst<T> {
21    pub fn run_iter<'a>(&'a self, input: &'a [u8]) -> Iter<'a> {
22        Iter::new(self.bytecode.borrow(), input)
23    }
24
25    pub fn run<'a>(&'a self, input: &'a [u8]) -> Vec<Accept> {
26        self.run_iter(input).collect()
27    }
28
29    pub fn bytecode<'a>(&'a self) -> &'a [u8] {
30        self.bytecode.borrow()
31    }
32}
33
34impl Fst<Vec<u8>> {
35    pub fn build<'a, I: IntoIterator<Item = (&'a [u8], u32)>>(inputs: I) -> Self {
36        let m = mast::Mast::build(inputs);
37        Fst { bytecode: op::build(m) }
38    }
39}
40
/// Interpreter state for one run of a bytecode program over one input.
///
/// It implements `Iterator`, producing an `Accept` each time the program
/// reaches an accepting instruction.
#[derive(Clone, Debug)]
pub struct Iter<'a> {
    /// Index into `iseq` of the next byte to decode.
    pc: usize,
    /// The bytecode being executed.
    iseq: &'a [u8],
    /// The bytes being matched against the program.
    input: &'a [u8],
    /// Number of bytes of `input` consumed so far.
    len: usize,
}
48
/// An output value yielded by a run: the `u32` carried by the accepting
/// instruction that was reached.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Accept(pub u32);
51
52impl<'a> Iter<'a> {
53    pub fn new(iseq: &'a [u8], input: &'a [u8]) -> Self {
54        Iter {
55            pc: 0,
56            iseq: iseq,
57            input: input,
58            len: 0,
59        }
60    }
61
62    fn read_u16(&mut self) -> u16 {
63        let from = self.iseq[self.pc..].as_ptr() as *const u16;
64        self.pc += 2; // skip 16 bits
65        unsafe { *from }
66    }
67
68    fn read_u32(&mut self) -> u32 {
69        let from = self.iseq[self.pc..].as_ptr() as *const u32;
70        self.pc += 4; // skip 32 bits
71        unsafe { *from }
72    }
73
74    fn get_jump_offset(&mut self, jump_size: u8) -> usize {
75        if jump_size == op::JUMP_SIZE_16 {
76            self.read_u16() as usize
77        } else {
78            debug_assert!(jump_size == op::JUMP_SIZE_32, "invalid bytecode");
79            self.read_u32() as usize
80        }
81    }
82
83    fn run_jump(&mut self) {
84        let op = op::Op(self.iseq[self.pc]);
85        self.pc += 1;
86        let cmp = self.iseq[self.pc];
87        self.pc += 1;
88
89        let jump = self.get_jump_offset(op.jump_bytes());
90        if cmp != self.input[self.len] {
91            return;
92        }
93        self.len += 1;
94        self.pc += jump;
95    }
96
97    fn run_outjump(&mut self) -> Option<u32> {
98        let op = op::Op(self.iseq[self.pc]);
99        self.pc += 1;
100        let cmp = self.iseq[self.pc];
101        self.pc += 1;
102        let jump = self.get_jump_offset(op.jump_bytes());
103        if cmp != self.input[self.len] {
104            self.pc += 4 as usize; // skip unused data bytes.
105            return None;
106        }
107        self.len += 1;
108        let n = self.read_u32();
109        self.pc += jump - 4 as usize;
110        Some(n)
111    }
112}
113
114impl<'a> Iterator for Iter<'a> {
115    type Item = Accept;
116
117    fn next(&mut self) -> Option<Self::Item> {
118        loop {
119            let op = op::Op(self.iseq[self.pc]);
120            match op.code() {
121                op::OPCODE_BREAK => return None,
122                op::OPCODE_JUMP => {
123                    if self.len >= self.input.len() {
124                        return None;
125                    }
126                    self.run_jump();
127                }
128                op::OPCODE_OUTJUMP => {
129                    if self.len >= self.input.len() {
130                        return None;
131                    }
132                    match self.run_outjump() {
133                        None => (),
134                        Some(n) => return Some(Accept(n)),
135                    }
136                }
137                op::OPCODE_ACCEPT_WITH => {
138                    self.pc += 1; // skip op::OPCODE_ACCEPT_WITH
139                    let n = self.read_u32();
140                    let accept = Accept(n);
141                    return Some(accept);
142                }
143                op => unreachable!("unknown operator in bytecode: {:?}", op),
144            }
145        }
146    }
147}
148
/// `run` must report the outputs of every key that is a prefix of (or equal
/// to) the input, including both values registered for the duplicate key.
#[test]
fn test_run() {
    use std::collections::HashSet;

    let entries: Vec<(&[u8], u32)> = vec![(b"ab", 0xFF), (b"abc", 0), (b"abc", !0), (b"abd", 1)];
    let fst = Fst::build(entries);

    let actual: HashSet<Accept> = fst.run(b"abc").into_iter().collect();
    let wanted: HashSet<Accept> = [0xFF, 0, !0].iter().map(|&value| Accept(value)).collect();

    assert_eq!(actual, wanted);
}
161
/// `run_iter` over a larger key set: the input `feb'` must yield both values
/// stored under `feb` plus the one stored under `feb'`.
#[test]
fn test_op() {
    use std::collections::HashSet;

    let entries: Vec<(&[u8], u32)> = vec![
        (b"apr", 0),
        (b"aug", 1),
        (b"dec", 2),
        (b"feb", 3),
        (b"feb", 4),
        (b"feb'", 8),
        (b"jan", 5),
        (b"jul", 6),
        (b"jun", 7),
    ];
    let fst = Fst::build(entries);

    let wanted: HashSet<Accept> = [3, 4, 8].iter().map(|&value| Accept(value)).collect();
    let actual: HashSet<Accept> = fst.run_iter(b"feb'").collect();

    assert_eq!(actual, wanted);
}