Skip to main content

rust_canto/
lib.rs

1mod trie;
2mod token;
3
4use trie::Trie;
5use once_cell::sync::Lazy;
6use wasm_minimal_protocol::*;
7
/// Contents of `../data/chars.tsv`, embedded at compile time.
///
/// `build_trie` reads it as tab-separated lines: the first character of the
/// first field is registered together with the second field (presumably the
/// character's romanised reading -- confirm against the data file).
const CHAR_DATA: &str = include_str!("../data/chars.tsv");
/// Contents of `../data/words.tsv`, embedded at compile time.
///
/// `build_trie` reads it as tab-separated lines and registers the first field
/// as a word together with the second field.
const WORD_DATA: &str = include_str!("../data/words.tsv");
/// Contents of `../data/freq.txt`, embedded at compile time.
///
/// `build_trie` reads it as tab-separated lines: the first field is the key
/// and the second field is parsed as an `i64` frequency.
const FREQ_DATA: &str = include_str!("../data/freq.txt");
11
12initiate_protocol!();
13
14static TRIE: Lazy<Trie> = Lazy::new(|| build_trie());
15
16fn build_trie() -> Trie {
17    let mut trie = Trie::new();
18
19    for line in CHAR_DATA.lines() {
20        let parts: Vec<&str> = line.split('\t').collect();
21        if parts.len() >= 2 {
22            if let Some(ch) = parts[0].chars().next() {
23                trie.insert_char(ch, parts[1]);
24            }
25        }
26    }
27
28    for line in WORD_DATA.lines() {
29        let parts: Vec<&str> = line.split('\t').collect();
30        if parts.len() >= 2 {
31            trie.insert_word(parts[0], parts[1]);
32        }
33    }
34
35    for line in FREQ_DATA.lines() {
36        let parts: Vec<&str> = line.split('\t').collect();
37        if parts.len() >= 2 {
38            if let Ok(freq) = parts[1].parse::<i64>() {
39                trie.insert_freq(parts[0], freq);
40            }
41        }
42    }
43
44    trie
45}
46
47#[wasm_func]
48pub fn annotate(input: &[u8]) -> Vec<u8> {
49    let text = std::str::from_utf8(input).unwrap_or("");
50    let tokens = TRIE.segment(text);
51    serde_json::to_string(&tokens)
52        .unwrap_or_else(|_| "[]".to_string())
53        .into_bytes()
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a freshly built trie splits each sample sentence into the
    /// expected words and attaches the expected reading to each word.
    #[test]
    fn test_segmentation() {
        let trie = build_trie();

        // Each entry: (input sentence, expected (word, reading) tokens in order).
        let cases = vec![
            (
                "都會大學",
                vec![("都會大學", Some("dou1 wui6 daai6 hok6"))],
            ),
            (
                "好學生",
                vec![("好", Some("hou2")), ("學生", Some("hok6 saang1"))],
            ),
            (
                "我係好學生",
                vec![
                    ("我", Some("ngo5")),
                    ("係", Some("hai6")),
                    ("好", Some("hou2")),
                    ("學生", Some("hok6 saang1")),
                ],
            ),
        ];

        for (input, expected) in cases {
            println!("Testing: {}", input);
            let tokens = trie.segment(input);

            // Lengths are compared first, so the pairwise walk below covers
            // every token on both sides.
            assert_eq!(
                tokens.len(),
                expected.len(),
                "token count mismatch for '{}': got {:?}",
                input,
                tokens.iter().map(|t| &t.word).collect::<Vec<_>>()
            );

            for (i, (token, &(want_word, want_reading))) in
                tokens.iter().zip(expected.iter()).enumerate()
            {
                assert_eq!(
                    token.word, want_word,
                    "word mismatch at index {} for '{}'",
                    i, input
                );
                assert_eq!(
                    token.reading.as_deref(),
                    want_reading,
                    "reading mismatch at index {} for '{}'",
                    i,
                    input
                );
            }
        }
    }
}