Skip to main content

ruv_neural_decoder/
knn_decoder.rs

1//! K-Nearest Neighbor decoder for cognitive state classification.
2
3use std::collections::HashMap;
4
5use ruv_neural_core::embedding::NeuralEmbedding;
6use ruv_neural_core::error::{Result, RuvNeuralError};
7use ruv_neural_core::topology::CognitiveState;
8use ruv_neural_core::traits::StateDecoder;
9
10/// Simple KNN decoder using stored labeled embeddings.
11///
12/// Classifies a query embedding by majority vote among its `k` nearest
13/// neighbors in Euclidean distance.
14pub struct KnnDecoder {
15    labeled_embeddings: Vec<(NeuralEmbedding, CognitiveState)>,
16    k: usize,
17}
18
19impl KnnDecoder {
20    /// Create a new KNN decoder with the given `k` (number of neighbors).
21    pub fn new(k: usize) -> Self {
22        let k = if k == 0 { 1 } else { k };
23        Self {
24            labeled_embeddings: Vec::new(),
25            k,
26        }
27    }
28
29    /// Load labeled training data into the decoder.
30    pub fn train(&mut self, embeddings: Vec<(NeuralEmbedding, CognitiveState)>) {
31        self.labeled_embeddings = embeddings;
32    }
33
34    /// Predict the cognitive state for a query embedding using majority vote.
35    ///
36    /// Returns `CognitiveState::Unknown` if no training data is available.
37    pub fn predict(&self, embedding: &NeuralEmbedding) -> CognitiveState {
38        self.predict_with_confidence(embedding).0
39    }
40
41    /// Predict the cognitive state with a confidence score in `[0, 1]`.
42    ///
43    /// Confidence is the fraction of the `k` nearest neighbors that agree
44    /// on the winning state.
45    pub fn predict_with_confidence(&self, embedding: &NeuralEmbedding) -> (CognitiveState, f64) {
46        if self.labeled_embeddings.is_empty() {
47            return (CognitiveState::Unknown, 0.0);
48        }
49
50        // Compute distances to all stored embeddings.
51        let mut distances: Vec<(f64, &CognitiveState)> = self
52            .labeled_embeddings
53            .iter()
54            .filter_map(|(stored, state)| {
55                let dist = euclidean_distance(&embedding.vector, &stored.vector);
56                Some((dist, state))
57            })
58            .collect();
59
60        // Sort by distance ascending.
61        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
62
63        // Take top-k neighbors.
64        let k = self.k.min(distances.len());
65        let neighbors = &distances[..k];
66
67        // Majority vote with distance weighting.
68        let mut vote_counts: HashMap<CognitiveState, f64> = HashMap::new();
69        for (dist, state) in neighbors {
70            // Use inverse distance weighting; add epsilon to avoid division by zero.
71            let weight = 1.0 / (dist + 1e-10);
72            *vote_counts.entry(**state).or_insert(0.0) += weight;
73        }
74
75        // Find the state with the highest weighted vote.
76        let total_weight: f64 = vote_counts.values().sum();
77        let (best_state, best_weight) = vote_counts
78            .into_iter()
79            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
80            .unwrap_or((CognitiveState::Unknown, 0.0));
81
82        let confidence = if total_weight > 0.0 {
83            (best_weight / total_weight).clamp(0.0, 1.0)
84        } else {
85            0.0
86        };
87
88        (best_state, confidence)
89    }
90
91    /// Number of stored labeled embeddings.
92    pub fn num_samples(&self) -> usize {
93        self.labeled_embeddings.len()
94    }
95}
96
97impl StateDecoder for KnnDecoder {
98    fn decode(&self, embedding: &NeuralEmbedding) -> Result<CognitiveState> {
99        if self.labeled_embeddings.is_empty() {
100            return Err(RuvNeuralError::Decoder(
101                "KNN decoder has no training data".into(),
102            ));
103        }
104        Ok(self.predict(embedding))
105    }
106
107    fn decode_with_confidence(
108        &self,
109        embedding: &NeuralEmbedding,
110    ) -> Result<(CognitiveState, f64)> {
111        if self.labeled_embeddings.is_empty() {
112            return Err(RuvNeuralError::Decoder(
113                "KNN decoder has no training data".into(),
114            ));
115        }
116        Ok(self.predict_with_confidence(embedding))
117    }
118}
119
/// Euclidean (L2) distance between two vectors.
///
/// Only the overlapping prefix is compared: when the slices have different
/// lengths, the trailing elements of the longer one are ignored.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    // `zip` stops at the shorter slice, which yields the prefix behavior.
    let sum_of_squares: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let delta = x - y;
            delta * delta
        })
        .sum();
    sum_of_squares.sqrt()
}
130
#[cfg(test)]
mod tests {
    use super::*;
    use ruv_neural_core::brain::Atlas;
    use ruv_neural_core::embedding::EmbeddingMetadata;

