// File: ruvector_dag/sona/reasoning_bank.rs

//! Reasoning Bank: pattern storage with cosine-similarity lookup and K-means++ clustering.
2
use std::collections::HashMap;
4
/// A stored pattern: an embedding vector together with its bookkeeping data.
#[derive(Debug, Clone)]
pub struct DagPattern {
    /// Identifier of this pattern.
    pub id: u64,
    /// The pattern's embedding vector.
    pub vector: Vec<f32>,
    /// Quality score attached to the pattern.
    pub quality_score: f32,
    /// Usage counter for the pattern.
    pub usage_count: usize,
    /// Free-form string key/value annotations.
    pub metadata: HashMap<String, String>,
}
13
/// Tunable settings for a reasoning bank.
#[derive(Debug, Clone)]
pub struct ReasoningBankConfig {
    /// Upper bound on the number of clusters built when clustering is recomputed
    /// (the effective count is capped by the number of stored patterns).
    pub num_clusters: usize,
    /// Pattern vector dimensionality.
    /// NOTE(review): no code visible in this file reads this field — confirm how it is used.
    pub pattern_dim: usize,
    /// Number of patterns that may be stored before eviction starts.
    pub max_patterns: usize,
    /// Minimum cosine similarity a pattern needs to appear in similarity query results.
    pub similarity_threshold: f32,
}

