Skip to main content

uni_locy/
neural.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Neural classifier abstraction for Locy `CREATE MODEL` (Phase B).
5//!
6//! [`NeuralClassifier`] is the row-at-a-time surface that
7//! `LocyModelInvoke` (Phase B Slice 3) will drive. This module ships the
8//! trait and a deterministic [`MockClassifier`] used by unit tests and
9//! TCK; a Xervo-backed implementation lives behind a separate adapter PR
10//! (uni-xervo is an external crate so we can't extend `ModelTask::*`
11//! directly).
12//!
13//! ### Scope (Slice 1+2)
14//!
15//! The trait exposes `classify` returning probabilities in `[0, 1]` and
16//! an optional `classify_logits` for calibration paths (Phase C). The
17//! default `classify_logits` derives logits from probabilities via
18//! inverse-sigmoid so providers that only emit probabilities work out of
19//! the box.
20//!
21//! Phase B Slice 3 will wire `LocyModelInvoke` to call `classify` once per
22//! batch per `(model, feature-hash)` group with memoization.
23
24use std::collections::HashMap;
25use std::sync::Arc;
26
27use async_trait::async_trait;
28
/// A single feature value handed to a neural classifier.
///
/// The variants mirror the value types the property graph emits;
/// `Vector` carries embedding inputs.
///
/// Equality and hashing compare floating-point payloads by raw bit
/// pattern (see the `PartialEq` / `Hash` impls below), so a value can be
/// used as a cache key.
#[derive(Debug, Clone)]
pub enum FeatureValue {
    /// 64-bit floating-point scalar.
    Float(f64),
    /// 64-bit signed integer.
    Int(i64),
    /// UTF-8 text.
    String(String),
    /// Dense `f32` vector (e.g. an embedding).
    Vector(Vec<f32>),
    /// Boolean flag.
    Bool(bool),
    /// Explicit null.
    Null,
}

// Phase B Slice 1 (post-Slice-3 follow-up): `Eq` + `Hash` so
// `FeatureValue` can be used as a cache key. Floats are compared by
// bit pattern on purpose: two NaNs with the same bits are equal, while
// `0.0` and `-0.0` are distinct. That makes equality reflexive, which
// is what keeps the `Eq` impl sound.
impl PartialEq for FeatureValue {
    /// Variant-wise equality; `Float` and `Vector` payloads are compared
    /// via `to_bits()`, everything else via the payload's own `==`.
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Null, Self::Null) => true,
            (Self::Bool(lhs), Self::Bool(rhs)) => lhs == rhs,
            (Self::Int(lhs), Self::Int(rhs)) => lhs == rhs,
            (Self::String(lhs), Self::String(rhs)) => lhs == rhs,
            (Self::Float(lhs), Self::Float(rhs)) => lhs.to_bits() == rhs.to_bits(),
            (Self::Vector(lhs), Self::Vector(rhs)) => {
                // `Iterator::eq` is false on a length mismatch, so no
                // separate length check is needed.
                let lhs_bits = lhs.iter().map(|x| x.to_bits());
                let rhs_bits = rhs.iter().map(|x| x.to_bits());
                lhs_bits.eq(rhs_bits)
            }
            // Different variants never compare equal.
            _ => false,
        }
    }
}

impl Eq for FeatureValue {}

impl std::hash::Hash for FeatureValue {
    /// Hashes the variant discriminant followed by the payload, using
    /// the same bit-pattern view of floats as `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // Bring the trait into scope explicitly so the `.hash(..)` calls
        // below resolve regardless of the file-level imports.
        use std::hash::Hash;

