1use ruv_neural_core::embedding::NeuralEmbedding;
4use ruv_neural_core::topology::{CognitiveState, TopologyMetrics};
5use serde::{Deserialize, Serialize};
6
7use crate::clinical::ClinicalScorer;
8use crate::knn_decoder::KnnDecoder;
9use crate::threshold_decoder::ThresholdDecoder;
10use crate::transition_decoder::{StateTransition, TransitionDecoder};
11
12pub struct DecoderPipeline {
17 knn: Option<KnnDecoder>,
18 threshold: Option<ThresholdDecoder>,
19 transition: Option<TransitionDecoder>,
20 clinical: Option<ClinicalScorer>,
21 ensemble_weights: [f64; 3],
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct DecoderOutput {
28 pub state: CognitiveState,
30 pub confidence: f64,
32 pub transition: Option<StateTransition>,
34 pub brain_health_index: Option<f64>,
36 pub clinical_flags: Vec<String>,
38 pub timestamp: f64,
40}
41
42impl DecoderPipeline {
43 pub fn new() -> Self {
45 Self {
46 knn: None,
47 threshold: None,
48 transition: None,
49 clinical: None,
50 ensemble_weights: [1.0, 1.0, 1.0],
51 }
52 }
53
54 pub fn with_knn(mut self, k: usize) -> Self {
56 self.knn = Some(KnnDecoder::new(k));
57 self
58 }
59
60 pub fn with_thresholds(mut self) -> Self {
62 self.threshold = Some(ThresholdDecoder::new());
63 self
64 }
65
66 pub fn with_transitions(mut self, window: usize) -> Self {
68 self.transition = Some(TransitionDecoder::new(window));
69 self
70 }
71
72 pub fn with_clinical(mut self, baseline: TopologyMetrics, std: TopologyMetrics) -> Self {
74 self.clinical = Some(ClinicalScorer::new(baseline, std));
75 self
76 }
77
78 pub fn with_weights(mut self, weights: [f64; 3]) -> Self {
80 self.ensemble_weights = weights;
81 self
82 }
83
84 pub fn knn_mut(&mut self) -> Option<&mut KnnDecoder> {
86 self.knn.as_mut()
87 }
88
89 pub fn threshold_mut(&mut self) -> Option<&mut ThresholdDecoder> {
91 self.threshold.as_mut()
92 }
93
94 pub fn transition_mut(&mut self) -> Option<&mut TransitionDecoder> {
96 self.transition.as_mut()
97 }
98
99 pub fn clinical_mut(&mut self) -> Option<&mut ClinicalScorer> {
101 self.clinical.as_mut()
102 }
103
104 pub fn decode(
106 &mut self,
107 embedding: &NeuralEmbedding,
108 metrics: &TopologyMetrics,
109 ) -> DecoderOutput {
110 let mut candidates: Vec<(CognitiveState, f64, f64)> = Vec::new(); if let Some(ref knn) = self.knn {
114 let (state, conf) = knn.predict_with_confidence(embedding);
115 if state != CognitiveState::Unknown {
116 candidates.push((state, conf, self.ensemble_weights[0]));
117 }
118 }
119
120 if let Some(ref threshold) = self.threshold {
122 let (state, conf) = threshold.decode(metrics);
123 if state != CognitiveState::Unknown {
124 candidates.push((state, conf, self.ensemble_weights[1]));
125 }
126 }
127
128 let transition = if let Some(ref mut trans) = self.transition {
130 let result = trans.update(metrics.clone());
131 if let Some(ref t) = result {
132 candidates.push((t.to, t.confidence, self.ensemble_weights[2]));
133 }
134 result
135 } else {
136 None
137 };
138
139 let (state, confidence) = if candidates.is_empty() {
141 (CognitiveState::Unknown, 0.0)
142 } else {
143 weighted_vote(&candidates)
144 };
145
146 let mut brain_health_index = None;
148 let mut clinical_flags = Vec::new();
149
150 if let Some(ref clinical) = self.clinical {
151 let health = clinical.brain_health_index(metrics);
152 brain_health_index = Some(health);
153
154 let alz = clinical.alzheimer_risk(metrics);
155 let epi = clinical.epilepsy_risk(metrics);
156 let dep = clinical.depression_risk(metrics);
157
158 if alz > 0.7 {
159 clinical_flags.push(format!("Elevated Alzheimer risk: {:.2}", alz));
160 }
161 if epi > 0.7 {
162 clinical_flags.push(format!("Elevated epilepsy risk: {:.2}", epi));
163 }
164 if dep > 0.7 {
165 clinical_flags.push(format!("Elevated depression risk: {:.2}", dep));
166 }
167 if health < 0.3 {
168 clinical_flags.push(format!("Low brain health index: {:.2}", health));
169 }
170 }
171
172 DecoderOutput {
173 state,
174 confidence,
175 transition,
176 brain_health_index,
177 clinical_flags,
178 timestamp: metrics.timestamp,
179 }
180 }
181}
182
183impl Default for DecoderPipeline {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
189fn weighted_vote(candidates: &[(CognitiveState, f64, f64)]) -> (CognitiveState, f64) {
194 use std::collections::HashMap;
195
196 let mut state_scores: HashMap<CognitiveState, f64> = HashMap::new();
197 let mut total_weight = 0.0;
198
199 for &(state, confidence, weight) in candidates {
200 let score = confidence * weight;
201 *state_scores.entry(state).or_insert(0.0) += score;
202 total_weight += score;
203 }
204
205 let (best_state, best_score) = state_scores
206 .into_iter()
207 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
208 .unwrap_or((CognitiveState::Unknown, 0.0));
209
210 let normalized = if total_weight > 0.0 {
211 (best_score / total_weight).clamp(0.0, 1.0)
212 } else {
213 0.0
214 };
215
216 (best_state, normalized)
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use ruv_neural_core::brain::Atlas;
    use ruv_neural_core::embedding::EmbeddingMetadata;

    /// Tolerance for comparing computed confidences against expected ratios.
    const EPS: f64 = 1e-9;

    fn make_embedding(vector: Vec<f64>) -> NeuralEmbedding {
        NeuralEmbedding::new(
            vector,
            0.0,
            EmbeddingMetadata {
                subject_id: None,
                session_id: None,
                cognitive_state: None,
                source_atlas: Atlas::DesikanKilliany68,
                embedding_method: "test".into(),
            },
        )
        .unwrap()
    }

    fn make_metrics(mincut: f64, modularity: f64) -> TopologyMetrics {
        TopologyMetrics {
            global_mincut: mincut,
            modularity,
            global_efficiency: 0.3,
            local_efficiency: 0.2,
            graph_entropy: 2.0,
            fiedler_value: 0.5,
            num_modules: 4,
            timestamp: 0.0,
        }
    }

    /// Per-metric spread passed to `with_clinical` in the clinical tests.
    fn make_std_metrics() -> TopologyMetrics {
        TopologyMetrics {
            global_mincut: 1.0,
            modularity: 0.1,
            global_efficiency: 0.05,
            local_efficiency: 0.05,
            graph_entropy: 0.3,
            fiedler_value: 0.1,
            num_modules: 1,
            timestamp: 0.0,
        }
    }

    #[test]
    fn test_empty_pipeline() {
        let mut pipeline = DecoderPipeline::new();
        let emb = make_embedding(vec![1.0, 0.0]);
        let met = make_metrics(5.0, 0.4);
        let output = pipeline.decode(&emb, &met);
        assert_eq!(output.state, CognitiveState::Unknown);
        assert!((0.0..=1.0).contains(&output.confidence));
    }

    #[test]
    fn test_pipeline_with_knn() {
        let mut pipeline = DecoderPipeline::new().with_knn(3);
        pipeline.knn_mut().unwrap().train(vec![
            (make_embedding(vec![1.0, 0.0]), CognitiveState::Rest),
            (make_embedding(vec![1.1, 0.1]), CognitiveState::Rest),
            (make_embedding(vec![0.9, 0.0]), CognitiveState::Rest),
        ]);

        let output = pipeline.decode(&make_embedding(vec![1.0, 0.05]), &make_metrics(5.0, 0.4));
        assert_eq!(output.state, CognitiveState::Rest);
        assert!(output.confidence > 0.0);
    }

    #[test]
    fn test_pipeline_with_thresholds() {
        let mut pipeline = DecoderPipeline::new().with_thresholds();
        pipeline.threshold_mut().unwrap().set_threshold(
            CognitiveState::Focused,
            crate::threshold_decoder::TopologyThreshold {
                mincut_range: (7.0, 9.0),
                modularity_range: (0.5, 0.7),
                efficiency_range: (0.2, 0.4),
                entropy_range: (1.5, 2.5),
            },
        );

        let output = pipeline.decode(&make_embedding(vec![0.5, 0.5]), &make_metrics(8.0, 0.6));
        assert_eq!(output.state, CognitiveState::Focused);
    }

    #[test]
    fn test_pipeline_with_clinical() {
        let baseline = make_metrics(5.0, 0.4);
        let mut pipeline = DecoderPipeline::new()
            .with_knn(1)
            .with_clinical(baseline, make_std_metrics());
        pipeline
            .knn_mut()
            .unwrap()
            .train(vec![(make_embedding(vec![1.0]), CognitiveState::Rest)]);

        let output = pipeline.decode(&make_embedding(vec![1.0]), &make_metrics(5.0, 0.4));
        let health = output
            .brain_health_index
            .expect("clinical scorer is configured, so a health index must be reported");
        assert!((0.0..=1.0).contains(&health));
    }

    #[test]
    fn test_pipeline_all_decoders() {
        let baseline = make_metrics(5.0, 0.4);
        let mut pipeline = DecoderPipeline::new()
            .with_knn(3)
            .with_thresholds()
            .with_transitions(5)
            .with_clinical(baseline, make_std_metrics());

        pipeline.knn_mut().unwrap().train(vec![
            (make_embedding(vec![1.0, 0.0]), CognitiveState::Rest),
            (make_embedding(vec![1.1, 0.1]), CognitiveState::Rest),
        ]);

        let output = pipeline.decode(&make_embedding(vec![1.0, 0.05]), &make_metrics(5.0, 0.4));
        assert!((0.0..=1.0).contains(&output.confidence));
        assert!(output.brain_health_index.is_some());
    }

    #[test]
    fn test_decoder_output_serialization() {
        let output = DecoderOutput {
            state: CognitiveState::Rest,
            confidence: 0.95,
            transition: None,
            brain_health_index: Some(0.92),
            clinical_flags: vec![],
            timestamp: 1234.5,
        };
        let json = serde_json::to_string(&output).unwrap();
        let parsed: DecoderOutput = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.state, CognitiveState::Rest);
        assert!((parsed.confidence - 0.95).abs() < 1e-10);
    }

    #[test]
    fn test_weighted_vote_empty() {
        let (state, conf) = weighted_vote(&[]);
        assert_eq!(state, CognitiveState::Unknown);
        assert!(conf.abs() < EPS);
    }

    #[test]
    fn test_weighted_vote_single_candidate_takes_full_share() {
        let (state, conf) = weighted_vote(&[(CognitiveState::Rest, 0.4, 1.0)]);
        assert_eq!(state, CognitiveState::Rest);
        // Confidence is the winner's share of the total score.
        assert!((conf - 1.0).abs() < EPS);
    }

    #[test]
    fn test_weighted_vote_accumulates_agreeing_votes() {
        let (state, conf) = weighted_vote(&[
            (CognitiveState::Rest, 0.3, 1.0),
            (CognitiveState::Focused, 0.5, 1.0),
            (CognitiveState::Rest, 0.4, 1.0),
        ]);
        assert_eq!(state, CognitiveState::Rest);
        assert!((conf - 0.7 / 1.2).abs() < EPS);
    }

    #[test]
    fn test_weighted_vote_respects_weights() {
        // A confident but nearly-muted vote loses to a moderate, fully weighted one.
        let (state, conf) = weighted_vote(&[
            (CognitiveState::Rest, 0.9, 0.1),
            (CognitiveState::Focused, 0.5, 1.0),
        ]);
        assert_eq!(state, CognitiveState::Focused);
        assert!((conf - 0.5 / (0.09 + 0.5)).abs() < EPS);
    }

    #[test]
    fn test_weighted_vote_zero_weight_gives_zero_confidence() {
        let (state, conf) = weighted_vote(&[(CognitiveState::Rest, 0.8, 0.0)]);
        assert_eq!(state, CognitiveState::Rest);
        assert!(conf.abs() < EPS);
    }
}