1mod pinyin_dict;
22
23use jieba_rs::Jieba;
24use pinyin_dict::{lookup_numbers, numbers_to_marks};
25use std::sync::OnceLock;
26
27use wasm_minimal_protocol::*;
// Emits the plugin/host glue that `wasm_minimal_protocol` requires once per
// crate; the `#[wasm_func]` exports in this file rely on it.
initiate_protocol!();
31
32static JIEBA: OnceLock<Jieba> = OnceLock::new();
35
36fn get_jieba() -> &'static Jieba {
37 JIEBA.get_or_init(|| {
38 use ruzstd::streaming_decoder::StreamingDecoder;
39 use std::io::Read;
40 static DICT_ZSTD: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/dict.dat"));
41 let mut buf = Vec::new();
42 StreamingDecoder::new(DICT_ZSTD)
43 .expect("invalid zstd stream in dict.dat")
44 .read_to_end(&mut buf)
45 .expect("failed to decompress dict.dat");
46 Jieba::with_dict(&mut buf.as_slice())
47 .expect("failed to load jieba dictionary")
48 })
49}
50
/// Reports whether `ch` is a CJK ideograph that may carry a pinyin reading.
///
/// The following ranges are covered:
/// - CJK Extension A (U+3400–U+4DBF)
/// - the CJK Unified Ideographs block (U+4E00–U+9FFF)
/// - CJK Compatibility Ideographs (U+F900–U+FAFF)
/// - all of planes 2 and 3 (U+20000–U+3FFFF)
///
/// CJK punctuation and full-width forms such as `，` or `。` are deliberately
/// excluded, so they never receive pinyin.
fn is_cjk(ch: char) -> bool {
    const IDEOGRAPH_RANGES: [std::ops::RangeInclusive<u32>; 4] = [
        0x3400..=0x4DBF,
        0x4E00..=0x9FFF,
        0xF900..=0xFAFF,
        0x2_0000..=0x3_FFFF,
    ];
    let code = u32::from(ch);
    IDEOGRAPH_RANGES.iter().any(|range| range.contains(&code))
}
62
63fn apply_style(numbers: &str, style: &str) -> String {
68 match style {
69 "numbers" | "pinyin_numbers" => numbers.to_string(),
70 _ => numbers_to_marks(numbers),
71 }
72}
73
74fn render_word(word: &str, style: &str) -> Option<Vec<String>> {
87 if !word.chars().any(is_cjk) {
88 return None;
89 }
90
91 let char_count = word.chars().count();
92
93 if let Some(numbers) = lookup_numbers(word) {
95 let syllables: Vec<String> = numbers
96 .split_whitespace()
97 .map(|s| apply_style(s, style))
98 .collect();
99 if syllables.len() == char_count {
100 return Some(syllables);
101 }
102 }
103
104 Some(
106 word.chars()
107 .map(|ch| {
108 if is_cjk(ch) {
109 let s = ch.to_string();
110 lookup_numbers(&s)
111 .map(|n| apply_style(n.split_whitespace().next().unwrap_or(""), style))
112 .unwrap_or(s)
113 } else {
114 ch.to_string()
115 }
116 })
117 .collect(),
118 )
119}
120
121pub fn to_pinyin_flat(text: &str, style: &str) -> String {
127 get_jieba()
128 .cut(text, false)
129 .iter()
130 .filter_map(|w| render_word(w, style))
131 .flatten()
132 .collect::<Vec<_>>()
133 .join(" ")
134}
135
136#[derive(serde::Serialize, Debug, PartialEq)]
139pub struct Segment {
140 pub word: String,
141 pub pinyin: Option<Vec<String>>,
142}
143
144pub fn to_pinyin_segmented(text: &str, style: &str) -> Vec<Segment> {
146 get_jieba()
147 .cut(text, false)
148 .iter()
149 .map(|w| Segment {
150 word: w.to_string(),
151 pinyin: render_word(w, style),
152 })
153 .collect()
154}
155
156#[wasm_func]
160pub fn pinyin_flat(text: &[u8], style: &[u8]) -> Vec<u8> {
161 let text = std::str::from_utf8(text).unwrap_or("");
162 let style = std::str::from_utf8(style).unwrap_or("marks");
163 to_pinyin_flat(text, style).into_bytes()
164}
165
166#[wasm_func]
168pub fn pinyin_segmented(text: &[u8], style: &[u8]) -> Vec<u8> {
169 let text = std::str::from_utf8(text).unwrap_or("");
170 let style = std::str::from_utf8(style).unwrap_or("marks");
171 serde_json::to_vec(&to_pinyin_segmented(text, style))
172 .unwrap_or_else(|_| b"[]".to_vec())
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn flat_marks_basic() {
        assert_eq!(to_pinyin_flat("你好", "marks"), "nǐ hǎo");
    }

    #[test]
    fn flat_numbers_basic() {
        assert_eq!(to_pinyin_flat("你好", "numbers"), "ni3 hao3");
    }

    #[test]
    fn flat_marks_beijing() {
        assert_eq!(to_pinyin_flat("北京", "marks"), "běi jīng");
    }

    #[test]
    fn flat_numbers_beijing() {
        assert_eq!(to_pinyin_flat("北京", "numbers"), "bei3 jing1");
    }

    #[test]
    fn heteronym_zhong_in_zhongguo() {
        assert_eq!(to_pinyin_flat("中國", "marks"), "Zhōng guó");
    }

    #[test]
    fn heteronym_le_in_kuaile() {
        assert_eq!(to_pinyin_flat("快樂", "marks"), "kuài lè");
    }

    #[test]
    fn heteronym_yue_in_yinyue() {
        assert_eq!(to_pinyin_flat("音樂", "marks"), "yīn yuè");
    }

    #[test]
    fn segmented_ziran_yuyan() {
        assert_eq!(
            to_pinyin_segmented("自然語言", "marks"),
            vec![Segment {
                word: "自然語言".to_string(),
                pinyin: Some(vec![
                    "zì".to_string(),
                    "rán".to_string(),
                    "yǔ".to_string(),
                    "yán".to_string(),
                ]),
            }]
        );
    }

    #[test]
    fn segmented_empty() {
        assert!(to_pinyin_segmented("", "marks").is_empty());
    }

    #[test]
    fn latin_word_pinyin_is_null() {
        let segs = to_pinyin_segmented("world", "marks");
        assert_eq!(segs.len(), 1);
        assert_eq!(segs[0].pinyin, None);
    }

    #[test]
    fn punctuation_pinyin_is_null() {
        // ASCII punctuation plus the full-width / CJK punctuation that
        // commonly appears in Chinese text.
        for token in ["!", "?", ",", "，", "。", "！", "？", "、", " ", "\n"] {
            let segs = to_pinyin_segmented(token, "marks");
            for seg in &segs {
                assert_eq!(
                    seg.pinyin, None,
                    "expected null pinyin for {:?}, got {:?}",
                    token, seg.pinyin
                );
            }
        }
    }

    #[test]
    fn flat_skips_non_chinese() {
        assert_eq!(to_pinyin_flat("world!", "marks"), "");
        assert_eq!(to_pinyin_flat("北京!world", "marks"), "běi jīng");
    }

    #[test]
    fn unknown_style_falls_back_to_marks() {
        assert_eq!(
            to_pinyin_flat("好", "marks"),
            to_pinyin_flat("好", "whatever")
        );
    }

    #[test]
    fn is_cjk_range_boundaries() {
        assert!(is_cjk('中'));
        assert!(is_cjk('\u{3400}'));
        assert!(is_cjk('\u{4DBF}'));
        assert!(!is_cjk('\u{4DC0}'));
        assert!(is_cjk('\u{FAFF}'));
        assert!(!is_cjk('\u{FB00}'));
        assert!(is_cjk('\u{20000}'));
        assert!(!is_cjk('\u{40000}'));
        assert!(!is_cjk('a'));
        assert!(!is_cjk('，'));
        assert!(!is_cjk('。'));
    }

    #[test]
    fn flat_agrees_with_segmented() {
        // The flat output must be exactly the segmented pinyin joined by spaces.
        let text = "北京!world 中國,音樂";
        for style in ["marks", "numbers"] {
            let from_segments: Vec<String> = to_pinyin_segmented(text, style)
                .into_iter()
                .filter_map(|seg| seg.pinyin)
                .flatten()
                .collect();
            assert_eq!(to_pinyin_flat(text, style), from_segments.join(" "));
        }
    }
}