1mod prompts;
33mod types;
34
35pub use prompts::build_extraction_prompt;
36pub use types::{
37 EntityType, ExtractedEntity, ExtractedRelation, ExtractionOptions, ExtractionResult,
38 RelationType,
39};
40
41use serde::Deserialize;
42
43use crate::constants::{
44 EXTRACTION_CONFIDENCE_DEFAULT, EXTRACTION_CONFIDENCE_MAX, EXTRACTION_CONFIDENCE_MIN,
45 EXTRACTION_ENTITIES_COUNT_MAX, EXTRACTION_RELATIONS_COUNT_MAX, EXTRACTION_TEXT_BYTES_MAX,
46};
47use crate::llm::{CompletionRequest, LLMProvider, ProviderError};
48
49#[derive(Debug, Clone, thiserror::Error)]
58pub enum ExtractionError {
59 #[error("Text is empty")]
61 EmptyText,
62
63 #[error("Text too long: {len} bytes (max {max})")]
65 TextTooLong {
66 len: usize,
68 max: usize,
70 },
71
72 #[error("Invalid confidence: {value} (must be {min}-{max})")]
74 InvalidConfidence {
75 value: f64,
77 min: f64,
79 max: f64,
81 },
82}
83
84#[derive(Debug, Deserialize, Default)]
90struct LLMExtractionResponse {
91 #[serde(default)]
92 entities: Vec<RawEntity>,
93 #[serde(default)]
94 relations: Vec<RawRelation>,
95}
96
97#[derive(Debug, Deserialize)]
98struct RawEntity {
99 name: Option<String>,
100 #[serde(rename = "type")]
101 entity_type: Option<String>,
102 content: Option<String>,
103 confidence: Option<f64>,
104}
105
106#[derive(Debug, Deserialize)]
107struct RawRelation {
108 source: Option<String>,
109 target: Option<String>,
110 #[serde(rename = "type")]
111 relation_type: Option<String>,
112 confidence: Option<f64>,
113}
114
115#[derive(Debug)]
145pub struct EntityExtractor<P: LLMProvider> {
146 provider: P,
147}
148
149impl<P: LLMProvider> EntityExtractor<P> {
150 #[must_use]
152 pub fn new(provider: P) -> Self {
153 Self { provider }
154 }
155
156 pub async fn extract(
173 &self,
174 text: &str,
175 options: ExtractionOptions,
176 ) -> Result<ExtractionResult, ExtractionError> {
177 if text.is_empty() {
179 return Err(ExtractionError::EmptyText);
180 }
181 if text.len() > EXTRACTION_TEXT_BYTES_MAX {
182 return Err(ExtractionError::TextTooLong {
183 len: text.len(),
184 max: EXTRACTION_TEXT_BYTES_MAX,
185 });
186 }
187 if !(EXTRACTION_CONFIDENCE_MIN..=EXTRACTION_CONFIDENCE_MAX)
188 .contains(&options.min_confidence)
189 {
190 return Err(ExtractionError::InvalidConfidence {
191 value: options.min_confidence,
192 min: EXTRACTION_CONFIDENCE_MIN,
193 max: EXTRACTION_CONFIDENCE_MAX,
194 });
195 }
196
197 let existing = if options.existing_entities.is_empty() {
199 None
200 } else {
201 Some(options.existing_entities.as_slice())
202 };
203 let prompt = build_extraction_prompt(text, existing);
204
205 let (entities, relations) = match self.call_llm(&prompt, text).await {
207 Ok((e, r)) => (e, r),
208 Err(_) => {
209 (self.create_fallback_entity(text), Vec::new())
211 }
212 };
213
214 let entities: Vec<_> = if options.min_confidence > 0.0 {
216 entities
217 .into_iter()
218 .filter(|e| e.confidence >= options.min_confidence)
219 .collect()
220 } else {
221 entities
222 };
223
224 let relations: Vec<_> = if options.min_confidence > 0.0 {
225 relations
226 .into_iter()
227 .filter(|r| r.confidence >= options.min_confidence)
228 .collect()
229 } else {
230 relations
231 };
232
233 let entities: Vec<_> = entities
235 .into_iter()
236 .take(EXTRACTION_ENTITIES_COUNT_MAX)
237 .collect();
238 let relations: Vec<_> = relations
239 .into_iter()
240 .take(EXTRACTION_RELATIONS_COUNT_MAX)
241 .collect();
242
243 let result = ExtractionResult::new(entities, relations, text);
244
245 debug_assert!(
247 result.entity_count() <= EXTRACTION_ENTITIES_COUNT_MAX,
248 "too many entities"
249 );
250 debug_assert!(
251 result.relation_count() <= EXTRACTION_RELATIONS_COUNT_MAX,
252 "too many relations"
253 );
254
255 Ok(result)
256 }
257
258 pub async fn extract_entities_only(
266 &self,
267 text: &str,
268 ) -> Result<Vec<ExtractedEntity>, ExtractionError> {
269 let result = self.extract(text, ExtractionOptions::default()).await?;
270 Ok(result.entities)
271 }
272
273 async fn call_llm(
275 &self,
276 prompt: &str,
277 original_text: &str,
278 ) -> Result<(Vec<ExtractedEntity>, Vec<ExtractedRelation>), ProviderError> {
279 let request = CompletionRequest::new(prompt).with_json_mode();
280 let response = self.provider.complete(&request).await?;
281
282 let parsed = self.parse_response(&response, original_text);
284 Ok(parsed)
285 }
286
287 fn parse_response(
289 &self,
290 response: &str,
291 original_text: &str,
292 ) -> (Vec<ExtractedEntity>, Vec<ExtractedRelation>) {
293 let json_str = Self::extract_json_from_response(response);
295
296 let data: LLMExtractionResponse = match serde_json::from_str(json_str) {
298 Ok(d) => d,
299 Err(_) => {
300 return (self.create_fallback_entity(original_text), Vec::new());
302 }
303 };
304
305 let entities = self.parse_entities(&data.entities, original_text);
306 let relations = self.parse_relations(&data.relations);
307
308 if entities.is_empty() {
310 return (self.create_fallback_entity(original_text), relations);
311 }
312
313 (entities, relations)
314 }
315
/// Returns the JSON payload of `response`, stripping a surrounding Markdown
/// code fence (```` ``` ```` or ```` ```json ````) if one is present.
///
/// * No leading fence: the trimmed response is returned unchanged.
/// * Leading fence but no line break after it: the trimmed response is
///   returned unchanged.
/// * Leading fence with a body: everything after the opening fence line, up
///   to the last closing fence found in that body, is returned trimmed. If
///   the body contains no closing fence, the whole body is returned.
///
/// The closing fence is searched for only *after* the opening fence line, so
/// an unterminated fence can no longer produce an inverted slice range (which
/// previously panicked).
fn extract_json_from_response(response: &str) -> &str {
    const FENCE: &str = "```";

    let trimmed = response.trim();
    if !trimmed.starts_with(FENCE) {
        return trimmed;
    }

    // Skip the opening fence line (which may carry a language tag such as `json`).
    let body = match trimmed.find('\n') {
        Some(newline_idx) => &trimmed[newline_idx + 1..],
        None => return trimmed,
    };

    // Cut at the last closing fence inside the body, if there is one.
    let payload = match body.rfind(FENCE) {
        Some(close_idx) => &body[..close_idx],
        None => body,
    };

    payload.trim()
}
344
345 fn parse_entities(
347 &self,
348 raw_entities: &[RawEntity],
349 original_text: &str,
350 ) -> Vec<ExtractedEntity> {
351 let mut entities = Vec::new();
352
353 for raw in raw_entities {
354 let name = match &raw.name {
356 Some(n) if !n.trim().is_empty() => n.trim().to_string(),
357 _ => continue,
358 };
359
360 let name = if name.len() > crate::constants::EXTRACTION_ENTITY_NAME_BYTES_MAX {
362 name[..crate::constants::EXTRACTION_ENTITY_NAME_BYTES_MAX].to_string()
363 } else {
364 name
365 };
366
367 let entity_type = raw
369 .entity_type
370 .as_deref()
371 .map(EntityType::from_str_or_note)
372 .unwrap_or(EntityType::Note);
373
374 let content = raw
376 .content
377 .as_deref()
378 .unwrap_or(&original_text[..200.min(original_text.len())])
379 .to_string();
380
381 let content = if content.len() > crate::constants::EXTRACTION_ENTITY_CONTENT_BYTES_MAX {
383 content[..crate::constants::EXTRACTION_ENTITY_CONTENT_BYTES_MAX].to_string()
384 } else {
385 content
386 };
387
388 let confidence = raw
390 .confidence
391 .map(|c| c.clamp(EXTRACTION_CONFIDENCE_MIN, EXTRACTION_CONFIDENCE_MAX))
392 .unwrap_or(EXTRACTION_CONFIDENCE_DEFAULT);
393
394 entities.push(ExtractedEntity::new(name, entity_type, content, confidence));
395 }
396
397 entities
398 }
399
400 fn parse_relations(&self, raw_relations: &[RawRelation]) -> Vec<ExtractedRelation> {
402 let mut relations = Vec::new();
403
404 for raw in raw_relations {
405 let source = match &raw.source {
407 Some(s) if !s.trim().is_empty() => s.trim().to_string(),
408 _ => continue,
409 };
410
411 let target = match &raw.target {
412 Some(t) if !t.trim().is_empty() => t.trim().to_string(),
413 _ => continue,
414 };
415
416 let relation_type = raw
418 .relation_type
419 .as_deref()
420 .map(RelationType::from_str_or_relates_to)
421 .unwrap_or(RelationType::RelatesTo);
422
423 let confidence = raw
425 .confidence
426 .map(|c| c.clamp(EXTRACTION_CONFIDENCE_MIN, EXTRACTION_CONFIDENCE_MAX))
427 .unwrap_or(EXTRACTION_CONFIDENCE_DEFAULT);
428
429 relations.push(ExtractedRelation::new(
430 source,
431 target,
432 relation_type,
433 confidence,
434 ));
435 }
436
437 relations
438 }
439
440 fn create_fallback_entity(&self, text: &str) -> Vec<ExtractedEntity> {
442 let name = format!("Note: {}", &text[..50.min(text.len())]);
443 let content = text[..500.min(text.len())].to_string();
444
445 vec![ExtractedEntity::new(
446 name,
447 EntityType::Note,
448 content,
449 EXTRACTION_CONFIDENCE_DEFAULT,
450 )]
451 }
452
453 #[must_use]
455 pub fn provider(&self) -> &P {
456 &self.provider
457 }
458}
459
/// Unit tests for the extractor, driven by the seeded simulated provider.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::SimLLMProvider;

    /// Builds an extractor backed by a simulated provider seeded with `seed`.
    fn create_test_extractor(seed: u64) -> EntityExtractor<SimLLMProvider> {
        let provider = SimLLMProvider::with_seed(seed);
        EntityExtractor::new(provider)
    }

    /// A simple sentence yields a non-empty result that echoes the input text.
    #[tokio::test]
    async fn test_basic_extraction() {
        let text = "Alice works at Acme Corp";

        let result = create_test_extractor(42)
            .extract(text, ExtractionOptions::default())
            .await
            .unwrap();

        assert!(!result.is_empty());
        assert_eq!(result.raw_text, text);
    }

    /// Supplying already-known entities still produces a non-empty result.
    #[tokio::test]
    async fn test_extraction_with_existing_entities() {
        let known = vec!["Alice".to_string(), "Acme".to_string()];
        let options = ExtractionOptions::new().with_existing_entities(known);

        let result = create_test_extractor(42)
            .extract("She joined last month", options)
            .await
            .unwrap();

        assert!(!result.is_empty());
    }

    /// The entities-only convenience method returns at least one entity.
    #[tokio::test]
    async fn test_extraction_entities_only() {
        let entities = create_test_extractor(42)
            .extract_entities_only("Bob met Charlie at Google")
            .await
            .unwrap();

        assert!(!entities.is_empty());
    }

    /// Every returned entity meets the requested confidence threshold.
    #[tokio::test]
    async fn test_extraction_with_min_confidence() {
        let options = ExtractionOptions::new().with_min_confidence(0.9);

        let result = create_test_extractor(42)
            .extract("Alice works at Acme", options)
            .await
            .unwrap();

        assert!(result.entities.iter().all(|entity| entity.confidence >= 0.9));
    }

    /// Empty input is rejected with `EmptyText`.
    #[tokio::test]
    async fn test_empty_text_error() {
        let outcome = create_test_extractor(42)
            .extract("", ExtractionOptions::default())
            .await;

        assert!(matches!(outcome, Err(ExtractionError::EmptyText)));
    }

    /// Input one byte over the limit is rejected with `TextTooLong`.
    #[tokio::test]
    async fn test_text_too_long_error() {
        let oversized = "x".repeat(EXTRACTION_TEXT_BYTES_MAX + 1);

        let outcome = create_test_extractor(42)
            .extract(&oversized, ExtractionOptions::default())
            .await;

        assert!(matches!(outcome, Err(ExtractionError::TextTooLong { .. })));
    }

    /// A threshold outside the accepted range is rejected with `InvalidConfidence`.
    #[tokio::test]
    async fn test_invalid_confidence_error() {
        // Built as a struct literal so the out-of-range value reaches `extract` untouched.
        let out_of_range = ExtractionOptions {
            existing_entities: Vec::new(),
            min_confidence: 1.5,
        };

        let outcome = create_test_extractor(42).extract("test", out_of_range).await;

        assert!(matches!(
            outcome,
            Err(ExtractionError::InvalidConfidence { .. })
        ));
    }

    /// Two extractors with the same seed produce the same counts for the same input.
    #[tokio::test]
    async fn test_determinism() {
        let text = "Alice works at Microsoft";

        let first = create_test_extractor(42)
            .extract(text, ExtractionOptions::default())
            .await
            .unwrap();
        let second = create_test_extractor(42)
            .extract(text, ExtractionOptions::default())
            .await
            .unwrap();

        assert_eq!(first.entity_count(), second.entity_count());
        assert_eq!(first.relation_count(), second.relation_count());
    }

    /// Well-formed raw entities are converted with their names and mapped types.
    #[test]
    fn test_parse_entities_with_valid_data() {
        let extractor = create_test_extractor(42);

        let raw: Vec<RawEntity> = [
            ("Alice", "person", "A person", 0.9),
            ("Acme", "org", "A company", 0.8),
        ]
        .iter()
        .map(|&(name, kind, content, confidence)| RawEntity {
            name: Some(name.to_string()),
            entity_type: Some(kind.to_string()),
            content: Some(content.to_string()),
            confidence: Some(confidence),
        })
        .collect();

        let entities = extractor.parse_entities(&raw, "original text");

        assert_eq!(entities.len(), 2);

        let (alice, acme) = (&entities[0], &entities[1]);
        assert_eq!(alice.name, "Alice");
        assert_eq!(alice.entity_type, EntityType::Person);
        assert_eq!(acme.name, "Acme");
        assert_eq!(acme.entity_type, EntityType::Organization);
    }

    /// Raw entities with a missing or blank name are dropped.
    #[test]
    fn test_parse_entities_with_invalid_data() {
        let extractor = create_test_extractor(42);

        let nameless = RawEntity {
            name: None,
            entity_type: Some("person".to_string()),
            content: None,
            confidence: None,
        };
        let blank_name = RawEntity {
            name: Some(" ".to_string()),
            entity_type: None,
            content: None,
            confidence: None,
        };
        let raw = vec![nameless, blank_name];

        let entities = extractor.parse_entities(&raw, "original text");

        assert!(entities.is_empty());
    }

    /// An unrecognised type label falls back to `EntityType::Note`.
    #[test]
    fn test_parse_entities_with_unknown_type() {
        let extractor = create_test_extractor(42);

        let unknown = RawEntity {
            name: Some("Unknown".to_string()),
            entity_type: Some("unknown_type".to_string()),
            content: None,
            confidence: None,
        };
        let raw = vec![unknown];

        let entities = extractor.parse_entities(&raw, "original text");

        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].entity_type, EntityType::Note);
    }

    /// A well-formed raw relation is converted with its endpoints and mapped type.
    #[test]
    fn test_parse_relations_with_valid_data() {
        let extractor = create_test_extractor(42);

        let works_at = RawRelation {
            source: Some("Alice".to_string()),
            target: Some("Acme".to_string()),
            relation_type: Some("works_at".to_string()),
            confidence: Some(0.9),
        };
        let raw = vec![works_at];

        let relations = extractor.parse_relations(&raw);

        assert_eq!(relations.len(), 1);

        let relation = &relations[0];
        assert_eq!(relation.source, "Alice");
        assert_eq!(relation.target, "Acme");
        assert_eq!(relation.relation_type, RelationType::WorksAt);
    }

    /// Raw relations lacking a source or a target are dropped.
    #[test]
    fn test_parse_relations_with_missing_fields() {
        let extractor = create_test_extractor(42);

        let no_source = RawRelation {
            source: None,
            target: Some("Acme".to_string()),
            relation_type: None,
            confidence: None,
        };
        let no_target = RawRelation {
            source: Some("Alice".to_string()),
            target: None,
            relation_type: None,
            confidence: None,
        };
        let raw = vec![no_source, no_target];

        let relations = extractor.parse_relations(&raw);

        assert!(relations.is_empty());
    }

    /// The fallback is a single `Note` entity with the default confidence.
    #[test]
    fn test_create_fallback_entity() {
        let extractor = create_test_extractor(42);

        let fallback = extractor.create_fallback_entity("This is some text for testing");

        assert_eq!(fallback.len(), 1);

        let note = &fallback[0];
        assert!(note.name.starts_with("Note: "));
        assert_eq!(note.entity_type, EntityType::Note);
        assert_eq!(note.confidence, EXTRACTION_CONFIDENCE_DEFAULT);
    }

    /// The accessor exposes the provider the extractor was built with.
    #[test]
    fn test_provider_accessor() {
        let extractor = EntityExtractor::new(SimLLMProvider::with_seed(42));

        assert!(extractor.provider().is_simulation());
    }
}
726
/// Fault-injection tests: each one runs the extractor inside a seeded
/// simulation with an LLM fault configured and checks how `extract` reacts.
#[cfg(test)]
mod dst_tests {
    use super::*;
    use crate::dst::{FaultConfig, FaultType, SimConfig, Simulation};
    use crate::llm::SimLLMProvider;

