1use std::ffi::{CStr, CString};
25use std::fs::File;
26use std::io::BufWriter;
27use std::path::Path;
28
/// Base URL that [`download_model`] prefixes to a model file name
/// (`{MODEL_BASE_URL}/{filename}`) to build the download address.
const MODEL_BASE_URL: &str =
    "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131";
32
/// Error returned by every fallible operation in this module.
///
/// The error is a plain message: failures reported by the C library, I/O
/// failures and download failures are all flattened into [`UdpipeError::message`].
#[derive(Debug, Clone)]
pub struct UdpipeError {
    /// Human-readable description of the failure (without the `UDPipe error: ` prefix
    /// that the `Display` implementation adds).
    pub message: String,
}

impl UdpipeError {
    /// Builds an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for UdpipeError {
    /// Renders the error as `UDPipe error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(f, "UDPipe error: {}", message)
    }
}

impl std::error::Error for UdpipeError {}

impl From<std::io::Error> for UdpipeError {
    /// Converts an I/O error by keeping only its display text; the original
    /// error value (and its kind) is not retained.
    fn from(err: std::io::Error) -> Self {
        Self::new(err.to_string())
    }
}
64
/// One token produced by [`Model::parse`], with every string copied into owned storage.
///
/// Field names mirror the fields of the C-side word record. Their meaning is
/// presumed to follow the CoNLL-U columns of the same names — confirm against
/// the C wrapper if exact semantics matter.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Word {
    /// Token text as reported by the parser.
    pub form: String,
    /// Lemma reported for the token.
    pub lemma: String,
    /// Universal POS tag; the predicates below compare it against values such
    /// as `"VERB"`, `"AUX"`, `"NOUN"`, `"PROPN"`, `"ADJ"` and `"PUNCT"`.
    pub upostag: String,
    /// Treebank-specific POS tag (presumably; not interpreted in this file).
    pub xpostag: String,
    /// Morphological features as `Key=Value` pairs separated by `|`
    /// (for example `Mood=Imp|VerbForm=Fin`); may be empty.
    pub feats: String,
    /// Dependency relation label; `"root"` marks the sentence root.
    pub deprel: String,
    /// Miscellaneous attributes separated by `|`, e.g. `SpaceAfter=No`.
    pub misc: String,
    /// Token index as reported by the parser (presumably 1-based within the sentence).
    pub id: i32,
    /// Index of the head token as reported by the parser (presumably 0 for the root).
    pub head: i32,
    /// Index of the sentence this token belongs to, as reported by the parser.
    pub sentence_id: i32,
}

impl Word {
    /// Returns `true` when `feats` contains the pair `key=value` exactly.
    ///
    /// The comparison is against the whole value, so a multi-valued feature
    /// such as `Case=Acc,Dat` does not match `("Case", "Acc")`.
    #[must_use]
    pub fn has_feature(&self, key: &str, value: &str) -> bool {
        self.get_feature(key) == Some(value)
    }

    /// Looks up the value of the morphological feature `key` in `feats`.
    ///
    /// Returns `None` when the feature is absent (including when `feats` is empty).
    /// The key must match the full feature name: `"Moo"` does not match `Mood=Imp`.
    #[must_use]
    pub fn get_feature(&self, key: &str) -> Option<&str> {
        self.feats
            .split('|')
            // Strip the key first, then require the `=` separator right after it,
            // so that a key which is only a prefix of a longer name is rejected.
            .find_map(|pair| pair.strip_prefix(key)?.strip_prefix('='))
    }

    /// `true` for tokens tagged `VERB` or `AUX`.
    #[must_use]
    pub fn is_verb(&self) -> bool {
        matches!(self.upostag.as_str(), "VERB" | "AUX")
    }

    /// `true` for tokens tagged `NOUN` or `PROPN`.
    #[must_use]
    pub fn is_noun(&self) -> bool {
        matches!(self.upostag.as_str(), "NOUN" | "PROPN")
    }

    /// `true` for tokens tagged `ADJ`.
    #[must_use]
    pub fn is_adjective(&self) -> bool {
        self.upostag == "ADJ"
    }

    /// `true` for tokens tagged `PUNCT`.
    #[must_use]
    pub fn is_punct(&self) -> bool {
        self.upostag == "PUNCT"
    }

    /// `true` when the dependency relation is `root`.
    #[must_use]
    pub fn is_root(&self) -> bool {
        self.deprel == "root"
    }

