1use std::ffi::{CStr, CString};
25use std::fs::File;
26use std::io::BufWriter;
27use std::path::Path;
28
/// Base URL of the model repository. A model's file name is appended to it,
/// separated by `/`, to form the download URL.
const MODEL_BASE_URL: &str =
    "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131";
32
/// Error type returned by the fallible operations in this module.
///
/// It carries nothing but a human-readable description of the failure.
#[derive(Debug, Clone)]
pub struct UdpipeError {
    /// Description of what went wrong.
    pub message: String,
}

impl UdpipeError {
    /// Builds an error from anything that converts into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for UdpipeError {
    /// Renders the error as `UDPipe error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "UDPipe error: {message}")
    }
}

impl std::error::Error for UdpipeError {}

impl From<std::io::Error> for UdpipeError {
    /// Converts an I/O error by keeping only its display text.
    fn from(err: std::io::Error) -> Self {
        Self::new(err.to_string())
    }
}
64
/// One word of a parse result together with its annotations.
///
/// The string fields hold the text reported by UDPipe for the word; the
/// helper methods below interpret `upostag`, `feats`, `deprel` and `misc`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Word {
    /// Form of the word as reported by UDPipe.
    pub form: String,
    /// Lemma of the word as reported by UDPipe.
    pub lemma: String,
    /// Part-of-speech tag compared against values such as `VERB`, `NOUN`,
    /// `ADJ` and `PUNCT` by the `is_*` helpers.
    pub upostag: String,
    /// Second part-of-speech tag reported by UDPipe (not interpreted here).
    pub xpostag: String,
    /// Features as a `|`-separated list of `Key=Value` entries.
    pub feats: String,
    /// Dependency relation label; `root` marks the root word.
    pub deprel: String,
    /// Miscellaneous `|`-separated annotations such as `SpaceAfter=No`.
    pub misc: String,
    /// Identifier of the word as reported by UDPipe.
    pub id: i32,
    /// Identifier of the head word as reported by UDPipe.
    pub head: i32,
    /// Identifier of the sentence the word belongs to, as reported by UDPipe.
    pub sentence_id: i32,
}

impl Word {
    /// Returns `true` when the feature `key` is present with exactly `value`.
    #[must_use]
    pub fn has_feature(&self, key: &str, value: &str) -> bool {
        self.get_feature(key) == Some(value)
    }

    /// Looks up the value of feature `key` in `feats`.
    ///
    /// Returns `None` when no entry has exactly the form `key=...`; a key that
    /// is merely a prefix of another key (e.g. `Num` vs. `NumType`) does not
    /// match.
    #[must_use]
    pub fn get_feature(&self, key: &str) -> Option<&str> {
        self.feats
            .split('|')
            .find_map(|f| f.strip_prefix(key)?.strip_prefix('='))
    }

    /// Returns `true` for the tags `VERB` and `AUX`.
    #[must_use]
    pub fn is_verb(&self) -> bool {
        self.upostag == "VERB" || self.upostag == "AUX"
    }

    /// Returns `true` for the tags `NOUN` and `PROPN`.
    #[must_use]
    pub fn is_noun(&self) -> bool {
        self.upostag == "NOUN" || self.upostag == "PROPN"
    }

    /// Returns `true` for the tag `ADJ`.
    #[must_use]
    pub fn is_adjective(&self) -> bool {
        self.upostag == "ADJ"
    }

    /// Returns `true` for the tag `PUNCT`.
    #[must_use]
    pub fn is_punct(&self) -> bool {
        self.upostag == "PUNCT"
    }

    /// Returns `true` when the dependency relation is `root`.
    #[must_use]
    pub fn is_root(&self) -> bool {
        self.deprel == "root"
    }

    /// Returns `false` only when `misc` contains the entry `SpaceAfter=No`.
    ///
    /// Entries are compared as whole `|`-separated items, so unrelated
    /// annotations that merely contain that text (for example
    /// `NoSpaceAfter=No`) do not suppress the space.
    #[must_use]
    pub fn space_after(&self) -> bool {
        !self.misc.split('|').any(|entry| entry == "SpaceAfter=No")
    }
}
180
/// Raw declarations of the C functions and data types used by this file.
///
/// Nothing in here is safe to call directly; the wrappers in the parent
/// module are the intended entry points.
mod ffi {
    use std::os::raw::c_char;