        // Discriminant first, so e.g. `Float(0.0)` and `Int(0)` hash
        // distinctly.
        std::mem::discriminant(self).hash(state);
        match self {
            Self::Null => {}
            Self::Bool(flag) => flag.hash(state),
            Self::Int(n) => n.hash(state),
            Self::Float(x) => x.to_bits().hash(state),
            Self::String(text) => text.hash(state),
            Self::Vector(items) => {
                // `f32` is not `Hash`, so feed the length and then each
                // element's bit pattern one at a time (no allocation).
                items.len().hash(state);
                items.iter().for_each(|x| x.to_bits().hash(state));
            }
        }
    }
}
88
89/// One row of input to a classifier. Field names match the `FEATURES`
90/// clause identifiers from the `CREATE MODEL` declaration.
91#[derive(Debug, Clone, Default)]
92pub struct ClassifyInput {
93    pub features: HashMap<String, FeatureValue>,
94}
95
96impl ClassifyInput {
97    pub fn new() -> Self {
98        Self::default()
99    }
100    pub fn with(mut self, name: impl Into<String>, value: FeatureValue) -> Self {
101        self.features.insert(name.into(), value);
102        self
103    }
104
105    /// Order-independent stable hash used as a memoization key.
106    /// `HashMap` iteration order is non-deterministic; we collect
107    /// to a sorted Vec by feature name before hashing.
108    pub fn stable_hash(&self) -> u64 {
109        use std::hash::{Hash, Hasher};
110        let mut entries: Vec<(&String, &FeatureValue)> = self.features.iter().collect();
111        entries.sort_by(|a, b| a.0.cmp(b.0));
112        let mut h = std::collections::hash_map::DefaultHasher::new();
113        entries.len().hash(&mut h);
114        for (k, v) in entries {
115            k.hash(&mut h);
116            v.hash(&mut h);
117        }
118        h.finish()
119    }
120}
121
122impl PartialEq for ClassifyInput {
123    fn eq(&self, other: &Self) -> bool {
124        self.features == other.features
125    }
126}
127
128impl Eq for ClassifyInput {}
129
/// Failure modes a [`NeuralClassifier`] implementation can report.
#[derive(Debug, Clone, PartialEq)]
pub enum ClassifierError {
    /// The number of outputs did not match the number of inputs, or the
    /// provider returned a malformed batch.
    ArityMismatch { expected: usize, actual: usize },
    /// An output value fell outside `[0, 1]` after calibration.
    DomainViolation { value: f64 },
    /// An error surfaced by the upstream provider.
    Provider(String),
}

impl std::fmt::Display for ClassifierError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Provider(msg) => write!(f, "classifier provider error: {msg}"),
            Self::DomainViolation { value } => {
                write!(f, "classifier output {value} outside [0, 1]")
            }
            Self::ArityMismatch { expected, actual } => write!(
                f,
                "classifier arity mismatch: expected {expected} outputs, got {actual}"
            ),
        }
    }
}

impl std::error::Error for ClassifierError {}