    /// Returns `false` only when `misc` carries the attribute `SpaceAfter=No`.
    ///
    /// `misc` is split on `|` and each attribute is compared as a whole, so an
    /// unrelated attribute that merely contains the text `SpaceAfter=No`
    /// (for example `NoSpaceAfter=No` or `SpaceAfter=Nope`) does not suppress
    /// the space. An empty `misc` means a space follows.
    #[must_use]
    pub fn has_space_after(&self) -> bool {
        !self.misc.split('|').any(|attr| attr == "SpaceAfter=No")
    }
}
181
/// Raw declarations for the UDPipe C wrapper that this file links against.
///
/// Nothing here is safe to call directly; the safe surface is [`Model`].
mod ffi {
    use std::os::raw::c_char;

    /// Opaque model handle. The zero-sized private field means it is only
    /// ever used behind a raw pointer handed out by the C side.
    #[repr(C)]
    pub struct UdpipeModel {
        _private: [u8; 0],
    }

    /// Opaque parse-result handle, used the same way as [`UdpipeModel`].
    #[repr(C)]
    pub struct UdpipeParseResult {
        _private: [u8; 0],
    }

    /// Word record returned by value from `udpipe_result_get_word`.
    ///
    /// The string fields are raw C-string pointers. The tests in this file
    /// show they are null when the lookup is made on a null result.
    /// NOTE(review): the pointers presumably borrow from the parse result and
    /// stay valid until `udpipe_result_free` — confirm against the C wrapper.
    #[repr(C)]
    pub struct UdpipeWord {
        pub form: *const c_char,
        pub lemma: *const c_char,
        pub upostag: *const c_char,
        pub xpostag: *const c_char,
        pub feats: *const c_char,
        pub deprel: *const c_char,
        pub misc: *const c_char,
        pub id: i32,
        pub head: i32,
        pub sentence_id: i32,
    }

