1use crate::types::{LearnedPattern, PatternType, QueryTrajectory};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct PatternConfig {
12 pub k_clusters: usize,
14 pub embedding_dim: usize,
16 pub max_iterations: usize,
18 pub convergence_threshold: f32,
20 pub min_cluster_size: usize,
22 pub max_trajectories: usize,
24 pub quality_threshold: f32,
26}
27
28impl Default for PatternConfig {
29 fn default() -> Self {
30 Self {
34 k_clusters: 100, embedding_dim: 256,
36 max_iterations: 100,
37 convergence_threshold: 0.001,
38 min_cluster_size: 5,
39 max_trajectories: 10000,
40 quality_threshold: 0.3, }
42 }
43}
44
45#[derive(Clone, Debug)]
47pub struct ReasoningBank {
48 config: PatternConfig,
50 trajectories: Vec<TrajectoryEntry>,
52 patterns: HashMap<u64, LearnedPattern>,
54 next_pattern_id: u64,
56 pattern_index: Vec<(Vec<f32>, u64)>,
58}
59
/// A buffered trajectory reduced to what clustering needs.
#[derive(Debug, Clone)]
struct TrajectoryEntry {
    /// Embedding of length `embedding_dim`, L2-normalized unless its norm is ~0.
    embedding: Vec<f32>,
    /// The trajectory's `final_quality`.
    quality: f32,
    /// Cluster assigned by the most recent pattern extraction; `None` until then.
    cluster: Option<usize>,
    /// Id of the trajectory this entry was built from.
    trajectory_id: u64,
}
72
73impl ReasoningBank {
74 pub fn new(config: PatternConfig) -> Self {
76 Self {
77 config,
78 trajectories: Vec::new(),
79 patterns: HashMap::new(),
80 next_pattern_id: 0,
81 pattern_index: Vec::new(),
82 }
83 }
84
85 pub fn add_trajectory(&mut self, trajectory: &QueryTrajectory) {
87 let embedding = self.compute_embedding(trajectory);
89
90 let entry = TrajectoryEntry {
91 embedding,
92 quality: trajectory.final_quality,
93 cluster: None,
94 trajectory_id: trajectory.id,
95 };
96
97 if self.trajectories.len() >= self.config.max_trajectories {
99 let to_remove = self.trajectories.len() - self.config.max_trajectories + 1;
101 self.trajectories.drain(0..to_remove);
102 }
103
104 self.trajectories.push(entry);
105 }
106
107 fn compute_embedding(&self, trajectory: &QueryTrajectory) -> Vec<f32> {
109 let dim = self.config.embedding_dim;
110 let mut embedding = vec![0.0f32; dim];
111
112 let query_len = trajectory.query_embedding.len().min(dim);
114 embedding[..query_len].copy_from_slice(&trajectory.query_embedding[..query_len]);
115
116 if !trajectory.steps.is_empty() {
118 let mut total_reward = 0.0f32;
119
120 for step in &trajectory.steps {
121 let weight = step.reward.max(0.0);
122 total_reward += weight;
123
124 for (i, &act) in step.activations.iter().enumerate() {
125 if i < dim {
126 embedding[i] += act * weight;
127 }
128 }
129 }
130
131 if total_reward > 0.0 {
132 for e in &mut embedding {
133 *e /= total_reward + 1.0; }
135 }
136 }
137
138 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
140 if norm > 1e-8 {
141 for e in &mut embedding {
142 *e /= norm;
143 }
144 }
145
146 embedding
147 }
148
149 pub fn extract_patterns(&mut self) -> Vec<LearnedPattern> {
151 if self.trajectories.is_empty() {
152 return Vec::new();
153 }
154
155 let k = self.config.k_clusters.min(self.trajectories.len());
156 if k == 0 {
157 return Vec::new();
158 }
159
160 let centroids = self.kmeans_plus_plus_init(k);
162
163 let (final_centroids, assignments) = self.run_kmeans(centroids);
165
166 let mut patterns = Vec::new();
168
169 for (cluster_idx, centroid) in final_centroids.into_iter().enumerate() {
170 let members: Vec<_> = self
172 .trajectories
173 .iter()
174 .enumerate()
175 .filter(|(i, _)| assignments.get(*i) == Some(&cluster_idx))
176 .map(|(_, t)| t)
177 .collect();
178
179 if members.len() < self.config.min_cluster_size {
180 continue;
181 }
182
183 let cluster_size = members.len();
185 let total_weight: f32 = members.iter().map(|t| t.quality).sum();
186 let avg_quality = total_weight / cluster_size as f32;
187
188 if avg_quality < self.config.quality_threshold {
189 continue;
190 }
191
192 let pattern_id = self.next_pattern_id;
193 self.next_pattern_id += 1;
194
195 let now = crate::time_compat::SystemTime::now()
196 .duration_since_epoch()
197 .as_secs();
198 let pattern = LearnedPattern {
199 id: pattern_id,
200 centroid,
201 cluster_size,
202 total_weight,
203 avg_quality,
204 created_at: now,
205 last_accessed: now,
206 access_count: 0,
207 pattern_type: PatternType::General,
208 };
209
210 self.patterns.insert(pattern_id, pattern.clone());
211 self.pattern_index
212 .push((pattern.centroid.clone(), pattern_id));
213 patterns.push(pattern);
214 }
215
216 for (i, cluster) in assignments.into_iter().enumerate() {
218 if i < self.trajectories.len() {
219 self.trajectories[i].cluster = Some(cluster);
220 }
221 }
222
223 patterns
224 }
225
226 fn kmeans_plus_plus_init(&self, k: usize) -> Vec<Vec<f32>> {
228 let mut centroids = Vec::with_capacity(k);
229 let n = self.trajectories.len();
230
231 if n == 0 || k == 0 {
232 return centroids;
233 }
234
235 let first_idx = 0;
237 centroids.push(self.trajectories[first_idx].embedding.clone());
238
239 for _ in 1..k {
241 let mut distances: Vec<f32> = self
243 .trajectories
244 .iter()
245 .map(|t| {
246 centroids
247 .iter()
248 .map(|c| self.squared_distance(&t.embedding, c))
249 .fold(f32::MAX, f32::min)
250 })
251 .collect();
252
253 let total: f32 = distances.iter().sum();
255 if total > 0.0 {
256 for d in &mut distances {
257 *d /= total;
258 }
259 }
260
261 let (next_idx, _) = distances
264 .iter()
265 .enumerate()
266 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
267 .unwrap_or((0, &0.0));
268
269 centroids.push(self.trajectories[next_idx].embedding.clone());
270 }
271
272 centroids
273 }
274
275 fn run_kmeans(&self, mut centroids: Vec<Vec<f32>>) -> (Vec<Vec<f32>>, Vec<usize>) {
277 let n = self.trajectories.len();
278 let k = centroids.len();
279 let dim = self.config.embedding_dim;
280
281 let mut assignments = vec![0usize; n];
282
283 for _iter in 0..self.config.max_iterations {
284 let mut changed = false;
286 for (i, t) in self.trajectories.iter().enumerate() {
287 let (nearest, _) = centroids
289 .iter()
290 .enumerate()
291 .map(|(j, c)| (j, self.squared_distance(&t.embedding, c)))
292 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
293 .unwrap_or((0, 0.0));
294
295 if assignments[i] != nearest {
296 assignments[i] = nearest;
297 changed = true;
298 }
299 }
300
301 if !changed {
302 break;
303 }
304
305 let mut new_centroids = vec![vec![0.0f32; dim]; k];
307 let mut counts = vec![0usize; k];
308
309 for (i, t) in self.trajectories.iter().enumerate() {
310 let cluster = assignments[i];
311 counts[cluster] += 1;
312 for (j, &e) in t.embedding.iter().enumerate() {
313 new_centroids[cluster][j] += e;
314 }
315 }
316
317 let mut max_shift = 0.0f32;
319 for (i, new_c) in new_centroids.iter_mut().enumerate() {
320 if counts[i] > 0 {
321 for e in new_c.iter_mut() {
322 *e /= counts[i] as f32;
323 }
324 let shift = self.squared_distance(new_c, ¢roids[i]).sqrt();
325 max_shift = max_shift.max(shift);
326 }
327 }
328
329 centroids = new_centroids;
330
331 if max_shift < self.config.convergence_threshold {
332 break;
333 }
334 }
335
336 (centroids, assignments)
337 }
338
339 fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
341 a.iter()
342 .zip(b.iter())
343 .map(|(&x, &y)| (x - y) * (x - y))
344 .sum()
345 }
346
347 pub fn find_similar(&self, query: &[f32], k: usize) -> Vec<&LearnedPattern> {
349 let mut scored: Vec<_> = self
350 .patterns
351 .values()
352 .map(|p| (p, p.similarity(query)))
353 .collect();
354
355 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
357
358 scored.into_iter().take(k).map(|(p, _)| p).collect()
359 }
360
361 pub fn get_pattern(&self, id: u64) -> Option<&LearnedPattern> {
363 self.patterns.get(&id)
364 }
365
366 pub fn get_pattern_mut(&mut self, id: u64) -> Option<&mut LearnedPattern> {
368 self.patterns.get_mut(&id)
369 }
370
371 pub fn trajectory_count(&self) -> usize {
373 self.trajectories.len()
374 }
375
376 pub fn pattern_count(&self) -> usize {
378 self.patterns.len()
379 }
380
381 pub fn clear_trajectories(&mut self) {
383 self.trajectories.clear();
384 }
385
386 pub fn prune_patterns(&mut self, min_quality: f32, min_accesses: u32, max_age_secs: u64) {
388 let to_remove: Vec<u64> = self
389 .patterns
390 .iter()
391 .filter(|(_, p)| p.should_prune(min_quality, min_accesses, max_age_secs))
392 .map(|(id, _)| *id)
393 .collect();
394
395 for id in to_remove {
396 self.patterns.remove(&id);
397 }
398
399 self.pattern_index
401 .retain(|(_, id)| self.patterns.contains_key(id));
402 }
403
404 pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
406 self.patterns.values().cloned().collect()
407 }
408
409 pub fn consolidate(&mut self, similarity_threshold: f32) {
411 let pattern_ids: Vec<u64> = self.patterns.keys().copied().collect();
412 let mut merged = Vec::new();
413
414 for i in 0..pattern_ids.len() {
415 for j in i + 1..pattern_ids.len() {
416 let id1 = pattern_ids[i];
417 let id2 = pattern_ids[j];
418
419 if merged.contains(&id1) || merged.contains(&id2) {
420 continue;
421 }
422
423 if let (Some(p1), Some(p2)) = (self.patterns.get(&id1), self.patterns.get(&id2)) {
424 let sim = p1.similarity(&p2.centroid);
425 if sim > similarity_threshold {
426 let merged_pattern = p1.merge(p2);
428 self.patterns.insert(id1, merged_pattern);
429 merged.push(id2);
430 }
431 }
432 }
433 }
434
435 for id in merged {
437 self.patterns.remove(&id);
438 }
439
440 self.pattern_index
442 .retain(|(_, id)| self.patterns.contains_key(id));
443 }
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a finalized trajectory with the given query embedding and quality.
    fn make_trajectory(id: u64, embedding: Vec<f32>, quality: f32) -> QueryTrajectory {
        let mut t = QueryTrajectory::new(id, embedding);
        t.finalize(quality, 1000);
        t
    }

    /// 4-dimensional config with quality filtering disabled, shared by the clustering tests.
    fn small_config(k_clusters: usize, min_cluster_size: usize) -> PatternConfig {
        PatternConfig {
            embedding_dim: 4,
            k_clusters,
            min_cluster_size,
            quality_threshold: 0.0,
            ..Default::default()
        }
    }

    /// Bank holding five trajectories along each of two orthogonal axes.
    fn two_group_bank(first_quality: f32, second_quality: f32) -> ReasoningBank {
        let mut bank = ReasoningBank::new(small_config(2, 2));
        for i in 0..5 {
            bank.add_trajectory(&make_trajectory(i, vec![1.0, 0.0, 0.0, 0.0], first_quality));
        }
        for i in 5..10 {
            bank.add_trajectory(&make_trajectory(i, vec![0.0, 1.0, 0.0, 0.0], second_quality));
        }
        bank
    }

    #[test]
    fn test_bank_creation() {
        let bank = ReasoningBank::new(PatternConfig::default());
        assert_eq!(bank.trajectory_count(), 0);
        assert_eq!(bank.pattern_count(), 0);
    }

    #[test]
    fn test_add_trajectory() {
        let config = PatternConfig {
            embedding_dim: 4,
            ..Default::default()
        };
        let mut bank = ReasoningBank::new(config);

        bank.add_trajectory(&make_trajectory(1, vec![0.1, 0.2, 0.3, 0.4], 0.8));

        assert_eq!(bank.trajectory_count(), 1);
    }

    #[test]
    fn test_add_trajectory_evicts_beyond_capacity() {
        let config = PatternConfig {
            embedding_dim: 4,
            max_trajectories: 3,
            ..Default::default()
        };
        let mut bank = ReasoningBank::new(config);

        for i in 0..5 {
            bank.add_trajectory(&make_trajectory(i, vec![0.1, 0.2, 0.3, 0.4], 0.5));
        }

        assert_eq!(bank.trajectory_count(), 3);
    }

    #[test]
    fn test_extract_patterns_on_empty_bank() {
        let mut bank = ReasoningBank::new(small_config(2, 1));
        assert!(bank.extract_patterns().is_empty());
        assert_eq!(bank.pattern_count(), 0);
    }

    #[test]
    fn test_extract_patterns() {
        let mut bank = two_group_bank(0.8, 0.7);

        let patterns = bank.extract_patterns();

        // Two well-separated groups of five -> two patterns covering all ten trajectories.
        assert_eq!(patterns.len(), 2);
        assert_eq!(patterns.iter().map(|p| p.cluster_size).sum::<usize>(), 10);
        assert_eq!(bank.pattern_count(), 2);
    }

    #[test]
    fn test_find_similar() {
        let mut bank = two_group_bank(0.8, 0.8);
        bank.extract_patterns();

        let query = vec![0.9, 0.1, 0.0, 0.0];
        assert_eq!(bank.find_similar(&query, 1).len(), 1);
        // Asking for more than exist returns every stored pattern.
        assert_eq!(bank.find_similar(&query, 10).len(), 2);
    }

    #[test]
    fn test_consolidate() {
        let mut bank = ReasoningBank::new(small_config(3, 1));

        for i in 0..9 {
            let emb = vec![1.0 + (i as f32 * 0.001), 0.0, 0.0, 0.0];
            bank.add_trajectory(&make_trajectory(i, emb, 0.8));
        }

        bank.extract_patterns();
        let before = bank.pattern_count();
        assert!(before >= 1);

        bank.consolidate(0.99);
        let after = bank.pattern_count();

        // Consolidation only folds patterns into one another; it never removes them all.
        assert!(after <= before);
        assert!(after >= 1);
    }
}