1use std::ffi::{CStr, CString};
25use std::fs::File;
26use std::io::BufWriter;
27use std::path::Path;
28
/// Base URL that model file names are appended to when building a download URL
/// (see `download_model`, which formats `{MODEL_BASE_URL}/{filename}`).
const MODEL_BASE_URL: &str =
    "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131";
32
/// Error returned by every fallible operation in this module.
///
/// The error is a plain message: failures reported by the native library,
/// I/O failures and download failures are all flattened into [`UdpipeError::message`].
#[derive(Debug, Clone)]
pub struct UdpipeError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl UdpipeError {
    /// Builds an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for UdpipeError {
    /// Renders the error as `UDPipe error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("UDPipe error: {}", self.message))
    }
}

impl std::error::Error for UdpipeError {}

impl From<std::io::Error> for UdpipeError {
    /// Converts an I/O error by keeping only its display text, which lets
    /// `?` be used on `std::io` results inside functions returning `UdpipeError`.
    fn from(err: std::io::Error) -> Self {
        Self::new(err.to_string())
    }
}
64
/// One token of a parsed text together with its annotations.
///
/// The string fields are copied verbatim from the native parse result; the
/// helper methods below interpret them using Universal Dependencies conventions
/// (`upostag` values such as `VERB`, `|`-separated `key=value` features, ...).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Word {
    /// Surface form of the token as reported by the parser.
    pub form: String,
    /// Lemma of the token.
    pub lemma: String,
    /// Universal part-of-speech tag (compared against `VERB`, `NOUN`, ... below).
    pub upostag: String,
    /// Language-specific part-of-speech tag.
    pub xpostag: String,
    /// Morphological features encoded as `Key=Value` pairs joined by `|`.
    pub feats: String,
    /// Dependency relation to the head (`root` marks the sentence root).
    pub deprel: String,
    /// Miscellaneous attributes joined by `|`, e.g. `SpaceAfter=No`.
    pub misc: String,
    /// Index of the word as reported by the parser (presumably 1-based within
    /// its sentence — confirm against the C wrapper).
    pub id: i32,
    /// Index of the syntactic head (presumably `0` for the root — confirm
    /// against the C wrapper).
    pub head: i32,
    /// Index of the sentence this word belongs to.
    pub sentence_id: i32,
}

impl Word {
    /// Returns `true` when the feature `key` is present with exactly `value`.
    #[must_use]
    pub fn has_feature(&self, key: &str, value: &str) -> bool {
        self.get_feature(key) == Some(value)
    }

    /// Looks up the value of the morphological feature `key` in [`Word::feats`].
    ///
    /// The key must match the whole feature name: the text after `key` has to
    /// start with `=`, so `"Mo"` does not match `"Mood=Imp"`. Returns `None`
    /// when the feature is absent.
    #[must_use]
    pub fn get_feature(&self, key: &str) -> Option<&str> {
        self.feats.split('|').find_map(|pair| {
            let after_key = pair.strip_prefix(key)?;
            after_key.strip_prefix('=')
        })
    }

    /// Returns `true` for verbs, counting auxiliaries (`AUX`) as verbs too.
    #[must_use]
    pub fn is_verb(&self) -> bool {
        matches!(self.upostag.as_str(), "VERB" | "AUX")
    }

    /// Returns `true` for common nouns and proper nouns (`PROPN`).
    #[must_use]
    pub fn is_noun(&self) -> bool {
        matches!(self.upostag.as_str(), "NOUN" | "PROPN")
    }

    /// Returns `true` when the word is tagged `ADJ`.
    #[must_use]
    pub fn is_adjective(&self) -> bool {
        self.upostag == "ADJ"
    }

    /// Returns `true` when the word is tagged `PUNCT`.
    #[must_use]
    pub fn is_punct(&self) -> bool {
        self.upostag == "PUNCT"
    }

    /// Returns `true` when the word's dependency relation is `root`.
    #[must_use]
    pub fn is_root(&self) -> bool {
        self.deprel == "root"
    }