    unsafe extern "C" {
        /// Loads a model from the file at `model_path`; callers treat a null
        /// return as failure and then read `udpipe_get_error`.
        pub fn udpipe_model_load(model_path: *const c_char) -> *mut UdpipeModel;
        /// Loads a model from `len` bytes at `data`; null return means failure.
        pub fn udpipe_model_load_from_memory(data: *const u8, len: usize) -> *mut UdpipeModel;
        /// Releases a model previously returned by one of the load functions.
        pub fn udpipe_model_free(model: *mut UdpipeModel);
        /// Parses `text` with `model`; null return means failure.
        pub fn udpipe_parse(model: *mut UdpipeModel, text: *const c_char)
        -> *mut UdpipeParseResult;
        /// Releases a result previously returned by `udpipe_parse`.
        pub fn udpipe_result_free(result: *mut UdpipeParseResult);
        /// Returns the message for the most recent failure.
        /// NOTE(review): storage and lifetime of the returned string are not
        /// visible here (presumably a static or thread-local buffer) — confirm.
        pub fn udpipe_get_error() -> *const c_char;
        /// Number of words in `result`.
        pub fn udpipe_result_word_count(result: *mut UdpipeParseResult) -> i32;
        /// Word at position `index` of `result`.
        pub fn udpipe_result_get_word(result: *mut UdpipeParseResult, index: i32) -> UdpipeWord;
    }
}
245
246fn get_ffi_error() -> String {
248 let err_ptr = unsafe { ffi::udpipe_get_error() };
250 assert!(!err_ptr.is_null(), "UDPipe returned null error pointer");
251 unsafe { CStr::from_ptr(err_ptr) }
253 .to_string_lossy()
254 .into_owned()
255}
256
/// A loaded UDPipe model.
///
/// Wraps the raw handle returned by the C library. The handle is obtained in
/// [`Model::load`] / [`Model::load_from_memory`] and released in the `Drop`
/// implementation, so a `Model` owns its handle for its whole lifetime.
pub struct Model {
    /// Raw handle from the C library; null only when constructed by hand (as the tests do).
    inner: *mut ffi::UdpipeModel,
}
272
273impl std::fmt::Debug for Model {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 f.debug_struct("Model")
276 .field("inner", &(!self.inner.is_null()))
277 .finish()
278 }
279}
280
// SAFETY (assumption — not verifiable from this file): the underlying UDPipe
// model handle is presumed to have no thread affinity, so moving ownership of
// a `Model` to another thread is sound. `Sync` is not implemented here.
// TODO confirm against the C wrapper.
unsafe impl Send for Model {}
291
292impl Model {
304 pub fn load(path: impl AsRef<Path>) -> Result<Self, UdpipeError> {
317 let path_str = path.as_ref().to_string_lossy();
318 let c_path = CString::new(path_str.as_bytes()).map_err(|_| UdpipeError {
319 message: "Invalid path (contains null byte)".to_owned(),
320 })?;
321
322 let model = unsafe { ffi::udpipe_model_load(c_path.as_ptr()) };
324
325 if model.is_null() {
326 return Err(UdpipeError {
327 message: get_ffi_error(),
328 });
329 }
330
331 Ok(Self { inner: model })
332 }
333
334 pub fn load_from_memory(data: &[u8]) -> Result<Self, UdpipeError> {
350 let model = unsafe { ffi::udpipe_model_load_from_memory(data.as_ptr(), data.len()) };
352
353 if model.is_null() {
354 return Err(UdpipeError {
355 message: get_ffi_error(),
356 });
357 }
358
359 Ok(Self { inner: model })
360 }
361
362 pub fn parse(&self, text: &str) -> Result<Vec<Word>, UdpipeError> {
382 let c_text = CString::new(text).map_err(|_| UdpipeError {
383 message: "Invalid text (contains null byte)".to_owned(),
384 })?;
385
386 let result = unsafe { ffi::udpipe_parse(self.inner, c_text.as_ptr()) };
389 if result.is_null() {
390 return Err(UdpipeError {
391 message: get_ffi_error(),
392 });
393 }
394
395 let word_count = unsafe { ffi::udpipe_result_word_count(result) };
397 let capacity = usize::try_from(word_count).unwrap_or(0);
398 let mut words = Vec::with_capacity(capacity);
399
400 for i in 0..word_count {
401 let word = unsafe { ffi::udpipe_result_get_word(result, i) };
403 words.push(Word {
404 form: ptr_to_string(word.form),
405 lemma: ptr_to_string(word.lemma),
406 upostag: ptr_to_string(word.upostag),
407 xpostag: ptr_to_string(word.xpostag),
408 feats: ptr_to_string(word.feats),
409 deprel: ptr_to_string(word.deprel),
410 misc: ptr_to_string(word.misc),
411 id: word.id,
412 head: word.head,
413 sentence_id: word.sentence_id,
414 });
415 }
416
417 unsafe { ffi::udpipe_result_free(result) };
419
420 Ok(words)
421 }
422}
423
/// Copies a C string into an owned `String`, replacing invalid UTF-8 sequences.
///
/// A null pointer yields an empty string. The C API can hand back null string
/// pointers (the tests in this file observe that for lookups on a null
/// result), and passing null to `CStr::from_ptr` would be undefined behaviour.
///
/// A non-null `ptr` must reference a NUL-terminated string that stays valid
/// for the duration of the call.
fn ptr_to_string(ptr: *const std::os::raw::c_char) -> String {
    if ptr.is_null() {
        return String::new();
    }

    // SAFETY: `ptr` is non-null (checked above); the caller guarantees it
    // points at a live NUL-terminated string.
    unsafe { CStr::from_ptr(ptr) }
        .to_string_lossy()
        .into_owned()
}
434
435impl Drop for Model {
436 fn drop(&mut self) {
437 if !self.inner.is_null() {
438 unsafe { ffi::udpipe_model_free(self.inner) };
440 }
441 }
442}
443
/// Model identifiers accepted by [`download_model`].
///
/// Each entry is turned into a file name by [`model_filename`]. The list is
/// kept in ascending byte order; a unit test in this file enforces that, so
/// new entries must be inserted at their sorted position.
pub const AVAILABLE_MODELS: &[&str] = &[
    "afrikaans-afribooms",
    "ancient_greek-perseus",
    "ancient_greek-proiel",
    "arabic-padt",
    "armenian-armtdp",
    "basque-bdt",
    "belarusian-hse",
    "bulgarian-btb",
    "buryat-bdt",
    "catalan-ancora",
    "chinese-gsd",
    "chinese-gsdsimp",
    "classical_chinese-kyoto",
    "coptic-scriptorium",
    "croatian-set",
    "czech-cac",
    "czech-cltt",
    "czech-fictree",
    "czech-pdt",
    "danish-ddt",
    "dutch-alpino",
    "dutch-lassysmall",
    "english-ewt",
    "english-gum",
    "english-lines",
    "english-partut",
    "estonian-edt",
    "estonian-ewt",
    "finnish-ftb",
    "finnish-tdt",
    "french-gsd",
    "french-partut",
    "french-sequoia",
    "french-spoken",
    "galician-ctg",
    "galician-treegal",
    "german-gsd",
    "german-hdt",
    "gothic-proiel",
    "greek-gdt",
    "hebrew-htb",
    "hindi-hdtb",
    "hungarian-szeged",
    "indonesian-gsd",
    "irish-idt",
    "italian-isdt",
    "italian-partut",
    "italian-postwita",
    "italian-twittiro",
    "italian-vit",
    "japanese-gsd",
    "kazakh-ktb",
    "korean-gsd",
    "korean-kaist",
    "kurmanji-mg",
    "latin-ittb",
    "latin-perseus",
    "latin-proiel",
    "latvian-lvtb",
    "lithuanian-alksnis",
    "lithuanian-hse",
    "maltese-mudt",
    "marathi-ufal",
    "north_sami-giella",
    "norwegian-bokmaal",
    "norwegian-nynorsk",
    "norwegian-nynorsklia",
    "old_church_slavonic-proiel",
    "old_french-srcmf",
    "old_russian-torot",
    "persian-seraji",
    "polish-lfg",
    "polish-pdb",
    "polish-sz",
    "portuguese-bosque",
    "portuguese-br",
    "portuguese-gsd",
    "romanian-nonstandard",
    "romanian-rrt",
    "russian-gsd",
    "russian-syntagrus",
    "russian-taiga",
    "sanskrit-ufal",
    "scottish_gaelic-arcosg",
    "serbian-set",
    "slovak-snk",
    "slovenian-ssj",
    "slovenian-sst",
    "spanish-ancora",
    "spanish-gsd",
    "swedish-lines",
    "swedish-talbanken",
    "tamil-ttb",
    "telugu-mtg",
    "turkish-imst",
    "ukrainian-iu",
    "upper_sorbian-ufal",
    "urdu-udtb",
    "uyghur-udt",
    "vietnamese-vtb",
    "wolof-wtb",
];
551
552pub fn download_model(language: &str, dest_dir: impl AsRef<Path>) -> Result<String, UdpipeError> {
582 let dest_dir = dest_dir.as_ref();
583
584 if !AVAILABLE_MODELS.contains(&language) {
586 return Err(UdpipeError {
587 message: format!(
588 "Unknown language '{}'. Use one of: {}",
589 language,
590 AVAILABLE_MODELS[..5].join(", ") + ", ..."
591 ),
592 });
593 }
594
595 let filename = model_filename(language);
597 let dest_path = dest_dir.join(&filename);
598 let url = format!("{MODEL_BASE_URL}/{filename}");
599
600 download_model_from_url(&url, &dest_path)?;
602
603 Ok(dest_path.to_string_lossy().into_owned())
604}
605
606pub fn download_model_from_url(url: &str, path: impl AsRef<Path>) -> Result<(), UdpipeError> {
628 let path = path.as_ref();
629
630 let response = ureq::get(url).call().map_err(|e| UdpipeError {
632 message: format!("Failed to download: {e}"),
633 })?;
634
635 let file = File::create(path)?;
637 let mut writer = BufWriter::new(file);
638 let bytes_written = std::io::copy(&mut response.into_body().into_reader(), &mut writer)?;
639
640 if bytes_written == 0 {
641 return Err(UdpipeError {
642 message: "Downloaded file is empty".to_owned(),
643 });
644 }
645
646 Ok(())
647}
648
/// Returns the file name under which the model `language` is published:
/// the identifier followed by `-ud-2.5-191206.udpipe`.
///
/// The input is not validated against [`AVAILABLE_MODELS`].
#[must_use]
pub fn model_filename(language: &str) -> String {
    const RELEASE_SUFFIX: &str = "-ud-2.5-191206.udpipe";

    let mut name = String::with_capacity(language.len() + RELEASE_SUFFIX.len());
    name.push_str(language);
    name.push_str(RELEASE_SUFFIX);
    name
}
663
// Unit tests. The pure-Rust helpers (`Word`, `UdpipeError`, `model_filename`,
// `AVAILABLE_MODELS`) are tested directly. The `Model` and raw `ffi` tests
// call into the linked C library, and the download tests use the `tempfile`
// and `mockito` crates.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a NOUN/root word whose only varying part is the feature string.
    fn make_word(feats: &str) -> Word {
        Word {
            form: "test".to_owned(),
            lemma: "test".to_owned(),
            upostag: "NOUN".to_owned(),
            xpostag: String::new(),
            feats: feats.to_owned(),
            deprel: "root".to_owned(),
            misc: String::new(),
            id: 1,
            head: 0,
            sentence_id: 0,
        }
    }

    #[test]
    fn test_word_has_feature() {
        let word = make_word("Mood=Imp|VerbForm=Fin");

        assert!(word.has_feature("Mood", "Imp"));
        assert!(word.has_feature("VerbForm", "Fin"));
        // Right key with wrong value, and a key that is absent altogether.
        assert!(!word.has_feature("Mood", "Ind"));
        assert!(!word.has_feature("Tense", "Past"));
    }

    #[test]
    fn test_word_has_feature_empty() {
        let word = make_word("");
        assert!(!word.has_feature("Mood", "Imp"));
    }

    #[test]
    fn test_word_has_feature_single() {
        let word = make_word("Mood=Imp");
        assert!(word.has_feature("Mood", "Imp"));
        assert!(!word.has_feature("VerbForm", "Fin"));
    }

    #[test]
    fn test_word_get_feature() {
        let word = make_word("Tense=Pres|VerbForm=Part");

        assert_eq!(word.get_feature("Tense"), Some("Pres"));
        assert_eq!(word.get_feature("VerbForm"), Some("Part"));
        assert_eq!(word.get_feature("Mood"), None);
    }

    #[test]
    fn test_word_get_feature_empty() {
        let word = make_word("");
        assert_eq!(word.get_feature("Mood"), None);
    }

    #[test]
    fn test_word_get_feature_single() {
        let word = make_word("Mood=Imp");
        assert_eq!(word.get_feature("Mood"), Some("Imp"));
        assert_eq!(word.get_feature("VerbForm"), None);
    }

    #[test]
    fn test_word_is_verb() {
        let mut word = make_word("");
        word.upostag = "VERB".to_owned();
        assert!(word.is_verb());

        // Auxiliaries count as verbs too.
        word.upostag = "AUX".to_owned();
        assert!(word.is_verb());

        word.upostag = "NOUN".to_owned();
        assert!(!word.is_verb());
    }

    #[test]
    fn test_word_is_noun() {
        let mut word = make_word("");
        word.upostag = "NOUN".to_owned();
        assert!(word.is_noun());

        // Proper nouns count as nouns too.
        word.upostag = "PROPN".to_owned();
        assert!(word.is_noun());

        word.upostag = "VERB".to_owned();
        assert!(!word.is_noun());
    }

    #[test]
    fn test_word_is_root() {
        let mut word = make_word("");
        word.deprel = "root".to_owned();
        assert!(word.is_root());

        word.deprel = "nsubj".to_owned();
        assert!(!word.is_root());
    }

    #[test]
    fn test_word_is_adjective() {
        let mut word = make_word("");
        word.upostag = "ADJ".to_owned();
        assert!(word.is_adjective());

        word.upostag = "NOUN".to_owned();
        assert!(!word.is_adjective());
    }

    #[test]
    fn test_word_is_punct() {
        let mut word = make_word("");
        word.upostag = "PUNCT".to_owned();
        assert!(word.is_punct());

        word.upostag = "NOUN".to_owned();
        assert!(!word.is_punct());
    }

    // Two words with identical fields must hash and compare as equal.
    #[test]
    fn test_word_hash() {
        use std::collections::HashSet;

        let word1 = make_word("Mood=Imp");
        let word2 = make_word("Mood=Imp");
        let mut set = HashSet::new();
        set.insert(word1);
        assert!(set.contains(&word2));
    }

    #[test]
    fn test_model_filename() {
        assert_eq!(
            model_filename("english-ewt"),
            "english-ewt-ud-2.5-191206.udpipe"
        );
        assert_eq!(
            model_filename("dutch-alpino"),
            "dutch-alpino-ud-2.5-191206.udpipe"
        );
    }

    #[test]
    fn test_available_models_contains_common_languages() {
        assert!(AVAILABLE_MODELS.contains(&"english-ewt"));
        assert!(AVAILABLE_MODELS.contains(&"german-gsd"));
        assert!(AVAILABLE_MODELS.contains(&"french-gsd"));
        assert!(AVAILABLE_MODELS.contains(&"spanish-ancora"));
    }

    // Guards the ordering of the list: sorting a copy must change nothing.
    #[test]
    fn test_available_models_sorted() {
        let mut sorted = AVAILABLE_MODELS.to_vec();
        sorted.sort_unstable();
        assert_eq!(AVAILABLE_MODELS, sorted.as_slice());
    }

    // An unknown language is rejected before any download is attempted.
    #[test]
    fn test_download_model_invalid_language() {
        let result = download_model("invalid-language-xyz", ".");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("Unknown language"));
    }

    #[test]
    fn test_udpipe_error_display() {
        let err = UdpipeError::new("test error");
        assert_eq!(format!("{err}"), "UDPipe error: test error");
    }

    #[test]
    fn test_udpipe_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: UdpipeError = io_err.into();
        assert!(err.message.contains("not found"));
    }

    #[test]
    fn test_has_space_after() {
        let mut word = make_word("");
        // With no misc attributes a space is assumed to follow.
        word.misc = String::new();
        assert!(word.has_space_after());
        word.misc = "SpaceAfter=No".to_owned();
        assert!(!word.has_space_after());

        word.misc = "SpaceAfter=No|Other=Value".to_owned();
        assert!(!word.has_space_after());
    }

    #[test]
    fn test_model_load_nonexistent_file() {
        let result = Model::load("/nonexistent/path/to/model.udpipe");
        assert!(result.is_err());
    }

    // Interior NUL bytes are caught on the Rust side, before calling into C.
    #[test]
    fn test_model_load_path_with_null_byte() {
        let result = Model::load("path\0with\0nulls.udpipe");
        let err = result.expect_err("expected error");
        assert!(err.message.contains("null byte"));
    }

    #[test]
    fn test_model_load_from_memory_empty() {
        let result = Model::load_from_memory(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_model_load_from_memory_invalid() {
        let garbage = b"this is not a valid udpipe model";
        let result = Model::load_from_memory(garbage);
        assert!(result.is_err());
    }

    // A hand-built model with a null handle: the C side is expected to refuse
    // it with an "Invalid arguments" error rather than crash.
    #[test]
    fn test_parse_with_null_model() {
        let model = Model {
            inner: std::ptr::null_mut(),
        };
        let result = model.parse("test");
        let err = result.unwrap_err();
        assert!(err.message.contains("Invalid arguments"));
    }

    #[test]
    fn test_model_debug() {
        let model = Model {
            inner: std::ptr::null_mut(),
        };
        let debug_str = format!("{model:?}");
        assert!(debug_str.contains("Model"));
        assert!(debug_str.contains("inner"));
    }

    // `.invalid` hosts cannot resolve, so the request itself must fail.
    #[test]
    fn test_download_model_from_url_invalid_url() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("model.udpipe");
        let result = download_model_from_url("http://invalid.invalid/no-such-model", &path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("Failed to download"));
    }

    // Only asserts that some error is returned: the request to port 1 and the
    // missing parent directory would each cause one.
    #[test]
    fn test_download_model_from_url_nonexistent_dir() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("nonexistent/model.udpipe");
        let url = "http://localhost:1/model.udpipe";

        let result = download_model_from_url(url, &path);
        assert!(result.is_err());
    }

    // A 200 response with an empty body must be reported as an error.
    #[test]
    fn test_download_model_from_url_empty_response() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("model.udpipe");

        let mut server = mockito::Server::new();
        let mock = server
            .mock("GET", "/empty-model.udpipe")
            .with_status(200)
            .with_body("")
            .create();
        let full_url = format!("{}/empty-model.udpipe", server.url());

        let result = download_model_from_url(&full_url, &path);
        mock.assert();
        drop(server);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("empty"));
    }

    // The raw C API is expected to tolerate a null result: zero words...
    #[test]
    fn test_ffi_null_result_word_count() {
        let count = unsafe { ffi::udpipe_result_word_count(std::ptr::null_mut()) };
        assert_eq!(count, 0);
    }

    // ...and a word record whose string pointers are null.
    #[test]
    fn test_ffi_null_result_get_word() {
        let word = unsafe { ffi::udpipe_result_get_word(std::ptr::null_mut(), 0) };
        assert!(word.form.is_null());
        assert!(word.lemma.is_null());
        assert!(word.upostag.is_null());
    }

    #[test]
    fn test_ffi_invalid_index() {
        let word = unsafe { ffi::udpipe_result_get_word(std::ptr::null_mut(), -1) };
        assert!(word.form.is_null());
    }
}