ruvector_sona/
reasoning_bank.rs

//! ReasoningBank - Pattern storage and extraction for SONA
//!
//! Implements trajectory clustering with K-means (using deterministic,
//! K-means++-style farthest-point seeding) for pattern discovery.

5use crate::types::{LearnedPattern, PatternType, QueryTrajectory};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// ReasoningBank configuration
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct PatternConfig {
12    /// Number of clusters for K-means++
13    pub k_clusters: usize,
14    /// Embedding dimension
15    pub embedding_dim: usize,
16    /// Maximum K-means iterations
17    pub max_iterations: usize,
18    /// Convergence threshold
19    pub convergence_threshold: f32,
20    /// Minimum cluster size to keep
21    pub min_cluster_size: usize,
22    /// Maximum trajectories to store
23    pub max_trajectories: usize,
24    /// Quality threshold for pattern
25    pub quality_threshold: f32,
26}
27
28impl Default for PatternConfig {
29    fn default() -> Self {
30        Self {
31            k_clusters: 50,
32            embedding_dim: 256,
33            max_iterations: 100,
34            convergence_threshold: 0.001,
35            min_cluster_size: 5,
36            max_trajectories: 10000,
37            quality_threshold: 0.5,
38        }
39    }
40}
41
42/// ReasoningBank for pattern storage and extraction
43#[derive(Clone, Debug)]
44pub struct ReasoningBank {
45    /// Configuration
46    config: PatternConfig,
47    /// Stored trajectories
48    trajectories: Vec<TrajectoryEntry>,
49    /// Extracted patterns
50    patterns: HashMap<u64, LearnedPattern>,
51    /// Next pattern ID
52    next_pattern_id: u64,
53    /// Pattern index (embedding -> pattern_id)
54    pattern_index: Vec<(Vec<f32>, u64)>,
55}
56
/// A trajectory as stored by the bank: its embedding plus bookkeeping.
#[derive(Debug, Clone)]
struct TrajectoryEntry {
    /// L2-normalized blend of the query embedding and reward-weighted step activations.
    embedding: Vec<f32>,
    /// The trajectory's final quality score.
    quality: f32,
    /// K-means cluster index from the most recent extraction, if any.
    cluster: Option<usize>,
    /// ID of the source trajectory.
    trajectory_id: u64,
}
69
70impl ReasoningBank {
71    /// Create new ReasoningBank
72    pub fn new(config: PatternConfig) -> Self {
73        Self {
74            config,
75            trajectories: Vec::new(),
76            patterns: HashMap::new(),
77            next_pattern_id: 0,
78            pattern_index: Vec::new(),
79        }
80    }
81
82    /// Add trajectory to bank
83    pub fn add_trajectory(&mut self, trajectory: &QueryTrajectory) {
84        // Compute embedding from trajectory
85        let embedding = self.compute_embedding(trajectory);
86
87        let entry = TrajectoryEntry {
88            embedding,
89            quality: trajectory.final_quality,
90            cluster: None,
91            trajectory_id: trajectory.id,
92        };
93
94        // Enforce capacity
95        if self.trajectories.len() >= self.config.max_trajectories {
96            // Remove oldest entries
97            let to_remove = self.trajectories.len() - self.config.max_trajectories + 1;
98            self.trajectories.drain(0..to_remove);
99        }
100
101        self.trajectories.push(entry);
102    }
103
104    /// Compute embedding from trajectory
105    fn compute_embedding(&self, trajectory: &QueryTrajectory) -> Vec<f32> {
106        let dim = self.config.embedding_dim;
107        let mut embedding = vec![0.0f32; dim];
108
109        // Start with query embedding
110        let query_len = trajectory.query_embedding.len().min(dim);
111        embedding[..query_len].copy_from_slice(&trajectory.query_embedding[..query_len]);
112
113        // Average in step activations (weighted by reward)
114        if !trajectory.steps.is_empty() {
115            let mut total_reward = 0.0f32;
116
117            for step in &trajectory.steps {
118                let weight = step.reward.max(0.0);
119                total_reward += weight;
120
121                for (i, &act) in step.activations.iter().enumerate() {
122                    if i < dim {
123                        embedding[i] += act * weight;
124                    }
125                }
126            }
127
128            if total_reward > 0.0 {
129                for e in &mut embedding {
130                    *e /= total_reward + 1.0; // +1 for query contribution
131                }
132            }
133        }
134
135        // L2 normalize
136        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
137        if norm > 1e-8 {
138            for e in &mut embedding {
139                *e /= norm;
140            }
141        }
142
143        embedding
144    }
145
146    /// Extract patterns using K-means++
147    pub fn extract_patterns(&mut self) -> Vec<LearnedPattern> {
148        if self.trajectories.is_empty() {
149            return Vec::new();
150        }
151
152        let k = self.config.k_clusters.min(self.trajectories.len());
153        if k == 0 {
154            return Vec::new();
155        }
156
157        // K-means++ initialization
158        let centroids = self.kmeans_plus_plus_init(k);
159
160        // Run K-means
161        let (final_centroids, assignments) = self.run_kmeans(centroids);
162
163        // Create patterns from clusters
164        let mut patterns = Vec::new();
165
166        for (cluster_idx, centroid) in final_centroids.into_iter().enumerate() {
167            // Collect cluster members
168            let members: Vec<_> = self.trajectories.iter()
169                .enumerate()
170                .filter(|(i, _)| assignments.get(*i) == Some(&cluster_idx))
171                .map(|(_, t)| t)
172                .collect();
173
174            if members.len() < self.config.min_cluster_size {
175                continue;
176            }
177
178            // Compute cluster statistics
179            let cluster_size = members.len();
180            let total_weight: f32 = members.iter().map(|t| t.quality).sum();
181            let avg_quality = total_weight / cluster_size as f32;
182
183            if avg_quality < self.config.quality_threshold {
184                continue;
185            }
186
187            let pattern_id = self.next_pattern_id;
188            self.next_pattern_id += 1;
189
190            let pattern = LearnedPattern {
191                id: pattern_id,
192                centroid,
193                cluster_size,
194                total_weight,
195                avg_quality,
196                created_at: std::time::SystemTime::now()
197                    .duration_since(std::time::UNIX_EPOCH)
198                    .unwrap_or_default()
199                    .as_secs(),
200                last_accessed: std::time::SystemTime::now()
201                    .duration_since(std::time::UNIX_EPOCH)
202                    .unwrap_or_default()
203                    .as_secs(),
204                access_count: 0,
205                pattern_type: PatternType::General,
206            };
207
208            self.patterns.insert(pattern_id, pattern.clone());
209            self.pattern_index.push((pattern.centroid.clone(), pattern_id));
210            patterns.push(pattern);
211        }
212
213        // Update trajectory cluster assignments
214        for (i, cluster) in assignments.into_iter().enumerate() {
215            if i < self.trajectories.len() {
216                self.trajectories[i].cluster = Some(cluster);
217            }
218        }
219
220        patterns
221    }
222
223    /// K-means++ initialization
224    fn kmeans_plus_plus_init(&self, k: usize) -> Vec<Vec<f32>> {
225        let mut centroids = Vec::with_capacity(k);
226        let n = self.trajectories.len();
227
228        if n == 0 || k == 0 {
229            return centroids;
230        }
231
232        // First centroid: random (use deterministic selection for reproducibility)
233        let first_idx = 0;
234        centroids.push(self.trajectories[first_idx].embedding.clone());
235
236        // Remaining centroids: D^2 weighting
237        for _ in 1..k {
238            // Compute distances to nearest centroid
239            let mut distances: Vec<f32> = self.trajectories.iter()
240                .map(|t| {
241                    centroids.iter()
242                        .map(|c| self.squared_distance(&t.embedding, c))
243                        .fold(f32::MAX, f32::min)
244                })
245                .collect();
246
247            // Normalize to probabilities
248            let total: f32 = distances.iter().sum();
249            if total > 0.0 {
250                for d in &mut distances {
251                    *d /= total;
252                }
253            }
254
255            // Select next centroid (deterministic: highest distance)
256            let (next_idx, _) = distances.iter()
257                .enumerate()
258                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
259                .unwrap_or((0, &0.0));
260
261            centroids.push(self.trajectories[next_idx].embedding.clone());
262        }
263
264        centroids
265    }
266
267    /// Run K-means algorithm
268    fn run_kmeans(&self, mut centroids: Vec<Vec<f32>>) -> (Vec<Vec<f32>>, Vec<usize>) {
269        let n = self.trajectories.len();
270        let k = centroids.len();
271        let dim = self.config.embedding_dim;
272
273        let mut assignments = vec![0usize; n];
274
275        for _iter in 0..self.config.max_iterations {
276            // Assign points to nearest centroid
277            let mut changed = false;
278            for (i, t) in self.trajectories.iter().enumerate() {
279                let (nearest, _) = centroids.iter()
280                    .enumerate()
281                    .map(|(j, c)| (j, self.squared_distance(&t.embedding, c)))
282                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
283                    .unwrap_or((0, 0.0));
284
285                if assignments[i] != nearest {
286                    assignments[i] = nearest;
287                    changed = true;
288                }
289            }
290
291            if !changed {
292                break;
293            }
294
295            // Update centroids
296            let mut new_centroids = vec![vec![0.0f32; dim]; k];
297            let mut counts = vec![0usize; k];
298
299            for (i, t) in self.trajectories.iter().enumerate() {
300                let cluster = assignments[i];
301                counts[cluster] += 1;
302                for (j, &e) in t.embedding.iter().enumerate() {
303                    new_centroids[cluster][j] += e;
304                }
305            }
306
307            // Average and check convergence
308            let mut max_shift = 0.0f32;
309            for (i, new_c) in new_centroids.iter_mut().enumerate() {
310                if counts[i] > 0 {
311                    for e in new_c.iter_mut() {
312                        *e /= counts[i] as f32;
313                    }
314                    let shift = self.squared_distance(new_c, &centroids[i]).sqrt();
315                    max_shift = max_shift.max(shift);
316                }
317            }
318
319            centroids = new_centroids;
320
321            if max_shift < self.config.convergence_threshold {
322                break;
323            }
324        }
325
326        (centroids, assignments)
327    }
328
329    /// Squared Euclidean distance
330    fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
331        a.iter()
332            .zip(b.iter())
333            .map(|(&x, &y)| (x - y) * (x - y))
334            .sum()
335    }
336
337    /// Find similar patterns
338    pub fn find_similar(&self, query: &[f32], k: usize) -> Vec<&LearnedPattern> {
339        let mut scored: Vec<_> = self.patterns.values()
340            .map(|p| (p, p.similarity(query)))
341            .collect();
342
343        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
344
345        scored.into_iter()
346            .take(k)
347            .map(|(p, _)| p)
348            .collect()
349    }
350
351    /// Get pattern by ID
352    pub fn get_pattern(&self, id: u64) -> Option<&LearnedPattern> {
353        self.patterns.get(&id)
354    }
355
356    /// Get mutable pattern by ID
357    pub fn get_pattern_mut(&mut self, id: u64) -> Option<&mut LearnedPattern> {
358        self.patterns.get_mut(&id)
359    }
360
361    /// Get trajectory count
362    pub fn trajectory_count(&self) -> usize {
363        self.trajectories.len()
364    }
365
366    /// Get pattern count
367    pub fn pattern_count(&self) -> usize {
368        self.patterns.len()
369    }
370
371    /// Clear trajectories (keep patterns)
372    pub fn clear_trajectories(&mut self) {
373        self.trajectories.clear();
374    }
375
376    /// Prune low-quality patterns
377    pub fn prune_patterns(&mut self, min_quality: f32, min_accesses: u32, max_age_secs: u64) {
378        let to_remove: Vec<u64> = self.patterns.iter()
379            .filter(|(_, p)| p.should_prune(min_quality, min_accesses, max_age_secs))
380            .map(|(id, _)| *id)
381            .collect();
382
383        for id in to_remove {
384            self.patterns.remove(&id);
385        }
386
387        // Update index
388        self.pattern_index.retain(|(_, id)| self.patterns.contains_key(id));
389    }
390
391    /// Consolidate similar patterns
392    pub fn consolidate(&mut self, similarity_threshold: f32) {
393        let pattern_ids: Vec<u64> = self.patterns.keys().copied().collect();
394        let mut merged = Vec::new();
395
396        for i in 0..pattern_ids.len() {
397            for j in i+1..pattern_ids.len() {
398                let id1 = pattern_ids[i];
399                let id2 = pattern_ids[j];
400
401                if merged.contains(&id1) || merged.contains(&id2) {
402                    continue;
403                }
404
405                if let (Some(p1), Some(p2)) = (self.patterns.get(&id1), self.patterns.get(&id2)) {
406                    let sim = p1.similarity(&p2.centroid);
407                    if sim > similarity_threshold {
408                        // Merge p2 into p1
409                        let merged_pattern = p1.merge(p2);
410                        self.patterns.insert(id1, merged_pattern);
411                        merged.push(id2);
412                    }
413                }
414            }
415        }
416
417        // Remove merged patterns
418        for id in merged {
419            self.patterns.remove(&id);
420        }
421
422        // Update index
423        self.pattern_index.retain(|(_, id)| self.patterns.contains_key(id));
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a trajectory from `embedding` and finalize it with `quality`.
    fn make_trajectory(id: u64, embedding: Vec<f32>, quality: f32) -> QueryTrajectory {
        let mut trajectory = QueryTrajectory::new(id, embedding);
        trajectory.finalize(quality, 1000);
        trajectory
    }

    /// 4-dimensional one-hot vector with 1.0 at `axis`.
    fn one_hot(axis: usize) -> Vec<f32> {
        let mut v = vec![0.0; 4];
        v[axis] = 1.0;
        v
    }

    /// 4-dimensional config with quality filtering disabled.
    fn clustering_config(k_clusters: usize, min_cluster_size: usize) -> PatternConfig {
        PatternConfig {
            embedding_dim: 4,
            k_clusters,
            min_cluster_size,
            quality_threshold: 0.0,
            ..Default::default()
        }
    }

    #[test]
    fn test_bank_creation() {
        let bank = ReasoningBank::new(PatternConfig::default());
        assert_eq!(bank.trajectory_count(), 0);
        assert_eq!(bank.pattern_count(), 0);
    }

    #[test]
    fn test_add_trajectory() {
        let mut bank = ReasoningBank::new(PatternConfig {
            embedding_dim: 4,
            ..Default::default()
        });

        bank.add_trajectory(&make_trajectory(1, vec![0.1, 0.2, 0.3, 0.4], 0.8));

        assert_eq!(bank.trajectory_count(), 1);
    }

    #[test]
    fn test_extract_patterns() {
        let mut bank = ReasoningBank::new(clustering_config(2, 2));

        // Two groups of five: axis 0 with quality 0.8, axis 1 with quality 0.7.
        for id in 0..10u64 {
            let (axis, quality) = if id < 5 { (0, 0.8) } else { (1, 0.7) };
            bank.add_trajectory(&make_trajectory(id, one_hot(axis), quality));
        }

        let extracted = bank.extract_patterns();
        assert!(!extracted.is_empty());
    }

    #[test]
    fn test_find_similar() {
        let mut bank = ReasoningBank::new(clustering_config(2, 2));

        for id in 0..10u64 {
            let axis = if id < 5 { 0 } else { 1 };
            bank.add_trajectory(&make_trajectory(id, one_hot(axis), 0.8));
        }

        bank.extract_patterns();

        let nearest = bank.find_similar(&[0.9, 0.1, 0.0, 0.0], 1);
        assert!(!nearest.is_empty());
    }

    #[test]
    fn test_consolidate() {
        let mut bank = ReasoningBank::new(clustering_config(3, 1));

        // Nine nearly identical trajectories.
        for id in 0..9u64 {
            let embedding = vec![1.0 + id as f32 * 0.001, 0.0, 0.0, 0.0];
            bank.add_trajectory(&make_trajectory(id, embedding, 0.8));
        }

        bank.extract_patterns();
        let before = bank.pattern_count();

        bank.consolidate(0.99);

        assert!(bank.pattern_count() <= before);
    }
}
535}