    /// Returns `false` only when [`Word::misc`] carries the attribute
    /// `SpaceAfter=No`; every other value (including an empty `misc`) means the
    /// word is followed by whitespace.
    #[must_use]
    pub fn space_after(&self) -> bool {
        // Compare whole `|`-separated attributes rather than searching for a
        // substring, so that e.g. `NoSpaceAfter=No` or `SpaceAfter=Nope` are
        // not mistaken for the real marker.
        !self.misc.split('|').any(|attr| attr == "SpaceAfter=No")
    }
}
181
/// Raw declarations for the C wrapper around UDPipe.
///
/// Nothing in here is safe to call directly; the safe API is `Model`.
mod ffi {
    use std::os::raw::c_char;

    /// Opaque handle to a loaded model. Only ever used behind a raw pointer;
    /// the zero-sized private field prevents construction from Rust.
    #[repr(C)]
    pub struct UdpipeModel {
        _private: [u8; 0],
    }

    /// Opaque handle to the result of one `udpipe_parse` call. Only ever used
    /// behind a raw pointer.
    #[repr(C)]
    pub struct UdpipeParseResult {
        _private: [u8; 0],
    }

    /// One word as returned by value from `udpipe_result_get_word`.
    ///
    /// The string fields are C strings. NOTE(review): they are presumably owned
    /// by the parse result — the Rust side copies them before freeing the
    /// result; confirm the lifetime against the C wrapper.
    #[repr(C)]
    pub struct UdpipeWord {
        pub form: *const c_char,
        pub lemma: *const c_char,
        pub upostag: *const c_char,
        pub xpostag: *const c_char,
        pub feats: *const c_char,
        pub deprel: *const c_char,
        pub misc: *const c_char,
        pub id: i32,
        pub head: i32,
        pub sentence_id: i32,
    }