/// Shorthand for results whose error type is [`ClassifierError`].
pub type ClassifierResult<T> = std::result::Result<T, ClassifierError>;
160
161/// Row-at-a-time neural classifier.
162///
163/// A provider returns one probability per input row. Logits are optional
164/// and used by Phase C calibration paths (Platt scaling, temperature
165/// scaling) that operate on pre-sigmoid scores; the default impl derives
166/// them from probabilities.
167#[async_trait]
168pub trait NeuralClassifier: Send + Sync + std::fmt::Debug {
169    /// Return probabilities in `[0, 1]`. `output.len() == inputs.len()`
170    /// is the trait contract — implementers MUST enforce it.
171    async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>>;
172
173    /// Return pre-sigmoid logits. The default implementation calls
174    /// [`NeuralClassifier::classify`] and applies inverse-sigmoid; bespoke providers
175    /// (Candle, MistralRS) override to expose raw logits cheaply.
176    async fn classify_logits(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
177        let probs = self.classify(inputs).await?;
178        Ok(probs.into_iter().map(inverse_sigmoid).collect())
179    }
180
181    /// Provider identifier for EXPLAIN / telemetry. Should match the
182    /// `xervo_alias` from `CREATE MODEL`.
183    fn name(&self) -> &str;
184
185    /// Phase C B1–B3 follow-up: introspect a wrapped Calibrator
186    /// when this classifier composes one (e.g.,
187    /// [`CalibratedClassifier`]). Default `None` — bare classifiers
188    /// don't expose a calibrator. EXPLAIN uses this to surface the
189    /// active calibrator's `confidence_band(p)` on derivations.
190    fn get_calibrator(&self) -> Option<Arc<dyn crate::calibration::Calibrator>> {
191        None
192    }
193
194    /// Phase C B1–B3 follow-up: return `(raw, Some(calibrated))`
195    /// per input when this classifier wraps a Calibrator, or
196    /// `(raw, None)` otherwise. The runtime writes both into the
197    /// per-query [`NeuralProvenanceStore`] so EXPLAIN can show
198    /// pre- and post-calibrator values side-by-side. The default
199    /// impl delegates to `classify` and reports `None` for the
200    /// calibrated half (the raw output IS whatever the classifier
201    /// emits — no introspection without an override).
202    async fn raw_and_calibrated(
203        &self,
204        inputs: &[ClassifyInput],
205    ) -> ClassifierResult<Vec<(f64, Option<f64>)>> {
206        let raw = self.classify(inputs).await?;
207        Ok(raw.into_iter().map(|p| (p, None)).collect())
208    }
209}
210
211/// Deterministic mock classifier for tests and TCK scenarios.
212///
213/// Holds a closure `Fn(&ClassifyInput) -> f64` so each scenario can
214/// configure the mapping it wants. Output clamps to `[0, 1]` and emits
215/// a `DomainViolation` error when the closure returns NaN.
216pub struct MockClassifier {
217    name: String,
218    f: Arc<dyn Fn(&ClassifyInput) -> f64 + Send + Sync>,
219}
220
221impl MockClassifier {
222    pub fn new<F>(name: impl Into<String>, f: F) -> Self
223    where
224        F: Fn(&ClassifyInput) -> f64 + Send + Sync + 'static,
225    {
226        Self {
227            name: name.into(),
228            f: Arc::new(f),
229        }
230    }
231
232    /// Construct a constant-output classifier, the canonical "always
233    /// returns 0.7" test fixture.
234    pub fn constant(name: impl Into<String>, value: f64) -> Self {
235        let v = value.clamp(0.0, 1.0);
236        Self::new(name, move |_| v)
237    }
238}
239
240impl std::fmt::Debug for MockClassifier {
241    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
242        f.debug_struct("MockClassifier")
243            .field("name", &self.name)
244            .finish_non_exhaustive()
245    }
246}
247
248#[async_trait]
249impl NeuralClassifier for MockClassifier {
250    async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
251        let mut out = Vec::with_capacity(inputs.len());
252        for inp in inputs {
253            let v = (self.f)(inp);
254            if v.is_nan() {
255                return Err(ClassifierError::DomainViolation { value: v });
256            }
257            out.push(v.clamp(0.0, 1.0));
258        }
259        Ok(out)
260    }
261
262    fn name(&self) -> &str {
263        &self.name
264    }
265}
266
267/// Adapter wrapping a base classifier with a fitted [`crate::calibration::Calibrator`]
268/// (Phase C C2). After running `CALIBRATE`, users construct one of
269/// these and re-register it under the same model name to make
270/// subsequent invocations produce calibrated probabilities.
271///
272/// ```ignore
273/// let result = locy_result.command_results().iter().find_map(|(_, r)| match r {
274///     CommandResult::Calibrate(c) => Some(c.clone()),
275///     _ => None,
276/// }).unwrap();
277/// let wrapped = CalibratedClassifier::new(
278///     "scorer",
279///     Arc::clone(&base_classifier),
280///     Arc::clone(&result.calibrator),
281/// );
282/// config.classifier_registry.insert("scorer".into(), Arc::new(wrapped));
283/// ```
284pub struct CalibratedClassifier {
285    name: String,
286    base: Arc<dyn NeuralClassifier>,
287    calibrator: Arc<dyn crate::calibration::Calibrator>,
288}
289
290impl CalibratedClassifier {
291    pub fn new(
292        name: impl Into<String>,
293        base: Arc<dyn NeuralClassifier>,
294        calibrator: Arc<dyn crate::calibration::Calibrator>,
295    ) -> Self {
296        Self {
297            name: name.into(),
298            base,
299            calibrator,
300        }
301    }
302}
303
304impl std::fmt::Debug for CalibratedClassifier {
305    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
306        f.debug_struct("CalibratedClassifier")
307            .field("name", &self.name)
308            .field("base", &self.base.name())
309            .field("method", &self.calibrator.method())
310            .finish_non_exhaustive()
311    }
312}
313
314#[async_trait]
315impl NeuralClassifier for CalibratedClassifier {
316    async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
317        let raw = self.base.classify(inputs).await?;
318        Ok(self.calibrator.apply_batch(&raw))
319    }
320
321    fn name(&self) -> &str {
322        &self.name
323    }
324
325    fn get_calibrator(&self) -> Option<Arc<dyn crate::calibration::Calibrator>> {
326        Some(Arc::clone(&self.calibrator))
327    }
328
329    async fn raw_and_calibrated(
330        &self,
331        inputs: &[ClassifyInput],
332    ) -> ClassifierResult<Vec<(f64, Option<f64>)>> {
333        let raw = self.base.classify(inputs).await?;
334        let calibrated = self.calibrator.apply_batch(&raw);
335        Ok(raw
336            .into_iter()
337            .zip(calibrated)
338            .map(|(r, c)| (r, Some(c)))
339            .collect())
340    }
341}
342
/// Shared `RwLock<HashMap<(model, input_hash), V>>` backing the two
/// per-query side-channel stores ([`NeuralProvenanceStore`] and
/// [`ModelInvocationCache`]). Centralizes the lock-poison-tolerant
/// accessors so each store only adds its domain-specific surface
/// (eviction policy, record shape) on top.
///
/// Both stores are keyed by `(model_name, ClassifyInput::stable_hash)`,
/// so rows with identical feature values share one entry (consistent
/// with the existing memoization semantics; the classifier output for a
/// given input is deterministic).
///
/// [`NeuralProvenanceStore`] (Phase C B1–B3 follow-up) is the per-query
/// side-channel store recording the raw / calibrated / confidence-band
/// tuple for every classifier invocation. EXPLAIN reads from that store
/// when building `NeuralProvenance` entries on each `DerivationNode`.
#[derive(Debug)]
struct KeyedStore<V> {
    /// Entries keyed by `(model name, input hash)`.
    inner: std::sync::RwLock<HashMap<(String, u64), V>>,
}