    /// With a timeout fault configured at 1.0, extraction must return exactly
    /// one fallback `Note` entity rather than an error.
    #[tokio::test]
    async fn test_extract_with_llm_timeout() {
        let timeout_fault = FaultConfig::new(FaultType::LlmTimeout, 1.0);
        let sim = Simulation::new(SimConfig::with_seed(42)).with_fault(timeout_fault);

        sim.run(|env| async move {
            let extractor =
                EntityExtractor::new(SimLLMProvider::with_faults(42, env.faults.clone()));

            let outcome = extractor
                .extract("Alice works at Acme Corp", ExtractionOptions::default())
                .await;

            let extraction = match outcome {
                Ok(extraction) => extraction,
                Err(e) => {
                    panic!("BUG: LLM timeout should return fallback, not error: {:?}", e);
                }
            };

            assert!(
                !extraction.entities.is_empty(),
                "BUG: Should return fallback entity on timeout, got empty"
            );
            assert_eq!(
                extraction.entities.len(),
                1,
                "BUG: Should have exactly one fallback entity"
            );

            // The shape of the entity proves the fallback path (not the normal path) ran.
            let entity = &extraction.entities[0];

            assert_eq!(
                entity.entity_type,
                EntityType::Note,
                "BUG: Fallback entity should have type Note, got {:?}. This suggests fault didn't fire!",
                entity.entity_type
            );
            assert!(
                entity.name.starts_with("Note: "),
                "BUG: Fallback entity name should start with 'Note: ', got '{}'. Fault may not have fired!",
                entity.name
            );
            assert_eq!(
                entity.confidence,
                EXTRACTION_CONFIDENCE_DEFAULT,
                "BUG: Fallback should have confidence {}, got {}",
                EXTRACTION_CONFIDENCE_DEFAULT,
                entity.confidence
            );

            println!("✓ VERIFIED: LLM timeout actually fired, fallback entity created (type=Note, name={}, confidence={})",
                entity.name, entity.confidence);

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }

    /// With a rate-limit fault configured at 1.0, extraction must still
    /// return at least one entity rather than an error.
    #[tokio::test]
    async fn test_extract_with_llm_rate_limit() {
        let rate_limit_fault = FaultConfig::new(FaultType::LlmRateLimit, 1.0);
        let sim = Simulation::new(SimConfig::with_seed(42)).with_fault(rate_limit_fault);

        sim.run(|env| async move {
            let extractor =
                EntityExtractor::new(SimLLMProvider::with_faults(42, env.faults.clone()));

            let outcome = extractor
                .extract("Bob is the CTO at TechCo", ExtractionOptions::default())
                .await;

            let extraction = match outcome {
                Ok(extraction) => extraction,
                Err(e) => {
                    panic!(
                        "BUG: Rate limit should return fallback, not error: {:?}",
                        e
                    );
                }
            };

            assert!(
                !extraction.entities.is_empty(),
                "BUG: Should return fallback on rate limit, got empty"
            );

            println!("✓ LLM rate limit handled gracefully: fallback entity created");

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }

    /// With an invalid-response fault configured at 1.0, a successful result
    /// must contain at least one entity; an error result is tolerated.
    #[tokio::test]
    async fn test_extract_with_llm_invalid_response() {
        let invalid_response_fault = FaultConfig::new(FaultType::LlmInvalidResponse, 1.0);
        let sim = Simulation::new(SimConfig::with_seed(42)).with_fault(invalid_response_fault);

        sim.run(|env| async move {
            let extractor =
                EntityExtractor::new(SimLLMProvider::with_faults(42, env.faults.clone()));

            let outcome = extractor
                .extract(
                    "Carol manages the engineering team",
                    ExtractionOptions::default(),
                )
                .await;

            match outcome {
                Err(e) => {
                    // Deliberately not a failure for this fault type.
                    println!("Invalid response returned error (acceptable): {:?}", e);
                }
                Ok(extraction) => {
                    assert!(
                        !extraction.entities.is_empty(),
                        "BUG: Should return fallback on invalid response, got empty"
                    );

                    println!("✓ Invalid LLM response handled: fallback entity created");
                }
            }

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }

    /// With a timeout fault configured at 0.5 and seed 42, ten consecutive
    /// extractions are expected to all take the fallback path.
    #[tokio::test]
    async fn test_extract_with_probabilistic_failure() {
        let flaky_timeout = FaultConfig::new(FaultType::LlmTimeout, 0.5);
        let sim = Simulation::new(SimConfig::with_seed(42)).with_fault(flaky_timeout);

        sim.run(|env| async move {
            let extractor =
                EntityExtractor::new(SimLLMProvider::with_faults(42, env.faults.clone()));

            let mut fallback_count = 0;
            let mut success_count = 0;

            for i in 0..10 {
                let text = format!("Person {} is a software engineer", i);
                let outcome = extractor
                    .extract(&text, ExtractionOptions::default())
                    .await;

                // A lone `Note` entity is the fallback signature; errors count as fallbacks too.
                let took_fallback = match outcome {
                    Ok(extraction) => {
                        extraction.entities.len() == 1
                            && extraction.entities[0].entity_type == EntityType::Note
                    }
                    Err(_) => true,
                };

                if took_fallback {
                    fallback_count += 1;
                } else {
                    success_count += 1;
                }
            }

            assert!(
                fallback_count == 10,
                "BUG: With seed 42 + 50% rate, should have 10 fallbacks (deterministic). Got {}",
                fallback_count
            );
            assert!(
                success_count == 0,
                "BUG: With seed 42 + 50% rate, should have 0 successes (deterministic). Got {}",
                success_count
            );

            println!(
                "✓ Probabilistic failure DETERMINISTIC: {} fallbacks, {} successes (seed 42)",
                fallback_count, success_count
            );

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }

    /// With a service-unavailable fault configured at 1.0, extraction must
    /// still return at least one entity rather than an error.
    #[tokio::test]
    async fn test_extract_with_llm_service_unavailable() {
        let unavailable_fault = FaultConfig::new(FaultType::LlmServiceUnavailable, 1.0);
        let sim = Simulation::new(SimConfig::with_seed(42)).with_fault(unavailable_fault);

        sim.run(|env| async move {
            let extractor =
                EntityExtractor::new(SimLLMProvider::with_faults(42, env.faults.clone()));

            let outcome = extractor
                .extract("Test entity extraction", ExtractionOptions::default())
                .await;

            let extraction = match outcome {
                Ok(extraction) => extraction,
                Err(e) => {
                    panic!(
                        "BUG: Service unavailable should return fallback, not error: {:?}",
                        e
                    );
                }
            };

            assert!(
                !extraction.entities.is_empty(),
                "BUG: Should return fallback on service unavailable"
            );
            println!("✓ Service unavailable handled: fallback entity created");

            Ok::<_, anyhow::Error>(())
        })
        .await
        .unwrap();
    }
}