1#![deny(missing_docs)]
25
26use std::ffi::{CStr, CString};
27use std::io::Read;
28use std::path::Path;
29
30const MODEL_BASE_URL: &str =
32 "https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-3131";
33
/// Error type returned by every fallible operation in this module.
///
/// It carries a single human-readable message; there are no error kinds.
#[derive(Debug, Clone)]
pub struct UdpipeError {
    /// Description of what went wrong.
    pub message: String,
}

impl UdpipeError {
    /// Builds an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for UdpipeError {
    /// Writes the message prefixed with `UDPipe error: `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("UDPipe error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for UdpipeError {}

impl From<std::io::Error> for UdpipeError {
    /// Keeps only the textual form of the I/O error.
    fn from(err: std::io::Error) -> Self {
        Self::new(err.to_string())
    }
}
65
/// One word of parser output, with every string copied into owned storage.
///
/// The string fields hold whatever the native parser reported for the
/// corresponding column; the helper methods below interpret `upostag`,
/// `feats`, `deprel` and `misc`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Word {
    /// Surface form of the word.
    pub form: String,
    /// Lemma of the word.
    pub lemma: String,
    /// Universal part-of-speech tag (compared against values such as `VERB`, `NOUN`).
    pub upostag: String,
    /// Treebank-specific part-of-speech tag.
    pub xpostag: String,
    /// Morphological features as `Key=Value` pairs separated by `|`.
    pub feats: String,
    /// Dependency relation to the head (`root` marks the sentence root).
    pub deprel: String,
    /// Miscellaneous attributes separated by `|`, e.g. `SpaceAfter=No`.
    pub misc: String,
    /// Index of this word as reported by the parser (presumably 1-based within its sentence).
    pub id: i32,
    /// Index of the head word as reported by the parser (presumably 0 for the root).
    pub head: i32,
    /// Identifier of the sentence this word belongs to, as reported by the parser.
    pub sentence_id: i32,
}

impl Word {
    /// Returns `true` when `feats` contains the pair `key=value`.
    #[must_use]
    pub fn has_feature(&self, key: &str, value: &str) -> bool {
        self.get_feature(key) == Some(value)
    }

    /// Returns the value stored for `key` in `feats`, or `None` if the key is absent.
    ///
    /// The key must match the whole feature name: looking up `Mood` does not
    /// match a feature called `MoodX`.
    #[must_use]
    pub fn get_feature(&self, key: &str) -> Option<&str> {
        self.feats
            .split('|')
            .find_map(|pair| pair.strip_prefix(key)?.strip_prefix('='))
    }

    /// Returns `true` for verbs, including auxiliaries (`VERB` or `AUX`).
    #[must_use]
    pub fn is_verb(&self) -> bool {
        matches!(self.upostag.as_str(), "VERB" | "AUX")
    }

    /// Returns `true` for common and proper nouns (`NOUN` or `PROPN`).
    #[must_use]
    pub fn is_noun(&self) -> bool {
        matches!(self.upostag.as_str(), "NOUN" | "PROPN")
    }

    /// Returns `true` when the word is tagged `ADJ`.
    #[must_use]
    pub fn is_adjective(&self) -> bool {
        self.upostag == "ADJ"
    }

    /// Returns `true` when the word is tagged `PUNCT`.
    #[must_use]
    pub fn is_punct(&self) -> bool {
        self.upostag == "PUNCT"
    }

    /// Returns `true` when the dependency relation is `root`.
    #[must_use]
    pub fn is_root(&self) -> bool {
        self.deprel == "root"
    }

    /// Returns `false` only when `misc` carries the attribute `SpaceAfter=No`.
    ///
    /// The attribute list is split on `|` and compared exactly, so an
    /// attribute that merely contains that text (for example
    /// `SpaceAfter=Nope`) is not mistaken for it.
    #[must_use]
    pub fn space_after(&self) -> bool {
        !self.misc.split('|').any(|attr| attr == "SpaceAfter=No")
    }
}
181
/// Raw declarations for the native UDPipe wrapper library.
///
/// Nothing here is safe to call directly; the safe wrappers live on `Model`.
mod ffi {
    use std::os::raw::c_char;