impl<V> Default for KeyedStore<V> {
    /// Creates an empty store.
    fn default() -> Self {
        Self {
            inner: std::sync::RwLock::new(HashMap::new()),
        }
    }
}

impl<V> KeyedStore<V> {
    /// Acquires the read guard, recovering it if the lock is poisoned.
    ///
    /// Every write made through this type is a single `HashMap` call, so
    /// a panic in some other lock holder cannot leave the map half
    /// updated. Recovering keeps the store usable instead of turning
    /// every later lookup into a silent miss.
    fn read_guard(&self) -> std::sync::RwLockReadGuard<'_, HashMap<(String, u64), V>> {
        self.inner
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Acquires the write guard, recovering it if the lock is poisoned
    /// (same reasoning as [`KeyedStore::read_guard`]).
    fn write_guard(&self) -> std::sync::RwLockWriteGuard<'_, HashMap<(String, u64), V>> {
        self.inner
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Stores `value` under `(model, input_hash)`, replacing any
    /// previous entry for that key.
    fn insert(&self, model: &str, input_hash: u64, value: V) {
        self.write_guard()
            .insert((model.to_string(), input_hash), value);
    }

    /// Insert, but first drop the whole map when it has reached
    /// `max_entries` (`max_entries == 0` disables the bound). The
    /// size check and insert happen under one write lock so the bound
    /// holds even under concurrent inserts.
    fn insert_bounded(&self, model: &str, input_hash: u64, value: V, max_entries: usize) {
        let mut map = self.write_guard();
        if max_entries > 0 && map.len() >= max_entries {
            map.clear();
        }
        map.insert((model.to_string(), input_hash), value);
    }

    /// Removes every entry.
    fn clear(&self) {
        self.write_guard().clear();
    }

    /// Number of entries currently stored.
    fn len(&self) -> usize {
        self.read_guard().len()
    }
}

