1use crate::types::{LearnedPattern, PatternType, QueryTrajectory};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct PatternConfig {
12 pub k_clusters: usize,
14 pub embedding_dim: usize,
16 pub max_iterations: usize,
18 pub convergence_threshold: f32,
20 pub min_cluster_size: usize,
22 pub max_trajectories: usize,
24 pub quality_threshold: f32,
26}
27
28impl Default for PatternConfig {
29 fn default() -> Self {
30 Self {
31 k_clusters: 50,
32 embedding_dim: 256,
33 max_iterations: 100,
34 convergence_threshold: 0.001,
35 min_cluster_size: 5,
36 max_trajectories: 10000,
37 quality_threshold: 0.5,
38 }
39 }
40}
41
42#[derive(Clone, Debug)]
44pub struct ReasoningBank {
45 config: PatternConfig,
47 trajectories: Vec<TrajectoryEntry>,
49 patterns: HashMap<u64, LearnedPattern>,
51 next_pattern_id: u64,
53 pattern_index: Vec<(Vec<f32>, u64)>,
55}
56
/// One buffered trajectory, reduced to what clustering needs.
#[derive(Debug, Clone)]
struct TrajectoryEntry {
    /// Embedding computed for the trajectory when it was added.
    embedding: Vec<f32>,
    /// The trajectory's final quality score.
    quality: f32,
    /// Cluster index from the most recent extraction, `None` until then.
    cluster: Option<usize>,
    /// Id of the source trajectory.
    trajectory_id: u64,
}
69
70impl ReasoningBank {
71 pub fn new(config: PatternConfig) -> Self {
73 Self {
74 config,
75 trajectories: Vec::new(),
76 patterns: HashMap::new(),
77 next_pattern_id: 0,
78 pattern_index: Vec::new(),
79 }
80 }
81
82 pub fn add_trajectory(&mut self, trajectory: &QueryTrajectory) {
84 let embedding = self.compute_embedding(trajectory);
86
87 let entry = TrajectoryEntry {
88 embedding,
89 quality: trajectory.final_quality,
90 cluster: None,
91 trajectory_id: trajectory.id,
92 };
93
94 if self.trajectories.len() >= self.config.max_trajectories {
96 let to_remove = self.trajectories.len() - self.config.max_trajectories + 1;
98 self.trajectories.drain(0..to_remove);
99 }
100
101 self.trajectories.push(entry);
102 }
103
104 fn compute_embedding(&self, trajectory: &QueryTrajectory) -> Vec<f32> {
106 let dim = self.config.embedding_dim;
107 let mut embedding = vec![0.0f32; dim];
108
109 let query_len = trajectory.query_embedding.len().min(dim);
111 embedding[..query_len].copy_from_slice(&trajectory.query_embedding[..query_len]);
112
113 if !trajectory.steps.is_empty() {
115 let mut total_reward = 0.0f32;
116
117 for step in &trajectory.steps {
118 let weight = step.reward.max(0.0);
119 total_reward += weight;
120
121 for (i, &act) in step.activations.iter().enumerate() {
122 if i < dim {
123 embedding[i] += act * weight;
124 }
125 }
126 }
127
128 if total_reward > 0.0 {
129 for e in &mut embedding {
130 *e /= total_reward + 1.0; }
132 }
133 }
134
135 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
137 if norm > 1e-8 {
138 for e in &mut embedding {
139 *e /= norm;
140 }
141 }
142
143 embedding
144 }
145
146 pub fn extract_patterns(&mut self) -> Vec<LearnedPattern> {
148 if self.trajectories.is_empty() {
149 return Vec::new();
150 }
151
152 let k = self.config.k_clusters.min(self.trajectories.len());
153 if k == 0 {
154 return Vec::new();
155 }
156
157 let centroids = self.kmeans_plus_plus_init(k);
159
160 let (final_centroids, assignments) = self.run_kmeans(centroids);
162
163 let mut patterns = Vec::new();
165
166 for (cluster_idx, centroid) in final_centroids.into_iter().enumerate() {
167 let members: Vec<_> = self.trajectories.iter()
169 .enumerate()
170 .filter(|(i, _)| assignments.get(*i) == Some(&cluster_idx))
171 .map(|(_, t)| t)
172 .collect();
173
174 if members.len() < self.config.min_cluster_size {
175 continue;
176 }
177
178 let cluster_size = members.len();
180 let total_weight: f32 = members.iter().map(|t| t.quality).sum();
181 let avg_quality = total_weight / cluster_size as f32;
182
183 if avg_quality < self.config.quality_threshold {
184 continue;
185 }
186
187 let pattern_id = self.next_pattern_id;
188 self.next_pattern_id += 1;
189
190 let pattern = LearnedPattern {
191 id: pattern_id,
192 centroid,
193 cluster_size,
194 total_weight,
195 avg_quality,
196 created_at: std::time::SystemTime::now()
197 .duration_since(std::time::UNIX_EPOCH)
198 .unwrap_or_default()
199 .as_secs(),
200 last_accessed: std::time::SystemTime::now()
201 .duration_since(std::time::UNIX_EPOCH)
202 .unwrap_or_default()
203 .as_secs(),
204 access_count: 0,
205 pattern_type: PatternType::General,
206 };
207
208 self.patterns.insert(pattern_id, pattern.clone());
209 self.pattern_index.push((pattern.centroid.clone(), pattern_id));
210 patterns.push(pattern);
211 }
212
213 for (i, cluster) in assignments.into_iter().enumerate() {
215 if i < self.trajectories.len() {
216 self.trajectories[i].cluster = Some(cluster);
217 }
218 }
219
220 patterns
221 }
222
223 fn kmeans_plus_plus_init(&self, k: usize) -> Vec<Vec<f32>> {
225 let mut centroids = Vec::with_capacity(k);
226 let n = self.trajectories.len();
227
228 if n == 0 || k == 0 {
229 return centroids;
230 }
231
232 let first_idx = 0;
234 centroids.push(self.trajectories[first_idx].embedding.clone());
235
236 for _ in 1..k {
238 let mut distances: Vec<f32> = self.trajectories.iter()
240 .map(|t| {
241 centroids.iter()
242 .map(|c| self.squared_distance(&t.embedding, c))
243 .fold(f32::MAX, f32::min)
244 })
245 .collect();
246
247 let total: f32 = distances.iter().sum();
249 if total > 0.0 {
250 for d in &mut distances {
251 *d /= total;
252 }
253 }
254
255 let (next_idx, _) = distances.iter()
257 .enumerate()
258 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
259 .unwrap_or((0, &0.0));
260
261 centroids.push(self.trajectories[next_idx].embedding.clone());
262 }
263
264 centroids
265 }
266
267 fn run_kmeans(&self, mut centroids: Vec<Vec<f32>>) -> (Vec<Vec<f32>>, Vec<usize>) {
269 let n = self.trajectories.len();
270 let k = centroids.len();
271 let dim = self.config.embedding_dim;
272
273 let mut assignments = vec![0usize; n];
274
275 for _iter in 0..self.config.max_iterations {
276 let mut changed = false;
278 for (i, t) in self.trajectories.iter().enumerate() {
279 let (nearest, _) = centroids.iter()
280 .enumerate()
281 .map(|(j, c)| (j, self.squared_distance(&t.embedding, c)))
282 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
283 .unwrap_or((0, 0.0));
284
285 if assignments[i] != nearest {
286 assignments[i] = nearest;
287 changed = true;
288 }
289 }
290
291 if !changed {
292 break;
293 }
294
295 let mut new_centroids = vec![vec![0.0f32; dim]; k];
297 let mut counts = vec![0usize; k];
298
299 for (i, t) in self.trajectories.iter().enumerate() {
300 let cluster = assignments[i];
301 counts[cluster] += 1;
302 for (j, &e) in t.embedding.iter().enumerate() {
303 new_centroids[cluster][j] += e;
304 }
305 }
306
307 let mut max_shift = 0.0f32;
309 for (i, new_c) in new_centroids.iter_mut().enumerate() {
310 if counts[i] > 0 {
311 for e in new_c.iter_mut() {
312 *e /= counts[i] as f32;
313 }
314 let shift = self.squared_distance(new_c, ¢roids[i]).sqrt();
315 max_shift = max_shift.max(shift);
316 }
317 }
318
319 centroids = new_centroids;
320
321 if max_shift < self.config.convergence_threshold {
322 break;
323 }
324 }
325
326 (centroids, assignments)
327 }
328
329 fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
331 a.iter()
332 .zip(b.iter())
333 .map(|(&x, &y)| (x - y) * (x - y))
334 .sum()
335 }
336
337 pub fn find_similar(&self, query: &[f32], k: usize) -> Vec<&LearnedPattern> {
339 let mut scored: Vec<_> = self.patterns.values()
340 .map(|p| (p, p.similarity(query)))
341 .collect();
342
343 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
344
345 scored.into_iter()
346 .take(k)
347 .map(|(p, _)| p)
348 .collect()
349 }
350
351 pub fn get_pattern(&self, id: u64) -> Option<&LearnedPattern> {
353 self.patterns.get(&id)
354 }
355
356 pub fn get_pattern_mut(&mut self, id: u64) -> Option<&mut LearnedPattern> {
358 self.patterns.get_mut(&id)
359 }
360
361 pub fn trajectory_count(&self) -> usize {
363 self.trajectories.len()
364 }
365
366 pub fn pattern_count(&self) -> usize {
368 self.patterns.len()
369 }
370
371 pub fn clear_trajectories(&mut self) {
373 self.trajectories.clear();
374 }
375
376 pub fn prune_patterns(&mut self, min_quality: f32, min_accesses: u32, max_age_secs: u64) {
378 let to_remove: Vec<u64> = self.patterns.iter()
379 .filter(|(_, p)| p.should_prune(min_quality, min_accesses, max_age_secs))
380 .map(|(id, _)| *id)
381 .collect();
382
383 for id in to_remove {
384 self.patterns.remove(&id);
385 }
386
387 self.pattern_index.retain(|(_, id)| self.patterns.contains_key(id));
389 }
390
391 pub fn consolidate(&mut self, similarity_threshold: f32) {
393 let pattern_ids: Vec<u64> = self.patterns.keys().copied().collect();
394 let mut merged = Vec::new();
395
396 for i in 0..pattern_ids.len() {
397 for j in i+1..pattern_ids.len() {
398 let id1 = pattern_ids[i];
399 let id2 = pattern_ids[j];
400
401 if merged.contains(&id1) || merged.contains(&id2) {
402 continue;
403 }
404
405 if let (Some(p1), Some(p2)) = (self.patterns.get(&id1), self.patterns.get(&id2)) {
406 let sim = p1.similarity(&p2.centroid);
407 if sim > similarity_threshold {
408 let merged_pattern = p1.merge(p2);
410 self.patterns.insert(id1, merged_pattern);
411 merged.push(id2);
412 }
413 }
414 }
415 }
416
417 for id in merged {
419 self.patterns.remove(&id);
420 }
421
422 self.pattern_index.retain(|(_, id)| self.patterns.contains_key(id));
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trajectory via `QueryTrajectory::new` and finalizes it with
    /// the given quality.
    fn make_trajectory(id: u64, embedding: Vec<f32>, quality: f32) -> QueryTrajectory {
        let mut trajectory = QueryTrajectory::new(id, embedding);
        trajectory.finalize(quality, 1000);
        trajectory
    }

    /// Four-dimensional config with the quality filter disabled.
    fn clustering_config(k_clusters: usize, min_cluster_size: usize) -> PatternConfig {
        PatternConfig {
            embedding_dim: 4,
            k_clusters,
            min_cluster_size,
            quality_threshold: 0.0,
            ..Default::default()
        }
    }

    #[test]
    fn test_bank_creation() {
        let bank = ReasoningBank::new(PatternConfig::default());

        assert_eq!(bank.trajectory_count(), 0);
        assert_eq!(bank.pattern_count(), 0);
    }

    #[test]
    fn test_add_trajectory() {
        let mut bank = ReasoningBank::new(PatternConfig {
            embedding_dim: 4,
            ..Default::default()
        });

        bank.add_trajectory(&make_trajectory(1, vec![0.1, 0.2, 0.3, 0.4], 0.8));

        assert_eq!(bank.trajectory_count(), 1);
    }

    #[test]
    fn test_extract_patterns() {
        let mut bank = ReasoningBank::new(clustering_config(2, 2));

        // Two well-separated groups of five trajectories each.
        for id in 0..10u64 {
            let (embedding, quality) = if id < 5 {
                (vec![1.0, 0.0, 0.0, 0.0], 0.8)
            } else {
                (vec![0.0, 1.0, 0.0, 0.0], 0.7)
            };
            bank.add_trajectory(&make_trajectory(id, embedding, quality));
        }

        let patterns = bank.extract_patterns();
        assert!(!patterns.is_empty());
    }

    #[test]
    fn test_find_similar() {
        let mut bank = ReasoningBank::new(clustering_config(2, 2));

        for id in 0..10u64 {
            let embedding = if id < 5 {
                vec![1.0, 0.0, 0.0, 0.0]
            } else {
                vec![0.0, 1.0, 0.0, 0.0]
            };
            bank.add_trajectory(&make_trajectory(id, embedding, 0.8));
        }

        bank.extract_patterns();

        let query = vec![0.9, 0.1, 0.0, 0.0];
        let similar = bank.find_similar(&query, 1);
        assert!(!similar.is_empty());
    }

    #[test]
    fn test_consolidate() {
        let mut bank = ReasoningBank::new(clustering_config(3, 1));

        // Nine nearly identical trajectories along the first axis.
        for id in 0..9u64 {
            let embedding = vec![1.0 + (id as f32 * 0.001), 0.0, 0.0, 0.0];
            bank.add_trajectory(&make_trajectory(id, embedding, 0.8));
        }

        bank.extract_patterns();
        let before = bank.pattern_count();

        bank.consolidate(0.99);
        let after = bank.pattern_count();

        assert!(after <= before);
    }
}
535}