    unsafe extern "C" {
        /// Loads a model from the file at `model_path`; callers treat a null
        /// return as failure and then consult `udpipe_get_error`.
        pub fn udpipe_model_load(model_path: *const c_char) -> *mut UdpipeModel;
        /// Loads a model from `len` bytes starting at `data`; callers treat a
        /// null return as failure.
        pub fn udpipe_model_load_from_memory(data: *const u8, len: usize) -> *mut UdpipeModel;
        /// Releases a model obtained from one of the loaders.
        pub fn udpipe_model_free(model: *mut UdpipeModel);
        /// Parses NUL-terminated `text` with `model`; callers treat a null
        /// return as failure.
        pub fn udpipe_parse(model: *mut UdpipeModel, text: *const c_char)
            -> *mut UdpipeParseResult;
        /// Releases a result obtained from `udpipe_parse`.
        pub fn udpipe_result_free(result: *mut UdpipeParseResult);
        /// Returns the message describing the most recent failure.
        /// NOTE(review): storage and thread-scope of the message are not
        /// visible from here — confirm against the C wrapper.
        pub fn udpipe_get_error() -> *const c_char;
        /// Number of words in `result`.
        pub fn udpipe_result_word_count(result: *mut UdpipeParseResult) -> i32;
        /// Returns the word at `index` (callers iterate `0..word_count`).
        pub fn udpipe_result_get_word(result: *mut UdpipeParseResult, index: i32) -> UdpipeWord;
    }
}
245
246fn get_ffi_error() -> String {
248 let err_ptr = unsafe { ffi::udpipe_get_error() };
250 assert!(!err_ptr.is_null(), "UDPipe returned null error pointer");
251 unsafe { CStr::from_ptr(err_ptr) }
253 .to_string_lossy()
254 .into_owned()
255}
256
257pub struct Model {
262 inner: *mut ffi::UdpipeModel,
264}
265
266impl std::fmt::Debug for Model {
267 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268 f.debug_struct("Model")
269 .field("inner", &(!self.inner.is_null()))
270 .finish()
271 }
272}
273
274unsafe impl Send for Model {}
276unsafe impl Sync for Model {}
279
280impl Model {
281 pub fn load(path: impl AsRef<Path>) -> Result<Self, UdpipeError> {
294 let path_str = path.as_ref().to_string_lossy();
295 let c_path = CString::new(path_str.as_bytes()).map_err(|_| UdpipeError {
296 message: "Invalid path (contains null byte)".to_owned(),
297 })?;
298
299 let model = unsafe { ffi::udpipe_model_load(c_path.as_ptr()) };
301
302 if model.is_null() {
303 return Err(UdpipeError {
304 message: get_ffi_error(),
305 });
306 }
307
308 Ok(Self { inner: model })
309 }
310
311 pub fn load_from_memory(data: &[u8]) -> Result<Self, UdpipeError> {
327 let model = unsafe { ffi::udpipe_model_load_from_memory(data.as_ptr(), data.len()) };
329
330 if model.is_null() {
331 return Err(UdpipeError {
332 message: get_ffi_error(),
333 });
334 }
335
336 Ok(Self { inner: model })
337 }
338
339 pub fn parse(&self, text: &str) -> Result<Vec<Word>, UdpipeError> {
359 let c_text = CString::new(text).map_err(|_| UdpipeError {
360 message: "Invalid text (contains null byte)".to_owned(),
361 })?;
362
363 let result = unsafe { ffi::udpipe_parse(self.inner, c_text.as_ptr()) };
366 if result.is_null() {
367 return Err(UdpipeError {
368 message: get_ffi_error(),
369 });
370 }
371
372 let word_count = unsafe { ffi::udpipe_result_word_count(result) };
374 let capacity = usize::try_from(word_count).unwrap_or(0);
375 let mut words = Vec::with_capacity(capacity);
376
377 for i in 0..word_count {
378 let word = unsafe { ffi::udpipe_result_get_word(result, i) };
380 words.push(Word {
381 form: ptr_to_string(word.form),
382 lemma: ptr_to_string(word.lemma),
383 upostag: ptr_to_string(word.upostag),
384 xpostag: ptr_to_string(word.xpostag),
385 feats: ptr_to_string(word.feats),
386 deprel: ptr_to_string(word.deprel),
387 misc: ptr_to_string(word.misc),
388 id: word.id,
389 head: word.head,
390 sentence_id: word.sentence_id,
391 });
392 }
393
394 unsafe { ffi::udpipe_result_free(result) };
396
397 Ok(words)
398 }
399}
400
/// Copies a C string into an owned `String`.
///
/// A null pointer yields an empty string instead of being dereferenced, and
/// invalid UTF-8 is replaced with U+FFFD rather than rejected.
///
/// A non-null `ptr` must point to a valid NUL-terminated string that stays
/// alive for the duration of the call.
fn ptr_to_string(ptr: *const std::os::raw::c_char) -> String {
    // `CStr::from_ptr` on null is undefined behaviour, so treat a missing field
    // as empty text.
    if ptr.is_null() {
        return String::new();
    }

    // SAFETY: `ptr` is non-null; the caller guarantees it points to a live
    // NUL-terminated string.
    unsafe { CStr::from_ptr(ptr) }
        .to_string_lossy()
        .into_owned()
}
411
412impl Drop for Model {
413 fn drop(&mut self) {
414 if !self.inner.is_null() {
415 unsafe { ffi::udpipe_model_free(self.inner) };
417 }
418 }
419}
420
/// Model names accepted by `download_model`.
///
/// Each name is turned into a file name by `model_filename`. The list is kept
/// in sorted order (a unit test checks this), so add new entries in place
/// rather than at the end.
pub const AVAILABLE_MODELS: &[&str] = &[
    "afrikaans-afribooms",
    "ancient_greek-perseus",
    "ancient_greek-proiel",
    "arabic-padt",
    "armenian-armtdp",
    "basque-bdt",
    "belarusian-hse",
    "bulgarian-btb",
    "buryat-bdt",
    "catalan-ancora",
    "chinese-gsd",
    "chinese-gsdsimp",
    "classical_chinese-kyoto",
    "coptic-scriptorium",
    "croatian-set",
    "czech-cac",
    "czech-cltt",
    "czech-fictree",
    "czech-pdt",
    "danish-ddt",
    "dutch-alpino",
    "dutch-lassysmall",
    "english-ewt",
    "english-gum",
    "english-lines",
    "english-partut",
    "estonian-edt",
    "estonian-ewt",
    "finnish-ftb",
    "finnish-tdt",
    "french-gsd",
    "french-partut",
    "french-sequoia",
    "french-spoken",
    "galician-ctg",
    "galician-treegal",
    "german-gsd",
    "german-hdt",
    "gothic-proiel",
    "greek-gdt",
    "hebrew-htb",
    "hindi-hdtb",
    "hungarian-szeged",
    "indonesian-gsd",
    "irish-idt",
    "italian-isdt",
    "italian-partut",
    "italian-postwita",
    "italian-twittiro",
    "italian-vit",
    "japanese-gsd",
    "kazakh-ktb",
    "korean-gsd",
    "korean-kaist",
    "kurmanji-mg",
    "latin-ittb",
    "latin-perseus",
    "latin-proiel",
    "latvian-lvtb",
    "lithuanian-alksnis",
    "lithuanian-hse",
    "maltese-mudt",
    "marathi-ufal",
    "north_sami-giella",
    "norwegian-bokmaal",
    "norwegian-nynorsk",
    "norwegian-nynorsklia",
    "old_church_slavonic-proiel",
    "old_french-srcmf",
    "old_russian-torot",
    "persian-seraji",
    "polish-lfg",
    "polish-pdb",
    "polish-sz",
    "portuguese-bosque",
    "portuguese-br",
    "portuguese-gsd",
    "romanian-nonstandard",
    "romanian-rrt",
    "russian-gsd",
    "russian-syntagrus",
    "russian-taiga",
    "sanskrit-ufal",
    "scottish_gaelic-arcosg",
    "serbian-set",
    "slovak-snk",
    "slovenian-ssj",
    "slovenian-sst",
    "spanish-ancora",
    "spanish-gsd",
    "swedish-lines",
    "swedish-talbanken",
    "tamil-ttb",
    "telugu-mtg",
    "turkish-imst",
    "ukrainian-iu",
    "upper_sorbian-ufal",
    "urdu-udtb",
    "uyghur-udt",
    "vietnamese-vtb",
    "wolof-wtb",
];
528
529pub fn download_model(language: &str, dest_dir: impl AsRef<Path>) -> Result<String, UdpipeError> {
559 let dest_dir = dest_dir.as_ref();
560
561 if !AVAILABLE_MODELS.contains(&language) {
563 return Err(UdpipeError {
564 message: format!(
565 "Unknown language '{}'. Use one of: {}",
566 language,
567 AVAILABLE_MODELS[..5].join(", ") + ", ..."
568 ),
569 });
570 }
571
572 let filename = model_filename(language);
574 let dest_path = dest_dir.join(&filename);
575 let url = format!("{MODEL_BASE_URL}/{filename}");
576
577 download_model_from_url(&url, &dest_path)?;
579
580 Ok(dest_path.to_string_lossy().into_owned())
581}
582
583pub fn download_model_from_url(url: &str, path: impl AsRef<Path>) -> Result<(), UdpipeError> {
605 let path = path.as_ref();
606
607 let response = ureq::get(url).call().map_err(|e| UdpipeError {
609 message: format!("Failed to download: {e}"),
610 })?;
611
612 let file = File::create(path)?;
614 let mut writer = BufWriter::new(file);
615 let bytes_written = std::io::copy(&mut response.into_body().into_reader(), &mut writer)?;
616
617 if bytes_written == 0 {
618 return Err(UdpipeError {
619 message: "Downloaded file is empty".to_owned(),
620 });
621 }
622
623 Ok(())
624}
625
/// Returns the file name under which the model `language` is published:
/// the model name followed by the fixed `-ud-2.5-191206.udpipe` suffix.
///
/// The name is not validated against [`AVAILABLE_MODELS`].
#[must_use]
pub fn model_filename(language: &str) -> String {
    const SUFFIX: &str = "-ud-2.5-191206.udpipe";

    let mut name = String::with_capacity(language.len() + SUFFIX.len());
    name.push_str(language);
    name.push_str(SUFFIX);
    name
}
640
// Unit tests. The `Word`, error and file-name tests are pure; the `Model` tests
// call into the native library, and the download tests need `tempfile` and
// `mockito`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a NOUN/root word whose only varying part is its feature string.
    fn make_word(feats: &str) -> Word {
        Word {
            form: "test".to_owned(),
            lemma: "test".to_owned(),
            upostag: "NOUN".to_owned(),
            xpostag: String::new(),
            feats: feats.to_owned(),
            deprel: "root".to_owned(),
            misc: String::new(),
            id: 1,
            head: 0,
            sentence_id: 0,
        }
    }

    // Present features match on exact key and value; absent ones do not.
    #[test]
    fn test_word_has_feature() {
        let word = make_word("Mood=Imp|VerbForm=Fin");

        assert!(word.has_feature("Mood", "Imp"));
        assert!(word.has_feature("VerbForm", "Fin"));
        assert!(!word.has_feature("Mood", "Ind"));
        assert!(!word.has_feature("Tense", "Past"));
    }

    // An empty feature string contains no features.
    #[test]
    fn test_word_has_feature_empty() {
        let word = make_word("");
        assert!(!word.has_feature("Mood", "Imp"));
    }

    // A feature string without any `|` separator still works.
    #[test]
    fn test_word_has_feature_single() {
        let word = make_word("Mood=Imp");
        assert!(word.has_feature("Mood", "Imp"));
        assert!(!word.has_feature("VerbForm", "Fin"));
    }

    // Lookup returns the value for known keys and `None` otherwise.
    #[test]
    fn test_word_get_feature() {
        let word = make_word("Tense=Pres|VerbForm=Part");

        assert_eq!(word.get_feature("Tense"), Some("Pres"));
        assert_eq!(word.get_feature("VerbForm"), Some("Part"));
        assert_eq!(word.get_feature("Mood"), None);
    }

    // Lookup on an empty feature string finds nothing.
    #[test]
    fn test_word_get_feature_empty() {
        let word = make_word("");
        assert_eq!(word.get_feature("Mood"), None);
    }

    // Lookup on a single-feature string.
    #[test]
    fn test_word_get_feature_single() {
        let word = make_word("Mood=Imp");
        assert_eq!(word.get_feature("Mood"), Some("Imp"));
        assert_eq!(word.get_feature("VerbForm"), None);
    }

    // Both VERB and AUX count as verbs.
    #[test]
    fn test_word_is_verb() {
        let mut word = make_word("");
        word.upostag = "VERB".to_owned();
        assert!(word.is_verb());

        word.upostag = "AUX".to_owned();
        assert!(word.is_verb());

        word.upostag = "NOUN".to_owned();
        assert!(!word.is_verb());
    }

    // Both NOUN and PROPN count as nouns.
    #[test]
    fn test_word_is_noun() {
        let mut word = make_word("");
        word.upostag = "NOUN".to_owned();
        assert!(word.is_noun());

        word.upostag = "PROPN".to_owned();
        assert!(word.is_noun());

        word.upostag = "VERB".to_owned();
        assert!(!word.is_noun());
    }

    // Only the `root` relation marks a root word.
    #[test]
    fn test_word_is_root() {
        let mut word = make_word("");
        word.deprel = "root".to_owned();
        assert!(word.is_root());

        word.deprel = "nsubj".to_owned();
        assert!(!word.is_root());
    }

    // Only the ADJ tag marks an adjective.
    #[test]
    fn test_word_is_adjective() {
        let mut word = make_word("");
        word.upostag = "ADJ".to_owned();
        assert!(word.is_adjective());

        word.upostag = "NOUN".to_owned();
        assert!(!word.is_adjective());
    }

    // Only the PUNCT tag marks punctuation.
    #[test]
    fn test_word_is_punct() {
        let mut word = make_word("");
        word.upostag = "PUNCT".to_owned();
        assert!(word.is_punct());

        word.upostag = "NOUN".to_owned();
        assert!(!word.is_punct());
    }

    // Equal words hash equally, so `Word` is usable as a set element.
    #[test]
    fn test_word_hash() {
        use std::collections::HashSet;

        let word1 = make_word("Mood=Imp");
        let word2 = make_word("Mood=Imp");
        let mut set = HashSet::new();
        set.insert(word1);
        assert!(set.contains(&word2));
    }

    // The file name is the model name plus the fixed suffix.
    #[test]
    fn test_model_filename() {
        assert_eq!(
            model_filename("english-ewt"),
            "english-ewt-ud-2.5-191206.udpipe"
        );
        assert_eq!(
            model_filename("dutch-alpino"),
            "dutch-alpino-ud-2.5-191206.udpipe"
        );
    }

    // Spot-check a few widely used models.
    #[test]
    fn test_available_models_contains_common_languages() {
        assert!(AVAILABLE_MODELS.contains(&"english-ewt"));
        assert!(AVAILABLE_MODELS.contains(&"german-gsd"));
        assert!(AVAILABLE_MODELS.contains(&"french-gsd"));
        assert!(AVAILABLE_MODELS.contains(&"spanish-ancora"));
    }

    // Guards the ordering of the list against out-of-place additions.
    #[test]
    fn test_available_models_sorted() {
        let mut sorted = AVAILABLE_MODELS.to_vec();
        sorted.sort_unstable();
        assert_eq!(AVAILABLE_MODELS, sorted.as_slice());
    }

    // Unknown names are rejected before any network access is attempted.
    #[test]
    fn test_download_model_invalid_language() {
        let result = download_model("invalid-language-xyz", ".");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("Unknown language"));
    }

    // `Display` prefixes the message with "UDPipe error: ".
    #[test]
    fn test_udpipe_error_display() {
        let err = UdpipeError::new("test error");
        assert_eq!(format!("{err}"), "UDPipe error: test error");
    }

    // Converting an I/O error keeps its text.
    #[test]
    fn test_udpipe_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: UdpipeError = io_err.into();
        assert!(err.message.contains("not found"));
    }

    // A space follows unless `misc` carries `SpaceAfter=No`.
    #[test]
    fn test_space_after() {
        let mut word = make_word("");
        word.misc = String::new();
        assert!(word.space_after());
        word.misc = "SpaceAfter=No".to_owned();
        assert!(!word.space_after());

        word.misc = "SpaceAfter=No|Other=Value".to_owned();
        assert!(!word.space_after());
    }

    // Loading a missing file surfaces an error instead of a null model.
    #[test]
    fn test_model_load_nonexistent_file() {
        let result = Model::load("/nonexistent/path/to/model.udpipe");
        assert!(result.is_err());
    }

    // A NUL byte in the path is rejected on the Rust side.
    #[test]
    fn test_model_load_path_with_null_byte() {
        let result = Model::load("path\0with\0nulls.udpipe");
        let err = result.expect_err("expected error");
        assert!(err.message.contains("null byte"));
    }

    // Empty input is not a valid model.
    #[test]
    fn test_model_load_from_memory_empty() {
        let result = Model::load_from_memory(&[]);
        assert!(result.is_err());
    }

    // Arbitrary bytes are not a valid model.
    #[test]
    fn test_model_load_from_memory_invalid() {
        let garbage = b"this is not a valid udpipe model";
        let result = Model::load_from_memory(garbage);
        assert!(result.is_err());
    }

    // Expects the native library to reject a null model with an
    // "Invalid arguments" message rather than crash.
    #[test]
    fn test_parse_with_null_model() {
        let model = Model {
            inner: std::ptr::null_mut(),
        };
        let result = model.parse("test");
        let err = result.unwrap_err();
        assert!(err.message.contains("Invalid arguments"));
    }

    // The debug output names the type and its field.
    #[test]
    fn test_model_debug() {
        let model = Model {
            inner: std::ptr::null_mut(),
        };
        let debug_str = format!("{model:?}");
        assert!(debug_str.contains("Model"));
        assert!(debug_str.contains("inner"));
    }

    // A host that cannot be resolved produces a "Failed to download" error.
    #[test]
    fn test_download_model_from_url_invalid_url() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("model.udpipe");
        let result = download_model_from_url("http://invalid.invalid/no-such-model", &path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("Failed to download"));
    }

    // NOTE(review): the URL targets port 1 on localhost, so the request is
    // expected to fail before the missing directory is ever touched; this
    // only asserts that some error is returned.
    #[test]
    fn test_download_model_from_url_nonexistent_dir() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("nonexistent/model.udpipe");
        let url = "http://localhost:1/model.udpipe";

        let result = download_model_from_url(url, &path);
        assert!(result.is_err());
    }

    // A successful response with an empty body is reported as an error.
    #[test]
    fn test_download_model_from_url_empty_response() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("model.udpipe");

        let mut server = mockito::Server::new();
        let mock = server
            .mock("GET", "/empty-model.udpipe")
            .with_status(200)
            .with_body("")
            .create();
        let full_url = format!("{}/empty-model.udpipe", server.url());

        let result = download_model_from_url(&full_url, &path);
        mock.assert();
        drop(server);

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("empty"));
    }
}