1use std::sync::Arc;
29
30use crate::engine::{EngineError, RedactionEngine};
31use crate::merge::Finding;
32
/// Minimum intersection-over-union at which two findings with the same label are treated as
/// the same entity and merged. Comparisons use `>=`, so an IoU of exactly 0.5 merges.
const IOU_THRESHOLD: f64 = 0.5;
35
/// Redaction engine that runs several sub-engines over the same text and merges their findings.
///
/// Findings whose canonical label matches and whose spans overlap with IoU >= `IOU_THRESHOLD`
/// are merged into one finding (the longest span wins). Optionally, selected labels must be
/// reported by at least two distinct sub-engines before they are kept ("dual confirm").
pub struct EnsembleEngine {
    // Sub-engines, queried in order; the first engine error aborts the whole inference.
    engines: Vec<Arc<dyn RedactionEngine>>,
    // Canonical labels whose findings survive only if >= 2 distinct engines reported them.
    dual_confirm_labels: std::collections::BTreeSet<crate::label::PrivacyLabel>,
    // One human-readable id per engine (same order as `engines`), used for attribution.
    // Empty by default, in which case attribution falls back to `unknown-<index>`.
    model_ids: Vec<String>,
}
64
65impl std::fmt::Debug for EnsembleEngine {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("EnsembleEngine")
71 .field("engine_count", &self.engines.len())
72 .field("dual_confirm_labels", &self.dual_confirm_labels)
73 .field("model_ids", &self.model_ids)
74 .finish_non_exhaustive()
75 }
76}
77
/// A finding tagged with the position (in `EnsembleEngine::engines`) of the engine that
/// produced it, so merging can count how many distinct engines agree on a cluster.
struct PerEngineFinding {
    // Index of the producing engine within the ensemble.
    engine_idx: usize,
    // The finding exactly as returned by that engine.
    finding: Finding,
}
84
/// Records which sub-engines contributed to one merged finding returned by
/// `EnsembleEngine::infer_with_attribution` / `infer_with_attribution_with_lang`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct EngineAttribution {
    /// Index of the described finding in the returned findings vector.
    pub finding_index: usize,
    /// Ids of every distinct engine whose finding landed in the merged cluster, sorted
    /// lexicographically; engines without a registered model id appear as `unknown-<index>`.
    pub contributing_engines: Vec<String>,
}
111
112impl EnsembleEngine {
113 pub fn new(engines: Vec<Arc<dyn RedactionEngine>>) -> Self {
116 Self {
117 engines,
118 dual_confirm_labels: std::collections::BTreeSet::new(),
119 model_ids: Vec::new(),
120 }
121 }
122
123 pub fn with_model_ids<I, S>(mut self, ids: I) -> Self
129 where
130 I: IntoIterator<Item = S>,
131 S: Into<String>,
132 {
133 let collected: Vec<String> = ids.into_iter().map(Into::into).collect();
134 assert_eq!(
135 collected.len(),
136 self.engines.len(),
137 "EnsembleEngine::with_model_ids:ids.len()={} 与 engines.len()={} 不匹配",
138 collected.len(),
139 self.engines.len()
140 );
141 self.model_ids = collected;
142 self
143 }
144
145 pub fn with_dual_confirm<I>(mut self, labels: I) -> Self
158 where
159 I: IntoIterator<Item = crate::label::PrivacyLabel>,
160 {
161 self.dual_confirm_labels = labels.into_iter().collect();
162 self
163 }
164
165 pub fn engine_count(&self) -> usize {
167 self.engines.len()
168 }
169
170 pub fn infer_with_attribution(
189 &self,
190 text: &str,
191 ) -> Result<(Vec<Finding>, Vec<EngineAttribution>), EngineError> {
192 self.infer_with_attribution_with_lang(text, None)
195 }
196
197 pub fn infer_with_attribution_with_lang(
213 &self,
214 text: &str,
215 lang: Option<&str>,
216 ) -> Result<(Vec<Finding>, Vec<EngineAttribution>), EngineError> {
217 let mut per_engine: Vec<PerEngineFinding> = Vec::new();
218 for (idx, engine) in self.engines.iter().enumerate() {
219 let f = engine.infer_with_lang(text, lang)?;
221 for finding in f {
222 per_engine.push(PerEngineFinding {
223 engine_idx: idx,
224 finding,
225 });
226 }
227 }
228
229 let (findings, attrs) =
231 ensemble_merge_with_attribution(per_engine, &self.dual_confirm_labels, &self.model_ids);
232 Ok((findings, attrs))
233 }
234}
235
236impl RedactionEngine for EnsembleEngine {
237 fn infer(&self, text: &str) -> Result<Vec<Finding>, EngineError> {
238 self.infer_with_lang(text, None)
241 }
242
243 fn infer_with_lang(&self, text: &str, lang: Option<&str>) -> Result<Vec<Finding>, EngineError> {
251 let mut per_engine: Vec<PerEngineFinding> = Vec::new();
252 for (idx, engine) in self.engines.iter().enumerate() {
253 let f = engine.infer_with_lang(text, lang)?;
255 for finding in f {
256 per_engine.push(PerEngineFinding {
257 engine_idx: idx,
258 finding,
259 });
260 }
261 }
262 Ok(ensemble_merge_with_dual_confirm(
263 per_engine,
264 &self.dual_confirm_labels,
265 ))
266 }
267}
268
/// Intersection-over-union of two half-open spans `(start, end)`.
///
/// Returns `0.0` when the spans do not overlap, including empty or inverted
/// (`start >= end`) spans. Otherwise returns overlap length divided by the length of the
/// smallest span covering both.
fn iou(a: (usize, usize), b: (usize, usize)) -> f64 {
    let overlap = a.1.min(b.1).saturating_sub(a.0.max(b.0));
    if overlap == 0 {
        return 0.0;
    }
    // A non-empty overlap implies min(starts) <= max(starts) < min(ends) <= max(ends), so the
    // covering span is strictly positive and this subtraction cannot underflow.
    let covering = a.1.max(b.1) - a.0.min(b.0);
    overlap as f64 / covering as f64
}
286
287#[allow(dead_code)] fn ensemble_merge(all: Vec<Finding>) -> Vec<Finding> {
296 let mut finals: Vec<Finding> = Vec::new();
297 for f in all {
298 let mut absorbed = false;
299 for slot in finals.iter_mut() {
300 if slot.kind == f.kind && iou(slot.span, f.span) >= IOU_THRESHOLD {
301 let cur_len = f.span.1.saturating_sub(f.span.0);
302 let exist_len = slot.span.1.saturating_sub(slot.span.0);
303 if cur_len > exist_len {
304 *slot = f.clone();
305 }
306 absorbed = true;
307 break;
308 }
309 }
310 if !absorbed {
311 finals.push(f);
312 }
313 }
314 finals.sort_by_key(|f| f.span.0);
315 finals
316}
317
318fn ensemble_merge_with_dual_confirm(
332 per_engine: Vec<PerEngineFinding>,
333 dual_confirm: &std::collections::BTreeSet<crate::label::PrivacyLabel>,
334) -> Vec<Finding> {
335 use crate::label::PrivacyLabel;
336
337 let mut clusters: Vec<(Option<PrivacyLabel>, Vec<PerEngineFinding>)> = Vec::new();
339 for pf in per_engine {
340 let canonical = PrivacyLabel::from_kind(pf.finding.kind);
341 let target_idx = clusters.iter().position(|(existing_label, group)| {
343 *existing_label == canonical
344 && group
345 .iter()
346 .any(|g| iou(g.finding.span, pf.finding.span) >= IOU_THRESHOLD)
347 });
348 match target_idx {
349 Some(idx) => clusters[idx].1.push(pf),
350 None => clusters.push((canonical, vec![pf])),
351 }
352 }
353
354 let mut finals: Vec<Finding> = Vec::new();
356 for (canonical, group) in clusters {
357 if let Some(label) = canonical {
359 if dual_confirm.contains(&label) {
360 let distinct: std::collections::BTreeSet<usize> =
361 group.iter().map(|p| p.engine_idx).collect();
362 if distinct.len() < 2 {
363 continue;
365 }
366 }
367 }
368 if let Some(longest) = group
370 .into_iter()
371 .map(|p| p.finding)
372 .max_by_key(|f| f.span.1.saturating_sub(f.span.0))
373 {
374 finals.push(longest);
375 }
376 }
377
378 finals.sort_by_key(|f| f.span.0);
379 finals
380}
381
382fn ensemble_merge_with_attribution(
388 per_engine: Vec<PerEngineFinding>,
389 dual_confirm: &std::collections::BTreeSet<crate::label::PrivacyLabel>,
390 model_ids: &[String],
391) -> (Vec<Finding>, Vec<EngineAttribution>) {
392 use crate::label::PrivacyLabel;
393
394 let mut clusters: Vec<(Option<PrivacyLabel>, Vec<PerEngineFinding>)> = Vec::new();
395 for pf in per_engine {
396 let canonical = PrivacyLabel::from_kind(pf.finding.kind);
397 let target_idx = clusters.iter().position(|(existing_label, group)| {
398 *existing_label == canonical
399 && group
400 .iter()
401 .any(|g| iou(g.finding.span, pf.finding.span) >= IOU_THRESHOLD)
402 });
403 match target_idx {
404 Some(idx) => clusters[idx].1.push(pf),
405 None => clusters.push((canonical, vec![pf])),
406 }
407 }
408
409 let mut staged: Vec<(Finding, std::collections::BTreeSet<usize>)> = Vec::new();
411 for (canonical, group) in clusters {
412 let distinct: std::collections::BTreeSet<usize> =
413 group.iter().map(|p| p.engine_idx).collect();
414
415 if let Some(label) = canonical {
417 if dual_confirm.contains(&label) && distinct.len() < 2 {
418 continue;
419 }
420 }
421 if let Some(longest) = group
423 .into_iter()
424 .map(|p| p.finding)
425 .max_by_key(|f| f.span.1.saturating_sub(f.span.0))
426 {
427 staged.push((longest, distinct));
428 }
429 }
430
431 staged.sort_by_key(|(f, _)| f.span.0);
432
433 let mut findings = Vec::with_capacity(staged.len());
434 let mut attrs = Vec::with_capacity(staged.len());
435 for (idx, (finding, distinct)) in staged.into_iter().enumerate() {
436 let mut contributing_engines: Vec<String> = distinct
437 .into_iter()
438 .map(|engine_idx| {
439 model_ids
440 .get(engine_idx)
441 .cloned()
442 .unwrap_or_else(|| format!("unknown-{engine_idx}"))
443 })
444 .collect();
445 contributing_engines.sort();
447 findings.push(finding);
448 attrs.push(EngineAttribution {
449 finding_index: idx,
450 contributing_engines,
451 });
452 }
453
454 (findings, attrs)
455}
456
/// Unit tests for `EnsembleEngine`, driven by mock sub-engines. They cover IoU merging,
/// dual-confirm filtering, attribution, error propagation and `lang` forwarding.
#[cfg(test)]
// Tests use unwrap/expect/panic deliberately: a failure there is a test failure.
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::engine::{MockEngine, NoopEngine};

    /// An ensemble with no sub-engines yields no findings.
    #[test]
    fn ensemble_empty_engines_returns_empty() {
        let ens = EnsembleEngine::new(vec![]);
        let f = ens.infer("text").unwrap();
        assert!(f.is_empty(), "0 engines 应返空 findings");
        assert_eq!(ens.engine_count(), 0);
    }

    /// A single no-op sub-engine yields no findings.
    #[test]
    fn ensemble_single_noop_returns_empty() {
        let ens = EnsembleEngine::new(vec![Arc::new(NoopEngine)]);
        let f = ens.infer("hello world").unwrap();
        assert!(f.is_empty());
        assert_eq!(ens.engine_count(), 1);
    }

    /// Non-overlapping findings from different engines are all kept, sorted by span start.
    #[test]
    fn ensemble_two_engines_disjoint_findings_both_kept() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 5),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "email",
            (10, 30),
            0.95,
            10,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]);
        let f = ens.infer("anything").unwrap();
        assert_eq!(f.len(), 2, "无重叠应都保留");
        assert_eq!(f[0].span, (0, 5));
        assert_eq!(f[1].span, (10, 30));
    }

    /// Same label, IoU exactly 0.5 ((0,5) vs (0,10)): merged, and the longer span is kept.
    #[test]
    fn ensemble_same_kind_overlapping_picks_longer_span() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 5),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 10),
            0.85,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]);
        let f = ens.infer("anything").unwrap();
        assert_eq!(f.len(), 1, "同 kind IoU >= 0.5 应合并");
        assert_eq!(f[0].span, (0, 10), "应取 longer span (10 > 5)");
    }

    /// Same label with disjoint spans: both findings are kept.
    #[test]
    fn ensemble_same_kind_low_iou_both_kept() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 5),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (20, 30),
            0.9,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]);
        let f = ens.infer("any").unwrap();
        assert_eq!(f.len(), 2, "同 kind 不重叠应都保留");
    }

    /// Identical spans with different labels are not deduplicated.
    #[test]
    fn ensemble_different_kind_overlapping_both_kept() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 10),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "email",
            (0, 10),
            0.9,
            10,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]);
        let f = ens.infer("any").unwrap();
        assert_eq!(f.len(), 2, "不同 kind 重叠不去重");
    }

    /// Any sub-engine error fails the whole ensemble call (fail-closed).
    #[test]
    fn ensemble_propagates_engine_error() {
        // Only `infer` is overridden; this presumably relies on the trait's default
        // `infer_with_lang` delegating to `infer` — NOTE(review): confirm in the trait.
        struct FailingEngine;
        impl RedactionEngine for FailingEngine {
            fn infer(&self, _: &str) -> Result<Vec<Finding>, EngineError> {
                Err(EngineError::InferRun("mock-failure".to_string()))
            }
        }
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 5),
            0.9,
            5,
        )]));
        let b = Arc::new(FailingEngine);
        let ens = EnsembleEngine::new(vec![a, b]);
        let r = ens.infer("any");
        assert!(
            matches!(r, Err(EngineError::InferRun(_))),
            "任一引擎失败应 propagate(fail-closed)"
        );
    }

    /// Three engines, one silent: the two overlapping findings (IoU 0.6) merge into one.
    #[test]
    fn ensemble_three_engines_iou_above_threshold_merges() {
        let xlmr = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 6),
            0.85,
            5,
        )]));
        let yonigo = Arc::new(MockEngine::from_findings(vec![]));
        let openai = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 10),
            0.95,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![xlmr, yonigo, openai]);
        let f = ens.infer("John Smith works here.").unwrap();
        assert_eq!(f.len(), 1, "三 engine 同 kind IoU 0.6 应合 1");
        assert_eq!(f[0].span, (0, 10), "应取 longer span (10 > 6)");
    }

    /// Realistic partial-name overlap (IoU 0.4) stays below the threshold: both kept.
    #[test]
    fn ensemble_spike3_realistic_iou_below_threshold_keeps_both() {
        let xlmr = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 4),
            0.85,
            5,
        )]));
        let openai = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 10),
            0.95,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![xlmr, openai]);
        let f = ens.infer("John Smith.").unwrap();
        assert_eq!(f.len(), 2, "IoU 0.4 < 0.5 不合并(spike-3 实测真实行为)");
    }

    /// Overlapping spans below the threshold are not merged. Despite the test name, the IoU
    /// of (0,4) and (3,10) is 1/10 = 0.1, well under 0.5.
    #[test]
    fn ensemble_iou_threshold_boundary_just_below_05_keeps_both() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 4),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (3, 10),
            0.9,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]);
        let f = ens.infer("any").unwrap();
        assert_eq!(f.len(), 2, "IoU < 0.5 不合并");
    }

    use crate::label::PrivacyLabel;

    /// Without `with_dual_confirm`, a finding reported by a single engine is kept.
    #[test]
    fn dual_confirm_default_off_keeps_single_engine_finding() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "private_address",
            (0, 10),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![]));
        let ens = EnsembleEngine::new(vec![a, b]);
        let f = ens.infer("any").unwrap();
        assert_eq!(f.len(), 1, "默认无 dual_confirm,单 engine 报应保留");
    }

    /// With dual-confirm on Address, an address reported by only one engine is dropped.
    #[test]
    fn dual_confirm_address_drops_single_engine_finding() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "private_address",
            (0, 10),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![]));
        let ens = EnsembleEngine::new(vec![a, b]).with_dual_confirm([PrivacyLabel::Address]);
        let f = ens.infer("any").unwrap();
        assert!(
            f.is_empty(),
            "dual_confirm Address 启用 + 仅 engine_a 报 → 丢弃,实际: {:?}",
            f
        );
    }

    /// With dual-confirm on Address, two engines agreeing (IoU 0.6) keep one merged finding.
    #[test]
    fn dual_confirm_address_keeps_dual_engine_consensus() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "private_address",
            (0, 6),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "private_address",
            (0, 10),
            0.9,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]).with_dual_confirm([PrivacyLabel::Address]);
        let f = ens.infer("any").unwrap();
        assert_eq!(f.len(), 1, "双 engine 共识应保留 1");
        assert_eq!(f[0].span, (0, 10), "应取 longest span");
    }

    /// Dual-confirm only affects the configured labels; others pass with a single engine.
    #[test]
    fn dual_confirm_selective_keeps_other_labels() {
        let a = Arc::new(MockEngine::from_findings(vec![
            Finding::model("private_person", (0, 10), 0.9, 5),
            Finding::model("private_address", (20, 30), 0.9, 5),
        ]));
        let b = Arc::new(MockEngine::from_findings(vec![]));
        let ens = EnsembleEngine::new(vec![a, b]).with_dual_confirm([PrivacyLabel::Address]);
        let f = ens.infer("any").unwrap();
        assert_eq!(f.len(), 1);
        assert_eq!(f[0].kind, "private_person");
    }

    /// Several dual-confirm labels are each enforced; unlisted labels (email) survive.
    #[test]
    fn dual_confirm_multi_labels() {
        let a = Arc::new(MockEngine::from_findings(vec![
            Finding::model("private_address", (0, 10), 0.9, 5),
            Finding::model("private_date", (20, 30), 0.9, 5),
            Finding::model("private_email", (40, 50), 0.9, 5),
        ]));
        let b = Arc::new(MockEngine::from_findings(vec![]));
        let ens = EnsembleEngine::new(vec![a, b])
            .with_dual_confirm([PrivacyLabel::Address, PrivacyLabel::Date]);
        let f = ens.infer("any").unwrap();
        assert_eq!(f.len(), 1);
        assert_eq!(f[0].kind, "private_email");
    }

    /// Agreement is counted per cluster: two engines reporting non-overlapping addresses do
    /// not confirm each other, so both are dropped.
    #[test]
    fn dual_confirm_separate_clusters_each_checked() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "private_address",
            (0, 5),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "private_address",
            (20, 30),
            0.9,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]).with_dual_confirm([PrivacyLabel::Address]);
        let f = ens.infer("any").unwrap();
        assert!(
            f.is_empty(),
            "两个独立 cluster 各 1 engine → 都丢(dual_confirm 不跨 cluster 共识)"
        );
    }

    /// Without registered model ids, attribution names engines `unknown-<index>`.
    #[test]
    fn attribution_default_uses_unknown_idx() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 10),
            0.9,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![a]);
        let (findings, attrs) = ens.infer_with_attribution("any").unwrap();
        assert_eq!(findings.len(), 1);
        assert_eq!(attrs.len(), 1);
        assert_eq!(attrs[0].finding_index, 0);
        assert_eq!(attrs[0].contributing_engines, vec!["unknown-0".to_string()]);
    }

    /// Registered model ids are used as engine names in attribution.
    #[test]
    fn attribution_with_model_ids_returns_real_names() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 10),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "email",
            (20, 30),
            0.9,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![a, b])
            .with_model_ids(["openai-privacy-filter-v1", "xlmr-pii-v1"]);
        let (findings, attrs) = ens.infer_with_attribution("any").unwrap();
        assert_eq!(findings.len(), 2);
        assert_eq!(attrs.len(), 2);
        assert_eq!(
            attrs[0].contributing_engines,
            vec!["openai-privacy-filter-v1".to_string()]
        );
        assert_eq!(
            attrs[1].contributing_engines,
            vec!["xlmr-pii-v1".to_string()]
        );
    }

    /// A merged finding lists every contributing engine, sorted by name (not engine index).
    #[test]
    fn attribution_consensus_lists_all_contributing_engines() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 6),
            0.85,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "person",
            (0, 10),
            0.95,
            5,
        )]));
        let ens = EnsembleEngine::new(vec![a, b])
            .with_model_ids(["xlmr-pii-v1", "openai-privacy-filter-v1"]);
        let (findings, attrs) = ens.infer_with_attribution("any").unwrap();
        assert_eq!(findings.len(), 1, "IoU 0.6 应合 1");
        assert_eq!(findings[0].span, (0, 10));
        assert_eq!(attrs.len(), 1);
        assert_eq!(
            attrs[0].contributing_engines,
            vec![
                "openai-privacy-filter-v1".to_string(),
                "xlmr-pii-v1".to_string()
            ]
        );
    }

    /// `with_model_ids` panics when the id count differs from the engine count.
    #[test]
    #[should_panic(expected = "with_model_ids")]
    fn attribution_with_mismatched_ids_count_panics() {
        let a = Arc::new(MockEngine::from_findings(vec![]));
        let b = Arc::new(MockEngine::from_findings(vec![]));
        let _ens = EnsembleEngine::new(vec![a, b]).with_model_ids(["only-one"]);
    }

    /// Findings dropped by dual-confirm produce no attribution entries either.
    #[test]
    fn attribution_dual_confirm_drops_single_engine_consistent() {
        let a = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "private_address",
            (0, 10),
            0.9,
            5,
        )]));
        let b = Arc::new(MockEngine::from_findings(vec![]));
        let ens = EnsembleEngine::new(vec![a, b])
            .with_model_ids(["openai", "xlmr"])
            .with_dual_confirm([PrivacyLabel::Address]);
        let (findings, attrs) = ens.infer_with_attribution("any").unwrap();
        assert!(findings.is_empty());
        assert!(
            attrs.is_empty(),
            "dual_confirm 丢 finding 时 attribution 也应丢"
        );
    }

    /// Input arrives out of span order; after sorting, `attrs[i].finding_index == i`.
    #[test]
    fn attribution_finding_index_aligns_with_findings_array() {
        let a = Arc::new(MockEngine::from_findings(vec![
            Finding::model("address", (50, 100), 0.9, 5),
            Finding::model("person", (0, 10), 0.9, 5),
        ]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "email",
            (20, 40),
            0.9,
            10,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]).with_model_ids(["e0", "e1"]);
        let (findings, attrs) = ens.infer_with_attribution("any").unwrap();
        assert_eq!(findings.len(), 3);
        assert_eq!(attrs.len(), 3);
        for (i, a) in attrs.iter().enumerate() {
            assert_eq!(
                a.finding_index, i,
                "finding_index 必须与 findings 数组下标对齐"
            );
        }
    }

    /// Test engine that records the `lang` value of every call and returns no findings.
    /// A call through `infer` (or `infer_with_lang` with `None`) is recorded as `None`.
    struct LangCapturingTestEngine {
        captured: std::sync::Mutex<Vec<Option<String>>>,
    }

    impl LangCapturingTestEngine {
        fn new() -> Self {
            Self {
                captured: std::sync::Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of all recorded `lang` values, in call order.
        fn captured(&self) -> Vec<Option<String>> {
            self.captured.lock().unwrap().clone()
        }
    }

    impl RedactionEngine for LangCapturingTestEngine {
        fn infer(&self, _text: &str) -> Result<Vec<Finding>, EngineError> {
            self.captured.lock().unwrap().push(None);
            Ok(Vec::new())
        }

        fn infer_with_lang(
            &self,
            text: &str,
            lang: Option<&str>,
        ) -> Result<Vec<Finding>, EngineError> {
            // `None` is routed through `infer`, so it is recorded exactly once as `None`.
            if lang.is_none() {
                return self.infer(text);
            }
            self.captured.lock().unwrap().push(lang.map(String::from));
            Ok(Vec::new())
        }
    }

    /// `EnsembleEngine::infer_with_lang` forwards `lang` to every sub-engine.
    #[test]
    fn ensemble_infer_with_lang_propagates_lang_to_all_sub_engines() {
        let a = Arc::new(LangCapturingTestEngine::new());
        let b = Arc::new(LangCapturingTestEngine::new());
        let ens = EnsembleEngine::new(vec![a.clone(), b.clone()]);

        let _ = ens.infer_with_lang("any text", Some("de")).unwrap();

        assert_eq!(
            a.captured(),
            vec![Some("de".to_string())],
            "engine a 应收到透传的 lang Some(\"de\");若是 None 表明 ensemble 走 default 委托 infer (bug 根因)"
        );
        assert_eq!(
            b.captured(),
            vec![Some("de".to_string())],
            "engine b 应收到透传的 lang Some(\"de\")"
        );
    }

    /// Legacy `infer` reaches sub-engines with no language hint.
    #[test]
    fn ensemble_infer_legacy_passes_none_lang() {
        let a = Arc::new(LangCapturingTestEngine::new());
        let ens = EnsembleEngine::new(vec![a.clone()]);
        let _ = ens.infer("any").unwrap();
        assert_eq!(
            a.captured(),
            vec![None],
            "legacy ensemble.infer 应让子 engine 走 lang None(等价 v0.8)"
        );
    }

    /// The attribution path forwards `lang` to every sub-engine too.
    #[test]
    fn ensemble_infer_with_attribution_with_lang_propagates_lang() {
        let a = Arc::new(LangCapturingTestEngine::new());
        let b = Arc::new(LangCapturingTestEngine::new());
        let ens = EnsembleEngine::new(vec![a.clone(), b.clone()]).with_model_ids(["e0", "e1"]);

        let _ = ens
            .infer_with_attribution_with_lang("any", Some("de"))
            .unwrap();
        assert_eq!(
            a.captured(),
            vec![Some("de".to_string())],
            "engine a 必须收到 lang Some(\"de\")(P1.3 R1 NICE attribution lang 透传)"
        );
        assert_eq!(b.captured(), vec![Some("de".to_string())]);
    }

    /// Legacy `infer_with_attribution` reaches sub-engines with no language hint.
    #[test]
    fn ensemble_infer_with_attribution_legacy_passes_none_lang() {
        let a = Arc::new(LangCapturingTestEngine::new());
        let ens = EnsembleEngine::new(vec![a.clone()]).with_model_ids(["e0"]);
        let _ = ens.infer_with_attribution("any").unwrap();
        assert_eq!(
            a.captured(),
            vec![None],
            "legacy infer_with_attribution 应走 lang None(等价 v0.9 baseline)"
        );
    }

    /// Merged output is sorted by span start regardless of engine or input order.
    #[test]
    fn ensemble_output_sorted_by_span_start() {
        let a = Arc::new(MockEngine::from_findings(vec![
            Finding::model("address", (50, 100), 0.9, 5),
            Finding::model("person", (0, 10), 0.9, 5),
        ]));
        let b = Arc::new(MockEngine::from_findings(vec![Finding::model(
            "email",
            (20, 40),
            0.9,
            10,
        )]));
        let ens = EnsembleEngine::new(vec![a, b]);
        let f = ens.infer("any").unwrap();
        assert_eq!(f.len(), 3);
        assert_eq!(f[0].span, (0, 10));
        assert_eq!(f[1].span, (20, 40));
        assert_eq!(f[2].span, (50, 100));
    }
}