Skip to main content

ruvector_sona/
reasoning_bank.rs

1//! ReasoningBank - Pattern storage and extraction for SONA
2//!
3//! Implements trajectory clustering using K-means++ for pattern discovery.
4
5use crate::types::{LearnedPattern, PatternType, QueryTrajectory};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// ReasoningBank configuration
10#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct PatternConfig {
12    /// Number of clusters for K-means++
13    pub k_clusters: usize,
14    /// Embedding dimension
15    pub embedding_dim: usize,
16    /// Maximum K-means iterations
17    pub max_iterations: usize,
18    /// Convergence threshold
19    pub convergence_threshold: f32,
20    /// Minimum cluster size to keep
21    pub min_cluster_size: usize,
22    /// Maximum trajectories to store
23    pub max_trajectories: usize,
24    /// Quality threshold for pattern
25    pub quality_threshold: f32,
26}
27
28impl Default for PatternConfig {
29    fn default() -> Self {
30        // OPTIMIZED DEFAULTS based on @ruvector/sona v0.1.1 benchmarks:
31        // - ADR-123: Relaxed thresholds to enable pattern crystallization
32        //   with fewer trajectories. Previous defaults (k=100, min=5, q=0.3)
33        //   prevented crystallization when trajectory count < 500.
34        Self {
35            k_clusters: 5,   // Was 50; fewer clusters = more members per cluster with low trajectory counts
36            embedding_dim: 256,
37            max_iterations: 100,
38            convergence_threshold: 0.001,
39            min_cluster_size: 1,  // Was 2; allow single-trajectory clusters to crystallize
40            max_trajectories: 10000,
41            quality_threshold: 0.05, // Was 0.1; very permissive so early patterns survive
42        }
43    }
44}
45
46/// ReasoningBank for pattern storage and extraction
47#[derive(Clone, Debug)]
48pub struct ReasoningBank {
49    /// Configuration
50    config: PatternConfig,
51    /// Stored trajectories
52    trajectories: Vec<TrajectoryEntry>,
53    /// Extracted patterns
54    patterns: HashMap<u64, LearnedPattern>,
55    /// Next pattern ID
56    next_pattern_id: u64,
57    /// Pattern index (embedding -> pattern_id)
58    pattern_index: Vec<(Vec<f32>, u64)>,
59}
60
/// Bank-internal record kept for each added trajectory.
#[derive(Clone, Debug)]
struct TrajectoryEntry {
    /// Embedding built from the query embedding and reward-weighted step activations.
    embedding: Vec<f32>,
    /// The trajectory's final quality score.
    quality: f32,
    /// Cluster index written by the most recent pattern extraction, if any.
    cluster: Option<usize>,
    /// ID of the trajectory this entry came from.
    _trajectory_id: u64,
}
73
74impl ReasoningBank {
75    /// Create new ReasoningBank
76    pub fn new(config: PatternConfig) -> Self {
77        Self {
78            config,
79            trajectories: Vec::new(),
80            patterns: HashMap::new(),
81            next_pattern_id: 0,
82            pattern_index: Vec::new(),
83        }
84    }
85
86    /// Add trajectory to bank
87    pub fn add_trajectory(&mut self, trajectory: &QueryTrajectory) {
88        // Compute embedding from trajectory
89        let embedding = self.compute_embedding(trajectory);
90
91        let entry = TrajectoryEntry {
92            embedding,
93            quality: trajectory.final_quality,
94            cluster: None,
95            _trajectory_id: trajectory.id,
96        };
97
98        // Enforce capacity
99        if self.trajectories.len() >= self.config.max_trajectories {
100            // Remove oldest entries
101            let to_remove = self.trajectories.len() - self.config.max_trajectories + 1;
102            self.trajectories.drain(0..to_remove);
103        }
104
105        self.trajectories.push(entry);
106    }
107
108    /// Compute embedding from trajectory
109    fn compute_embedding(&self, trajectory: &QueryTrajectory) -> Vec<f32> {
110        let dim = self.config.embedding_dim;
111        let mut embedding = vec![0.0f32; dim];
112
113        // Start with query embedding
114        let query_len = trajectory.query_embedding.len().min(dim);
115        embedding[..query_len].copy_from_slice(&trajectory.query_embedding[..query_len]);
116
117        // Average in step activations (weighted by reward)
118        if !trajectory.steps.is_empty() {
119            let mut total_reward = 0.0f32;
120
121            for step in &trajectory.steps {
122                let weight = step.reward.max(0.0);
123                total_reward += weight;
124
125                for (i, &act) in step.activations.iter().enumerate() {
126                    if i < dim {
127                        embedding[i] += act * weight;
128                    }
129                }
130            }
131
132            if total_reward > 0.0 {
133                for e in &mut embedding {
134                    *e /= total_reward + 1.0; // +1 for query contribution
135                }
136            }
137        }
138
139        // L2 normalize
140        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
141        if norm > 1e-8 {
142            for e in &mut embedding {
143                *e /= norm;
144            }
145        }
146
147        embedding
148    }
149
150    /// Extract patterns using K-means++
151    pub fn extract_patterns(&mut self) -> Vec<LearnedPattern> {
152        if self.trajectories.is_empty() {
153            return Vec::new();
154        }
155
156        let k = self.config.k_clusters.min(self.trajectories.len());
157        if k == 0 {
158            return Vec::new();
159        }
160
161        // K-means++ initialization
162        let centroids = self.kmeans_plus_plus_init(k);
163
164        // Run K-means
165        let (final_centroids, assignments) = self.run_kmeans(centroids);
166
167        // Create patterns from clusters
168        let mut patterns = Vec::new();
169
170        for (cluster_idx, centroid) in final_centroids.into_iter().enumerate() {
171            // Collect cluster members
172            let members: Vec<_> = self
173                .trajectories
174                .iter()
175                .enumerate()
176                .filter(|(i, _)| assignments.get(*i) == Some(&cluster_idx))
177                .map(|(_, t)| t)
178                .collect();
179
180            if members.len() < self.config.min_cluster_size {
181                continue;
182            }
183
184            // Compute cluster statistics
185            let cluster_size = members.len();
186            let total_weight: f32 = members.iter().map(|t| t.quality).sum();
187            let avg_quality = total_weight / cluster_size as f32;
188
189            if avg_quality < self.config.quality_threshold {
190                continue;
191            }
192
193            let pattern_id = self.next_pattern_id;
194            self.next_pattern_id += 1;
195
196            let now = crate::time_compat::SystemTime::now()
197                .duration_since_epoch()
198                .as_secs();
199            let pattern = LearnedPattern {
200                id: pattern_id,
201                centroid,
202                cluster_size,
203                total_weight,
204                avg_quality,
205                created_at: now,
206                last_accessed: now,
207                access_count: 0,
208                pattern_type: PatternType::General,
209            };
210
211            self.patterns.insert(pattern_id, pattern.clone());
212            self.pattern_index
213                .push((pattern.centroid.clone(), pattern_id));
214            patterns.push(pattern);
215        }
216
217        // Update trajectory cluster assignments
218        for (i, cluster) in assignments.into_iter().enumerate() {
219            if i < self.trajectories.len() {
220                self.trajectories[i].cluster = Some(cluster);
221            }
222        }
223
224        patterns
225    }
226
227    /// K-means++ initialization
228    fn kmeans_plus_plus_init(&self, k: usize) -> Vec<Vec<f32>> {
229        let mut centroids = Vec::with_capacity(k);
230        let n = self.trajectories.len();
231
232        if n == 0 || k == 0 {
233            return centroids;
234        }
235
236        // First centroid: random (use deterministic selection for reproducibility)
237        let first_idx = 0;
238        centroids.push(self.trajectories[first_idx].embedding.clone());
239
240        // Remaining centroids: D^2 weighting
241        for _ in 1..k {
242            // Compute distances to nearest centroid
243            let mut distances: Vec<f32> = self
244                .trajectories
245                .iter()
246                .map(|t| {
247                    centroids
248                        .iter()
249                        .map(|c| self.squared_distance(&t.embedding, c))
250                        .fold(f32::MAX, f32::min)
251                })
252                .collect();
253
254            // Normalize to probabilities
255            let total: f32 = distances.iter().sum();
256            if total > 0.0 {
257                for d in &mut distances {
258                    *d /= total;
259                }
260            }
261
262            // Select next centroid (deterministic: highest distance)
263            // SECURITY FIX (H-004): Handle NaN values in partial_cmp safely
264            let (next_idx, _) = distances
265                .iter()
266                .enumerate()
267                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
268                .unwrap_or((0, &0.0));
269
270            centroids.push(self.trajectories[next_idx].embedding.clone());
271        }
272
273        centroids
274    }
275
276    /// Run K-means algorithm
277    fn run_kmeans(&self, mut centroids: Vec<Vec<f32>>) -> (Vec<Vec<f32>>, Vec<usize>) {
278        let n = self.trajectories.len();
279        let k = centroids.len();
280        let dim = self.config.embedding_dim;
281
282        let mut assignments = vec![0usize; n];
283
284        for _iter in 0..self.config.max_iterations {
285            // Assign points to nearest centroid
286            let mut changed = false;
287            for (i, t) in self.trajectories.iter().enumerate() {
288                // SECURITY FIX (H-004): Handle NaN values in partial_cmp safely
289                let (nearest, _) = centroids
290                    .iter()
291                    .enumerate()
292                    .map(|(j, c)| (j, self.squared_distance(&t.embedding, c)))
293                    .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
294                    .unwrap_or((0, 0.0));
295
296                if assignments[i] != nearest {
297                    assignments[i] = nearest;
298                    changed = true;
299                }
300            }
301
302            if !changed {
303                break;
304            }
305
306            // Update centroids
307            let mut new_centroids = vec![vec![0.0f32; dim]; k];
308            let mut counts = vec![0usize; k];
309
310            for (i, t) in self.trajectories.iter().enumerate() {
311                let cluster = assignments[i];
312                counts[cluster] += 1;
313                for (j, &e) in t.embedding.iter().enumerate() {
314                    new_centroids[cluster][j] += e;
315                }
316            }
317
318            // Average and check convergence
319            let mut max_shift = 0.0f32;
320            for (i, new_c) in new_centroids.iter_mut().enumerate() {
321                if counts[i] > 0 {
322                    for e in new_c.iter_mut() {
323                        *e /= counts[i] as f32;
324                    }
325                    let shift = self.squared_distance(new_c, &centroids[i]).sqrt();
326                    max_shift = max_shift.max(shift);
327                }
328            }
329
330            centroids = new_centroids;
331
332            if max_shift < self.config.convergence_threshold {
333                break;
334            }
335        }
336
337        (centroids, assignments)
338    }
339
340    /// Squared Euclidean distance
341    fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
342        a.iter()
343            .zip(b.iter())
344            .map(|(&x, &y)| (x - y) * (x - y))
345            .sum()
346    }
347
348    /// Find similar patterns
349    pub fn find_similar(&self, query: &[f32], k: usize) -> Vec<&LearnedPattern> {
350        let mut scored: Vec<_> = self
351            .patterns
352            .values()
353            .map(|p| (p, p.similarity(query)))
354            .collect();
355
356        // Note: This already has the safe unwrap_or pattern for NaN handling
357        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
358
359        scored.into_iter().take(k).map(|(p, _)| p).collect()
360    }
361
362    /// Get pattern by ID
363    pub fn get_pattern(&self, id: u64) -> Option<&LearnedPattern> {
364        self.patterns.get(&id)
365    }
366
367    /// Get mutable pattern by ID
368    pub fn get_pattern_mut(&mut self, id: u64) -> Option<&mut LearnedPattern> {
369        self.patterns.get_mut(&id)
370    }
371
372    /// Get trajectory count
373    pub fn trajectory_count(&self) -> usize {
374        self.trajectories.len()
375    }
376
377    /// Get pattern count
378    pub fn pattern_count(&self) -> usize {
379        self.patterns.len()
380    }
381
382    /// Clear trajectories (keep patterns)
383    pub fn clear_trajectories(&mut self) {
384        self.trajectories.clear();
385    }
386
387    /// Prune low-quality patterns
388    pub fn prune_patterns(&mut self, min_quality: f32, min_accesses: u32, max_age_secs: u64) {
389        let to_remove: Vec<u64> = self
390            .patterns
391            .iter()
392            .filter(|(_, p)| p.should_prune(min_quality, min_accesses, max_age_secs))
393            .map(|(id, _)| *id)
394            .collect();
395
396        for id in to_remove {
397            self.patterns.remove(&id);
398        }
399
400        // Update index
401        self.pattern_index
402            .retain(|(_, id)| self.patterns.contains_key(id));
403    }
404
405    /// Get current configuration (read-only)
406    pub fn config(&self) -> &PatternConfig {
407        &self.config
408    }
409
410    /// Dynamically adjust the quality threshold for pattern crystallization.
411    /// Used by the self-optimization loop to adapt when patterns are not forming.
412    pub fn set_quality_threshold(&mut self, threshold: f32) {
413        self.config.quality_threshold = threshold.clamp(0.01, 1.0);
414    }
415
416    /// Get all patterns for export
417    pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
418        self.patterns.values().cloned().collect()
419    }
420
421    /// Insert a pattern directly (for state restoration, fixes #274)
422    pub fn insert_pattern(&mut self, pattern: LearnedPattern) {
423        let id = pattern.id;
424        if id >= self.next_pattern_id {
425            self.next_pattern_id = id + 1;
426        }
427        self.pattern_index.push((pattern.centroid.clone(), id));
428        self.patterns.insert(id, pattern);
429    }
430
431    /// Set the minimum cluster size required for a cluster to become a pattern.
432    ///
433    /// Lower values allow patterns to crystallize from fewer trajectories.
434    /// Default is 1 (was 2 before permissive tuning).
435    pub fn set_min_trajectory_length(&mut self, n: usize) {
436        self.config.min_cluster_size = n;
437    }
438
439    /// Set the minimum average quality a cluster must have to become a pattern.
440    ///
441    /// Lower values allow lower-quality patterns through.
442    /// Default is 0.05 (was 0.1 before permissive tuning).
443    pub fn set_min_pattern_quality(&mut self, q: f64) {
444        self.config.quality_threshold = q as f32;
445    }
446
447    /// Consolidate similar patterns
448    pub fn consolidate(&mut self, similarity_threshold: f32) {
449        let pattern_ids: Vec<u64> = self.patterns.keys().copied().collect();
450        let mut merged = Vec::new();
451
452        for i in 0..pattern_ids.len() {
453            for j in i + 1..pattern_ids.len() {
454                let id1 = pattern_ids[i];
455                let id2 = pattern_ids[j];
456
457                if merged.contains(&id1) || merged.contains(&id2) {
458                    continue;
459                }
460
461                if let (Some(p1), Some(p2)) = (self.patterns.get(&id1), self.patterns.get(&id2)) {
462                    let sim = p1.similarity(&p2.centroid);
463                    if sim > similarity_threshold {
464                        // Merge p2 into p1
465                        let merged_pattern = p1.merge(p2);
466                        self.patterns.insert(id1, merged_pattern);
467                        merged.push(id2);
468                    }
469                }
470            }
471        }
472
473        // Remove merged patterns
474        for id in merged {
475            self.patterns.remove(&id);
476        }
477
478        // Update index
479        self.pattern_index
480            .retain(|(_, id)| self.patterns.contains_key(id));
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trajectory with the given query embedding and finalizes it
    /// with `quality`.
    fn make_trajectory(id: u64, embedding: Vec<f32>, quality: f32) -> QueryTrajectory {
        let mut trajectory = QueryTrajectory::new(id, embedding);
        trajectory.finalize(quality, 1000);
        trajectory
    }

    /// Four-dimensional configuration with no quality filtering.
    fn clustering_config(k_clusters: usize, min_cluster_size: usize) -> PatternConfig {
        PatternConfig {
            embedding_dim: 4,
            k_clusters,
            min_cluster_size,
            quality_threshold: 0.0,
            ..Default::default()
        }
    }

    #[test]
    fn test_bank_creation() {
        // A fresh bank holds nothing.
        let bank = ReasoningBank::new(PatternConfig::default());
        assert_eq!(bank.trajectory_count(), 0);
        assert_eq!(bank.pattern_count(), 0);
    }

    #[test]
    fn test_add_trajectory() {
        let mut bank = ReasoningBank::new(PatternConfig {
            embedding_dim: 4,
            ..Default::default()
        });

        bank.add_trajectory(&make_trajectory(1, vec![0.1, 0.2, 0.3, 0.4], 0.8));

        assert_eq!(bank.trajectory_count(), 1);
    }

    #[test]
    fn test_extract_patterns() {
        let mut bank = ReasoningBank::new(clustering_config(2, 2));

        // Two tight groups: five along the first axis, five along the second.
        for id in 0..10u64 {
            let (embedding, quality) = if id < 5 {
                (vec![1.0, 0.0, 0.0, 0.0], 0.8)
            } else {
                (vec![0.0, 1.0, 0.0, 0.0], 0.7)
            };
            bank.add_trajectory(&make_trajectory(id, embedding, quality));
        }

        let patterns = bank.extract_patterns();
        assert!(!patterns.is_empty());
    }

    #[test]
    fn test_find_similar() {
        let mut bank = ReasoningBank::new(clustering_config(2, 2));

        for id in 0..10u64 {
            let embedding = if id < 5 {
                vec![1.0, 0.0, 0.0, 0.0]
            } else {
                vec![0.0, 1.0, 0.0, 0.0]
            };
            bank.add_trajectory(&make_trajectory(id, embedding, 0.8));
        }

        bank.extract_patterns();

        let query = vec![0.9, 0.1, 0.0, 0.0];
        let similar = bank.find_similar(&query, 1);
        assert!(!similar.is_empty());
    }

    #[test]
    fn test_consolidate() {
        let mut bank = ReasoningBank::new(clustering_config(3, 1));

        // Nine nearly identical trajectories.
        for id in 0..9u64 {
            let embedding = vec![1.0 + (id as f32 * 0.001), 0.0, 0.0, 0.0];
            bank.add_trajectory(&make_trajectory(id, embedding, 0.8));
        }

        bank.extract_patterns();
        let before = bank.pattern_count();

        bank.consolidate(0.99);
        let after = bank.pattern_count();

        // Consolidation may merge patterns but never creates new ones.
        assert!(after <= before);
    }
}