1use std::collections::HashMap;
25use std::sync::Arc;
26
27use async_trait::async_trait;
28
/// A single typed feature value fed to a classifier.
///
/// Floating-point payloads (`Float` and every element of `Vector`) are
/// compared and hashed by their IEEE-754 bit pattern so that `Eq` and `Hash`
/// agree: two NaNs with identical bits are equal, while `0.0` and `-0.0` are
/// not. Values of different variants are never equal.
#[derive(Debug, Clone)]
pub enum FeatureValue {
    Float(f64),
    Int(i64),
    String(String),
    Vector(Vec<f32>),
    Bool(bool),
    Null,
}

impl PartialEq for FeatureValue {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Null, Self::Null) => true,
            (Self::Bool(lhs), Self::Bool(rhs)) => lhs == rhs,
            (Self::Int(lhs), Self::Int(rhs)) => lhs == rhs,
            (Self::String(lhs), Self::String(rhs)) => lhs == rhs,
            (Self::Float(lhs), Self::Float(rhs)) => lhs.to_bits() == rhs.to_bits(),
            // `Iterator::eq` also rejects sequences of different length.
            (Self::Vector(lhs), Self::Vector(rhs)) => lhs
                .iter()
                .map(|x| x.to_bits())
                .eq(rhs.iter().map(|x| x.to_bits())),
            _ => false,
        }
    }
}

// Sound because `eq` compares floats by bit pattern, which is reflexive.
impl Eq for FeatureValue {}

impl std::hash::Hash for FeatureValue {
    /// Feeds the variant tag, then the payload, into `state`.
    ///
    /// Floats contribute their bit patterns, and vectors prefix their length,
    /// mirroring `PartialEq` exactly.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        use std::hash::Hash;

        std::mem::discriminant(self).hash(state);
        match self {
            Self::Null => {}
            Self::Bool(flag) => flag.hash(state),
            Self::Int(n) => n.hash(state),
            Self::Float(x) => x.to_bits().hash(state),
            Self::String(text) => text.hash(state),
            Self::Vector(items) => {
                items.len().hash(state);
                items.iter().for_each(|x| x.to_bits().hash(state));
            }
        }
    }
}
88
89#[derive(Debug, Clone, Default)]
92pub struct ClassifyInput {
93 pub features: HashMap<String, FeatureValue>,
94}
95
96impl ClassifyInput {
97 pub fn new() -> Self {
98 Self::default()
99 }
100 pub fn with(mut self, name: impl Into<String>, value: FeatureValue) -> Self {
101 self.features.insert(name.into(), value);
102 self
103 }
104
105 pub fn stable_hash(&self) -> u64 {
109 use std::hash::{Hash, Hasher};
110 let mut entries: Vec<(&String, &FeatureValue)> = self.features.iter().collect();
111 entries.sort_by(|a, b| a.0.cmp(b.0));
112 let mut h = std::collections::hash_map::DefaultHasher::new();
113 entries.len().hash(&mut h);
114 for (k, v) in entries {
115 k.hash(&mut h);
116 v.hash(&mut h);
117 }
118 h.finish()
119 }
120}
121
122impl PartialEq for ClassifyInput {
123 fn eq(&self, other: &Self) -> bool {
124 self.features == other.features
125 }
126}
127
128impl Eq for ClassifyInput {}
129
/// Errors produced while scoring inputs with a classifier.
#[derive(Debug, Clone, PartialEq)]
pub enum ClassifierError {
    /// A stage produced a different number of outputs than expected.
    ArityMismatch { expected: usize, actual: usize },
    /// A classifier output was not a valid probability (the mock uses this
    /// for NaN).
    DomainViolation { value: f64 },
    /// Failure reported by the underlying model backend, as a message.
    Provider(String),
}

impl std::fmt::Display for ClassifierError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::ArityMismatch { expected, actual } => write!(
                f,
                "classifier arity mismatch: expected {expected} outputs, got {actual}"
            ),
            Self::DomainViolation { value } => {
                write!(f, "classifier output {value} outside [0, 1]")
            }
            Self::Provider(ref msg) => write!(f, "classifier provider error: {msg}"),
        }
    }
}

// No wrapped source error: every variant carries only plain data.
impl std::error::Error for ClassifierError {}

