ruvector_dag/sona/
reasoning_bank.rs1use std::collections::HashMap;
4
/// A single stored reasoning pattern: an embedding vector plus the bookkeeping
/// fields the bank uses for ranking and eviction.
#[derive(Debug, Clone)]
pub struct DagPattern {
    /// Identifier handed out by the bank when the pattern is stored.
    pub id: u64,
    /// Embedding vector compared via cosine similarity / Euclidean distance.
    pub vector: Vec<f32>,
    /// Quality score supplied by the caller at store time; used in eviction scoring.
    pub quality_score: f32,
    /// Usage counter used in eviction scoring. Starts at 0 for patterns created by
    /// `DagReasoningBank::store_pattern`. NOTE(review): nothing visible increments it.
    pub usage_count: usize,
    /// Free-form string metadata; empty for patterns created by `store_pattern`.
    pub metadata: HashMap<String, String>,
}
13
/// Tuning parameters for `DagReasoningBank`.
#[derive(Debug, Clone)]
pub struct ReasoningBankConfig {
    /// Upper bound on the number of k-means clusters built by `recompute_clusters`
    /// (the actual count is also capped by the number of stored patterns).
    pub num_clusters: usize,
    /// Intended dimensionality of pattern vectors.
    /// NOTE(review): not read by the visible code; nothing enforces it — confirm usage.
    pub pattern_dim: usize,
    /// Maximum number of stored patterns; storing beyond this triggers one eviction.
    pub max_patterns: usize,
    /// Minimum cosine similarity a pattern needs to be returned by `query_similar`.
    pub similarity_threshold: f32,
}

