1use std::sync::Arc;
28use std::time::Duration;
29
30use zeph_llm::any::AnyProvider;
31use zeph_llm::provider::LlmProvider as _;
32
33use crate::graph::GraphStore;
34
/// Tunable configuration for the memory-write quality gate.
///
/// With the defaults the three rule weights sum to 1.0, so `threshold`
/// can be read as a cutoff on a 0.0..=1.0 score.
#[derive(Debug, Clone)]
pub struct QualityGateConfig {
    /// Master switch; when `false`, `evaluate` admits every write.
    pub enabled: bool,
    /// Minimum blended score required to admit a write (0.0..=1.0).
    pub threshold: f32,
    /// Intended size of the "recent items" window — presumably the number of
    /// `recent_embeddings` callers pass to `evaluate`; not enforced in this
    /// file (TODO confirm against callers).
    pub recent_window: usize,
    /// Graph edges younger than this many seconds do not count as
    /// "established" conflicts when assessing contradiction risk.
    pub contradiction_grace_seconds: u64,
    /// Weight of the novelty (information-value) component.
    pub information_value_weight: f32,
    /// Weight of the heuristic reference-completeness component.
    pub reference_completeness_weight: f32,
    /// Weight of the (inverted) contradiction-risk component.
    pub contradiction_weight: f32,
    /// Rolling rejection-rate ratio above which a warning is logged.
    pub rejection_rate_alarm_ratio: f32,
    /// Timeout for the optional LLM scorer call, in milliseconds.
    pub llm_timeout_ms: u64,
    /// Blend factor between rule score and LLM score (0.0 = rules only).
    pub llm_weight: f32,
    /// Enable the English-only heuristic reference check; disable for
    /// non-English content to avoid false penalties.
    pub reference_check_lang_en: bool,
}
66
impl Default for QualityGateConfig {
    /// Conservative defaults: gate disabled, rule weights summing to 1.0,
    /// English reference checks on.
    fn default() -> Self {
        Self {
            enabled: false,
            threshold: 0.55,
            recent_window: 32,
            // 5 minutes before an existing edge counts as "established".
            contradiction_grace_seconds: 300,
            // 0.4 + 0.3 + 0.3 = 1.0, keeping the rule score on a 0..=1 scale.
            information_value_weight: 0.4,
            reference_completeness_weight: 0.3,
            contradiction_weight: 0.3,
            rejection_rate_alarm_ratio: 0.35,
            llm_timeout_ms: 500,
            llm_weight: 0.5,
            reference_check_lang_en: true,
        }
    }
}
84
/// Component and blended quality scores for a single evaluated message.
///
/// NOTE(review): not constructed anywhere in this file; presumably part of
/// the public API consumed by callers or diagnostics — confirm before
/// removing or changing fields.
#[derive(Debug, Clone)]
pub struct QualityScore {
    /// Novelty relative to recent embeddings (1.0 = fully novel).
    pub information_value: f32,
    /// Heuristic reference completeness (1.0 = no unresolved references).
    pub reference_completeness: f32,
    /// Graph-based contradiction risk (0.0 = none detected).
    pub contradiction_risk: f32,
    /// Weighted rule-based blend of the three components.
    pub combined: f32,
    /// Final score after optional LLM blending.
    pub final_score: f32,
}
103
/// Why a write was rejected by the quality gate.
///
/// Serialized in snake_case for metrics/log output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum QualityRejectionReason {
    /// Content too similar to recent writes (low information value).
    Redundant,
    /// Unresolved pronouns or deictic time references.
    IncompleteReference,
    /// Conflicts with an established fact in the knowledge graph.
    Contradiction,
    /// Catch-all for a low blended score (used even when no LLM is attached).
    LlmLowConfidence,
}
117
118impl QualityRejectionReason {
119 #[must_use]
121 pub fn label(self) -> &'static str {
122 match self {
123 Self::Redundant => "redundant",
124 Self::IncompleteReference => "incomplete_reference",
125 Self::Contradiction => "contradiction",
126 Self::LlmLowConfidence => "llm_low_confidence",
127 }
128 }
129}
130
/// Fixed-capacity sliding window of admit/reject outcomes with an O(1)
/// running count of rejections.
struct RollingRateTracker {
    /// Outcome history, oldest first; `true` means a rejected write.
    window: std::collections::VecDeque<bool>,
    /// Maximum number of outcomes retained.
    capacity: usize,
    /// Number of `true` entries currently in `window`.
    reject_count: usize,
}

