Skip to main content

ruv_neural_decoder/
pipeline.rs

1//! End-to-end decoder pipeline combining multiple decoding strategies.
2
3use ruv_neural_core::embedding::NeuralEmbedding;
4use ruv_neural_core::topology::{CognitiveState, TopologyMetrics};
5use serde::{Deserialize, Serialize};
6
7use crate::clinical::ClinicalScorer;
8use crate::knn_decoder::KnnDecoder;
9use crate::threshold_decoder::ThresholdDecoder;
10use crate::transition_decoder::{StateTransition, TransitionDecoder};
11
12/// End-to-end decoder pipeline that ensembles multiple decoding strategies.
13///
14/// Combines KNN, threshold, and transition decoders with configurable
15/// ensemble weights, and optionally includes clinical scoring.
16pub struct DecoderPipeline {
17    knn: Option<KnnDecoder>,
18    threshold: Option<ThresholdDecoder>,
19    transition: Option<TransitionDecoder>,
20    clinical: Option<ClinicalScorer>,
21    /// Ensemble weights: [knn_weight, threshold_weight, transition_weight].
22    ensemble_weights: [f64; 3],
23}
24
/// Result of a single [`DecoderPipeline::decode`] call.
///
/// The derived `Serialize` impl emits fields in declaration order, so
/// reordering the fields below changes the serialized layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecoderOutput {
    /// Ensemble-decoded cognitive state; `CognitiveState::Unknown` when no
    /// decoder produced a usable prediction.
    pub state: CognitiveState,
    /// Ensemble confidence in `[0, 1]`; `0.0` when no decoder voted.
    pub confidence: f64,
    /// State transition reported by the transition decoder on this call, if any.
    pub transition: Option<StateTransition>,
    /// Brain health index from the clinical scorer; `None` when no scorer is
    /// configured.
    pub brain_health_index: Option<f64>,
    /// Human-readable clinical warnings (empty when no warning fired or no
    /// scorer is configured).
    pub clinical_flags: Vec<String>,
    /// Timestamp copied from the input `TopologyMetrics`.
    pub timestamp: f64,
}
41
42impl DecoderPipeline {
43    /// Create an empty pipeline with default ensemble weights.
44    pub fn new() -> Self {
45        Self {
46            knn: None,
47            threshold: None,
48            transition: None,
49            clinical: None,
50            ensemble_weights: [1.0, 1.0, 1.0],
51        }
52    }
53
54    /// Add a KNN decoder to the pipeline.
55    pub fn with_knn(mut self, k: usize) -> Self {
56        self.knn = Some(KnnDecoder::new(k));
57        self
58    }
59
60    /// Add a threshold decoder to the pipeline.
61    pub fn with_thresholds(mut self) -> Self {
62        self.threshold = Some(ThresholdDecoder::new());
63        self
64    }
65
66    /// Add a transition decoder to the pipeline.
67    pub fn with_transitions(mut self, window: usize) -> Self {
68        self.transition = Some(TransitionDecoder::new(window));
69        self
70    }
71
72    /// Add a clinical scorer to the pipeline.
73    pub fn with_clinical(mut self, baseline: TopologyMetrics, std: TopologyMetrics) -> Self {
74        self.clinical = Some(ClinicalScorer::new(baseline, std));
75        self
76    }
77
78    /// Set custom ensemble weights for [knn, threshold, transition].
79    pub fn with_weights(mut self, weights: [f64; 3]) -> Self {
80        self.ensemble_weights = weights;
81        self
82    }
83
84    /// Get a mutable reference to the KNN decoder (for training).
85    pub fn knn_mut(&mut self) -> Option<&mut KnnDecoder> {
86        self.knn.as_mut()
87    }
88
89    /// Get a mutable reference to the threshold decoder (for configuring thresholds).
90    pub fn threshold_mut(&mut self) -> Option<&mut ThresholdDecoder> {
91        self.threshold.as_mut()
92    }
93
94    /// Get a mutable reference to the transition decoder (for registering patterns).
95    pub fn transition_mut(&mut self) -> Option<&mut TransitionDecoder> {
96        self.transition.as_mut()
97    }
98
99    /// Get a mutable reference to the clinical scorer.
100    pub fn clinical_mut(&mut self) -> Option<&mut ClinicalScorer> {
101        self.clinical.as_mut()
102    }
103
104    /// Run the full decoding pipeline on an embedding and topology metrics.
105    pub fn decode(
106        &mut self,
107        embedding: &NeuralEmbedding,
108        metrics: &TopologyMetrics,
109    ) -> DecoderOutput {
110        let mut candidates: Vec<(CognitiveState, f64, f64)> = Vec::new(); // (state, confidence, weight)
111
112        // KNN decoder.
113        if let Some(ref knn) = self.knn {
114            let (state, conf) = knn.predict_with_confidence(embedding);
115            if state != CognitiveState::Unknown {
116                candidates.push((state, conf, self.ensemble_weights[0]));
117            }
118        }
119
120        // Threshold decoder.
121        if let Some(ref threshold) = self.threshold {
122            let (state, conf) = threshold.decode(metrics);
123            if state != CognitiveState::Unknown {
124                candidates.push((state, conf, self.ensemble_weights[1]));
125            }
126        }
127
128        // Transition decoder.
129        let transition = if let Some(ref mut trans) = self.transition {
130            let result = trans.update(metrics.clone());
131            if let Some(ref t) = result {
132                candidates.push((t.to, t.confidence, self.ensemble_weights[2]));
133            }
134            result
135        } else {
136            None
137        };
138
139        // Ensemble: weighted vote.
140        let (state, confidence) = if candidates.is_empty() {
141            (CognitiveState::Unknown, 0.0)
142        } else {
143            weighted_vote(&candidates)
144        };
145
146        // Clinical scoring.
147        let mut brain_health_index = None;
148        let mut clinical_flags = Vec::new();
149
150        if let Some(ref clinical) = self.clinical {
151            let health = clinical.brain_health_index(metrics);
152            brain_health_index = Some(health);
153
154            let alz = clinical.alzheimer_risk(metrics);
155            let epi = clinical.epilepsy_risk(metrics);
156            let dep = clinical.depression_risk(metrics);
157
158            if alz > 0.7 {
159                clinical_flags.push(format!("Elevated Alzheimer risk: {:.2}", alz));
160            }
161            if epi > 0.7 {
162                clinical_flags.push(format!("Elevated epilepsy risk: {:.2}", epi));
163            }
164            if dep > 0.7 {
165                clinical_flags.push(format!("Elevated depression risk: {:.2}", dep));
166            }
167            if health < 0.3 {
168                clinical_flags.push(format!("Low brain health index: {:.2}", health));
169            }
170        }
171
172        DecoderOutput {
173            state,
174            confidence,
175            transition,
176            brain_health_index,
177            clinical_flags,
178            timestamp: metrics.timestamp,
179        }
180    }
181}
182
183impl Default for DecoderPipeline {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189/// Weighted majority vote across candidate predictions.
190///
191/// Returns the state with the highest weighted confidence and the
192/// normalized confidence score.
193fn weighted_vote(candidates: &[(CognitiveState, f64, f64)]) -> (CognitiveState, f64) {
194    use std::collections::HashMap;
195
196    let mut state_scores: HashMap<CognitiveState, f64> = HashMap::new();
197    let mut total_weight = 0.0;
198
199    for &(state, confidence, weight) in candidates {
200        let score = confidence * weight;
201        *state_scores.entry(state).or_insert(0.0) += score;
202        total_weight += score;
203    }
204
205    let (best_state, best_score) = state_scores
206        .into_iter()
207        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
208        .unwrap_or((CognitiveState::Unknown, 0.0));
209
210    let normalized = if total_weight > 0.0 {
211        (best_score / total_weight).clamp(0.0, 1.0)
212    } else {
213        0.0
214    };
215
216    (best_state, normalized)
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use ruv_neural_core::brain::Atlas;
    use ruv_neural_core::embedding::EmbeddingMetadata;

    /// Build an embedding with no subject/session/state metadata.
    fn make_embedding(vector: Vec<f64>) -> NeuralEmbedding {
        NeuralEmbedding::new(
            vector,
            0.0,
            EmbeddingMetadata {
                subject_id: None,
                session_id: None,
                cognitive_state: None,
                source_atlas: Atlas::DesikanKilliany68,
                embedding_method: "test".into(),
            },
        )
        .expect("test embedding should be accepted by NeuralEmbedding::new")
    }

    /// Build topology metrics with the given min-cut and modularity; all other
    /// fields are fixed.
    fn make_metrics(mincut: f64, modularity: f64) -> TopologyMetrics {
        TopologyMetrics {
            global_mincut: mincut,
            modularity,
            global_efficiency: 0.3,
            local_efficiency: 0.2,
            graph_entropy: 2.0,
            fiedler_value: 0.5,
            num_modules: 4,
            timestamp: 0.0,
        }
    }

    /// Spread metrics passed as `std` to `with_clinical`, shared by the clinical tests.
    fn make_std_metrics() -> TopologyMetrics {
        TopologyMetrics {
            global_mincut: 1.0,
            modularity: 0.1,
            global_efficiency: 0.05,
            local_efficiency: 0.05,
            graph_entropy: 0.3,
            fiedler_value: 0.1,
            num_modules: 1,
            timestamp: 0.0,
        }
    }

    #[test]
    fn test_empty_pipeline() {
        let mut pipeline = DecoderPipeline::new();
        let emb = make_embedding(vec![1.0, 0.0]);
        let met = make_metrics(5.0, 0.4);
        let output = pipeline.decode(&emb, &met);
        assert_eq!(output.state, CognitiveState::Unknown);
        assert!((0.0..=1.0).contains(&output.confidence));
    }

    #[test]
    fn test_default_pipeline_is_empty() {
        let mut pipeline = DecoderPipeline::default();
        let met = TopologyMetrics {
            timestamp: 42.5,
            ..make_metrics(5.0, 0.4)
        };
        let output = pipeline.decode(&make_embedding(vec![1.0, 0.0]), &met);
        assert_eq!(output.state, CognitiveState::Unknown);
        assert_eq!(output.confidence, 0.0);
        assert!(output.transition.is_none());
        assert!(output.brain_health_index.is_none());
        assert!(output.clinical_flags.is_empty());
        assert_eq!(output.timestamp, 42.5);
    }

    #[test]
    fn test_pipeline_with_knn() {
        let mut pipeline = DecoderPipeline::new().with_knn(3);
        pipeline.knn_mut().unwrap().train(vec![
            (make_embedding(vec![1.0, 0.0]), CognitiveState::Rest),
            (make_embedding(vec![1.1, 0.1]), CognitiveState::Rest),
            (make_embedding(vec![0.9, 0.0]), CognitiveState::Rest),
        ]);

        let output = pipeline.decode(&make_embedding(vec![1.0, 0.05]), &make_metrics(5.0, 0.4));
        assert_eq!(output.state, CognitiveState::Rest);
        assert!(output.confidence > 0.0);
    }

    #[test]
    fn test_pipeline_with_thresholds() {
        let mut pipeline = DecoderPipeline::new().with_thresholds();
        pipeline.threshold_mut().unwrap().set_threshold(
            CognitiveState::Focused,
            crate::threshold_decoder::TopologyThreshold {
                mincut_range: (7.0, 9.0),
                modularity_range: (0.5, 0.7),
                efficiency_range: (0.2, 0.4),
                entropy_range: (1.5, 2.5),
            },
        );

        let output = pipeline.decode(&make_embedding(vec![0.5, 0.5]), &make_metrics(8.0, 0.6));
        assert_eq!(output.state, CognitiveState::Focused);
    }

    #[test]
    fn test_pipeline_with_clinical() {
        let mut pipeline = DecoderPipeline::new()
            .with_knn(1)
            .with_clinical(make_metrics(5.0, 0.4), make_std_metrics());
        pipeline
            .knn_mut()
            .unwrap()
            .train(vec![(make_embedding(vec![1.0]), CognitiveState::Rest)]);

        let output = pipeline.decode(&make_embedding(vec![1.0]), &make_metrics(5.0, 0.4));
        let health = output
            .brain_health_index
            .expect("clinical scorer is configured, so a health index is expected");
        assert!((0.0..=1.0).contains(&health));
    }

    #[test]
    fn test_pipeline_all_decoders() {
        let mut pipeline = DecoderPipeline::new()
            .with_knn(3)
            .with_thresholds()
            .with_transitions(5)
            .with_clinical(make_metrics(5.0, 0.4), make_std_metrics());

        pipeline.knn_mut().unwrap().train(vec![
            (make_embedding(vec![1.0, 0.0]), CognitiveState::Rest),
            (make_embedding(vec![1.1, 0.1]), CognitiveState::Rest),
        ]);

        let output = pipeline.decode(&make_embedding(vec![1.0, 0.05]), &make_metrics(5.0, 0.4));
        // Should produce some output regardless of which decoders fire.
        assert!((0.0..=1.0).contains(&output.confidence));
        assert!(output.brain_health_index.is_some());
    }

    #[test]
    fn test_decoder_output_serialization() {
        let output = DecoderOutput {
            state: CognitiveState::Rest,
            confidence: 0.95,
            transition: None,
            brain_health_index: Some(0.92),
            clinical_flags: vec![],
            timestamp: 1234.5,
        };
        let json = serde_json::to_string(&output).unwrap();
        let parsed: DecoderOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.state, CognitiveState::Rest);
        assert!((parsed.confidence - 0.95).abs() < 1e-10);
        assert!(parsed.transition.is_none());
        let health = parsed
            .brain_health_index
            .expect("brain_health_index should survive the round trip");
        assert!((health - 0.92).abs() < 1e-10);
        assert!(parsed.clinical_flags.is_empty());
        assert!((parsed.timestamp - 1234.5).abs() < 1e-10);
    }
}