Skip to main content

rust_canto/
lib.rs

1mod trie;
2mod token;
3mod yale;
4use yale::jyutping_to_yale;
5use yale::jyutping_to_yale_vec;
6
7use trie::Trie;
8use token::Token;
9use once_cell::sync::Lazy;
10use wasm_minimal_protocol::*;
11
12const CHAR_DATA: &str = include_str!("../data/chars.tsv");
13const WORD_DATA: &str = include_str!("../data/words.tsv");
14const FREQ_DATA: &str = include_str!("../data/freq.txt");
15
16initiate_protocol!();
17
18static TRIE: Lazy<Trie> = Lazy::new(|| build_trie());
19
20fn build_trie() -> Trie {
21    let mut trie = Trie::new();
22
23    for line in CHAR_DATA.lines() {
24        let parts: Vec<&str> = line.split('\t').collect();
25        if parts.len() >= 2 {
26            if let Some(ch) = parts[0].chars().next() {
27                // parse "5%" → 5, missing → 100 (highest priority)
28                let weight = parts.get(2)
29                    .map(|s| s.replace('%', "").trim().parse::<u32>().unwrap_or(0))
30                    .unwrap_or(100);
31                trie.insert_char(ch, parts[1], weight);
32            }
33        }
34    }
35
36    for line in WORD_DATA.lines() {
37        let parts: Vec<&str> = line.split('\t').collect();
38        if parts.len() >= 2 {
39            trie.insert_word(parts[0], parts[1]);
40        }
41    }
42
43    for line in FREQ_DATA.lines() {
44        let parts: Vec<&str> = line.split('\t').collect();
45        if parts.len() >= 2 {
46            if let Ok(freq) = parts[1].parse::<i64>() {
47                trie.insert_freq(parts[0], freq);
48            }
49        }
50    }
51
52    trie
53}
54
55#[wasm_func]
56pub fn annotate(input: &[u8]) -> Vec<u8> {
57    let text = std::str::from_utf8(input).unwrap_or("");
58    let tokens = TRIE.segment(text);
59
60    let output: Vec<Token> = tokens
61        .into_iter()
62        .map(|t| Token {
63            word: t.word,
64            yale: t.reading.as_deref().and_then(jyutping_to_yale_vec),
65            reading: t.reading,
66        })
67        .collect();
68
69    serde_json::to_string(&output)
70        .unwrap_or_else(|_| "[]".to_string())
71        .into_bytes()
72}
73
74/// Input: jyutping bytes, e.g. b"gwong2 dung1 waa2"
75/// Output: Yale with tone numbers, e.g. b"gwong2 dung1 waa2"
76#[wasm_func]
77pub fn to_yale_numeric(input: &[u8]) -> Vec<u8> {
78    let jp = std::str::from_utf8(input).unwrap_or("");
79    jyutping_to_yale(jp, false)
80        .unwrap_or_default()
81        .into_bytes()
82}
83
84/// Input: jyutping bytes
85/// Output: Yale with diacritics, e.g. b"gwóngdūngwá"
86#[wasm_func]
87pub fn to_yale_diacritics(input: &[u8]) -> Vec<u8> {
88    let jp = std::str::from_utf8(input).unwrap_or("");
89    jyutping_to_yale(jp, true)
90        .unwrap_or_default()
91        .into_bytes()
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Segments a few sample sentences with a freshly built trie and checks
    /// both the token boundaries and the reading chosen for each token.
    #[test]
    fn test_segmentation() {
        let trie = build_trie();

        // Each case pairs an input sentence with its expected (word, reading) tokens.
        let cases = [
            (
                "都會大學",
                vec![("都會大學", Some("dou1 wui6 daai6 hok6"))],
            ),
            (
                "我會番教會",
                vec![
                    ("我", Some("ngo5")),
                    ("會", Some("wui5")),
                    ("番", Some("faan1")),
                    ("教會", Some("gaau3 wui2")),
                ],
            ),
            (
                "佢係好學生",
                vec![
                    ("佢", Some("keoi5")),
                    ("係", Some("hai6")),
                    ("好", Some("hou2")),
                    ("學生", Some("hok6 saang1")),
                ],
            ),
        ];

        for (input, expected) in cases {
            println!("Testing: {}", input);
            let tokens = trie.segment(input);

            // Check the count first so the pairwise comparison below covers every token.
            assert_eq!(tokens.len(), expected.len(),
                "token count mismatch for '{}': got {:?}", input,
                tokens.iter().map(|t| &t.word).collect::<Vec<_>>()
            );

            for (i, (token, (word, reading))) in tokens.iter().zip(&expected).enumerate() {
                assert_eq!(token.word, *word,
                    "word mismatch at index {} for '{}'", i, input);
                assert_eq!(token.reading.as_deref(), *reading,
                    "reading mismatch at index {} for '{}'", i, input);
            }
        }
    }
}