Skip to main content

ruvector_sona/
reasoning_bank.rs

1//! ReasoningBank - Pattern storage and extraction for SONA
2//!
3//! Implements trajectory clustering using K-means++ for pattern discovery.
4
5use crate::types::{LearnedPattern, PatternType, QueryTrajectory};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// ReasoningBank configuration
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct PatternConfig {
12    /// Number of clusters for K-means++
13    pub k_clusters: usize,
14    /// Embedding dimension
15    pub embedding_dim: usize,
16    /// Maximum K-means iterations
17    pub max_iterations: usize,
18    /// Convergence threshold
19    pub convergence_threshold: f32,
20    /// Minimum cluster size to keep
21    pub min_cluster_size: usize,
22    /// Maximum trajectories to store
23    pub max_trajectories: usize,
24    /// Quality threshold for pattern
25    pub quality_threshold: f32,
26}
27
28impl Default for PatternConfig {
29    fn default() -> Self {
30        // OPTIMIZED DEFAULTS based on @ruvector/sona v0.1.1 benchmarks:
31        // - 100 clusters = 1.3ms search vs 50 clusters = 3.0ms (2.3x faster)
32        // - Quality threshold 0.3 balances learning vs noise filtering
33        Self {
34            k_clusters: 100, // OPTIMIZED: 2.3x faster search (1.3ms vs 3.0ms)
35            embedding_dim: 256,
36            max_iterations: 100,
37            convergence_threshold: 0.001,
38            min_cluster_size: 5,
39            max_trajectories: 10000,
40            quality_threshold: 0.3, // OPTIMIZED: Lower threshold for more learning
41        }
42    }
43}
44
/// ReasoningBank for pattern storage and extraction.
///
/// Holds a bounded buffer of embedded trajectories and the patterns produced
/// by clustering them with K-means++.
#[derive(Clone, Debug)]
pub struct ReasoningBank {
    /// Configuration supplied at construction.
    config: PatternConfig,
    /// Stored trajectories in insertion order (oldest first); new entries are
    /// pushed at the back and evictions drain from the front.
    trajectories: Vec<TrajectoryEntry>,
    /// Extracted patterns keyed by pattern ID.
    patterns: HashMap<u64, LearnedPattern>,
    /// ID to assign to the next extracted pattern; incremented per pattern.
    next_pattern_id: u64,
    /// `(centroid, pattern_id)` pairs appended whenever a pattern is extracted;
    /// entries whose pattern no longer exists are dropped on prune/consolidate.
    pattern_index: Vec<(Vec<f32>, u64)>,
}
59
/// Internal trajectory entry: a trajectory reduced to its embedding plus the
/// bookkeeping used during clustering.
#[derive(Clone, Debug)]
struct TrajectoryEntry {
    /// Trajectory embedding: the query embedding blended with reward-weighted
    /// step activations, L2-normalized when non-zero (see `compute_embedding`).
    embedding: Vec<f32>,
    /// Quality score, copied from the trajectory's `final_quality`.
    quality: f32,
    /// Cluster index from the most recent `extract_patterns` run; `None` until then.
    cluster: Option<usize>,
    /// ID of the originating `QueryTrajectory`.
    trajectory_id: u64,
}
72
73impl ReasoningBank {
74    /// Create new ReasoningBank
75    pub fn new(config: PatternConfig) -> Self {
76        Self {
77            config,
78            trajectories: Vec::new(),
79            patterns: HashMap::new(),
80            next_pattern_id: 0,
81            pattern_index: Vec::new(),
82        }
83    }
84
85    /// Add trajectory to bank
86    pub fn add_trajectory(&mut self, trajectory: &QueryTrajectory) {
87        // Compute embedding from trajectory
88        let embedding = self.compute_embedding(trajectory);
89
90        let entry = TrajectoryEntry {
91            embedding,
92            quality: trajectory.final_quality,
93            cluster: None,
94            trajectory_id: trajectory.id,
95        };
96
97        // Enforce capacity
98        if self.trajectories.len() >= self.config.max_trajectories {
99            // Remove oldest entries
100            let to_remove = self.trajectories.len() - self.config.max_trajectories + 1;
101            self.trajectories.drain(0..to_remove);
102        }
103
104        self.trajectories.push(entry);
105    }
106
107    /// Compute embedding from trajectory
108    fn compute_embedding(&self, trajectory: &QueryTrajectory) -> Vec<f32> {
109        let dim = self.config.embedding_dim;
110        let mut embedding = vec![0.0f32; dim];
111
112        // Start with query embedding
113        let query_len = trajectory.query_embedding.len().min(dim);
114        embedding[..query_len].copy_from_slice(&trajectory.query_embedding[..query_len]);
115
116        // Average in step activations (weighted by reward)
117        if !trajectory.steps.is_empty() {
118            let mut total_reward = 0.0f32;
119
120            for step in &trajectory.steps {
121                let weight = step.reward.max(0.0);
122                total_reward += weight;
123
124                for (i, &act) in step.activations.iter().enumerate() {
125                    if i < dim {
126                        embedding[i] += act * weight;
127                    }
128                }
129            }
130
131            if total_reward > 0.0 {
132                for e in &mut embedding {
133                    *e /= total_reward + 1.0; // +1 for query contribution
134                }
135            }
136        }
137
138        // L2 normalize
139        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
140        if norm > 1e-8 {
141            for e in &mut embedding {
142                *e /= norm;
143            }
144        }
145
146        embedding
147    }
148
149    /// Extract patterns using K-means++
150    pub fn extract_patterns(&mut self) -> Vec<LearnedPattern> {
151        if self.trajectories.is_empty() {
152            return Vec::new();
153        }
154
155        let k = self.config.k_clusters.min(self.trajectories.len());
156        if k == 0 {
157            return Vec::new();
158        }
159
160        // K-means++ initialization
161        let centroids = self.kmeans_plus_plus_init(k);
162
163        // Run K-means
164        let (final_centroids, assignments) = self.run_kmeans(centroids);
165
166        // Create patterns from clusters
167        let mut patterns = Vec::new();
168
169        for (cluster_idx, centroid) in final_centroids.into_iter().enumerate() {
170            // Collect cluster members
171            let members: Vec<_> = self
172                .trajectories
173                .iter()
174                .enumerate()
175                .filter(|(i, _)| assignments.get(*i) == Some(&cluster_idx))
176                .map(|(_, t)| t)
177                .collect();
178
179            if members.len() < self.config.min_cluster_size {
180                continue;
181            }
182
183            // Compute cluster statistics
184            let cluster_size = members.len();
185            let total_weight: f32 = members.iter().map(|t| t.quality).sum();
186            let avg_quality = total_weight / cluster_size as f32;
187
188            if avg_quality < self.config.quality_threshold {
189                continue;
190            }
191
192            let pattern_id = self.next_pattern_id;
193            self.next_pattern_id += 1;
194
195            let now = crate::time_compat::SystemTime::now()
196                .duration_since_epoch()
197                .as_secs();
198            let pattern = LearnedPattern {
199                id: pattern_id,
200                centroid,
201                cluster_size,
202                total_weight,
203                avg_quality,
204                created_at: now,
205                last_accessed: now,
206                access_count: 0,
207                pattern_type: PatternType::General,
208            };
209
210            self.patterns.insert(pattern_id, pattern.clone());
211            self.pattern_index
212                .push((pattern.centroid.clone(), pattern_id));
213            patterns.push(pattern);
214        }
215
216        // Update trajectory cluster assignments
217        for (i, cluster) in assignments.into_iter().enumerate() {
218            if i < self.trajectories.len() {
219                self.trajectories[i].cluster = Some(cluster);
220            }
221        }
222
223        patterns
224    }
225
226    /// K-means++ initialization
227    fn kmeans_plus_plus_init(&self, k: usize) -> Vec<Vec<f32>> {
228        let mut centroids = Vec::with_capacity(k);
229        let n = self.trajectories.len();
230
231        if n == 0 || k == 0 {
232            return centroids;
233        }
234
235        // First centroid: random (use deterministic selection for reproducibility)
236        let first_idx = 0;
237        centroids.push(self.trajectories[first_idx].embedding.clone());
238
239        // Remaining centroids: D^2 weighting
240        for _ in 1..k {
241            // Compute distances to nearest centroid
242            let mut distances: Vec<f32> = self
243                .trajectories
244                .iter()
245                .map(|t| {
246                    centroids
247                        .iter()
248                        .map(|c| self.squared_distance(&t.embedding, c))
249                        .fold(f32::MAX, f32::min)
250                })
251                .collect();
252
253            // Normalize to probabilities
254            let total: f32 = distances.iter().sum();
255            if total > 0.0 {
256                for d in &mut distances {
257                    *d /= total;
258                }
259            }
260
261            // Select next centroid (deterministic: highest distance)
262            // SECURITY FIX (H-004): Handle NaN values in partial_cmp safely
263            let (next_idx, _) = distances
264                .iter()
265                .enumerate()
266                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
267                .unwrap_or((0, &0.0));
268
269            centroids.push(self.trajectories[next_idx].embedding.clone());
270        }
271
272        centroids
273    }
274
275    /// Run K-means algorithm
276    fn run_kmeans(&self, mut centroids: Vec<Vec<f32>>) -> (Vec<Vec<f32>>, Vec<usize>) {
277        let n = self.trajectories.len();
278        let k = centroids.len();
279        let dim = self.config.embedding_dim;
280
281        let mut assignments = vec![0usize; n];
282
283        for _iter in 0..self.config.max_iterations {
284            // Assign points to nearest centroid
285            let mut changed = false;
286            for (i, t) in self.trajectories.iter().enumerate() {
287                // SECURITY FIX (H-004): Handle NaN values in partial_cmp safely
288                let (nearest, _) = centroids
289                    .iter()
290                    .enumerate()
291                    .map(|(j, c)| (j, self.squared_distance(&t.embedding, c)))
292                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
293                    .unwrap_or((0, 0.0));
294
295                if assignments[i] != nearest {
296                    assignments[i] = nearest;
297                    changed = true;
298                }
299            }
300
301            if !changed {
302                break;
303            }
304
305            // Update centroids
306            let mut new_centroids = vec![vec![0.0f32; dim]; k];
307            let mut counts = vec![0usize; k];
308
309            for (i, t) in self.trajectories.iter().enumerate() {
310                let cluster = assignments[i];
311                counts[cluster] += 1;
312                for (j, &e) in t.embedding.iter().enumerate() {
313                    new_centroids[cluster][j] += e;
314                }
315            }
316
317            // Average and check convergence
318            let mut max_shift = 0.0f32;
319            for (i, new_c) in new_centroids.iter_mut().enumerate() {
320                if counts[i] > 0 {
321                    for e in new_c.iter_mut() {
322                        *e /= counts[i] as f32;
323                    }
324                    let shift = self.squared_distance(new_c, &centroids[i]).sqrt();
325                    max_shift = max_shift.max(shift);
326                }
327            }
328
329            centroids = new_centroids;
330
331            if max_shift < self.config.convergence_threshold {
332                break;
333            }
334        }
335
336        (centroids, assignments)
337    }
338
339    /// Squared Euclidean distance
340    fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
341        a.iter()
342            .zip(b.iter())
343            .map(|(&x, &y)| (x - y) * (x - y))
344            .sum()
345    }
346
347    /// Find similar patterns
348    pub fn find_similar(&self, query: &[f32], k: usize) -> Vec<&LearnedPattern> {
349        let mut scored: Vec<_> = self
350            .patterns
351            .values()
352            .map(|p| (p, p.similarity(query)))
353            .collect();
354
355        // Note: This already has the safe unwrap_or pattern for NaN handling
356        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
357
358        scored.into_iter().take(k).map(|(p, _)| p).collect()
359    }
360
361    /// Get pattern by ID
362    pub fn get_pattern(&self, id: u64) -> Option<&LearnedPattern> {
363        self.patterns.get(&id)
364    }
365
366    /// Get mutable pattern by ID
367    pub fn get_pattern_mut(&mut self, id: u64) -> Option<&mut LearnedPattern> {
368        self.patterns.get_mut(&id)
369    }
370
371    /// Get trajectory count
372    pub fn trajectory_count(&self) -> usize {
373        self.trajectories.len()
374    }
375
376    /// Get pattern count
377    pub fn pattern_count(&self) -> usize {
378        self.patterns.len()
379    }
380
381    /// Clear trajectories (keep patterns)
382    pub fn clear_trajectories(&mut self) {
383        self.trajectories.clear();
384    }
385
386    /// Prune low-quality patterns
387    pub fn prune_patterns(&mut self, min_quality: f32, min_accesses: u32, max_age_secs: u64) {
388        let to_remove: Vec<u64> = self
389            .patterns
390            .iter()
391            .filter(|(_, p)| p.should_prune(min_quality, min_accesses, max_age_secs))
392            .map(|(id, _)| *id)
393            .collect();
394
395        for id in to_remove {
396            self.patterns.remove(&id);
397        }
398
399        // Update index
400        self.pattern_index
401            .retain(|(_, id)| self.patterns.contains_key(id));
402    }
403
404    /// Get all patterns for export
405    pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
406        self.patterns.values().cloned().collect()
407    }
408
409    /// Consolidate similar patterns
410    pub fn consolidate(&mut self, similarity_threshold: f32) {
411        let pattern_ids: Vec<u64> = self.patterns.keys().copied().collect();
412        let mut merged = Vec::new();
413
414        for i in 0..pattern_ids.len() {
415            for j in i + 1..pattern_ids.len() {
416                let id1 = pattern_ids[i];
417                let id2 = pattern_ids[j];
418
419                if merged.contains(&id1) || merged.contains(&id2) {
420                    continue;
421                }
422
423                if let (Some(p1), Some(p2)) = (self.patterns.get(&id1), self.patterns.get(&id2)) {
424                    let sim = p1.similarity(&p2.centroid);
425                    if sim > similarity_threshold {
426                        // Merge p2 into p1
427                        let merged_pattern = p1.merge(p2);
428                        self.patterns.insert(id1, merged_pattern);
429                        merged.push(id2);
430                    }
431                }
432            }
433        }
434
435        // Remove merged patterns
436        for id in merged {
437            self.patterns.remove(&id);
438        }
439
440        // Update index
441        self.pattern_index
442            .retain(|(_, id)| self.patterns.contains_key(id));
443    }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a finalized trajectory with the given query embedding and quality.
    fn make_trajectory(id: u64, embedding: Vec<f32>, quality: f32) -> QueryTrajectory {
        let mut trajectory = QueryTrajectory::new(id, embedding);
        trajectory.finalize(quality, 1000);
        trajectory
    }

    /// 4-dimensional config with quality gating disabled, shared by clustering tests.
    fn small_config(k_clusters: usize, min_cluster_size: usize) -> PatternConfig {
        PatternConfig {
            embedding_dim: 4,
            k_clusters,
            min_cluster_size,
            quality_threshold: 0.0,
            ..Default::default()
        }
    }

    #[test]
    fn test_bank_creation() {
        let fresh = ReasoningBank::new(PatternConfig::default());
        assert_eq!(fresh.trajectory_count(), 0);
        assert_eq!(fresh.pattern_count(), 0);
    }

    #[test]
    fn test_add_trajectory() {
        let mut bank = ReasoningBank::new(PatternConfig {
            embedding_dim: 4,
            ..Default::default()
        });

        bank.add_trajectory(&make_trajectory(1, vec![0.1, 0.2, 0.3, 0.4], 0.8));

        assert_eq!(bank.trajectory_count(), 1);
    }

    #[test]
    fn test_extract_patterns() {
        let mut bank = ReasoningBank::new(small_config(2, 2));

        // Two tight groups of five: axis 0 with quality 0.8, axis 1 with quality 0.7.
        let groups = [
            (0..5u64, vec![1.0, 0.0, 0.0, 0.0], 0.8),
            (5..10u64, vec![0.0, 1.0, 0.0, 0.0], 0.7),
        ];
        for (ids, embedding, quality) in groups {
            for id in ids {
                bank.add_trajectory(&make_trajectory(id, embedding.clone(), quality));
            }
        }

        let extracted = bank.extract_patterns();
        assert!(!extracted.is_empty());
    }

    #[test]
    fn test_find_similar() {
        let mut bank = ReasoningBank::new(small_config(2, 2));

        for id in 0..10u64 {
            let embedding = match id {
                0..=4 => vec![1.0, 0.0, 0.0, 0.0],
                _ => vec![0.0, 1.0, 0.0, 0.0],
            };
            bank.add_trajectory(&make_trajectory(id, embedding, 0.8));
        }

        bank.extract_patterns();

        let probe = vec![0.9, 0.1, 0.0, 0.0];
        let nearest = bank.find_similar(&probe, 1);
        assert!(!nearest.is_empty());
    }

    #[test]
    fn test_consolidate() {
        let mut bank = ReasoningBank::new(small_config(3, 1));

        // Nearly identical trajectories along axis 0.
        for id in 0..9u64 {
            let offset = id as f32 * 0.001;
            bank.add_trajectory(&make_trajectory(id, vec![1.0 + offset, 0.0, 0.0, 0.0], 0.8));
        }

        bank.extract_patterns();
        let count_before = bank.pattern_count();

        bank.consolidate(0.99);
        let count_after = bank.pattern_count();

        assert!(count_after <= count_before);
    }
}