impl Default for ReasoningBankConfig {
    /// 100 clusters, `pattern_dim` 256, at most 10 000 patterns, similarity threshold 0.7.
    fn default() -> Self {
        ReasoningBankConfig {
            similarity_threshold: 0.7,
            max_patterns: 10_000,
            pattern_dim: 256,
            num_clusters: 100,
        }
    }
}
32
33pub struct DagReasoningBank {
34 config: ReasoningBankConfig,
35 patterns: Vec<DagPattern>,
36 centroids: Vec<Vec<f32>>,
37 cluster_assignments: Vec<usize>,
38 next_id: u64,
39}
40
41impl DagReasoningBank {
42 pub fn new(config: ReasoningBankConfig) -> Self {
43 Self {
44 config,
45 patterns: Vec::new(),
46 centroids: Vec::new(),
47 cluster_assignments: Vec::new(),
48 next_id: 0,
49 }
50 }
51
52 pub fn store_pattern(&mut self, vector: Vec<f32>, quality: f32) -> u64 {
54 let id = self.next_id;
55 self.next_id += 1;
56
57 let pattern = DagPattern {
58 id,
59 vector,
60 quality_score: quality,
61 usage_count: 0,
62 metadata: HashMap::new(),
63 };
64
65 self.patterns.push(pattern);
66
67 if self.patterns.len() > self.config.max_patterns {
69 self.evict_lowest_quality();
70 }
71
72 id
73 }
74
75 pub fn query_similar(&self, query: &[f32], k: usize) -> Vec<(u64, f32)> {
77 let mut similarities: Vec<(u64, f32)> = self
78 .patterns
79 .iter()
80 .map(|p| (p.id, cosine_similarity(&p.vector, query)))
81 .filter(|(_, sim)| *sim >= self.config.similarity_threshold)
82 .collect();
83
84 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
85 similarities.truncate(k);
86 similarities
87 }
88
89 pub fn recompute_clusters(&mut self) {
91 if self.patterns.is_empty() {
92 return;
93 }
94
95 let k = self.config.num_clusters.min(self.patterns.len());
96
97 self.centroids = kmeans_pp_init(&self.patterns, k);
99
100 for _ in 0..10 {
102 self.cluster_assignments = self
104 .patterns
105 .iter()
106 .map(|p| self.nearest_centroid(&p.vector))
107 .collect();
108
109 self.update_centroids();
111 }
112 }
113
114 fn nearest_centroid(&self, point: &[f32]) -> usize {
115 self.centroids
116 .iter()
117 .enumerate()
118 .map(|(i, c)| (i, euclidean_distance(point, c)))
119 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
120 .map(|(i, _)| i)
121 .unwrap_or(0)
122 }
123
124 fn update_centroids(&mut self) {
125 let k = self.centroids.len();
126 let dim = if !self.centroids.is_empty() {
127 self.centroids[0].len()
128 } else {
129 return;
130 };
131
132 let mut new_centroids = vec![vec![0.0; dim]; k];
134 let mut counts = vec![0usize; k];
135
136 for (pattern, &cluster) in self.patterns.iter().zip(self.cluster_assignments.iter()) {
138 if cluster < k {
139 for (i, &val) in pattern.vector.iter().enumerate() {
140 new_centroids[cluster][i] += val;
141 }
142 counts[cluster] += 1;
143 }
144 }
145
146 for (centroid, count) in new_centroids.iter_mut().zip(counts.iter()) {
148 if *count > 0 {
149 for val in centroid.iter_mut() {
150 *val /= *count as f32;
151 }
152 }
153 }
154
155 self.centroids = new_centroids;
156 }
157
158 fn evict_lowest_quality(&mut self) {
159 if let Some(min_idx) = self
161 .patterns
162 .iter()
163 .enumerate()
164 .min_by(|(_, a), (_, b)| {
165 let score_a = a.quality_score * (a.usage_count as f32 + 1.0).ln();
166 let score_b = b.quality_score * (b.usage_count as f32 + 1.0).ln();
167 score_a.partial_cmp(&score_b).unwrap()
168 })
169 .map(|(i, _)| i)
170 {
171 self.patterns.remove(min_idx);
172 }
173 }
174
175 pub fn pattern_count(&self) -> usize {
176 self.patterns.len()
177 }
178
179 pub fn cluster_count(&self) -> usize {
180 self.centroids.len()
181 }
182}
183
/// Cosine similarity between `a` and `b`.
///
/// The dot product runs over the common prefix of the two slices, while each norm is
/// taken over the whole slice. Returns 0.0 unless both norms are strictly positive
/// (so zero-length, all-zero, or NaN-containing vectors yield 0.0).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    let (norm_a, norm_b) = (norm(a), norm(b));

    (norm_a > 0.0 && norm_b > 0.0)
        .then(|| dot / (norm_a * norm_b))
        .unwrap_or(0.0)
}
194
/// Euclidean (L2) distance between `a` and `b`, computed over their common prefix
/// (extra trailing components of the longer slice are ignored).
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| (x - y).powi(2))
        .sum();
    squared_sum.sqrt()
}
202
203fn kmeans_pp_init(patterns: &[DagPattern], k: usize) -> Vec<Vec<f32>> {
204 use rand::Rng;
205
206 if patterns.is_empty() || k == 0 {
207 return Vec::new();
208 }
209
210 let mut rng = rand::thread_rng();
211 let mut centroids = Vec::with_capacity(k);
212 let _dim = patterns[0].vector.len();
213
214 let first_idx = rng.gen_range(0..patterns.len());
216 centroids.push(patterns[first_idx].vector.clone());
217
218 for _ in 1..k {
220 let mut distances = Vec::with_capacity(patterns.len());
221 let mut total_distance = 0.0f32;
222
223 for pattern in patterns {
225 let min_dist = centroids
226 .iter()
227 .map(|c| euclidean_distance(&pattern.vector, c))
228 .min_by(|a, b| a.partial_cmp(b).unwrap())
229 .unwrap_or(0.0);
230 let squared = min_dist * min_dist;
231 distances.push(squared);
232 total_distance += squared;
233 }
234
235 if total_distance > 0.0 {
237 let mut threshold = rng.gen::<f32>() * total_distance;
238 for (idx, &dist) in distances.iter().enumerate() {
239 threshold -= dist;
240 if threshold <= 0.0 {
241 centroids.push(patterns[idx].vector.clone());
242 break;
243 }
244 }
245 } else {
246 let idx = rng.gen_range(0..patterns.len());
248 centroids.push(patterns[idx].vector.clone());
249 }
250
251 if centroids.len() >= k {
252 break;
253 }
254 }
255
256 centroids
257}