    /// Opaque model handle; only ever used behind a raw pointer.
    #[repr(C)]
    pub struct UdpipeModel {
        // Zero-sized private field: the type cannot be constructed from Rust.
        _private: [u8; 0],
    }

    /// Opaque parse-result handle; only ever used behind a raw pointer.
    #[repr(C)]
    pub struct UdpipeParseResult {
        // Zero-sized private field: the type cannot be constructed from Rust.
        _private: [u8; 0],
    }

    /// One word as returned by value from `udpipe_result_get_word`.
    ///
    /// Field order must match the native struct because of `repr(C)`.
    /// NOTE(review): the string pointers presumably point into storage owned
    /// by the parse result and become invalid once it is freed — confirm
    /// against the C header.
    #[repr(C)]
    pub struct UdpipeWord {
        pub form: *const c_char,
        pub lemma: *const c_char,
        pub upostag: *const c_char,
        pub xpostag: *const c_char,
        pub feats: *const c_char,
        pub deprel: *const c_char,
        pub misc: *const c_char,
        pub id: i32,
        pub head: i32,
        pub sentence_id: i32,
    }

    unsafe extern "C" {
        /// Loads a model from a NUL-terminated path; callers treat null as failure.
        pub fn udpipe_model_load(model_path: *const c_char) -> *mut UdpipeModel;
        /// Loads a model from `len` bytes starting at `data`; callers treat null as failure.
        pub fn udpipe_model_load_from_memory(data: *const u8, len: usize) -> *mut UdpipeModel;
        /// Releases a model obtained from one of the load functions.
        pub fn udpipe_model_free(model: *mut UdpipeModel);
        /// Parses NUL-terminated text with the given model; callers treat null as failure.
        pub fn udpipe_parse(model: *mut UdpipeModel, text: *const c_char)
        -> *mut UdpipeParseResult;
        /// Releases a result obtained from `udpipe_parse`.
        pub fn udpipe_result_free(result: *mut UdpipeParseResult);
        /// Returns the last error message, or null when none is available.
        /// NOTE(review): whether this is per-thread or global is not visible here.
        pub fn udpipe_get_error() -> *const c_char;
        /// Number of words stored in a parse result.
        pub fn udpipe_result_word_count(result: *mut UdpipeParseResult) -> i32;
        /// Returns the word at `index` of a parse result.
        pub fn udpipe_result_get_word(result: *mut UdpipeParseResult, index: i32) -> UdpipeWord;
    }
}
222
223fn get_ffi_error(default: &str) -> String {
228 unsafe {
229 let err_ptr = ffi::udpipe_get_error();
230 if err_ptr.is_null() {
231 default.to_string()
232 } else {
233 CStr::from_ptr(err_ptr).to_string_lossy().into_owned()
234 }
235 }
236}
237
238fn check_ffi_result<T>(ptr: *mut T, error_msg: &str) -> Result<*mut T, UdpipeError> {
243 if ptr.is_null() {
244 Err(UdpipeError {
245 message: get_ffi_error(error_msg),
246 })
247 } else {
248 Ok(ptr)
249 }
250}
251
/// Owning handle to a UDPipe model loaded through the native library.
///
/// The wrapped pointer is released with `udpipe_model_free` when the value
/// is dropped.
pub struct Model {
    // Raw pointer returned by one of the native load functions.
    inner: *mut ffi::UdpipeModel,
}

