1use crate::types::{LearnedPattern, PatternType, QueryTrajectory};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9#[derive(Clone, Debug, Serialize, Deserialize)]
11pub struct PatternConfig {
12 pub k_clusters: usize,
14 pub embedding_dim: usize,
16 pub max_iterations: usize,
18 pub convergence_threshold: f32,
20 pub min_cluster_size: usize,
22 pub max_trajectories: usize,
24 pub quality_threshold: f32,
26}
27
28impl Default for PatternConfig {
29 fn default() -> Self {
30 Self {
35 k_clusters: 5, embedding_dim: 256,
37 max_iterations: 100,
38 convergence_threshold: 0.001,
39 min_cluster_size: 1, max_trajectories: 10000,
41 quality_threshold: 0.05, }
43 }
44}
45
46#[derive(Clone, Debug)]
48pub struct ReasoningBank {
49 config: PatternConfig,
51 trajectories: Vec<TrajectoryEntry>,
53 patterns: HashMap<u64, LearnedPattern>,
55 next_pattern_id: u64,
57 pattern_index: Vec<(Vec<f32>, u64)>,
59}
60
/// A trajectory reduced to the data clustering needs.
#[derive(Clone, Debug)]
struct TrajectoryEntry {
    /// Fixed-length embedding computed from the source trajectory.
    embedding: Vec<f32>,
    /// Final quality score of the source trajectory.
    quality: f32,
    /// Cluster index assigned by the most recent extraction, if any.
    cluster: Option<usize>,
    /// Id of the source trajectory (underscore-prefixed: stored but not read here).
    _trajectory_id: u64,
}
73
74impl ReasoningBank {
75 pub fn new(config: PatternConfig) -> Self {
77 Self {
78 config,
79 trajectories: Vec::new(),
80 patterns: HashMap::new(),
81 next_pattern_id: 0,
82 pattern_index: Vec::new(),
83 }
84 }
85
86 pub fn add_trajectory(&mut self, trajectory: &QueryTrajectory) {
88 let embedding = self.compute_embedding(trajectory);
90
91 let entry = TrajectoryEntry {
92 embedding,
93 quality: trajectory.final_quality,
94 cluster: None,
95 _trajectory_id: trajectory.id,
96 };
97
98 if self.trajectories.len() >= self.config.max_trajectories {
100 let to_remove = self.trajectories.len() - self.config.max_trajectories + 1;
102 self.trajectories.drain(0..to_remove);
103 }
104
105 self.trajectories.push(entry);
106 }
107
108 fn compute_embedding(&self, trajectory: &QueryTrajectory) -> Vec<f32> {
110 let dim = self.config.embedding_dim;
111 let mut embedding = vec![0.0f32; dim];
112
113 let query_len = trajectory.query_embedding.len().min(dim);
115 embedding[..query_len].copy_from_slice(&trajectory.query_embedding[..query_len]);
116
117 if !trajectory.steps.is_empty() {
119 let mut total_reward = 0.0f32;
120
121 for step in &trajectory.steps {
122 let weight = step.reward.max(0.0);
123 total_reward += weight;
124
125 for (i, &act) in step.activations.iter().enumerate() {
126 if i < dim {
127 embedding[i] += act * weight;
128 }
129 }
130 }
131
132 if total_reward > 0.0 {
133 for e in &mut embedding {
134 *e /= total_reward + 1.0; }
136 }
137 }
138
139 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
141 if norm > 1e-8 {
142 for e in &mut embedding {
143 *e /= norm;
144 }
145 }
146
147 embedding
148 }
149
150 pub fn extract_patterns(&mut self) -> Vec<LearnedPattern> {
152 if self.trajectories.is_empty() {
153 return Vec::new();
154 }
155
156 let k = self.config.k_clusters.min(self.trajectories.len());
157 if k == 0 {
158 return Vec::new();
159 }
160
161 let centroids = self.kmeans_plus_plus_init(k);
163
164 let (final_centroids, assignments) = self.run_kmeans(centroids);
166
167 let mut patterns = Vec::new();
169
170 for (cluster_idx, centroid) in final_centroids.into_iter().enumerate() {
171 let members: Vec<_> = self
173 .trajectories
174 .iter()
175 .enumerate()
176 .filter(|(i, _)| assignments.get(*i) == Some(&cluster_idx))
177 .map(|(_, t)| t)
178 .collect();
179
180 if members.len() < self.config.min_cluster_size {
181 continue;
182 }
183
184 let cluster_size = members.len();
186 let total_weight: f32 = members.iter().map(|t| t.quality).sum();
187 let avg_quality = total_weight / cluster_size as f32;
188
189 if avg_quality < self.config.quality_threshold {
190 continue;
191 }
192
193 let pattern_id = self.next_pattern_id;
194 self.next_pattern_id += 1;
195
196 let now = crate::time_compat::SystemTime::now()
197 .duration_since_epoch()
198 .as_secs();
199 let pattern = LearnedPattern {
200 id: pattern_id,
201 centroid,
202 cluster_size,
203 total_weight,
204 avg_quality,
205 created_at: now,
206 last_accessed: now,
207 access_count: 0,
208 pattern_type: PatternType::General,
209 };
210
211 self.patterns.insert(pattern_id, pattern.clone());
212 self.pattern_index
213 .push((pattern.centroid.clone(), pattern_id));
214 patterns.push(pattern);
215 }
216
217 for (i, cluster) in assignments.into_iter().enumerate() {
219 if i < self.trajectories.len() {
220 self.trajectories[i].cluster = Some(cluster);
221 }
222 }
223
224 patterns
225 }
226
227 fn kmeans_plus_plus_init(&self, k: usize) -> Vec<Vec<f32>> {
229 let mut centroids = Vec::with_capacity(k);
230 let n = self.trajectories.len();
231
232 if n == 0 || k == 0 {
233 return centroids;
234 }
235
236 let first_idx = 0;
238 centroids.push(self.trajectories[first_idx].embedding.clone());
239
240 for _ in 1..k {
242 let mut distances: Vec<f32> = self
244 .trajectories
245 .iter()
246 .map(|t| {
247 centroids
248 .iter()
249 .map(|c| self.squared_distance(&t.embedding, c))
250 .fold(f32::MAX, f32::min)
251 })
252 .collect();
253
254 let total: f32 = distances.iter().sum();
256 if total > 0.0 {
257 for d in &mut distances {
258 *d /= total;
259 }
260 }
261
262 let (next_idx, _) = distances
265 .iter()
266 .enumerate()
267 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
268 .unwrap_or((0, &0.0));
269
270 centroids.push(self.trajectories[next_idx].embedding.clone());
271 }
272
273 centroids
274 }
275
276 fn run_kmeans(&self, mut centroids: Vec<Vec<f32>>) -> (Vec<Vec<f32>>, Vec<usize>) {
278 let n = self.trajectories.len();
279 let k = centroids.len();
280 let dim = self.config.embedding_dim;
281
282 let mut assignments = vec![0usize; n];
283
284 for _iter in 0..self.config.max_iterations {
285 let mut changed = false;
287 for (i, t) in self.trajectories.iter().enumerate() {
288 let (nearest, _) = centroids
290 .iter()
291 .enumerate()
292 .map(|(j, c)| (j, self.squared_distance(&t.embedding, c)))
293 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
294 .unwrap_or((0, 0.0));
295
296 if assignments[i] != nearest {
297 assignments[i] = nearest;
298 changed = true;
299 }
300 }
301
302 if !changed {
303 break;
304 }
305
306 let mut new_centroids = vec![vec![0.0f32; dim]; k];
308 let mut counts = vec![0usize; k];
309
310 for (i, t) in self.trajectories.iter().enumerate() {
311 let cluster = assignments[i];
312 counts[cluster] += 1;
313 for (j, &e) in t.embedding.iter().enumerate() {
314 new_centroids[cluster][j] += e;
315 }
316 }
317
318 let mut max_shift = 0.0f32;
320 for (i, new_c) in new_centroids.iter_mut().enumerate() {
321 if counts[i] > 0 {
322 for e in new_c.iter_mut() {
323 *e /= counts[i] as f32;
324 }
325 let shift = self.squared_distance(new_c, ¢roids[i]).sqrt();
326 max_shift = max_shift.max(shift);
327 }
328 }
329
330 centroids = new_centroids;
331
332 if max_shift < self.config.convergence_threshold {
333 break;
334 }
335 }
336
337 (centroids, assignments)
338 }
339
340 fn squared_distance(&self, a: &[f32], b: &[f32]) -> f32 {
342 a.iter()
343 .zip(b.iter())
344 .map(|(&x, &y)| (x - y) * (x - y))
345 .sum()
346 }
347
348 pub fn find_similar(&self, query: &[f32], k: usize) -> Vec<&LearnedPattern> {
350 let mut scored: Vec<_> = self
351 .patterns
352 .values()
353 .map(|p| (p, p.similarity(query)))
354 .collect();
355
356 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
358
359 scored.into_iter().take(k).map(|(p, _)| p).collect()
360 }
361
362 pub fn get_pattern(&self, id: u64) -> Option<&LearnedPattern> {
364 self.patterns.get(&id)
365 }
366
367 pub fn get_pattern_mut(&mut self, id: u64) -> Option<&mut LearnedPattern> {
369 self.patterns.get_mut(&id)
370 }
371
372 pub fn trajectory_count(&self) -> usize {
374 self.trajectories.len()
375 }
376
377 pub fn pattern_count(&self) -> usize {
379 self.patterns.len()
380 }
381
382 pub fn clear_trajectories(&mut self) {
384 self.trajectories.clear();
385 }
386
387 pub fn prune_patterns(&mut self, min_quality: f32, min_accesses: u32, max_age_secs: u64) {
389 let to_remove: Vec<u64> = self
390 .patterns
391 .iter()
392 .filter(|(_, p)| p.should_prune(min_quality, min_accesses, max_age_secs))
393 .map(|(id, _)| *id)
394 .collect();
395
396 for id in to_remove {
397 self.patterns.remove(&id);
398 }
399
400 self.pattern_index
402 .retain(|(_, id)| self.patterns.contains_key(id));
403 }
404
405 pub fn config(&self) -> &PatternConfig {
407 &self.config
408 }
409
410 pub fn set_quality_threshold(&mut self, threshold: f32) {
413 self.config.quality_threshold = threshold.clamp(0.01, 1.0);
414 }
415
416 pub fn get_all_patterns(&self) -> Vec<LearnedPattern> {
418 self.patterns.values().cloned().collect()
419 }
420
421 pub fn insert_pattern(&mut self, pattern: LearnedPattern) {
423 let id = pattern.id;
424 if id >= self.next_pattern_id {
425 self.next_pattern_id = id + 1;
426 }
427 self.pattern_index.push((pattern.centroid.clone(), id));
428 self.patterns.insert(id, pattern);
429 }
430
431 pub fn set_min_trajectory_length(&mut self, n: usize) {
436 self.config.min_cluster_size = n;
437 }
438
439 pub fn set_min_pattern_quality(&mut self, q: f64) {
444 self.config.quality_threshold = q as f32;
445 }
446
447 pub fn consolidate(&mut self, similarity_threshold: f32) {
449 let pattern_ids: Vec<u64> = self.patterns.keys().copied().collect();
450 let mut merged = Vec::new();
451
452 for i in 0..pattern_ids.len() {
453 for j in i + 1..pattern_ids.len() {
454 let id1 = pattern_ids[i];
455 let id2 = pattern_ids[j];
456
457 if merged.contains(&id1) || merged.contains(&id2) {
458 continue;
459 }
460
461 if let (Some(p1), Some(p2)) = (self.patterns.get(&id1), self.patterns.get(&id2)) {
462 let sim = p1.similarity(&p2.centroid);
463 if sim > similarity_threshold {
464 let merged_pattern = p1.merge(p2);
466 self.patterns.insert(id1, merged_pattern);
467 merged.push(id2);
468 }
469 }
470 }
471 }
472
473 for id in merged {
475 self.patterns.remove(&id);
476 }
477
478 self.pattern_index
480 .retain(|(_, id)| self.patterns.contains_key(id));
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a finalized trajectory whose query embedding is `embedding`.
    fn make_trajectory(id: u64, embedding: Vec<f32>, quality: f32) -> QueryTrajectory {
        let mut trajectory = QueryTrajectory::new(id, embedding);
        trajectory.finalize(quality, 1000);
        trajectory
    }

    /// Four-dimensional config with no quality gate, shared by the clustering tests.
    fn clustering_config(k_clusters: usize, min_cluster_size: usize) -> PatternConfig {
        PatternConfig {
            embedding_dim: 4,
            k_clusters,
            min_cluster_size,
            quality_threshold: 0.0,
            ..Default::default()
        }
    }

    #[test]
    fn test_bank_creation() {
        let bank = ReasoningBank::new(PatternConfig::default());
        assert_eq!(bank.trajectory_count(), 0);
        assert_eq!(bank.pattern_count(), 0);
    }

    #[test]
    fn test_add_trajectory() {
        let mut bank = ReasoningBank::new(PatternConfig {
            embedding_dim: 4,
            ..Default::default()
        });

        bank.add_trajectory(&make_trajectory(1, vec![0.1, 0.2, 0.3, 0.4], 0.8));

        assert_eq!(bank.trajectory_count(), 1);
    }

    #[test]
    fn test_extract_patterns() {
        let mut bank = ReasoningBank::new(clustering_config(2, 2));

        // Two well-separated groups of five identical trajectories.
        for id in 0..10 {
            let (embedding, quality) = if id < 5 {
                (vec![1.0, 0.0, 0.0, 0.0], 0.8)
            } else {
                (vec![0.0, 1.0, 0.0, 0.0], 0.7)
            };
            bank.add_trajectory(&make_trajectory(id, embedding, quality));
        }

        let patterns = bank.extract_patterns();
        assert!(!patterns.is_empty());
    }

    #[test]
    fn test_find_similar() {
        let mut bank = ReasoningBank::new(clustering_config(2, 2));

        for id in 0..10 {
            let embedding = if id < 5 {
                vec![1.0, 0.0, 0.0, 0.0]
            } else {
                vec![0.0, 1.0, 0.0, 0.0]
            };
            bank.add_trajectory(&make_trajectory(id, embedding, 0.8));
        }
        bank.extract_patterns();

        let query = vec![0.9, 0.1, 0.0, 0.0];
        assert!(!bank.find_similar(&query, 1).is_empty());
    }

    #[test]
    fn test_consolidate() {
        let mut bank = ReasoningBank::new(clustering_config(3, 1));

        // Nearly identical trajectories that should all land in overlapping patterns.
        for id in 0..9 {
            let embedding = vec![1.0 + (id as f32 * 0.001), 0.0, 0.0, 0.0];
            bank.add_trajectory(&make_trajectory(id, embedding, 0.8));
        }

        bank.extract_patterns();
        let before = bank.pattern_count();

        bank.consolidate(0.99);

        assert!(bank.pattern_count() <= before);
    }
}