    /// Opaque model handle. It has no usable fields on the Rust side and is
    /// only ever handled through a raw pointer.
    #[repr(C)]
    pub struct UdpipeModel {
        _private: [u8; 0],
    }

    /// Opaque parse-result handle, likewise only used through a raw pointer.
    #[repr(C)]
    pub struct UdpipeParseResult {
        _private: [u8; 0],
    }

    /// One word as returned by value from `udpipe_result_get_word`.
    ///
    /// The string fields are read as NUL-terminated C strings by the parent
    /// module. NOTE(review): presumably they stay valid until the owning
    /// result is freed — confirm against the C wrapper.
    #[repr(C)]
    pub struct UdpipeWord {
        pub form: *const c_char,
        pub lemma: *const c_char,
        pub upostag: *const c_char,
        pub xpostag: *const c_char,
        pub feats: *const c_char,
        pub deprel: *const c_char,
        pub misc: *const c_char,
        pub id: i32,
        pub head: i32,
        pub sentence_id: i32,
    }

    unsafe extern "C" {
        /// Loads a model from a file path; callers in this file treat a null
        /// return as failure and read the message from `udpipe_get_error`.
        pub fn udpipe_model_load(model_path: *const c_char) -> *mut UdpipeModel;
        /// Loads a model from `len` bytes starting at `data`; a null return
        /// is treated as failure by the callers in this file.
        pub fn udpipe_model_load_from_memory(data: *const u8, len: usize) -> *mut UdpipeModel;
        /// Releases a model obtained from one of the loaders.
        pub fn udpipe_model_free(model: *mut UdpipeModel);
        /// Parses NUL-terminated text with the given model; a null return is
        /// treated as failure by the callers in this file.
        pub fn udpipe_parse(model: *mut UdpipeModel, text: *const c_char)
            -> *mut UdpipeParseResult;
        /// Releases a result obtained from `udpipe_parse`.
        pub fn udpipe_result_free(result: *mut UdpipeParseResult);
        /// Returns the message describing the most recent failure.
        pub fn udpipe_get_error() -> *const c_char;
        /// Number of words stored in a parse result.
        pub fn udpipe_result_word_count(result: *mut UdpipeParseResult) -> i32;
        /// Returns the word at `index` of a parse result.
        pub fn udpipe_result_get_word(result: *mut UdpipeParseResult, index: i32) -> UdpipeWord;
    }
}
221
222fn get_ffi_error() -> String {
224 unsafe {
225 let err_ptr = ffi::udpipe_get_error();
226 assert!(!err_ptr.is_null(), "UDPipe returned null error pointer");
227 CStr::from_ptr(err_ptr).to_string_lossy().into_owned()
228 }
229}
230
231pub struct Model {
236 inner: *mut ffi::UdpipeModel,
237}
238
239impl std::fmt::Debug for Model {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 f.debug_struct("Model")
242 .field("inner", &(!self.inner.is_null()))
243 .finish()
244 }
245}
246
247unsafe impl Send for Model {}
249unsafe impl Sync for Model {}
250
251impl Model {
252 pub fn load(path: impl AsRef<Path>) -> Result<Self, UdpipeError> {
264 let path_str = path.as_ref().to_string_lossy();
265 let c_path = CString::new(path_str.as_bytes()).map_err(|_| UdpipeError {
266 message: "Invalid path (contains null byte)".to_string(),
267 })?;
268
269 let model = unsafe { ffi::udpipe_model_load(c_path.as_ptr()) };
270
271 if model.is_null() {
272 return Err(UdpipeError {
273 message: get_ffi_error(),
274 });
275 }
276
277 Ok(Self { inner: model })
278 }
279
280 pub fn load_from_memory(data: &[u8]) -> Result<Self, UdpipeError> {
295 let model = unsafe { ffi::udpipe_model_load_from_memory(data.as_ptr(), data.len()) };
296
297 if model.is_null() {
298 return Err(UdpipeError {
299 message: get_ffi_error(),
300 });
301 }
302
303 Ok(Self { inner: model })
304 }
305
306 pub fn parse(&self, text: &str) -> Result<Vec<Word>, UdpipeError> {
324 let c_text = CString::new(text).map_err(|_| UdpipeError {
325 message: "Invalid text (contains null byte)".to_string(),
326 })?;
327
328 let result = unsafe { ffi::udpipe_parse(self.inner, c_text.as_ptr()) };
329 if result.is_null() {
330 return Err(UdpipeError {
331 message: get_ffi_error(),
332 });
333 }
334
335 let word_count = unsafe { ffi::udpipe_result_word_count(result) };
336 let capacity = usize::try_from(word_count).unwrap_or(0);
337 let mut words = Vec::with_capacity(capacity);
338
339 for i in 0..word_count {
340 let word = unsafe { ffi::udpipe_result_get_word(result, i) };
341 words.push(Word {
342 form: unsafe { CStr::from_ptr(word.form).to_string_lossy().into_owned() },
343 lemma: unsafe { CStr::from_ptr(word.lemma).to_string_lossy().into_owned() },
344 upostag: unsafe { CStr::from_ptr(word.upostag).to_string_lossy().into_owned() },
345 xpostag: unsafe { CStr::from_ptr(word.xpostag).to_string_lossy().into_owned() },
346 feats: unsafe { CStr::from_ptr(word.feats).to_string_lossy().into_owned() },
347 deprel: unsafe { CStr::from_ptr(word.deprel).to_string_lossy().into_owned() },
348 misc: unsafe { CStr::from_ptr(word.misc).to_string_lossy().into_owned() },
349 id: word.id,
350 head: word.head,
351 sentence_id: word.sentence_id,
352 });
353 }
354
355 unsafe { ffi::udpipe_result_free(result) };
356
357 Ok(words)
358 }
359}
360
361impl Drop for Model {
362 fn drop(&mut self) {
363 if !self.inner.is_null() {
364 unsafe { ffi::udpipe_model_free(self.inner) };
365 }
366 }
367}
368
/// Model names accepted by `download_model`.
///
/// Each entry is turned into a file name by `model_filename`. The list is
/// kept in ascending order, which the test suite checks.
pub const AVAILABLE_MODELS: &[&str] = &[
    "afrikaans-afribooms",
    "ancient_greek-perseus",
    "ancient_greek-proiel",
    "arabic-padt",
    "armenian-armtdp",
    "basque-bdt",
    "belarusian-hse",
    "bulgarian-btb",
    "buryat-bdt",
    "catalan-ancora",
    "chinese-gsd",
    "chinese-gsdsimp",
    "classical_chinese-kyoto",
    "coptic-scriptorium",
    "croatian-set",
    "czech-cac",
    "czech-cltt",
    "czech-fictree",
    "czech-pdt",
    "danish-ddt",
    "dutch-alpino",
    "dutch-lassysmall",
    "english-ewt",
    "english-gum",
    "english-lines",
    "english-partut",
    "estonian-edt",
    "estonian-ewt",
    "finnish-ftb",
    "finnish-tdt",
    "french-gsd",
    "french-partut",
    "french-sequoia",
    "french-spoken",
    "galician-ctg",
    "galician-treegal",
    "german-gsd",
    "german-hdt",
    "gothic-proiel",
    "greek-gdt",
    "hebrew-htb",
    "hindi-hdtb",
    "hungarian-szeged",
    "indonesian-gsd",
    "irish-idt",
    "italian-isdt",
    "italian-partut",
    "italian-postwita",
    "italian-twittiro",
    "italian-vit",
    "japanese-gsd",
    "kazakh-ktb",
    "korean-gsd",
    "korean-kaist",
    "kurmanji-mg",
    "latin-ittb",
    "latin-perseus",
    "latin-proiel",
    "latvian-lvtb",
    "lithuanian-alksnis",
    "lithuanian-hse",
    "maltese-mudt",
    "marathi-ufal",
    "north_sami-giella",
    "norwegian-bokmaal",
    "norwegian-nynorsk",
    "norwegian-nynorsklia",
    "old_church_slavonic-proiel",
    "old_french-srcmf",
    "old_russian-torot",
    "persian-seraji",
    "polish-lfg",
    "polish-pdb",
    "polish-sz",
    "portuguese-bosque",
    "portuguese-br",
    "portuguese-gsd",
    "romanian-nonstandard",
    "romanian-rrt",
    "russian-gsd",
    "russian-syntagrus",
    "russian-taiga",
    "sanskrit-ufal",
    "scottish_gaelic-arcosg",
    "serbian-set",
    "slovak-snk",
    "slovenian-ssj",
    "slovenian-sst",
    "spanish-ancora",
    "spanish-gsd",
    "swedish-lines",
    "swedish-talbanken",
    "tamil-ttb",
    "telugu-mtg",
    "turkish-imst",
    "ukrainian-iu",
    "upper_sorbian-ufal",
    "urdu-udtb",
    "uyghur-udt",
    "vietnamese-vtb",
    "wolof-wtb",
];
476
477pub fn download_model(language: &str, dest_dir: impl AsRef<Path>) -> Result<String, UdpipeError> {
505 let dest_dir = dest_dir.as_ref();
506
507 if !AVAILABLE_MODELS.contains(&language) {
509 return Err(UdpipeError {
510 message: format!(
511 "Unknown language '{}'. Use one of: {}",
512 language,
513 AVAILABLE_MODELS[..5].join(", ") + ", ..."
514 ),
515 });
516 }
517
518 let filename = model_filename(language);
520 let dest_path = dest_dir.join(&filename);
521 let url = format!("{MODEL_BASE_URL}/{filename}");
522
523 download_model_from_url(&url, &dest_path)?;
525
526 Ok(dest_path.to_string_lossy().into_owned())
527}
528
529pub fn download_model_from_url(url: &str, path: impl AsRef<Path>) -> Result<(), UdpipeError> {
549 let path = path.as_ref();
550
551 let response = ureq::get(url).call().map_err(|e| UdpipeError {
553 message: format!("Failed to download: {e}"),
554 })?;
555
556 let file = File::create(path)?;
558 let mut writer = BufWriter::new(file);
559 let bytes_written = std::io::copy(&mut response.into_body().into_reader(), &mut writer)?;
560
561 if bytes_written == 0 {
562 return Err(UdpipeError {
563 message: "Downloaded file is empty".to_string(),
564 });
565 }
566
567 Ok(())
568}
569
/// Returns the file name of the model called `language`, which is the name
/// followed by the fixed `-ud-2.5-191206.udpipe` suffix.
#[must_use]
pub fn model_filename(language: &str) -> String {
    const SUFFIX: &str = "-ud-2.5-191206.udpipe";

    let mut name = String::with_capacity(language.len() + SUFFIX.len());
    name.push_str(language);
    name.push_str(SUFFIX);
    name
}
581
// Unit tests for the word helpers, the error type, the model list and the
// failure paths of model loading and downloading.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a root NOUN named "test" whose only variable part is `feats`.
    fn make_word(feats: &str) -> Word {
        Word {
            form: "test".to_string(),
            lemma: "test".to_string(),
            upostag: "NOUN".to_string(),
            xpostag: String::new(),
            feats: feats.to_string(),
            deprel: "root".to_string(),
            misc: String::new(),
            id: 1,
            head: 0,
            sentence_id: 0,
        }
    }

    // Matching key/value pairs are found; a wrong value or a missing key is not.
    #[test]
    fn test_word_has_feature() {
        let word = make_word("Mood=Imp|VerbForm=Fin");

        assert!(word.has_feature("Mood", "Imp"));
        assert!(word.has_feature("VerbForm", "Fin"));
        assert!(!word.has_feature("Mood", "Ind"));
        assert!(!word.has_feature("Tense", "Past"));
    }

    // An empty feature string has no features.
    #[test]
    fn test_word_has_feature_empty() {
        let word = make_word("");
        assert!(!word.has_feature("Mood", "Imp"));
    }

    // A feature string without any `|` separator still works.
    #[test]
    fn test_word_has_feature_single() {
        let word = make_word("Mood=Imp");
        assert!(word.has_feature("Mood", "Imp"));
        assert!(!word.has_feature("VerbForm", "Fin"));
    }

    // Values are returned for present keys and `None` for absent ones.
    #[test]
    fn test_word_get_feature() {
        let word = make_word("Tense=Pres|VerbForm=Part");

        assert_eq!(word.get_feature("Tense"), Some("Pres"));
        assert_eq!(word.get_feature("VerbForm"), Some("Part"));
        assert_eq!(word.get_feature("Mood"), None);
    }

    // An empty feature string yields `None` for every key.
    #[test]
    fn test_word_get_feature_empty() {
        let word = make_word("");
        assert_eq!(word.get_feature("Mood"), None);
    }

    // Lookup in a feature string holding a single entry.
    #[test]
    fn test_word_get_feature_single() {
        let word = make_word("Mood=Imp");
        assert_eq!(word.get_feature("Mood"), Some("Imp"));
        assert_eq!(word.get_feature("VerbForm"), None);
    }

    // Both VERB and AUX count as verbs.
    #[test]
    fn test_word_is_verb() {
        let mut word = make_word("");
        word.upostag = "VERB".to_string();
        assert!(word.is_verb());

        word.upostag = "AUX".to_string();
        assert!(word.is_verb());

        word.upostag = "NOUN".to_string();
        assert!(!word.is_verb());
    }

    // Both NOUN and PROPN count as nouns.
    #[test]
    fn test_word_is_noun() {
        let mut word = make_word("");
        word.upostag = "NOUN".to_string();
        assert!(word.is_noun());

        word.upostag = "PROPN".to_string();
        assert!(word.is_noun());

        word.upostag = "VERB".to_string();
        assert!(!word.is_noun());
    }

    // Only the `root` relation marks a root word.
    #[test]
    fn test_word_is_root() {
        let mut word = make_word("");
        word.deprel = "root".to_string();
        assert!(word.is_root());

        word.deprel = "nsubj".to_string();
        assert!(!word.is_root());
    }

    // Only the ADJ tag marks an adjective.
    #[test]
    fn test_word_is_adjective() {
        let mut word = make_word("");
        word.upostag = "ADJ".to_string();
        assert!(word.is_adjective());

        word.upostag = "NOUN".to_string();
        assert!(!word.is_adjective());
    }

    // Only the PUNCT tag marks punctuation.
    #[test]
    fn test_word_is_punct() {
        let mut word = make_word("");
        word.upostag = "PUNCT".to_string();
        assert!(word.is_punct());

        word.upostag = "NOUN".to_string();
        assert!(!word.is_punct());
    }

    // Equal words hash equally, so they can be looked up in a `HashSet`.
    #[test]
    fn test_word_hash() {
        use std::collections::HashSet;

        let word1 = make_word("Mood=Imp");
        let word2 = make_word("Mood=Imp");
        let mut set = HashSet::new();
        set.insert(word1);
        assert!(set.contains(&word2));
    }

    // The file name is the model name plus the fixed suffix.
    #[test]
    fn test_model_filename() {
        assert_eq!(
            model_filename("english-ewt"),
            "english-ewt-ud-2.5-191206.udpipe"
        );
        assert_eq!(
            model_filename("dutch-alpino"),
            "dutch-alpino-ud-2.5-191206.udpipe"
        );
    }

    // Spot-checks a few entries of the model list.
    #[test]
    fn test_available_models_contains_common_languages() {
        assert!(AVAILABLE_MODELS.contains(&"english-ewt"));
        assert!(AVAILABLE_MODELS.contains(&"german-gsd"));
        assert!(AVAILABLE_MODELS.contains(&"french-gsd"));
        assert!(AVAILABLE_MODELS.contains(&"spanish-ancora"));
    }

    // The model list must stay in ascending order.
    #[test]
    fn test_available_models_sorted() {
        let mut sorted = AVAILABLE_MODELS.to_vec();
        sorted.sort_unstable();
        assert_eq!(AVAILABLE_MODELS, sorted.as_slice());
    }

    // An unlisted name is rejected before any download is attempted.
    #[test]
    fn test_download_model_invalid_language() {
        let result = download_model("invalid-language-xyz", ".");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("Unknown language"));
    }

    // `Display` prefixes the message with "UDPipe error: ".
    #[test]
    fn test_udpipe_error_display() {
        let err = UdpipeError::new("test error");
        assert_eq!(format!("{err}"), "UDPipe error: test error");
    }

    // Converting an I/O error keeps its message text.
    #[test]
    fn test_udpipe_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: UdpipeError = io_err.into();
        assert!(err.message.contains("not found"));
    }

    // A space follows unless `misc` carries `SpaceAfter=No`.
    #[test]
    fn test_space_after() {
        let mut word = make_word("");
        word.misc = String::new();
        assert!(word.space_after());
        word.misc = "SpaceAfter=No".to_string();
        assert!(!word.space_after());

        word.misc = "SpaceAfter=No|Other=Value".to_string();
        assert!(!word.space_after());
    }

    // Loading from a path that does not exist fails.
    #[test]
    fn test_model_load_nonexistent_file() {
        let result = Model::load("/nonexistent/path/to/model.udpipe");
        assert!(result.is_err());
    }

    // A path with interior NUL bytes is rejected on the Rust side.
    #[test]
    fn test_model_load_path_with_null_byte() {
        let result = Model::load("path\0with\0nulls.udpipe");
        let err = result.expect_err("expected error");
        assert!(err.message.contains("null byte"));
    }

    // Empty input is not a valid model.
    #[test]
    fn test_model_load_from_memory_empty() {
        let result = Model::load_from_memory(&[]);
        assert!(result.is_err());
    }

    // Arbitrary bytes are not a valid model.
    #[test]
    fn test_model_load_from_memory_invalid() {
        let garbage = b"this is not a valid udpipe model";
        let result = Model::load_from_memory(garbage);
        assert!(result.is_err());
    }

    // Parsing with a null model handle yields an error instead of crashing.
    // NOTE(review): the "Invalid arguments" text comes from the C wrapper,
    // not from this file.
    #[test]
    fn test_parse_with_null_model() {
        let model = Model {
            inner: std::ptr::null_mut(),
        };
        let result = model.parse("test");
        let err = result.unwrap_err();
        assert!(err.message.contains("Invalid arguments"));
    }

    // The debug output names the type and its field.
    #[test]
    fn test_model_debug() {
        let model = Model {
            inner: std::ptr::null_mut(),
        };
        let debug_str = format!("{model:?}");
        assert!(debug_str.contains("Model"));
        assert!(debug_str.contains("inner"));
    }

    // A failing request is reported with the "Failed to download" prefix.
    #[test]
    fn test_download_model_from_url_invalid_url() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("model.udpipe");
        let result = download_model_from_url("http://invalid.invalid/no-such-model", &path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("Failed to download"));
    }

    // NOTE(review): the request to port 1 presumably fails before the file is
    // created, so this may exercise the request error rather than the missing
    // directory — confirm which failure is intended.
    #[test]
    fn test_download_model_from_url_nonexistent_dir() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("nonexistent/model.udpipe");
        let url = "http://localhost:1/model.udpipe";

        let result = download_model_from_url(url, &path);
        assert!(result.is_err());
    }

    // A successful response with an empty body is reported as an error.
    #[test]
    fn test_download_model_from_url_empty_response() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("model.udpipe");

        let mut server = mockito::Server::new();
        let mock = server
            .mock("GET", "/empty-model.udpipe")
            .with_status(200)
            .with_body("")
            .create();
        let full_url = format!("{}/empty-model.udpipe", server.url());

        let result = download_model_from_url(&full_url, &path);
        mock.assert();
        drop(server);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("empty"));
    }
}