// SAFETY: NOTE(review): these impls assert that the native model may be moved
// to, and used concurrently from, multiple threads. Nothing in this file
// establishes that — confirm against the native library's guarantees.
unsafe impl Send for Model {}
unsafe impl Sync for Model {}
263
264impl Model {
265 pub fn load(path: impl AsRef<Path>) -> Result<Self, UdpipeError> {
273 let path_str = path.as_ref().to_string_lossy();
274 let c_path = CString::new(path_str.as_bytes()).map_err(|_| UdpipeError {
275 message: "Invalid path (contains null byte)".to_string(),
276 })?;
277
278 let model = unsafe { ffi::udpipe_model_load(c_path.as_ptr()) };
279
280 if model.is_null() {
281 return Err(UdpipeError {
282 message: get_ffi_error("Failed to load model"),
283 });
284 }
285
286 Ok(Model { inner: model })
287 }
288
289 pub fn load_from_memory(data: &[u8]) -> Result<Self, UdpipeError> {
300 let model = unsafe { ffi::udpipe_model_load_from_memory(data.as_ptr(), data.len()) };
301
302 if model.is_null() {
303 return Err(UdpipeError {
304 message: get_ffi_error("Failed to load model from memory"),
305 });
306 }
307
308 Ok(Model { inner: model })
309 }
310
311 pub fn parse(&self, text: &str) -> Result<Vec<Word>, UdpipeError> {
325 let c_text = CString::new(text).map_err(|_| UdpipeError {
326 message: "Invalid text (contains null byte)".to_string(),
327 })?;
328
329 let result = unsafe { ffi::udpipe_parse(self.inner, c_text.as_ptr()) };
330 let result = check_ffi_result(result, "Failed to parse text")?;
331
332 let word_count = unsafe { ffi::udpipe_result_word_count(result) };
333 let mut words = Vec::with_capacity(word_count as usize);
334
335 for i in 0..word_count {
336 let word = unsafe { ffi::udpipe_result_get_word(result, i) };
337 words.push(Word {
338 form: unsafe { CStr::from_ptr(word.form).to_string_lossy().into_owned() },
339 lemma: unsafe { CStr::from_ptr(word.lemma).to_string_lossy().into_owned() },
340 upostag: unsafe { CStr::from_ptr(word.upostag).to_string_lossy().into_owned() },
341 xpostag: unsafe { CStr::from_ptr(word.xpostag).to_string_lossy().into_owned() },
342 feats: unsafe { CStr::from_ptr(word.feats).to_string_lossy().into_owned() },
343 deprel: unsafe { CStr::from_ptr(word.deprel).to_string_lossy().into_owned() },
344 misc: unsafe { CStr::from_ptr(word.misc).to_string_lossy().into_owned() },
345 id: word.id,
346 head: word.head,
347 sentence_id: word.sentence_id,
348 });
349 }
350
351 unsafe { ffi::udpipe_result_free(result) };
352
353 Ok(words)
354 }
355}
356
357impl Drop for Model {
358 fn drop(&mut self) {
359 if !self.inner.is_null() {
360 unsafe { ffi::udpipe_model_free(self.inner) };
361 }
362 }
363}
364
/// Model names accepted by `download_model`.
///
/// Each entry is the prefix that `model_filename` turns into a file name.
/// The list is kept in ascending order; a unit test enforces this.
pub const AVAILABLE_MODELS: &[&str] = &[
    "afrikaans-afribooms",
    "ancient_greek-perseus",
    "ancient_greek-proiel",
    "arabic-padt",
    "armenian-armtdp",
    "basque-bdt",
    "belarusian-hse",
    "bulgarian-btb",
    "buryat-bdt",
    "catalan-ancora",
    "chinese-gsd",
    "chinese-gsdsimp",
    "classical_chinese-kyoto",
    "coptic-scriptorium",
    "croatian-set",
    "czech-cac",
    "czech-cltt",
    "czech-fictree",
    "czech-pdt",
    "danish-ddt",
    "dutch-alpino",
    "dutch-lassysmall",
    "english-ewt",
    "english-gum",
    "english-lines",
    "english-partut",
    "estonian-edt",
    "estonian-ewt",
    "finnish-ftb",
    "finnish-tdt",
    "french-gsd",
    "french-partut",
    "french-sequoia",
    "french-spoken",
    "galician-ctg",
    "galician-treegal",
    "german-gsd",
    "german-hdt",
    "gothic-proiel",
    "greek-gdt",
    "hebrew-htb",
    "hindi-hdtb",
    "hungarian-szeged",
    "indonesian-gsd",
    "irish-idt",
    "italian-isdt",
    "italian-partut",
    "italian-postwita",
    "italian-twittiro",
    "italian-vit",
    "japanese-gsd",
    "kazakh-ktb",
    "korean-gsd",
    "korean-kaist",
    "kurmanji-mg",
    "latin-ittb",
    "latin-perseus",
    "latin-proiel",
    "latvian-lvtb",
    "lithuanian-alksnis",
    "lithuanian-hse",
    "maltese-mudt",
    "marathi-ufal",
    "north_sami-giella",
    "norwegian-bokmaal",
    "norwegian-nynorsk",
    "norwegian-nynorsklia",
    "old_church_slavonic-proiel",
    "old_french-srcmf",
    "old_russian-torot",
    "persian-seraji",
    "polish-lfg",
    "polish-pdb",
    "polish-sz",
    "portuguese-bosque",
    "portuguese-br",
    "portuguese-gsd",
    "romanian-nonstandard",
    "romanian-rrt",
    "russian-gsd",
    "russian-syntagrus",
    "russian-taiga",
    "sanskrit-ufal",
    "scottish_gaelic-arcosg",
    "serbian-set",
    "slovak-snk",
    "slovenian-ssj",
    "slovenian-sst",
    "spanish-ancora",
    "spanish-gsd",
    "swedish-lines",
    "swedish-talbanken",
    "tamil-ttb",
    "telugu-mtg",
    "turkish-imst",
    "ukrainian-iu",
    "upper_sorbian-ufal",
    "urdu-udtb",
    "uyghur-udt",
    "vietnamese-vtb",
    "wolof-wtb",
];
472
473pub fn download_model(language: &str, dest_dir: impl AsRef<Path>) -> Result<String, UdpipeError> {
497 let dest_dir = dest_dir.as_ref();
498
499 if !AVAILABLE_MODELS.contains(&language) {
501 return Err(UdpipeError {
502 message: format!(
503 "Unknown language '{}'. Use one of: {}",
504 language,
505 AVAILABLE_MODELS[..5].join(", ") + ", ..."
506 ),
507 });
508 }
509
510 let filename = model_filename(language);
512 let dest_path = dest_dir.join(&filename);
513 let url = format!("{}/{}", MODEL_BASE_URL, filename);
514
515 download_model_from_url(&url, &dest_path)?;
517
518 Ok(dest_path.to_string_lossy().into_owned())
519}
520
521pub fn download_model_from_url(url: &str, path: impl AsRef<Path>) -> Result<(), UdpipeError> {
537 let path = path.as_ref();
538
539 if let Some(parent) = path.parent() {
541 if !parent.as_os_str().is_empty() {
542 std::fs::create_dir_all(parent)?;
543 }
544 }
545
546 let response = ureq::get(url).call().map_err(|e| UdpipeError {
548 message: format!("Failed to download: {}", e),
549 })?;
550
551 let mut data = Vec::new();
553 response
554 .into_body()
555 .into_reader()
556 .read_to_end(&mut data)
557 .map_err(|e| UdpipeError {
558 message: format!("Failed to read response: {}", e),
559 })?;
560
561 if data.is_empty() {
562 return Err(UdpipeError {
563 message: "Downloaded file is empty".to_string(),
564 });
565 }
566
567 std::fs::write(path, &data)?;
569
570 Ok(())
571}
572
/// Returns the file name used for the model called `language`:
/// the name followed by `-ud-2.5-191206.udpipe`.
pub fn model_filename(language: &str) -> String {
    let mut name = String::from(language);
    name.push_str("-ud-2.5-191206.udpipe");
    name
}
583
// Unit tests for the word helpers, the error type, the model table and the
// download functions. The HTTP tests use a local `mockito` server; the
// model-loading tests call into the native library.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a noun that is the sentence root, with the given feature string.
    fn make_word(feats: &str) -> Word {
        Word {
            form: "test".to_string(),
            lemma: "test".to_string(),
            upostag: "NOUN".to_string(),
            xpostag: String::new(),
            feats: feats.to_string(),
            deprel: "root".to_string(),
            misc: String::new(),
            id: 1,
            head: 0,
            sentence_id: 0,
        }
    }

    // Key and value must both match; absent keys never match.
    #[test]
    fn test_word_has_feature() {
        let word = make_word("Mood=Imp|VerbForm=Fin");

        assert!(word.has_feature("Mood", "Imp"));
        assert!(word.has_feature("VerbForm", "Fin"));
        assert!(!word.has_feature("Mood", "Ind"));
        assert!(!word.has_feature("Tense", "Past"));
    }

    // An empty feature string has no features.
    #[test]
    fn test_word_has_feature_empty() {
        let word = make_word("");
        assert!(!word.has_feature("Mood", "Imp"));
    }

    // A single pair without any `|` separator is still found.
    #[test]
    fn test_word_has_feature_single() {
        let word = make_word("Mood=Imp");
        assert!(word.has_feature("Mood", "Imp"));
        assert!(!word.has_feature("VerbForm", "Fin"));
    }

    // Lookup returns the value for present keys and `None` otherwise.
    #[test]
    fn test_word_get_feature() {
        let word = make_word("Tense=Pres|VerbForm=Part");

        assert_eq!(word.get_feature("Tense"), Some("Pres"));
        assert_eq!(word.get_feature("VerbForm"), Some("Part"));
        assert_eq!(word.get_feature("Mood"), None);
    }

    #[test]
    fn test_word_get_feature_empty() {
        let word = make_word("");
        assert_eq!(word.get_feature("Mood"), None);
    }

    #[test]
    fn test_word_get_feature_single() {
        let word = make_word("Mood=Imp");
        assert_eq!(word.get_feature("Mood"), Some("Imp"));
        assert_eq!(word.get_feature("VerbForm"), None);
    }

    // Both `VERB` and `AUX` count as verbs.
    #[test]
    fn test_word_is_verb() {
        let mut word = make_word("");
        word.upostag = "VERB".to_string();
        assert!(word.is_verb());

        word.upostag = "AUX".to_string();
        assert!(word.is_verb());

        word.upostag = "NOUN".to_string();
        assert!(!word.is_verb());
    }

    // Both `NOUN` and `PROPN` count as nouns.
    #[test]
    fn test_word_is_noun() {
        let mut word = make_word("");
        word.upostag = "NOUN".to_string();
        assert!(word.is_noun());

        word.upostag = "PROPN".to_string();
        assert!(word.is_noun());

        word.upostag = "VERB".to_string();
        assert!(!word.is_noun());
    }

    #[test]
    fn test_word_is_root() {
        let mut word = make_word("");
        word.deprel = "root".to_string();
        assert!(word.is_root());

        word.deprel = "nsubj".to_string();
        assert!(!word.is_root());
    }

    #[test]
    fn test_word_is_adjective() {
        let mut word = make_word("");
        word.upostag = "ADJ".to_string();
        assert!(word.is_adjective());

        word.upostag = "NOUN".to_string();
        assert!(!word.is_adjective());
    }

    #[test]
    fn test_word_is_punct() {
        let mut word = make_word("");
        word.upostag = "PUNCT".to_string();
        assert!(word.is_punct());

        word.upostag = "NOUN".to_string();
        assert!(!word.is_punct());
    }

    // Two words with equal fields hash and compare equal.
    #[test]
    fn test_word_hash() {
        use std::collections::HashSet;

        let word1 = make_word("Mood=Imp");
        let word2 = make_word("Mood=Imp");
        let mut set = HashSet::new();
        set.insert(word1.clone());
        assert!(set.contains(&word2));
    }

    #[test]
    fn test_model_filename() {
        assert_eq!(
            model_filename("english-ewt"),
            "english-ewt-ud-2.5-191206.udpipe"
        );
        assert_eq!(
            model_filename("dutch-alpino"),
            "dutch-alpino-ud-2.5-191206.udpipe"
        );
    }

    #[test]
    fn test_available_models_contains_common_languages() {
        assert!(AVAILABLE_MODELS.contains(&"english-ewt"));
        assert!(AVAILABLE_MODELS.contains(&"german-gsd"));
        assert!(AVAILABLE_MODELS.contains(&"french-gsd"));
        assert!(AVAILABLE_MODELS.contains(&"spanish-ancora"));
    }

    // Guards the ordering of the table: it must equal its own sorted copy.
    #[test]
    fn test_available_models_sorted() {
        let mut sorted = AVAILABLE_MODELS.to_vec();
        sorted.sort();
        assert_eq!(AVAILABLE_MODELS, sorted.as_slice());
    }

    // Unknown names are rejected before any network access is attempted.
    #[test]
    fn test_download_model_invalid_language() {
        let result = download_model("invalid-language-xyz", ".");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("Unknown language"));
    }

    #[test]
    fn test_udpipe_error_display() {
        let err = UdpipeError::new("test error");
        assert_eq!(format!("{}", err), "UDPipe error: test error");
    }

    // The I/O error's text is carried over into the message.
    #[test]
    fn test_udpipe_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: UdpipeError = io_err.into();
        assert!(err.message.contains("not found"));
    }

    #[test]
    fn test_space_after() {
        let mut word = make_word("");
        // No attribute at all: a space follows.
        word.misc = String::new();
        assert!(word.space_after());

        word.misc = "SpaceAfter=No".to_string();
        assert!(!word.space_after());

        // The attribute is also honoured alongside other attributes.
        word.misc = "SpaceAfter=No|Other=Value".to_string();
        assert!(!word.space_after());
    }

    #[test]
    fn test_model_load_nonexistent_file() {
        let result = Model::load("/nonexistent/path/to/model.udpipe");
        assert!(result.is_err());
    }

    // Interior NUL bytes are rejected before the native loader is called.
    #[test]
    fn test_model_load_path_with_null_byte() {
        let result = Model::load("path\0with\0nulls.udpipe");
        let err = result.err().expect("expected error");
        assert!(err.message.contains("null byte"));
    }

    #[test]
    fn test_model_load_from_memory_empty() {
        let result = Model::load_from_memory(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_model_load_from_memory_invalid() {
        let garbage = b"this is not a valid udpipe model";
        let result = Model::load_from_memory(garbage);
        assert!(result.is_err());
    }

    // A failing request surfaces as a "Failed to download" error.
    #[test]
    fn test_download_model_from_url_invalid_url() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("model.udpipe");
        let result = download_model_from_url("http://invalid.invalid/no-such-model", &path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("Failed to download"));
    }

    // A 200 response with an empty body is rejected.
    #[test]
    fn test_download_model_from_url_empty_response() {
        let mut server = mockito::Server::new();
        let mock = server
            .mock("GET", "/empty-model.udpipe")
            .with_status(200)
            .with_body("")
            .create();

        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("model.udpipe");
        let url = format!("{}/empty-model.udpipe", server.url());

        let result = download_model_from_url(&url, &path);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.message.contains("empty"));

        mock.assert();
    }

    // Missing intermediate directories are created before the file is written.
    #[test]
    fn test_download_model_from_url_creates_parent_dirs() {
        let mut server = mockito::Server::new();
        let mock = server
            .mock("GET", "/model.udpipe")
            .with_status(200)
            .with_body("fake model data")
            .create();

        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("nested/dir/model.udpipe");
        let url = format!("{}/model.udpipe", server.url());

        let result = download_model_from_url(&url, &path);
        assert!(result.is_ok());
        assert!(path.exists());

        mock.assert();
    }

    // Creating `nested/` inside a read-only directory must fail with an error.
    #[test]
    #[cfg(unix)]
    fn test_download_model_from_url_readonly_dir() {
        use std::os::unix::fs::PermissionsExt;

        let mut server = mockito::Server::new();
        let _mock = server
            .mock("GET", "/model.udpipe")
            .with_status(200)
            .with_body("fake model data")
            .create();

        let temp_dir = tempfile::tempdir().unwrap();
        let readonly_dir = temp_dir.path().join("readonly");
        std::fs::create_dir(&readonly_dir).unwrap();

        // Remove write permission from the directory.
        let mut perms = std::fs::metadata(&readonly_dir).unwrap().permissions();
        perms.set_mode(0o444);
        std::fs::set_permissions(&readonly_dir, perms.clone()).unwrap();

        let path = readonly_dir.join("nested/model.udpipe");
        let url = format!("{}/model.udpipe", server.url());

        let result = download_model_from_url(&url, &path);

        // Restore permissions before asserting, presumably so the temporary
        // directory can still be removed if the assertion fails.
        perms.set_mode(0o755);
        std::fs::set_permissions(&readonly_dir, perms).unwrap();

        assert!(result.is_err());
    }
}