/// Shorthand result type for classifier operations.
pub type ClassifierResult<T> = std::result::Result<T, ClassifierError>;
160
161#[async_trait]
168pub trait NeuralClassifier: Send + Sync + std::fmt::Debug {
169 async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>>;
172
173 async fn classify_logits(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
177 let probs = self.classify(inputs).await?;
178 Ok(probs.into_iter().map(inverse_sigmoid).collect())
179 }
180
181 fn name(&self) -> &str;
184
185 fn get_calibrator(&self) -> Option<Arc<dyn crate::calibration::Calibrator>> {
191 None
192 }
193
194 async fn raw_and_calibrated(
203 &self,
204 inputs: &[ClassifyInput],
205 ) -> ClassifierResult<Vec<(f64, Option<f64>)>> {
206 let raw = self.classify(inputs).await?;
207 Ok(raw.into_iter().map(|p| (p, None)).collect())
208 }
209}
210
211pub struct MockClassifier {
217 name: String,
218 f: Arc<dyn Fn(&ClassifyInput) -> f64 + Send + Sync>,
219}
220
221impl MockClassifier {
222 pub fn new<F>(name: impl Into<String>, f: F) -> Self
223 where
224 F: Fn(&ClassifyInput) -> f64 + Send + Sync + 'static,
225 {
226 Self {
227 name: name.into(),
228 f: Arc::new(f),
229 }
230 }
231
232 pub fn constant(name: impl Into<String>, value: f64) -> Self {
235 let v = value.clamp(0.0, 1.0);
236 Self::new(name, move |_| v)
237 }
238}
239
240impl std::fmt::Debug for MockClassifier {
241 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242 f.debug_struct("MockClassifier")
243 .field("name", &self.name)
244 .finish_non_exhaustive()
245 }
246}
247
248#[async_trait]
249impl NeuralClassifier for MockClassifier {
250 async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
251 let mut out = Vec::with_capacity(inputs.len());
252 for inp in inputs {
253 let v = (self.f)(inp);
254 if v.is_nan() {
255 return Err(ClassifierError::DomainViolation { value: v });
256 }
257 out.push(v.clamp(0.0, 1.0));
258 }
259 Ok(out)
260 }
261
262 fn name(&self) -> &str {
263 &self.name
264 }
265}
266
267pub struct CalibratedClassifier {
285 name: String,
286 base: Arc<dyn NeuralClassifier>,
287 calibrator: Arc<dyn crate::calibration::Calibrator>,
288}
289
290impl CalibratedClassifier {
291 pub fn new(
292 name: impl Into<String>,
293 base: Arc<dyn NeuralClassifier>,
294 calibrator: Arc<dyn crate::calibration::Calibrator>,
295 ) -> Self {
296 Self {
297 name: name.into(),
298 base,
299 calibrator,
300 }
301 }
302}
303
304impl std::fmt::Debug for CalibratedClassifier {
305 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306 f.debug_struct("CalibratedClassifier")
307 .field("name", &self.name)
308 .field("base", &self.base.name())
309 .field("method", &self.calibrator.method())
310 .finish_non_exhaustive()
311 }
312}
313
314#[async_trait]
315impl NeuralClassifier for CalibratedClassifier {
316 async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
317 let raw = self.base.classify(inputs).await?;
318 Ok(self.calibrator.apply_batch(&raw))
319 }
320
321 fn name(&self) -> &str {
322 &self.name
323 }
324
325 fn get_calibrator(&self) -> Option<Arc<dyn crate::calibration::Calibrator>> {
326 Some(Arc::clone(&self.calibrator))
327 }
328
329 async fn raw_and_calibrated(
330 &self,
331 inputs: &[ClassifyInput],
332 ) -> ClassifierResult<Vec<(f64, Option<f64>)>> {
333 let raw = self.base.classify(inputs).await?;
334 let calibrated = self.calibrator.apply_batch(&raw);
335 Ok(raw
336 .into_iter()
337 .zip(calibrated)
338 .map(|(r, c)| (r, Some(c)))
339 .collect())
340 }
341}
342
/// Thread-safe map keyed by `(model name, input hash)`.
///
/// Backs both `NeuralProvenanceStore` and `ModelInvocationCache`.
///
/// A poisoned lock (another thread panicked while holding it) is recovered
/// via `PoisonError::into_inner` instead of being treated as "store
/// unavailable". Otherwise a single panic would silently turn every later
/// `get` into a miss, every insert into a no-op, and `len` into 0 for the rest
/// of the process. The guarded value is a plain lookup table, so an
/// interrupted write can at worst lose or keep a single entry.
#[derive(Debug)]
struct KeyedStore<V> {
    inner: std::sync::RwLock<HashMap<(String, u64), V>>,
}

impl<V> Default for KeyedStore<V> {
    fn default() -> Self {
        Self {
            inner: std::sync::RwLock::new(HashMap::new()),
        }
    }
}

