1use std::collections::HashMap;
4
5use ruv_neural_core::embedding::NeuralEmbedding;
6use ruv_neural_core::error::{Result, RuvNeuralError};
7use ruv_neural_core::topology::CognitiveState;
8use ruv_neural_core::traits::StateDecoder;
9
10pub struct KnnDecoder {
15 labeled_embeddings: Vec<(NeuralEmbedding, CognitiveState)>,
16 k: usize,
17}
18
19impl KnnDecoder {
20 pub fn new(k: usize) -> Self {
22 let k = if k == 0 { 1 } else { k };
23 Self {
24 labeled_embeddings: Vec::new(),
25 k,
26 }
27 }
28
29 pub fn train(&mut self, embeddings: Vec<(NeuralEmbedding, CognitiveState)>) {
31 self.labeled_embeddings = embeddings;
32 }
33
34 pub fn predict(&self, embedding: &NeuralEmbedding) -> CognitiveState {
38 self.predict_with_confidence(embedding).0
39 }
40
41 pub fn predict_with_confidence(&self, embedding: &NeuralEmbedding) -> (CognitiveState, f64) {
46 if self.labeled_embeddings.is_empty() {
47 return (CognitiveState::Unknown, 0.0);
48 }
49
50 let mut distances: Vec<(f64, &CognitiveState)> = self
52 .labeled_embeddings
53 .iter()
54 .filter_map(|(stored, state)| {
55 let dist = euclidean_distance(&embedding.vector, &stored.vector);
56 Some((dist, state))
57 })
58 .collect();
59
60 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
62
63 let k = self.k.min(distances.len());
65 let neighbors = &distances[..k];
66
67 let mut vote_counts: HashMap<CognitiveState, f64> = HashMap::new();
69 for (dist, state) in neighbors {
70 let weight = 1.0 / (dist + 1e-10);
72 *vote_counts.entry(**state).or_insert(0.0) += weight;
73 }
74
75 let total_weight: f64 = vote_counts.values().sum();
77 let (best_state, best_weight) = vote_counts
78 .into_iter()
79 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
80 .unwrap_or((CognitiveState::Unknown, 0.0));
81
82 let confidence = if total_weight > 0.0 {
83 (best_weight / total_weight).clamp(0.0, 1.0)
84 } else {
85 0.0
86 };
87
88 (best_state, confidence)
89 }
90
91 pub fn num_samples(&self) -> usize {
93 self.labeled_embeddings.len()
94 }
95}
96
97impl StateDecoder for KnnDecoder {
98 fn decode(&self, embedding: &NeuralEmbedding) -> Result<CognitiveState> {
99 if self.labeled_embeddings.is_empty() {
100 return Err(RuvNeuralError::Decoder(
101 "KNN decoder has no training data".into(),
102 ));
103 }
104 Ok(self.predict(embedding))
105 }
106
107 fn decode_with_confidence(
108 &self,
109 embedding: &NeuralEmbedding,
110 ) -> Result<(CognitiveState, f64)> {
111 if self.labeled_embeddings.is_empty() {
112 return Err(RuvNeuralError::Decoder(
113 "KNN decoder has no training data".into(),
114 ));
115 }
116 Ok(self.predict_with_confidence(embedding))
117 }
118}
119
/// Euclidean (L2) distance between `a` and `b`.
///
/// Only the overlapping prefix is compared: if the slices differ in length,
/// the extra trailing elements of the longer one are ignored.
fn euclidean_distance(a: &[f64], b: &[f64]) -> f64 {
    let squared_sum: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared_sum.sqrt()
}
130
#[cfg(test)]
mod tests {
    use super::*;
    use ruv_neural_core::brain::Atlas;
    use ruv_neural_core::embedding::EmbeddingMetadata;

    /// Wraps `vector` in a `NeuralEmbedding` carrying placeholder test metadata.
    fn make_embedding(vector: Vec<f64>) -> NeuralEmbedding {
        let metadata = EmbeddingMetadata {
            subject_id: None,
            session_id: None,
            cognitive_state: None,
            source_atlas: Atlas::DesikanKilliany68,
            embedding_method: "test".into(),
        };
        NeuralEmbedding::new(vector, 0.0, metadata).unwrap()
    }

    /// Builds a decoder with neighbourhood size `k`, trained on `(vector, label)` pairs.
    fn trained_decoder(k: usize, samples: Vec<(Vec<f64>, CognitiveState)>) -> KnnDecoder {
        let mut decoder = KnnDecoder::new(k);
        let labeled = samples
            .into_iter()
            .map(|(vector, state)| (make_embedding(vector), state))
            .collect();
        decoder.train(labeled);
        decoder
    }

    #[test]
    fn test_knn_classifies_correctly() {
        let decoder = trained_decoder(
            3,
            vec![
                (vec![1.0, 0.0, 0.0], CognitiveState::Rest),
                (vec![1.1, 0.1, 0.0], CognitiveState::Rest),
                (vec![0.9, 0.0, 0.1], CognitiveState::Rest),
                (vec![0.0, 1.0, 0.0], CognitiveState::Focused),
                (vec![0.1, 1.1, 0.0], CognitiveState::Focused),
                (vec![0.0, 0.9, 0.1], CognitiveState::Focused),
            ],
        );

        // A query sitting next to the "Rest" cluster.
        let rest_query = make_embedding(vec![1.0, 0.05, 0.0]);
        let (state, confidence) = decoder.predict_with_confidence(&rest_query);
        assert_eq!(state, CognitiveState::Rest);
        assert!(confidence > 0.5);

        // A query sitting next to the "Focused" cluster.
        let focused_query = make_embedding(vec![0.05, 1.0, 0.0]);
        assert_eq!(decoder.predict(&focused_query), CognitiveState::Focused);
    }

    #[test]
    fn test_knn_empty_returns_unknown() {
        let untrained = KnnDecoder::new(3);
        let query = make_embedding(vec![1.0, 0.0]);
        assert_eq!(untrained.predict(&query), CognitiveState::Unknown);
    }

    #[test]
    fn test_confidence_in_range() {
        let decoder = trained_decoder(
            3,
            vec![
                (vec![1.0, 0.0], CognitiveState::Rest),
                (vec![0.0, 1.0], CognitiveState::Focused),
            ],
        );
        let query = make_embedding(vec![0.5, 0.5]);
        let (_, confidence) = decoder.predict_with_confidence(&query);
        assert!((0.0..=1.0).contains(&confidence));
    }

    #[test]
    fn test_state_decoder_trait() {
        let decoder = trained_decoder(1, vec![(vec![1.0, 0.0], CognitiveState::MotorPlanning)]);
        let decoded = decoder.decode(&make_embedding(vec![1.0, 0.0])).unwrap();
        assert_eq!(decoded, CognitiveState::MotorPlanning);
    }

    #[test]
    fn test_state_decoder_empty_errors() {
        let untrained = KnnDecoder::new(3);
        let query = make_embedding(vec![1.0]);
        assert!(untrained.decode(&query).is_err());
    }
}