1use std::collections::HashMap;
4use std::collections::VecDeque;
5
6use ruv_neural_core::embedding::NeuralEmbedding;
7use ruv_neural_core::error::Result;
8use ruv_neural_core::topology::CognitiveState;
9use ruv_neural_core::traits::NeuralMemory;
10
11#[derive(Debug, Clone)]
15pub struct NeuralMemoryStore {
16 embeddings: VecDeque<NeuralEmbedding>,
18 index: HashMap<String, Vec<usize>>,
20 capacity: usize,
22 evicted_count: usize,
25}
26
27impl NeuralMemoryStore {
28 pub fn new(capacity: usize) -> Self {
30 Self {
31 embeddings: VecDeque::with_capacity(capacity.min(1024)),
32 index: HashMap::new(),
33 capacity,
34 evicted_count: 0,
35 }
36 }
37
38 pub fn store(&mut self, embedding: NeuralEmbedding) -> Result<usize> {
44 if let Some(first) = self.embeddings.front() {
46 if embedding.dimension != first.dimension {
47 return Err(ruv_neural_core::error::RuvNeuralError::DimensionMismatch {
48 expected: first.dimension,
49 got: embedding.dimension,
50 });
51 }
52 }
53
54 if self.embeddings.len() >= self.capacity {
55 self.evict_oldest();
56 }
57
58 let idx = self.embeddings.len();
59
60 if let Some(ref subject_id) = embedding.metadata.subject_id {
61 self.index
62 .entry(subject_id.clone())
63 .or_default()
64 .push(idx);
65 }
66
67 self.embeddings.push_back(embedding);
68 Ok(idx)
69 }
70
71 pub fn get(&self, id: usize) -> Option<&NeuralEmbedding> {
73 self.embeddings.get(id)
74 }
75
76 pub fn len(&self) -> usize {
78 self.embeddings.len()
79 }
80
81 pub fn is_empty(&self) -> bool {
83 self.embeddings.is_empty()
84 }
85
86 pub fn query_nearest(&self, query: &NeuralEmbedding, k: usize) -> Vec<(usize, f64)> {
90 let mut distances: Vec<(usize, f64)> = self
91 .embeddings
92 .iter()
93 .enumerate()
94 .filter_map(|(i, emb)| {
95 emb.euclidean_distance(query).ok().map(|d| (i, d))
96 })
97 .collect();
98
99 distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
100 distances.truncate(k);
101 distances
102 }
103
104 pub fn query_by_state(&self, state: CognitiveState) -> Vec<&NeuralEmbedding> {
106 self.embeddings
107 .iter()
108 .filter(|e| e.metadata.cognitive_state == Some(state))
109 .collect()
110 }
111
112 pub fn query_by_subject(&self, subject_id: &str) -> Vec<&NeuralEmbedding> {
114 match self.index.get(subject_id) {
115 Some(indices) => indices
116 .iter()
117 .filter_map(|&i| self.embeddings.get(i))
118 .collect(),
119 None => Vec::new(),
120 }
121 }
122
123 pub fn query_time_range(&self, start: f64, end: f64) -> Vec<&NeuralEmbedding> {
125 self.embeddings
126 .iter()
127 .filter(|e| e.timestamp >= start && e.timestamp <= end)
128 .collect()
129 }
130
131 pub fn embeddings_iter(&self) -> impl Iterator<Item = &NeuralEmbedding> {
136 self.embeddings.iter()
137 }
138
139 pub fn embeddings(&self) -> Vec<&NeuralEmbedding> {
141 self.embeddings.iter().collect()
142 }
143
144 pub fn capacity(&self) -> usize {
146 self.capacity
147 }
148
149 fn evict_oldest(&mut self) {
154 if self.embeddings.is_empty() {
155 return;
156 }
157
158 let evicted = self.embeddings.pop_front().unwrap();
159 self.evicted_count += 1;
160
161 if let Some(ref subject_id) = evicted.metadata.subject_id {
163 if let Some(indices) = self.index.get_mut(subject_id) {
164 indices.retain(|&i| i != 0);
165 }
166 }
167
168 for indices in self.index.values_mut() {
170 for idx in indices.iter_mut() {
171 *idx -= 1;
172 }
173 }
174
175 self.index.retain(|_, v| !v.is_empty());
177 }
178}
179
180impl NeuralMemory for NeuralMemoryStore {
181 fn store(&mut self, embedding: &NeuralEmbedding) -> Result<()> {
182 NeuralMemoryStore::store(self, embedding.clone())?;
183 Ok(())
184 }
185
186 fn query_nearest(
187 &self,
188 embedding: &NeuralEmbedding,
189 k: usize,
190 ) -> Result<Vec<NeuralEmbedding>> {
191 let results = NeuralMemoryStore::query_nearest(self, embedding, k);
192 Ok(results
193 .into_iter()
194 .filter_map(|(i, _)| self.get(i).cloned())
195 .collect())
196 }
197
198 fn query_by_state(&self, state: CognitiveState) -> Result<Vec<NeuralEmbedding>> {
199 Ok(NeuralMemoryStore::query_by_state(self, state)
200 .into_iter()
201 .cloned()
202 .collect())
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use ruv_neural_core::brain::Atlas;
    use ruv_neural_core::embedding::EmbeddingMetadata;

    /// Builds a `Rest`-state embedding for `subject` at `timestamp`.
    fn make_embedding(vector: Vec<f64>, subject: &str, timestamp: f64) -> NeuralEmbedding {
        NeuralEmbedding::new(
            vector,
            timestamp,
            EmbeddingMetadata {
                subject_id: Some(subject.to_string()),
                session_id: None,
                cognitive_state: Some(CognitiveState::Rest),
                source_atlas: Atlas::Schaefer100,
                embedding_method: "test".to_string(),
            },
        )
        .unwrap()
    }

    /// Builds an embedding for subject "subj1" with the given cognitive `state`.
    fn make_embedding_with_state(
        vector: Vec<f64>,
        state: CognitiveState,
        timestamp: f64,
    ) -> NeuralEmbedding {
        NeuralEmbedding::new(
            vector,
            timestamp,
            EmbeddingMetadata {
                subject_id: Some("subj1".to_string()),
                session_id: None,
                cognitive_state: Some(state),
                source_atlas: Atlas::Schaefer100,
                embedding_method: "test".to_string(),
            },
        )
        .unwrap()
    }

    #[test]
    fn store_and_retrieve() {
        let mut store = NeuralMemoryStore::new(100);
        let emb = make_embedding(vec![1.0, 2.0, 3.0], "subj1", 0.0);
        let idx = store.store(emb.clone()).unwrap();
        assert_eq!(idx, 0);
        assert_eq!(store.len(), 1);

        let retrieved = store.get(0).unwrap();
        assert_eq!(retrieved.vector, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn store_rejects_dimension_mismatch() {
        let mut store = NeuralMemoryStore::new(10);
        store
            .store(make_embedding(vec![1.0, 2.0], "a", 0.0))
            .unwrap();

        let result = store.store(make_embedding(vec![1.0, 2.0, 3.0], "a", 1.0));
        assert!(result.is_err());
        // A rejected insert must leave both the entries and the subject index untouched.
        assert_eq!(store.len(), 1);
        assert_eq!(store.query_by_subject("a").len(), 1);
    }

    #[test]
    fn nearest_neighbor_returns_correct_results() {
        let mut store = NeuralMemoryStore::new(100);
        store
            .store(make_embedding(vec![0.0, 0.0, 0.0], "a", 0.0))
            .unwrap();
        store
            .store(make_embedding(vec![1.0, 0.0, 0.0], "b", 1.0))
            .unwrap();
        store
            .store(make_embedding(vec![10.0, 10.0, 10.0], "c", 2.0))
            .unwrap();

        let query = make_embedding(vec![0.5, 0.0, 0.0], "q", 3.0);
        let results = store.query_nearest(&query, 2);

        assert_eq!(results.len(), 2);
        assert!(results[0].1 <= results[1].1);
        // The far-away point must not make the top 2.
        assert!(results.iter().all(|&(i, _)| i != 2));

        // k = 0 yields nothing; k larger than the store yields every entry.
        assert!(store.query_nearest(&query, 0).is_empty());
        assert_eq!(store.query_nearest(&query, 10).len(), 3);
    }

    #[test]
    fn query_by_state_filters_correctly() {
        let mut store = NeuralMemoryStore::new(100);
        store
            .store(make_embedding_with_state(
                vec![1.0, 0.0],
                CognitiveState::Rest,
                0.0,
            ))
            .unwrap();
        store
            .store(make_embedding_with_state(
                vec![0.0, 1.0],
                CognitiveState::Focused,
                1.0,
            ))
            .unwrap();
        store
            .store(make_embedding_with_state(
                vec![1.0, 1.0],
                CognitiveState::Rest,
                2.0,
            ))
            .unwrap();

        let resting = store.query_by_state(CognitiveState::Rest);
        assert_eq!(resting.len(), 2);

        let focused = store.query_by_state(CognitiveState::Focused);
        assert_eq!(focused.len(), 1);
    }

    #[test]
    fn query_by_subject() {
        let mut store = NeuralMemoryStore::new(100);
        store
            .store(make_embedding(vec![1.0, 0.0], "alice", 0.0))
            .unwrap();
        store
            .store(make_embedding(vec![0.0, 1.0], "bob", 1.0))
            .unwrap();
        store
            .store(make_embedding(vec![1.0, 1.0], "alice", 2.0))
            .unwrap();

        let alice = store.query_by_subject("alice");
        assert_eq!(alice.len(), 2);

        let bob = store.query_by_subject("bob");
        assert_eq!(bob.len(), 1);

        let unknown = store.query_by_subject("charlie");
        assert_eq!(unknown.len(), 0);
    }

    #[test]
    fn query_time_range() {
        let mut store = NeuralMemoryStore::new(100);
        store
            .store(make_embedding(vec![1.0], "a", 1.0))
            .unwrap();
        store
            .store(make_embedding(vec![2.0], "a", 5.0))
            .unwrap();
        store
            .store(make_embedding(vec![3.0], "a", 10.0))
            .unwrap();

        let in_range = store.query_time_range(2.0, 8.0);
        assert_eq!(in_range.len(), 1);
        assert_eq!(in_range[0].vector, vec![2.0]);

        let all = store.query_time_range(0.0, 20.0);
        assert_eq!(all.len(), 3);
    }

    #[test]
    fn capacity_eviction() {
        let mut store = NeuralMemoryStore::new(2);
        store
            .store(make_embedding(vec![1.0], "a", 0.0))
            .unwrap();
        store
            .store(make_embedding(vec![2.0], "b", 1.0))
            .unwrap();
        assert_eq!(store.len(), 2);

        // A third insert evicts the oldest entry and shifts the survivors down.
        store
            .store(make_embedding(vec![3.0], "c", 2.0))
            .unwrap();
        assert_eq!(store.len(), 2);
        assert_eq!(store.get(0).unwrap().vector, vec![2.0]);
        assert_eq!(store.get(1).unwrap().vector, vec![3.0]);
        assert!(store.get(2).is_none());
    }

    #[test]
    fn eviction_keeps_subject_index_consistent() {
        let mut store = NeuralMemoryStore::new(2);
        store
            .store(make_embedding(vec![1.0], "a", 0.0))
            .unwrap();
        store
            .store(make_embedding(vec![2.0], "b", 1.0))
            .unwrap();
        store
            .store(make_embedding(vec![3.0], "b", 2.0))
            .unwrap();

        // The evicted subject disappears from the index entirely...
        assert!(store.query_by_subject("a").is_empty());
        // ...and the surviving subject's positions were remapped to the shifted entries.
        let b: Vec<Vec<f64>> = store
            .query_by_subject("b")
            .iter()
            .map(|e| e.vector.clone())
            .collect();
        assert_eq!(b, vec![vec![2.0], vec![3.0]]);
    }
}