impl<V: Clone> KeyedStore<V> {
    /// Returns a clone of the value stored for `(model, input_hash)`, if any.
    fn get(&self, model: &str, input_hash: u64) -> Option<V> {
        // The tuple key owns a `String`, so the lookup key must be built owned.
        self.lock_read()
            .get(&(model.to_string(), input_hash))
            .cloned()
    }
}

impl<V> KeyedStore<V> {
    /// Read guard that ignores lock poisoning (see type-level docs).
    fn lock_read(&self) -> std::sync::RwLockReadGuard<'_, HashMap<(String, u64), V>> {
        self.inner
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Write guard that ignores lock poisoning (see type-level docs).
    fn lock_write(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<(String, u64), V>> {
        self.inner
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Stores `value` for `(model, input_hash)` with no size limit.
    fn insert(&self, model: &str, input_hash: u64, value: V) {
        self.lock_write().insert((model.to_string(), input_hash), value);
    }

    /// Stores `value`, clearing the whole map first if it already holds
    /// `max_entries` entries and the key is new.
    ///
    /// `max_entries == 0` disables the bound. Overwriting a key that is
    /// already present never triggers the flush, because it does not grow
    /// the map.
    fn insert_bounded(&self, model: &str, input_hash: u64, value: V, max_entries: usize) {
        // Build the owned key before taking the lock to keep the critical section short.
        let key = (model.to_string(), input_hash);
        let mut map = self.lock_write();
        if max_entries > 0 && map.len() >= max_entries && !map.contains_key(&key) {
            map.clear();
        }
        map.insert(key, value);
    }

    /// Removes every entry.
    fn clear(&self) {
        self.lock_write().clear();
    }

    /// Number of stored entries.
    fn len(&self) -> usize {
        self.lock_read().len()
    }
}
411
412#[derive(Debug, Default)]
413pub struct NeuralProvenanceStore {
414 inner: KeyedStore<NeuralProvenanceRecord>,
415}
416
417#[derive(Debug, Clone)]
430pub struct NeuralProvenanceRecord {
431 pub raw_probability: f64,
432 pub calibrated_probability: Option<f64>,
433 pub confidence_band: Option<crate::result::ConfidenceBand>,
434 pub feature_inputs: HashMap<String, FeatureValue>,
435}
436
437impl NeuralProvenanceStore {
438 pub fn new() -> Self {
439 Self::default()
440 }
441
442 pub fn record(&self, model: &str, input_hash: u64, record: NeuralProvenanceRecord) {
443 self.inner.insert(model, input_hash, record);
444 }
445
446 pub fn get(&self, model: &str, input_hash: u64) -> Option<NeuralProvenanceRecord> {
447 self.inner.get(model, input_hash)
448 }
449
450 pub fn clear(&self) {
451 self.inner.clear();
452 }
453
454 pub fn len(&self) -> usize {
455 self.inner.len()
456 }
457
458 pub fn is_empty(&self) -> bool {
459 self.len() == 0
460 }
461}
462
463#[derive(Debug, Default)]
474pub struct ModelInvocationCache {
475 inner: KeyedStore<f64>,
476 max_entries: usize,
477}
478
479impl ModelInvocationCache {
480 pub fn new(max_entries: usize) -> Self {
481 Self {
482 inner: KeyedStore::default(),
483 max_entries,
484 }
485 }
486
487 pub fn get(&self, model: &str, input_hash: u64) -> Option<f64> {
489 self.inner.get(model, input_hash)
490 }
491
492 pub fn insert(&self, model: &str, input_hash: u64, value: f64) {
496 self.inner
497 .insert_bounded(model, input_hash, value, self.max_entries);
498 }
499
500 pub fn clear(&self) {
503 self.inner.clear();
504 }
505
506 pub fn len(&self) -> usize {
507 self.inner.len()
508 }
509
510 pub fn is_empty(&self) -> bool {
511 self.len() == 0
512 }
513}
514
515pub struct CandleLinearClassifier {
540 name: String,
541 feature_order: Vec<String>,
543 weight: Vec<f32>,
545 bias: f32,
547 device: candle_core::Device,
549}
550
551impl std::fmt::Debug for CandleLinearClassifier {
552 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553 f.debug_struct("CandleLinearClassifier")
554 .field("name", &self.name)
555 .field("feature_order", &self.feature_order)
556 .field("n_features", &self.weight.len())
557 .finish_non_exhaustive()
558 }
559}
560
561impl CandleLinearClassifier {
562 pub fn load(
569 name: impl Into<String>,
570 feature_order: Vec<String>,
571 weights_path: impl AsRef<std::path::Path>,
572 ) -> ClassifierResult<Self> {
573 let device = candle_core::Device::Cpu;
574 let path = weights_path.as_ref();
575 let tensors = candle_core::safetensors::load(path, &device).map_err(|e| {
576 ClassifierError::Provider(format!(
577 "candle: failed to load safetensors from {path:?}: {e}"
578 ))
579 })?;
580 let weight_t = tensors.get("weight").ok_or_else(|| {
581 ClassifierError::Provider("candle: safetensors missing 'weight' tensor".to_string())
582 })?;
583 let bias_t = tensors.get("bias").ok_or_else(|| {
584 ClassifierError::Provider("candle: safetensors missing 'bias' tensor".to_string())
585 })?;
586 let weight: Vec<f32> = weight_t
587 .flatten_all()
588 .and_then(|t| t.to_vec1::<f32>())
589 .map_err(|e| ClassifierError::Provider(format!("candle: weight read: {e}")))?;
590 let bias_vec: Vec<f32> = bias_t
591 .flatten_all()
592 .and_then(|t| t.to_vec1::<f32>())
593 .map_err(|e| ClassifierError::Provider(format!("candle: bias read: {e}")))?;
594 if bias_vec.len() != 1 {
595 return Err(ClassifierError::Provider(format!(
596 "candle: 'bias' must be scalar (shape [1]); got len={}",
597 bias_vec.len()
598 )));
599 }
600 if weight.len() != feature_order.len() {
601 return Err(ClassifierError::Provider(format!(
602 "candle: weight length {} != feature_order length {}",
603 weight.len(),
604 feature_order.len()
605 )));
606 }
607 Ok(Self {
608 name: name.into(),
609 feature_order,
610 weight,
611 bias: bias_vec[0],
612 device,
613 })
614 }
615
616 fn encode_feature(&self, v: Option<&FeatureValue>) -> f32 {
619 match v {
620 Some(FeatureValue::Float(f)) => *f as f32,
621 Some(FeatureValue::Int(i)) => *i as f32,
622 Some(FeatureValue::Bool(b)) => f32::from(*b),
623 Some(FeatureValue::String(s)) => {
624 let mut h: u32 = 5381;
626 for byte in s.as_bytes() {
627 h = h.wrapping_mul(33).wrapping_add(*byte as u32);
628 }
629 (h as i32) as f32 / i32::MAX as f32
630 }
631 Some(FeatureValue::Null) | None => 0.0,
632 _ => 0.0,
633 }
634 }
635}
636
637#[async_trait]
638impl NeuralClassifier for CandleLinearClassifier {
639 async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
640 if inputs.is_empty() {
641 return Ok(Vec::new());
642 }
643 let n_features = self.weight.len();
644 let mut data: Vec<f32> = Vec::with_capacity(inputs.len() * n_features);
646 for inp in inputs {
647 for fname in &self.feature_order {
648 data.push(self.encode_feature(inp.features.get(fname)));
649 }
650 }
651 let x = candle_core::Tensor::from_vec(data, (inputs.len(), n_features), &self.device)
652 .map_err(|e| ClassifierError::Provider(format!("candle: input tensor: {e}")))?;
653 let w = candle_core::Tensor::from_slice(&self.weight, (n_features, 1), &self.device)
654 .map_err(|e| ClassifierError::Provider(format!("candle: weight tensor: {e}")))?;
655 let logits = x
656 .matmul(&w)
657 .and_then(|t| t.broadcast_add(&candle_core::Tensor::new(&[self.bias], &self.device)?))
658 .map_err(|e| ClassifierError::Provider(format!("candle: forward pass: {e}")))?;
659 let probs = candle_nn::ops::sigmoid(&logits)
661 .and_then(|t| t.flatten_all())
662 .and_then(|t| t.to_vec1::<f32>())
663 .map_err(|e| ClassifierError::Provider(format!("candle: sigmoid: {e}")))?;
664 Ok(probs.into_iter().map(|p| p as f64).collect())
665 }
666
667 fn name(&self) -> &str {
668 &self.name
669 }
670}
671
/// Logit (log-odds) of a probability: `ln(p / (1 - p))`.
///
/// The input is clamped to `[0, 1]` first. After clamping, `0` maps to
/// negative infinity and `1` to positive infinity. NaN propagates as NaN.
fn inverse_sigmoid(p: f64) -> f64 {
    match p.clamp(0.0, 1.0) {
        q if q == 0.0 => f64::NEG_INFINITY,
        q if q == 1.0 => f64::INFINITY,
        q => (q / (1.0 - q)).ln(),
    }
}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashes one `FeatureValue` with a fresh `DefaultHasher`.
    fn digest(value: FeatureValue) -> u64 {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    /// Builds one single-feature row per value, all under `feature`.
    fn float_rows(feature: &str, values: &[f64]) -> Vec<ClassifyInput> {
        values
            .iter()
            .map(|&v| ClassifyInput::new().with(feature, FeatureValue::Float(v)))
            .collect()
    }

    #[tokio::test]
    async fn mock_constant_returns_value_per_row() {
        let classifier = MockClassifier::constant("classify/test", 0.7);
        let inputs = float_rows("x", &[1.0, 2.0, 3.0]);
        let scores = classifier.classify(&inputs).await.unwrap();
        assert_eq!(scores, vec![0.7, 0.7, 0.7]);
        assert_eq!(scores.len(), inputs.len());
        assert_eq!(classifier.name(), "classify/test");
    }

    #[tokio::test]
    async fn mock_feature_driven() {
        let classifier = MockClassifier::new("classify/feature", |inp| {
            if let Some(FeatureValue::Float(severity)) = inp.features.get("severity") {
                (*severity / 10.0).clamp(0.0, 1.0)
            } else {
                0.0
            }
        });
        let inputs = float_rows("severity", &[2.0, 9.0, 15.0]);
        let scores = classifier.classify(&inputs).await.unwrap();
        assert_eq!(scores, vec![0.2, 0.9, 1.0]);
    }

    #[tokio::test]
    async fn classify_logits_default_inverse_sigmoid() {
        let classifier = MockClassifier::constant("classify/test", 0.5);
        let logits = classifier
            .classify_logits(&[ClassifyInput::new()])
            .await
            .unwrap();
        assert!(logits[0].abs() < 1e-12);
    }

    #[tokio::test]
    async fn mock_rejects_nan() {
        let classifier = MockClassifier::new("classify/nan", |_| f64::NAN);
        let outcome = classifier.classify(&[ClassifyInput::new()]).await;
        assert!(matches!(
            outcome,
            Err(ClassifierError::DomainViolation { .. })
        ));
    }

    #[test]
    fn feature_value_hash_distinguishes_variants() {
        assert_ne!(digest(FeatureValue::Float(0.0)), digest(FeatureValue::Int(0)));
        assert_ne!(digest(FeatureValue::Null), digest(FeatureValue::Bool(false)));
        assert_eq!(digest(FeatureValue::Float(0.5)), digest(FeatureValue::Float(0.5)));
    }

    #[test]
    fn classify_input_hash_order_independent() {
        let us = ClassifyInput::new()
            .with("country", FeatureValue::String("US".into()))
            .with("revenue", FeatureValue::Float(1.0e6));
        let us_reordered = ClassifyInput::new()
            .with("revenue", FeatureValue::Float(1.0e6))
            .with("country", FeatureValue::String("US".into()));
        let de = ClassifyInput::new()
            .with("country", FeatureValue::String("DE".into()))
            .with("revenue", FeatureValue::Float(1.0e6));
        assert_eq!(us.stable_hash(), us_reordered.stable_hash());
        assert_ne!(us.stable_hash(), de.stable_hash());
    }

    #[test]
    fn feature_value_vector_hash() {
        let base = FeatureValue::Vector(vec![1.0, 2.0, 3.0]);
        assert_eq!(
            digest(base.clone()),
            digest(FeatureValue::Vector(vec![1.0, 2.0, 3.0]))
        );
        assert_ne!(digest(base), digest(FeatureValue::Vector(vec![1.0, 2.0, 3.5])));
    }

    #[test]
    fn model_invocation_cache_hit_miss() {
        let cache = ModelInvocationCache::new(100);
        assert_eq!(cache.get("m", 42), None);
        cache.insert("m", 42, 0.7);
        assert_eq!(cache.get("m", 42), Some(0.7));
        // A different model name or a different input hash must both miss.
        assert_eq!(cache.get("other", 42), None);
        assert_eq!(cache.get("m", 43), None);
    }

    #[test]
    fn model_invocation_cache_evicts_on_overflow() {
        let cache = ModelInvocationCache::new(2);
        for &(hash, value) in &[(1u64, 0.1), (2, 0.2)] {
            cache.insert("m", hash, value);
        }
        assert_eq!(cache.len(), 2);
        // A third key into a full cache flushes everything before storing.
        cache.insert("m", 3, 0.3);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("m", 3), Some(0.3));
    }

    #[test]
    fn inverse_sigmoid_endpoints() {
        assert_eq!(inverse_sigmoid(0.0), f64::NEG_INFINITY);
        assert_eq!(inverse_sigmoid(1.0), f64::INFINITY);
        assert!(inverse_sigmoid(0.5).abs() < 1e-12);
    }
}