impl Default for ReasoningBankConfig {
    /// Defaults: 100 clusters, 256 dimensions, 10 000 patterns, similarity threshold 0.7.
    fn default() -> Self {
        let (num_clusters, pattern_dim, max_patterns) = (100, 256, 10_000);
        Self {
            num_clusters,
            pattern_dim,
            max_patterns,
            similarity_threshold: 0.7,
        }
    }
}
32
33pub struct DagReasoningBank {
34    config: ReasoningBankConfig,
35    patterns: Vec<DagPattern>,
36    centroids: Vec<Vec<f32>>,
37    cluster_assignments: Vec<usize>,
38    next_id: u64,
39}
40
41impl DagReasoningBank {
42    pub fn new(config: ReasoningBankConfig) -> Self {
43        Self {
44            config,
45            patterns: Vec::new(),
46            centroids: Vec::new(),
47            cluster_assignments: Vec::new(),
48            next_id: 0,
49        }
50    }
51
52    /// Store a new pattern
53    pub fn store_pattern(&mut self, vector: Vec<f32>, quality: f32) -> u64 {
54        let id = self.next_id;
55        self.next_id += 1;
56
57        let pattern = DagPattern {
58            id,
59            vector,
60            quality_score: quality,
61            usage_count: 0,
62            metadata: HashMap::new(),
63        };
64
65        self.patterns.push(pattern);
66
67        // Evict if over capacity
68        if self.patterns.len() > self.config.max_patterns {
69            self.evict_lowest_quality();
70        }
71
72        id
73    }
74
75    /// Query similar patterns using cosine similarity
76    pub fn query_similar(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
77        let mut similarities: Vec<(u64, f32)> = self
78            .patterns
79            .iter()
80            .map(|p| (p.id, cosine_similarity(&p.vector, query)))
81            .filter(|(_, sim)| *sim >= self.config.similarity_threshold)
82            .collect();
83
84        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
85        similarities.truncate(k);
86        similarities
87    }
88
89    /// Run K-means++ clustering
90    pub fn recompute_clusters(&mut self) {
91        if self.patterns.is_empty() {
92            return;
93        }
94
95        let k = self.config.num_clusters.min(self.patterns.len());
96
97        // K-means++ initialization
98        self.centroids = kmeans_pp_init(&self.patterns, k);
99
100        // K-means iterations
101        for _ in 0..10 {
102            // Assign points to clusters
103            self.cluster_assignments = self
104                .patterns
105                .iter()
106                .map(|p| self.nearest_centroid(&p.vector))
107                .collect();
108
109            // Update centroids
110            self.update_centroids();
111        }
112    }
113
114    fn nearest_centroid(&self, point: &[f32]) -> usize {
115        self.centroids
116            .iter()
117            .enumerate()
118            .map(|(i, c)| (i, euclidean_distance(point, c)))
119            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
120            .map(|(i, _)| i)
121            .unwrap_or(0)
122    }
123
124    fn update_centroids(&mut self) {
125        let k = self.centroids.len();
126        let dim = if !self.centroids.is_empty() {
127            self.centroids[0].len()
128        } else {
129            return;
130        };
131
132        // Initialize new centroids
133        let mut new_centroids = vec![vec![0.0; dim]; k];
134        let mut counts = vec![0usize; k];
135
136        // Sum points in each cluster
137        for (pattern, &cluster) in self.patterns.iter().zip(self.cluster_assignments.iter()) {
138            if cluster < k {
139                for (i, &val) in pattern.vector.iter().enumerate() {
140                    new_centroids[cluster][i] += val;
141                }
142                counts[cluster] += 1;
143            }
144        }
145
146        // Average to get centroids
147        for (centroid, count) in new_centroids.iter_mut().zip(counts.iter()) {
148            if *count > 0 {
149                for val in centroid.iter_mut() {
150                    *val /= *count as f32;
151                }
152            }
153        }
154
155        self.centroids = new_centroids;
156    }
157
158    fn evict_lowest_quality(&mut self) {
159        // Remove pattern with lowest quality * usage score
160        if let Some(min_idx) = self
161            .patterns
162            .iter()
163            .enumerate()
164            .min_by(|(_, a), (_, b)| {
165                let score_a = a.quality_score * (a.usage_count as f32 + 1.0).ln();
166                let score_b = b.quality_score * (b.usage_count as f32 + 1.0).ln();
167                score_a.partial_cmp(&score_b).unwrap()
168            })
169            .map(|(i, _)| i)
170        {
171            self.patterns.remove(min_idx);
172        }
173    }
174
175    pub fn pattern_count(&self) -> usize {
176        self.patterns.len()
177    }
178
179    pub fn cluster_count(&self) -> usize {
180        self.centroids.len()
181    }
182}
183
/// Cosine similarity of `a` and `b`.
///
/// The dot product runs over the overlapping prefix while each norm covers the
/// whole slice, so a shorter slice behaves as if padded with zeros. Returns `0.0`
/// unless both norms are strictly positive.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (l2_norm(a), l2_norm(b));

    let both_positive = norm_a > 0.0 && norm_b > 0.0;
    if !both_positive {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (norm_a * norm_b)
}
194
/// Euclidean (L2) distance between `a` and `b`, measured over their overlapping
/// prefix; components beyond the shorter slice are ignored.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    sum_of_squares.sqrt()
}
202
203fn kmeans_pp_init(patterns: &[DagPattern], k: usize) -> Vec<Vec<f32>> {
204    use rand::Rng;
205
206    if patterns.is_empty() || k == 0 {
207        return Vec::new();
208    }
209
210    let mut rng = rand::thread_rng();
211    let mut centroids = Vec::with_capacity(k);
212    let _dim = patterns[0].vector.len();
213
214    // Choose first centroid randomly
215    let first_idx = rng.gen_range(0..patterns.len());
216    centroids.push(patterns[first_idx].vector.clone());
217
218    // Choose remaining centroids using D^2 weighting
219    for _ in 1..k {
220        let mut distances = Vec::with_capacity(patterns.len());
221        let mut total_distance = 0.0f32;
222
223        // Compute minimum distance to existing centroids for each point
224        for pattern in patterns {
225            let min_dist = centroids
226                .iter()
227                .map(|c| euclidean_distance(&pattern.vector, c))
228                .min_by(|a, b| a.partial_cmp(b).unwrap())
229                .unwrap_or(0.0);
230            let squared = min_dist * min_dist;
231            distances.push(squared);
232            total_distance += squared;
233        }
234
235        // Select next centroid with probability proportional to D^2
236        if total_distance > 0.0 {
237            let mut threshold = rng.gen::<f32>() * total_distance;
238            for (idx, &dist) in distances.iter().enumerate() {
239                threshold -= dist;
240                if threshold <= 0.0 {
241                    centroids.push(patterns[idx].vector.clone());
242                    break;
243                }
244            }
245        } else {
246            // Fallback: choose random point
247            let idx = rng.gen_range(0..patterns.len());
248            centroids.push(patterns[idx].vector.clone());
249        }
250
251        if centroids.len() >= k {
252            break;
253        }
254    }
255
256    centroids
257}