impl<V: Clone> KeyedStore<V> {
    /// Returns a clone of the value stored under `(model, input_hash)`,
    /// or `None` when the key is absent.
    fn get(&self, model: &str, input_hash: u64) -> Option<V> {
        self.read_guard()
            .get(&(model.to_string(), input_hash))
            .cloned()
    }
}
411
412#[derive(Debug, Default)]
413pub struct NeuralProvenanceStore {
414    inner: KeyedStore<NeuralProvenanceRecord>,
415}
416
417/// A single stored record. Matches the user-visible
418/// [`crate::NeuralProvenance`] shape so EXPLAIN can construct
419/// derivation entries without further transformation.
420///
421/// `feature_inputs` (Phase 12 EXPLAIN follow-up) carries the
422/// per-binding `FeatureValue` map that fed the classifier on the
423/// hot path. EXPLAIN Mode B reads from this map (when available)
424/// instead of re-evaluating feature expressions against the
425/// fact_row — which is the only way to surface authoritative values
426/// for graph-structural FEATURE functions (`degree_centrality`,
427/// `avg_neighbor`, etc.) whose evaluation requires the
428/// `GraphAlgoHandle` that isn't threaded into the EXPLAIN path.
429#[derive(Debug, Clone)]
430pub struct NeuralProvenanceRecord {
431    pub raw_probability: f64,
432    pub calibrated_probability: Option<f64>,
433    pub confidence_band: Option<crate::result::ConfidenceBand>,
434    pub feature_inputs: HashMap<String, FeatureValue>,
435}
436
437impl NeuralProvenanceStore {
438    pub fn new() -> Self {
439        Self::default()
440    }
441
442    pub fn record(&self, model: &str, input_hash: u64, record: NeuralProvenanceRecord) {
443        self.inner.insert(model, input_hash, record);
444    }
445
446    pub fn get(&self, model: &str, input_hash: u64) -> Option<NeuralProvenanceRecord> {
447        self.inner.get(model, input_hash)
448    }
449
450    pub fn clear(&self) {
451        self.inner.clear();
452    }
453
454    pub fn len(&self) -> usize {
455        self.inner.len()
456    }
457
458    pub fn is_empty(&self) -> bool {
459        self.len() == 0
460    }
461}
462
463/// Memoization cache for neural classifier outputs across a single
464/// query evaluation. Per impl plan §1.4 decision D-4: cache is scoped
465/// per-query (cleared at the start of evaluation); the key is
466/// `(model_name, ClassifyInput::stable_hash)`.
467///
468/// Eviction policy: a naive "clear-when-full" heuristic — when the
469/// cache reaches `max_entries`, the entire map is dropped. This keeps
470/// the type allocation-light and avoids dragging in an LRU dep for
471/// v1. Documented in impl plan as Stage 1 trade-off; a proper LRU
472/// follow-up can swap the inner type without changing the public API.
473#[derive(Debug, Default)]
474pub struct ModelInvocationCache {
475    inner: KeyedStore<f64>,
476    max_entries: usize,
477}
478
479impl ModelInvocationCache {
480    pub fn new(max_entries: usize) -> Self {
481        Self {
482            inner: KeyedStore::default(),
483            max_entries,
484        }
485    }
486
487    /// Lookup. Returns `Some(prob)` on hit, `None` on miss.
488    pub fn get(&self, model: &str, input_hash: u64) -> Option<f64> {
489        self.inner.get(model, input_hash)
490    }
491
492    /// Insert. On overflow (cache size ≥ `max_entries`), drops the
493    /// entire cache before inserting — naive but bounded. Callers
494    /// should size `max_entries` for the expected working set.
495    pub fn insert(&self, model: &str, input_hash: u64, value: f64) {
496        self.inner
497            .insert_bounded(model, input_hash, value, self.max_entries);
498    }
499
500    /// Empty the cache. Useful for `LocyConfig` users who reuse a
501    /// shared cache across evaluations and want explicit reset.
502    pub fn clear(&self) {
503        self.inner.clear();
504    }
505
506    pub fn len(&self) -> usize {
507        self.inner.len()
508    }
509
510    pub fn is_empty(&self) -> bool {
511        self.len() == 0
512    }
513}
514
515// ─── Phase B A3: Candle-backed linear classifier ────────────────────
516
517/// Real `NeuralClassifier` backed by a Candle single-layer logistic
518/// regression. Weights and bias are loaded from a `safetensors` file
519/// at construction time; the forward pass runs on CPU. Sufficient to
520/// close the Phase B "Real Candle classifier loads + invokes via
521/// mock-config TCK harness" gate without committing to a specific
522/// production architecture — future slices can swap the inner module
523/// for an MLP, transformer, or `hf-hub`-fetched checkpoint without
524/// touching the `NeuralClassifier` trait surface.
525///
526/// **Expected safetensors layout:**
527/// - `"weight"` — shape `[n_features]`, dtype `f32`.
528/// - `"bias"` — shape `[1]`, dtype `f32`.
529///
530/// **Feature encoding** (deterministic, matches the TCK fixture):
531/// - `FeatureValue::Float(f)` → `f as f32`.
532/// - `FeatureValue::Int(i)` → `i as f32`.
533/// - `FeatureValue::Bool(b)` → `0.0` / `1.0`.
534/// - `FeatureValue::String(s)` → stable hash projected to `[-1, 1]`
535///   via the djb2 algorithm divided by `i32::MAX` (production
536///   classifiers should route String features through embedding
537///   lookups in M3 D1; this is a pragmatic stand-in).
538/// - `FeatureValue::Null` (or missing feature) → `0.0`.
539pub struct CandleLinearClassifier {
540    name: String,
541    /// Feature names in the order the weight vector expects them.
542    feature_order: Vec<String>,
543    /// Loaded weight vector, length == `feature_order.len()`.
544    weight: Vec<f32>,
545    /// Scalar bias term.
546    bias: f32,
547    /// CPU device handle (cached for tensor construction).
548    device: candle_core::Device,
549}
550
551impl std::fmt::Debug for CandleLinearClassifier {
552    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
553        f.debug_struct("CandleLinearClassifier")
554            .field("name", &self.name)
555            .field("feature_order", &self.feature_order)
556            .field("n_features", &self.weight.len())
557            .finish_non_exhaustive()
558    }
559}
560
561impl CandleLinearClassifier {
562    /// Load weights from a safetensors file on disk.
563    ///
564    /// `feature_order` must list features in the order matching the
565    /// weight tensor's columns. Returns
566    /// [`ClassifierError::Provider`] if the file is missing or the
567    /// tensor shapes don't match.
568    pub fn load(
569        name: impl Into<String>,
570        feature_order: Vec<String>,
571        weights_path: impl AsRef<std::path::Path>,
572    ) -> ClassifierResult<Self> {
573        let device = candle_core::Device::Cpu;
574        let path = weights_path.as_ref();
575        let tensors = candle_core::safetensors::load(path, &device).map_err(|e| {
576            ClassifierError::Provider(format!(
577                "candle: failed to load safetensors from {path:?}: {e}"
578            ))
579        })?;
580        let weight_t = tensors.get("weight").ok_or_else(|| {
581            ClassifierError::Provider("candle: safetensors missing 'weight' tensor".to_string())
582        })?;
583        let bias_t = tensors.get("bias").ok_or_else(|| {
584            ClassifierError::Provider("candle: safetensors missing 'bias' tensor".to_string())
585        })?;
586        let weight: Vec<f32> = weight_t
587            .flatten_all()
588            .and_then(|t| t.to_vec1::<f32>())
589            .map_err(|e| ClassifierError::Provider(format!("candle: weight read: {e}")))?;
590        let bias_vec: Vec<f32> = bias_t
591            .flatten_all()
592            .and_then(|t| t.to_vec1::<f32>())
593            .map_err(|e| ClassifierError::Provider(format!("candle: bias read: {e}")))?;
594        if bias_vec.len() != 1 {
595            return Err(ClassifierError::Provider(format!(
596                "candle: 'bias' must be scalar (shape [1]); got len={}",
597                bias_vec.len()
598            )));
599        }
600        if weight.len() != feature_order.len() {
601            return Err(ClassifierError::Provider(format!(
602                "candle: weight length {} != feature_order length {}",
603                weight.len(),
604                feature_order.len()
605            )));
606        }
607        Ok(Self {
608            name: name.into(),
609            feature_order,
610            weight,
611            bias: bias_vec[0],
612            device,
613        })
614    }
615
616    /// Project a feature value to an `f32` deterministically. See the
617    /// struct-level documentation for the per-variant policy.
618    fn encode_feature(&self, v: Option<&FeatureValue>) -> f32 {
619        match v {
620            Some(FeatureValue::Float(f)) => *f as f32,
621            Some(FeatureValue::Int(i)) => *i as f32,
622            Some(FeatureValue::Bool(b)) => f32::from(*b),
623            Some(FeatureValue::String(s)) => {
624                // djb2 → i32 → [-1, 1].
625                let mut h: u32 = 5381;
626                for byte in s.as_bytes() {
627                    h = h.wrapping_mul(33).wrapping_add(*byte as u32);
628                }
629                (h as i32) as f32 / i32::MAX as f32
630            }
631            Some(FeatureValue::Null) | None => 0.0,
632            _ => 0.0,
633        }
634    }
635}
636
637#[async_trait]
638impl NeuralClassifier for CandleLinearClassifier {
639    async fn classify(&self, inputs: &[ClassifyInput]) -> ClassifierResult<Vec<f64>> {
640        if inputs.is_empty() {
641            return Ok(Vec::new());
642        }
643        let n_features = self.weight.len();
644        // Pack inputs into a row-major [batch, n_features] vec.
645        let mut data: Vec<f32> = Vec::with_capacity(inputs.len() * n_features);
646        for inp in inputs {
647            for fname in &self.feature_order {
648                data.push(self.encode_feature(inp.features.get(fname)));
649            }
650        }
651        let x = candle_core::Tensor::from_vec(data, (inputs.len(), n_features), &self.device)
652            .map_err(|e| ClassifierError::Provider(format!("candle: input tensor: {e}")))?;
653        let w = candle_core::Tensor::from_slice(&self.weight, (n_features, 1), &self.device)
654            .map_err(|e| ClassifierError::Provider(format!("candle: weight tensor: {e}")))?;
655        let logits = x
656            .matmul(&w)
657            .and_then(|t| t.broadcast_add(&candle_core::Tensor::new(&[self.bias], &self.device)?))
658            .map_err(|e| ClassifierError::Provider(format!("candle: forward pass: {e}")))?;
659        // Sigmoid; flatten to [batch].
660        let probs = candle_nn::ops::sigmoid(&logits)
661            .and_then(|t| t.flatten_all())
662            .and_then(|t| t.to_vec1::<f32>())
663            .map_err(|e| ClassifierError::Provider(format!("candle: sigmoid: {e}")))?;
664        Ok(probs.into_iter().map(|p| p as f64).collect())
665    }
666
667    fn name(&self) -> &str {
668        &self.name
669    }
670}
671
/// Inverse sigmoid (logit) used by the default `classify_logits`.
///
/// The input is clamped into `[0, 1]` first. The endpoints map to
/// `-∞` / `+∞`, which downstream calibration treats as a degenerate
/// score; everything in between is `ln(p / (1 - p))`.
fn inverse_sigmoid(p: f64) -> f64 {
    let clamped = p.clamp(0.0, 1.0);
    match clamped {
        c if c == 0.0 => f64::NEG_INFINITY,
        c if c == 1.0 => f64::INFINITY,
        c => (c / (1.0 - c)).ln(),
    }
}
685
#[cfg(test)]
mod tests {
    use super::*;