impl RollingRateTracker {
    /// Creates an empty tracker holding at most `capacity` outcomes.
    fn new(capacity: usize) -> Self {
        Self {
            window: std::collections::VecDeque::with_capacity(capacity + 1),
            capacity,
            reject_count: 0,
        }
    }

    /// Records one outcome, evicting the oldest entry once the window is full.
    fn push(&mut self, rejected: bool) {
        // Make room before inserting so the window never exceeds `capacity`.
        if self.window.len() >= self.capacity {
            if let Some(true) = self.window.pop_front() {
                // The evicted entry was a rejection; keep the count in sync.
                self.reject_count = self.reject_count.saturating_sub(1);
            }
        }
        self.window.push_back(rejected);
        if rejected {
            self.reject_count += 1;
        }
    }

    /// Fraction of rejected outcomes in the current window (0.0 when empty).
    #[allow(clippy::cast_precision_loss)]
    fn rate(&self) -> f32 {
        let total = self.window.len();
        if total == 0 {
            0.0
        } else {
            self.reject_count as f32 / total as f32
        }
    }
}
168
/// Gate that scores candidate memory writes and rejects low-quality ones.
///
/// Built with `QualityGate::new`; an LLM scorer and a graph store can be
/// attached via the builder methods. Interior mutability via `std::sync::Mutex`
/// keeps `evaluate(&self)` usable behind a shared reference; the locks are
/// never held across an `.await` in this file.
pub struct QualityGate {
    config: Arc<QualityGateConfig>,
    /// Optional secondary scorer blended in with `config.llm_weight`.
    llm_provider: Option<Arc<AnyProvider>>,
    /// Optional knowledge graph used for contradiction checks.
    graph_store: Option<Arc<GraphStore>>,
    /// Per-reason rejection counters for diagnostics.
    rejection_counts: std::sync::Mutex<std::collections::HashMap<QualityRejectionReason, u64>>,
    /// Rolling admit/reject history backing the rejection-rate alarm.
    rate_tracker: std::sync::Mutex<RollingRateTracker>,
}
190
191impl QualityGate {
192 #[must_use]
194 pub fn new(config: QualityGateConfig) -> Self {
195 Self {
196 config: Arc::new(config),
197 llm_provider: None,
198 graph_store: None,
199 rejection_counts: std::sync::Mutex::new(std::collections::HashMap::new()),
200 rate_tracker: std::sync::Mutex::new(RollingRateTracker::new(100)),
201 }
202 }
203
204 #[must_use]
206 pub fn with_llm_provider(mut self, provider: AnyProvider) -> Self {
207 self.llm_provider = Some(Arc::new(provider));
208 self
209 }
210
211 #[must_use]
213 pub fn with_graph_store(mut self, store: Arc<GraphStore>) -> Self {
214 self.graph_store = Some(store);
215 self
216 }
217
218 #[must_use]
220 pub fn config(&self) -> &QualityGateConfig {
221 &self.config
222 }
223
224 #[must_use]
226 pub fn rejection_counts(&self) -> std::collections::HashMap<QualityRejectionReason, u64> {
227 self.rejection_counts
228 .lock()
229 .map(|g| g.clone())
230 .unwrap_or_default()
231 }
232
233 #[tracing::instrument(name = "memory.quality_gate.evaluate", skip_all)]
240 pub async fn evaluate(
241 &self,
242 content: &str,
243 embed_provider: &AnyProvider,
244 recent_embeddings: &[Vec<f32>],
245 ) -> Option<QualityRejectionReason> {
246 if !self.config.enabled {
247 return None;
248 }
249
250 let info_val = compute_information_value(content, embed_provider, recent_embeddings).await;
251 let ref_comp = if self.config.reference_check_lang_en {
252 compute_reference_completeness(content)
253 } else {
254 1.0
255 };
256 let contradiction_risk =
257 compute_contradiction_risk(content, self.graph_store.as_deref(), &self.config).await;
258
259 let w_v = self.config.information_value_weight;
260 let w_c = self.config.reference_completeness_weight;
261 let w_k = self.config.contradiction_weight;
262
263 let rule_score = w_v * info_val + w_c * ref_comp + w_k * (1.0 - contradiction_risk);
264
265 let final_score = if let Some(ref llm) = self.llm_provider {
266 let llm_score = call_llm_scorer(content, llm, self.config.llm_timeout_ms).await;
267 let lw = self.config.llm_weight;
268 (1.0 - lw) * rule_score + lw * llm_score
269 } else {
270 rule_score
271 };
272
273 let rejected = final_score < self.config.threshold;
274
275 if let Ok(mut tracker) = self.rate_tracker.lock() {
277 tracker.push(rejected);
278 let rate = tracker.rate();
279 if rate > self.config.rejection_rate_alarm_ratio {
280 tracing::warn!(
281 rate = %format!("{:.2}", rate),
282 window_size = self.config.recent_window,
283 threshold = self.config.rejection_rate_alarm_ratio,
284 "quality_gate: high rejection rate alarm"
285 );
286 }
287 }
288
289 if !rejected {
290 return None;
291 }
292
293 let reason = if info_val < 0.1 {
295 QualityRejectionReason::Redundant
296 } else if ref_comp < 0.5 && self.config.reference_check_lang_en {
297 QualityRejectionReason::IncompleteReference
298 } else if contradiction_risk >= 1.0 {
299 QualityRejectionReason::Contradiction
300 } else {
301 QualityRejectionReason::LlmLowConfidence
302 };
303
304 if let Ok(mut counts) = self.rejection_counts.lock() {
305 *counts.entry(reason).or_insert(0) += 1;
306 }
307
308 tracing::debug!(
309 reason = reason.label(),
310 final_score,
311 info_val,
312 ref_comp,
313 contradiction_risk,
314 "quality_gate: rejected write"
315 );
316
317 Some(reason)
318 }
319}
320
321async fn compute_information_value(
327 content: &str,
328 provider: &AnyProvider,
329 recent_embeddings: &[Vec<f32>],
330) -> f32 {
331 if recent_embeddings.is_empty() {
332 return 1.0;
333 }
334 if !provider.supports_embeddings() {
335 return 1.0;
336 }
337 let candidate = match provider.embed(content).await {
338 Ok(v) => v,
339 Err(e) => {
340 tracing::debug!(error = %e, "quality_gate: embed failed, treating info_val = 1.0 (fail-open)");
341 return 1.0;
342 }
343 };
344 let max_sim = recent_embeddings
345 .iter()
346 .map(|r| zeph_common::math::cosine_similarity(&candidate, r))
347 .fold(0.0f32, f32::max);
348 (1.0 - max_sim).max(0.0)
349}
350
/// Heuristic reference-completeness score in `0.0..=1.0`.
///
/// Penalizes unresolved pronouns and deictic time expressions ("yesterday",
/// "next week", ...) that lack a concrete date anchor (a month name or a
/// 4-digit year), since such messages lose meaning in long-term storage.
/// English-only; callers should skip it for other languages.
#[must_use]
pub fn compute_reference_completeness(content: &str) -> f32 {
    const PRONOUNS: &[&str] = &["he", "she", "they", "it", "him", "her", "them"];
    const DEICTIC_TIME: &[&str] = &[
        "yesterday",
        "tomorrow",
        "last week",
        "next week",
        "last month",
        "next month",
        "last year",
        "next year",
    ];
    // Full and abbreviated month names, matched as whole words only so that
    // e.g. "maybe", "marching", or "grammar" do not count as date anchors
    // (the previous substring matching produced exactly those false positives).
    const MONTHS: &[&str] = &[
        "january", "february", "march", "april", "may", "june", "july", "august", "september",
        "october", "november", "december", "jan", "feb", "mar", "apr", "jun", "jul", "aug", "sep",
        "oct", "nov", "dec",
    ];

    let lower = content.to_lowercase();
    // Whole-word tokens with surrounding punctuation stripped, so "it." or
    // "(march)" still match; the old space-padded check missed those.
    let words: Vec<&str> = lower
        .split_ascii_whitespace()
        .map(|t| t.trim_matches(|c: char| !c.is_ascii_alphanumeric()))
        .collect();

    let pronoun_count = words.iter().filter(|w| PRONOUNS.contains(w)).count();

    // A concrete date anchor (year or month) excuses deictic time words.
    let has_date_anchor =
        has_4digit_year_anchor(&lower) || words.iter().any(|w| MONTHS.contains(w));
    let deictic_count = if has_date_anchor {
        0
    } else {
        // Multi-word phrases ("last week") need a substring check on the
        // lowered text rather than single-token matching.
        DEICTIC_TIME.iter().filter(|&&t| lower.contains(t)).count()
    };

    let total_issues = pronoun_count + deictic_count;
    if total_issues == 0 {
        return 1.0;
    }

    // Scale the penalty by message length: each issue in a short message
    // hurts more than in a long one.
    let word_count = content.split_ascii_whitespace().count().max(1);
    #[allow(clippy::cast_precision_loss)]
    let ratio = total_issues as f32 / word_count as f32;
    (1.0 - ratio * 2.0).clamp(0.0, 1.0)
}