    /// Wrap a raw vector in a `NeuralEmbedding` with placeholder metadata.
    fn make_embedding(vector: Vec<f64>) -> NeuralEmbedding {
        let metadata = EmbeddingMetadata {
            subject_id: None,
            session_id: None,
            cognitive_state: None,
            source_atlas: Atlas::DesikanKilliany68,
            embedding_method: "test".into(),
        };
        NeuralEmbedding::new(vector, 0.0, metadata).unwrap()
    }

    #[test]
    fn test_knn_classifies_correctly() {
        let rest_points = [[1.0, 0.0, 0.0], [1.1, 0.1, 0.0], [0.9, 0.0, 0.1]];
        let focused_points = [[0.0, 1.0, 0.0], [0.1, 1.1, 0.0], [0.0, 0.9, 0.1]];

        let training: Vec<(NeuralEmbedding, CognitiveState)> = rest_points
            .iter()
            .map(|p| (make_embedding(p.to_vec()), CognitiveState::Rest))
            .chain(
                focused_points
                    .iter()
                    .map(|p| (make_embedding(p.to_vec()), CognitiveState::Focused)),
            )
            .collect();

        let mut decoder = KnnDecoder::new(3);
        decoder.train(training);

        // A query beside the Rest cluster should be confidently Rest.
        let near_rest = make_embedding(vec![1.0, 0.05, 0.0]);
        let (state, confidence) = decoder.predict_with_confidence(&near_rest);
        assert_eq!(state, CognitiveState::Rest);
        assert!(confidence > 0.5);

        // A query beside the Focused cluster should be Focused.
        let near_focused = make_embedding(vec![0.05, 1.0, 0.0]);
        assert_eq!(decoder.predict(&near_focused), CognitiveState::Focused);
    }

    #[test]
    fn test_knn_empty_returns_unknown() {
        let untrained = KnnDecoder::new(3);
        let probe = make_embedding(vec![1.0, 0.0]);
        assert_eq!(untrained.predict(&probe), CognitiveState::Unknown);
    }

    #[test]
    fn test_confidence_in_range() {
        let mut decoder = KnnDecoder::new(3);
        decoder.train(vec![
            (make_embedding(vec![1.0, 0.0]), CognitiveState::Rest),
            (make_embedding(vec![0.0, 1.0]), CognitiveState::Focused),
        ]);
        let midpoint = make_embedding(vec![0.5, 0.5]);
        let (_, confidence) = decoder.predict_with_confidence(&midpoint);
        assert!((0.0..=1.0).contains(&confidence));
    }

    #[test]
    fn test_state_decoder_trait() {
        let sample = make_embedding(vec![1.0, 0.0]);
        let mut decoder = KnnDecoder::new(1);
        decoder.train(vec![(sample, CognitiveState::MotorPlanning)]);

        let probe = make_embedding(vec![1.0, 0.0]);
        let decoded = decoder.decode(&probe).unwrap();
        assert_eq!(decoded, CognitiveState::MotorPlanning);
    }

    #[test]
    fn test_state_decoder_empty_errors() {
        let untrained = KnnDecoder::new(3);
        let probe = make_embedding(vec![1.0]);
        assert!(untrained.decode(&probe).is_err());
    }
}