    /// Hashes one feature value with the std default hasher.
    fn hash_of(value: &FeatureValue) -> u64 {
        use std::hash::{Hash, Hasher};
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    /// Builds an input row holding a single float feature.
    fn float_row(name: &str, value: f64) -> ClassifyInput {
        ClassifyInput::new().with(name, FeatureValue::Float(value))
    }

    #[tokio::test]
    async fn mock_constant_returns_value_per_row() {
        let sr = MockClassifier::constant("classify/test", 0.7);
        let inputs = vec![
            float_row("x", 1.0),
            float_row("x", 2.0),
            float_row("x", 3.0),
        ];
        let out = sr.classify(&inputs).await.unwrap();
        assert_eq!(out, vec![0.7, 0.7, 0.7]);
        assert_eq!(out.len(), inputs.len());
        assert_eq!(sr.name(), "classify/test");
    }

    #[tokio::test]
    async fn mock_feature_driven() {
        let sr = MockClassifier::new("classify/feature", |inp| {
            match inp.features.get("severity") {
                Some(FeatureValue::Float(v)) => (*v / 10.0).clamp(0.0, 1.0),
                _ => 0.0,
            }
        });
        let inputs = vec![
            float_row("severity", 2.0),
            float_row("severity", 9.0),
            float_row("severity", 15.0), // clamps to 1.0
        ];
        let out = sr.classify(&inputs).await.unwrap();
        assert_eq!(out, vec![0.2, 0.9, 1.0]);
    }

    #[tokio::test]
    async fn classify_logits_default_inverse_sigmoid() {
        let sr = MockClassifier::constant("classify/test", 0.5);
        let out = sr.classify_logits(&[ClassifyInput::new()]).await.unwrap();
        // sigmoid⁻¹(0.5) = 0
        assert!((out[0] - 0.0).abs() < 1e-12);
    }

    #[tokio::test]
    async fn mock_rejects_nan() {
        let sr = MockClassifier::new("classify/nan", |_| f64::NAN);
        let err = sr.classify(&[ClassifyInput::new()]).await.unwrap_err();
        assert!(matches!(err, ClassifierError::DomainViolation { .. }));
    }

    #[test]
    fn feature_value_hash_distinguishes_variants() {
        // Slice 1: Float(0.0) and Int(0) MUST hash differently so the
        // memoization cache doesn't conflate them.
        assert_ne!(
            hash_of(&FeatureValue::Float(0.0)),
            hash_of(&FeatureValue::Int(0))
        );
        assert_ne!(
            hash_of(&FeatureValue::Null),
            hash_of(&FeatureValue::Bool(false))
        );
        // Same variant + value → same hash.
        assert_eq!(
            hash_of(&FeatureValue::Float(0.5)),
            hash_of(&FeatureValue::Float(0.5))
        );
    }

    #[test]
    fn classify_input_hash_order_independent() {
        // Slice 1: HashMap insertion order shouldn't affect the
        // stable_hash output — same set of features must hash equal.
        let country = |code: &str| FeatureValue::String(code.into());
        let revenue = || FeatureValue::Float(1.0e6);

        let a = ClassifyInput::new()
            .with("country", country("US"))
            .with("revenue", revenue());
        let b = ClassifyInput::new()
            .with("revenue", revenue())
            .with("country", country("US"));
        assert_eq!(a.stable_hash(), b.stable_hash());

        let c = ClassifyInput::new()
            .with("country", country("DE"))
            .with("revenue", revenue());
        assert_ne!(a.stable_hash(), c.stable_hash());
    }

    #[test]
    fn feature_value_vector_hash() {
        let a = FeatureValue::Vector(vec![1.0, 2.0, 3.0]);
        let b = FeatureValue::Vector(vec![1.0, 2.0, 3.0]);
        let c = FeatureValue::Vector(vec![1.0, 2.0, 3.5]);
        assert_eq!(hash_of(&a), hash_of(&b));
        assert_ne!(hash_of(&a), hash_of(&c));
    }

    #[test]
    fn model_invocation_cache_hit_miss() {
        let cache = ModelInvocationCache::new(100);
        assert!(cache.get("m", 42).is_none());
        cache.insert("m", 42, 0.7);
        assert_eq!(cache.get("m", 42), Some(0.7));
        // Different model with same hash → miss.
        assert!(cache.get("other", 42).is_none());
        // Different hash → miss.
        assert!(cache.get("m", 43).is_none());
    }

    #[test]
    fn model_invocation_cache_evicts_on_overflow() {
        let cache = ModelInvocationCache::new(2);
        cache.insert("m", 1, 0.1);
        cache.insert("m", 2, 0.2);
        assert_eq!(cache.len(), 2);
        // Third insert triggers clear() then inserts; net size = 1.
        cache.insert("m", 3, 0.3);
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("m", 3), Some(0.3));
    }

    #[test]
    fn inverse_sigmoid_endpoints() {
        let at_zero = inverse_sigmoid(0.0);
        let at_one = inverse_sigmoid(1.0);
        assert!(at_zero.is_infinite() && at_zero < 0.0);
        assert!(at_one.is_infinite() && at_one > 0.0);
        assert!((inverse_sigmoid(0.5) - 0.0).abs() < 1e-12);
    }
}