/// Returns `true` if `text` contains a standalone 4-digit year starting with
/// "19" or "20" that is not embedded in a longer digit run (e.g. "120250"
/// does not count).
fn has_4digit_year_anchor(text: &str) -> bool {
    let bytes = text.as_bytes();
    let len = bytes.len();
    if len < 4 {
        return false;
    }
    let mut i = 0usize;
    while i + 3 < len {
        let c0 = bytes[i];
        let c1 = bytes[i + 1];
        if ((c0 == b'1' && c1 == b'9') || (c0 == b'2' && c1 == b'0'))
            && bytes[i + 2].is_ascii_digit()
            && bytes[i + 3].is_ascii_digit()
        {
            // Reject candidates flanked by more digits (part of a longer number).
            let left_ok = i == 0 || !bytes[i - 1].is_ascii_digit();
            let right_ok = i + 4 >= len || !bytes[i + 4].is_ascii_digit();
            if left_ok && right_ok {
                return true;
            }
        }
        i += 1;
    }
    false
}
453
454async fn compute_contradiction_risk(
463 content: &str,
464 graph: Option<&GraphStore>,
465 config: &QualityGateConfig,
466) -> f32 {
467 let Some(store) = graph else {
468 return 0.0;
469 };
470
471 let content_lower = content.to_lowercase();
472
473 let subject_query = extract_subject_tokens(&content_lower);
476 if subject_query.is_empty() {
477 return 0.0;
478 }
479
480 let Ok(entities) = store.find_entities_fuzzy(&subject_query, 1).await else {
482 return 0.0;
483 };
484 let Some(subject_entity) = entities.into_iter().next() else {
485 return 0.0;
486 };
487
488 let canonical_predicate = extract_predicate_token(&content_lower);
490
491 let Ok(edges) = store.edges_for_entity(subject_entity.id).await else {
493 return 0.0;
494 };
495
496 let relevant_edges: Vec<_> = edges
498 .iter()
499 .filter(|e| {
500 e.source_entity_id == subject_entity.id
501 && canonical_predicate
502 .as_ref()
503 .is_none_or(|p| e.relation == *p)
504 })
505 .collect();
506
507 if relevant_edges.is_empty() {
508 return 0.0;
509 }
510
511 let now_secs = std::time::SystemTime::now()
512 .duration_since(std::time::UNIX_EPOCH)
513 .map_or(0, |d| d.as_secs());
514
515 let has_old_conflict = relevant_edges.iter().any(|edge| {
516 let edge_ts = chrono::DateTime::parse_from_rfc3339(&edge.created_at)
517 .map_or(0u64, |dt| u64::try_from(dt.timestamp()).unwrap_or(0));
518 now_secs.saturating_sub(edge_ts) > config.contradiction_grace_seconds
519 });
520
521 if has_old_conflict { 1.0 } else { 0.5 }
522}
523
/// Extracts a short subject phrase: the tokens before the first verb marker
/// (capped at 3), or the first two tokens when no verb marker is present.
fn extract_subject_tokens(content_lower: &str) -> String {
    const VERB_MARKERS: &[&str] = &["is", "was", "are", "were", "has", "have", "had", "will"];
    let tokens: Vec<&str> = content_lower.split_ascii_whitespace().collect();
    let verb_pos = tokens.iter().position(|t| VERB_MARKERS.contains(t));
    let cutoff = verb_pos.unwrap_or_else(|| 2.min(tokens.len())).min(3);
    tokens[..cutoff].join(" ")
}
535
/// Returns the first verb marker in the text, if any, used as the canonical
/// predicate when comparing against existing graph edges.
fn extract_predicate_token(content_lower: &str) -> Option<String> {
    const VERB_MARKERS: &[&str] = &["is", "was", "are", "were", "has", "have", "had", "will"];
    for token in content_lower.split_ascii_whitespace() {
        if VERB_MARKERS.contains(&token) {
            return Some(token.to_owned());
        }
    }
    None
}
544
545async fn call_llm_scorer(content: &str, provider: &AnyProvider, timeout_ms: u64) -> f32 {
549 use zeph_llm::provider::{Message, MessageMetadata, Role};
550
551 let system = "You are a memory quality judge. Rate the quality of the following message \
552 for long-term storage on a scale of 0.0 to 1.0. Consider: information density, \
553 completeness of references, factual clarity. \
554 Respond with ONLY a JSON object: \
555 {\"information_value\": 0.0-1.0, \"reference_completeness\": 0.0-1.0, \
556 \"contradiction_risk\": 0.0-1.0}";
557
558 let user = format!(
559 "Message: {}\n\nQuality JSON:",
560 content.chars().take(500).collect::<String>()
561 );
562
563 let messages = vec![
564 Message {
565 role: Role::System,
566 content: system.to_owned(),
567 parts: vec![],
568 metadata: MessageMetadata::default(),
569 },
570 Message {
571 role: Role::User,
572 content: user,
573 parts: vec![],
574 metadata: MessageMetadata::default(),
575 },
576 ];
577
578 let timeout = Duration::from_millis(timeout_ms);
579 let result = match tokio::time::timeout(timeout, provider.chat(&messages)).await {
580 Ok(Ok(r)) => r,
581 Ok(Err(e)) => {
582 tracing::debug!(error = %e, "quality_gate: LLM scorer failed, using 0.5");
583 return 0.5;
584 }
585 Err(_) => {
586 tracing::debug!("quality_gate: LLM scorer timed out, using 0.5");
587 return 0.5;
588 }
589 };
590
591 parse_llm_score(&result)
592}
593
594fn parse_llm_score(response: &str) -> f32 {
598 let start = response.find('{');
600 let end = response.rfind('}');
601 let (Some(s), Some(e)) = (start, end) else {
602 return 0.5;
603 };
604 let json_str = &response[s..=e];
605 let Ok(val) = serde_json::from_str::<serde_json::Value>(json_str) else {
606 return 0.5;
607 };
608
609 #[allow(clippy::cast_possible_truncation)]
610 let iv = val["information_value"].as_f64().unwrap_or(0.5) as f32;
611 #[allow(clippy::cast_possible_truncation)]
612 let rc = val["reference_completeness"].as_f64().unwrap_or(0.5) as f32;
613 #[allow(clippy::cast_possible_truncation)]
614 let cr = val["contradiction_risk"].as_f64().unwrap_or(0.0) as f32;
615
616 let score =
618 0.4 * iv.clamp(0.0, 1.0) + 0.3 * rc.clamp(0.0, 1.0) + 0.3 * (1.0 - cr.clamp(0.0, 1.0));
619 score.clamp(0.0, 1.0)
620}
621
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reference_completeness_clean_text() {
        let score = compute_reference_completeness("The Rust compiler enforces memory safety.");
        assert!((score - 1.0).abs() < 0.01, "clean text should score 1.0");
    }

    #[test]
    fn reference_completeness_pronoun_heavy() {
        let score = compute_reference_completeness("yeah he said they confirmed it");
        assert!(
            score < 0.5,
            "pronoun-heavy message should score below 0.5, got {score}"
        );
    }

    #[test]
    fn reference_completeness_deictic_without_anchor() {
        let score = compute_reference_completeness("We agreed yesterday to postpone");
        assert!(
            score < 1.0,
            "deictic time without anchor should penalize, got {score}"
        );
    }

    #[test]
    fn reference_completeness_deictic_with_anchor() {
        let score = compute_reference_completeness("We agreed yesterday (2026-04-18) to postpone");
        assert!(
            score >= 0.9,
            "deictic with anchor '20' should not penalize, got {score}"
        );
    }

    #[test]
    fn rejection_reason_labels() {
        assert_eq!(QualityRejectionReason::Redundant.label(), "redundant");
        assert_eq!(
            QualityRejectionReason::IncompleteReference.label(),
            "incomplete_reference"
        );
        assert_eq!(
            QualityRejectionReason::Contradiction.label(),
            "contradiction"
        );
        assert_eq!(
            QualityRejectionReason::LlmLowConfidence.label(),
            "llm_low_confidence"
        );
    }

    #[test]
    fn rolling_rate_tracker_basic() {
        let mut tracker = RollingRateTracker::new(4);
        tracker.push(true);
        tracker.push(true);
        tracker.push(false);
        tracker.push(false);
        let rate = tracker.rate();
        assert!((rate - 0.5).abs() < 0.01, "rate should be 0.5, got {rate}");
    }

    #[test]
    fn rolling_rate_tracker_evicts_oldest() {
        // Capacity 3: the fourth push evicts the initial rejection.
        let mut tracker = RollingRateTracker::new(3);
        tracker.push(true);
        tracker.push(false);
        tracker.push(false);
        tracker.push(false);
        let rate = tracker.rate();
        assert!(
            rate < 0.01,
            "evicted rejection should not count, rate={rate}"
        );
    }

    #[test]
    fn parse_llm_score_valid_json() {
        let json = r#"{"information_value": 0.8, "reference_completeness": 0.9, "contradiction_risk": 0.1}"#;
        let score = parse_llm_score(json);
        assert!(
            score > 0.7,
            "high-quality JSON should yield high score, got {score}"
        );
    }

    #[test]
    fn parse_llm_score_malformed_returns_neutral() {
        let score = parse_llm_score("not json");
        assert!(
            (score - 0.5).abs() < 0.01,
            "malformed JSON should return 0.5"
        );
    }

    /// Default mock provider used as both embedder and chat LLM in gate tests.
    fn mock_provider() -> zeph_llm::any::AnyProvider {
        zeph_llm::any::AnyProvider::Mock(zeph_llm::mock::MockProvider::default())
    }

    #[tokio::test]
    async fn gate_disabled_always_passes() {
        let config = QualityGateConfig {
            enabled: false,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let provider = mock_provider();

        let result = gate.evaluate("yeah he confirmed it", &provider, &[]).await;
        assert!(result.is_none(), "disabled gate must always pass");
    }

    #[tokio::test]
    async fn gate_admits_novel_clean_content() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.3,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let provider = mock_provider();

        let result = gate
            .evaluate(
                "The Rust compiler enforces memory safety through the borrow checker.",
                &provider,
                &[],
            )
            .await;
        assert!(result.is_none(), "clean novel content should be admitted");
    }

    #[tokio::test]
    async fn gate_rejects_pronoun_only_at_low_threshold() {
        // Weight reference completeness heavily so the pronoun soup dominates.
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.75,
            reference_completeness_weight: 0.9,
            information_value_weight: 0.05,
            contradiction_weight: 0.05,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let provider = mock_provider();

        let result = gate
            .evaluate("yeah he confirmed it they said so", &provider, &[])
            .await;
        assert!(
            result == Some(QualityRejectionReason::IncompleteReference),
            "pronoun-heavy message should be rejected as IncompleteReference, got {result:?}"
        );
    }

    #[test]
    fn quality_gate_counts_rejections() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.99,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);

        // Bump a counter directly (same-module access to the private field).
        if let Ok(mut counts) = gate.rejection_counts.lock() {
            *counts.entry(QualityRejectionReason::Redundant).or_insert(0) += 1;
        }

        let counts = gate.rejection_counts();
        assert_eq!(counts.get(&QualityRejectionReason::Redundant), Some(&1));
    }

    #[tokio::test]
    async fn gate_fail_open_on_embed_error() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.5,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);

        let provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::default().with_embed_invalid_input(),
        );

        let result = gate
            .evaluate("Alice confirmed the meeting at 3pm.", &provider, &[])
            .await;
        assert!(
            result.is_none(),
            "embed error must be treated as fail-open (admitted), got {result:?}"
        );
    }

    #[tokio::test]
    async fn gate_rejects_redundant_with_populated_embeddings() {
        // Novelty dominates the blend and the candidate embedding matches a
        // recent one exactly, so info_val collapses to 0.
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.5,
            information_value_weight: 0.9,
            reference_completeness_weight: 0.05,
            contradiction_weight: 0.05,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);

        let fixed_embedding = vec![0.1_f32; 384];
        let provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::default().with_embedding(fixed_embedding.clone()),
        );

        let result = gate
            .evaluate(
                "The Rust compiler enforces memory safety through the borrow checker.",
                &provider,
                &[fixed_embedding],
            )
            .await;
        assert_eq!(
            result,
            Some(QualityRejectionReason::Redundant),
            "identical recent embedding must trigger Redundant rejection"
        );
    }

    #[tokio::test]
    async fn gate_llm_timeout_falls_back_to_rule_score() {
        // The mock's 600ms delay exceeds the 50ms budget, forcing the timeout path.
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.3,
            llm_timeout_ms: 50,
            llm_weight: 0.5,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);

        let slow_provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::default().with_delay(600),
        );
        let gate = gate.with_llm_provider(slow_provider);

        let embed_provider = mock_provider();
        let result = gate
            .evaluate(
                "The release is scheduled for next Friday.",
                &embed_provider,
                &[],
            )
            .await;
        assert!(
            result.is_none(),
            "LLM timeout must fall back to rule score and admit clean content, got {result:?}"
        );
    }
}