1use std::collections::{HashMap, VecDeque};
8
9use crate::error::{NdimageError, NdimageResult};
10use crate::hyperdimensional_computing::types::{
11 AdaptationParameters, ConsolidationResult, Experience, HDCConfig, Hypervector,
12 OnlineLearningResult, PerformanceMetrics, PredictionResult, UpdateResult,
13};
14
15#[derive(Debug, Clone)]
17pub struct HDCMemory {
18 pub patterns: HashMap<String, Hypervector>,
20 pub item_memory: HashMap<String, Hypervector>,
22 pub config: HDCConfig,
24}
25
26impl HDCMemory {
27 pub fn new(config: HDCConfig) -> Self {
29 Self {
30 patterns: HashMap::new(),
31 item_memory: HashMap::new(),
32 config,
33 }
34 }
35
36 pub fn store(&mut self, label: String, pattern: Hypervector) {
43 self.patterns.insert(label, pattern);
44 }
45
46 pub fn retrieve(&self, query: &Hypervector) -> Option<(String, f64)> {
56 let mut best_match = None;
57 let mut best_similarity = 0.0;
58
59 for (label, pattern) in &self.patterns {
60 let similarity = query.similarity(pattern);
61 if similarity > best_similarity && similarity >= self.config.similarity_threshold {
62 best_similarity = similarity;
63 best_match = Some((label.clone(), similarity));
64 }
65 }
66
67 best_match
68 }
69
70 pub fn get_patterns(&self) -> &HashMap<String, Hypervector> {
72 &self.patterns
73 }
74
75 pub fn remove(&mut self, label: &str) -> Option<Hypervector> {
85 self.patterns.remove(label)
86 }
87
88 pub fn clear(&mut self) {
90 self.patterns.clear();
91 }
92
93 pub fn size(&self) -> usize {
95 self.patterns.len()
96 }
97
98 pub fn update_pattern(
106 &mut self,
107 label: String,
108 new_pattern: Hypervector,
109 learning_rate: f64,
110 ) -> NdimageResult<()> {
111 if let Some(existing_pattern) = self.patterns.get(&label) {
112 let weighted_new = new_pattern.scale(learning_rate);
114 let weighted_existing = existing_pattern.scale(1.0 - learning_rate);
115 let updated = weighted_existing.bundle(&weighted_new)?;
116 self.patterns.insert(label, updated);
117 } else {
118 self.patterns.insert(label, new_pattern);
119 }
120 Ok(())
121 }
122
123 pub fn store_item(&mut self, name: String, item: Hypervector) {
125 self.item_memory.insert(name, item);
126 }
127
128 pub fn get_item(&self, name: &str) -> Option<&Hypervector> {
130 self.item_memory.get(name)
131 }
132}
133
134#[derive(Debug, Clone)]
136pub struct ContinualLearningMemory {
137 pub associative_memory: HashMap<String, Hypervector>,
139 pub episodic_buffer: VecDeque<Experience>,
141 pub consolidated_memory: HashMap<String, Hypervector>,
143 pub config: HDCConfig,
145 pub buffer_capacity: usize,
147 pub interference_threshold: f64,
149}
150
151impl ContinualLearningMemory {
152 pub fn new(config: &HDCConfig) -> Self {
154 Self {
155 associative_memory: HashMap::new(),
156 episodic_buffer: VecDeque::new(),
157 consolidated_memory: HashMap::new(),
158 config: config.clone(),
159 buffer_capacity: 1000,
160 interference_threshold: 0.7,
161 }
162 }
163
164 pub fn add_experience(
171 &mut self,
172 experience: Experience,
173 _consolidation: &ConsolidationResult,
174 ) -> NdimageResult<()> {
175 let interference = self.calculate_interference(&experience.encoding);
177
178 if interference > self.interference_threshold {
179 self.perform_replay_consolidation(&experience)?;
180 }
181
182 self.episodic_buffer.push_back(experience.clone());
184
185 if self.episodic_buffer.len() > self.buffer_capacity {
187 self.episodic_buffer.pop_front();
188 }
189
190 self.associative_memory
192 .insert(experience.label.clone(), experience.encoding);
193
194 Ok(())
195 }
196
197 pub fn calculate_interference(&self, new_encoding: &Hypervector) -> f64 {
207 let mut max_interference = 0.0;
208
209 for (_, existing_encoding) in &self.associative_memory {
210 let similarity = new_encoding.similarity(existing_encoding);
211 if similarity > max_interference {
212 max_interference = similarity;
213 }
214 }
215
216 max_interference
217 }
218
219 fn perform_replay_consolidation(&mut self, new_experience: &Experience) -> NdimageResult<()> {
221 let mut most_similar_label = None;
226 let mut max_similarity = 0.0;
227
228 for experience in &self.episodic_buffer {
229 let similarity = new_experience.encoding.similarity(&experience.encoding);
230 if similarity > max_similarity {
231 max_similarity = similarity;
232 most_similar_label = Some(experience.label.clone());
233 }
234 }
235
236 if let Some(label) = most_similar_label {
238 if max_similarity > self.interference_threshold {
239 if let Some(existing) = self.associative_memory.get(&label) {
240 let consolidated = existing.bundle(&new_experience.encoding)?;
241 self.consolidated_memory.insert(label, consolidated);
242 }
243 }
244 }
245
246 Ok(())
247 }
248
249 pub fn retrieve(&self, query: &Hypervector) -> Option<(String, f64)> {
251 let mut best_match = None;
252 let mut best_similarity = 0.0;
253
254 for (label, pattern) in &self.associative_memory {
256 let similarity = query.similarity(pattern);
257 if similarity > best_similarity && similarity >= self.config.similarity_threshold {
258 best_similarity = similarity;
259 best_match = Some((label.clone(), similarity));
260 }
261 }
262
263 for (label, pattern) in &self.consolidated_memory {
265 let similarity = query.similarity(pattern);
266 if similarity > best_similarity && similarity >= self.config.similarity_threshold {
267 best_similarity = similarity;
268 best_match = Some((label.clone(), similarity));
269 }
270 }
271
272 best_match
273 }
274
275 pub fn get_memory_stats(&self) -> (usize, usize, usize) {
277 (
278 self.associative_memory.len(),
279 self.episodic_buffer.len(),
280 self.consolidated_memory.len(),
281 )
282 }
283
284 pub fn get_current_time(&self) -> usize {
286 self.episodic_buffer.len()
288 }
289
290 pub fn update_meta_learning_parameters(
292 &mut self,
293 _stats: &crate::hyperdimensional_computing::reasoning::ContinualLearningStats,
294 ) {
295 }
297
298 pub fn get_meta_learning_score(&self) -> f64 {
300 0.7
302 }
303}
304
/// Rolling record of accuracy and learning-speed observations.
///
/// Each history keeps at most `max_history_length` of the most recent values.
/// `update_count` counts every recorded update, including values that have
/// since been evicted from the histories.
#[derive(Debug, Clone)]
pub struct PerformanceTracker {
    /// Recent accuracy values, oldest at the front.
    pub accuracyhistory: VecDeque<f64>,
    /// Recent learning-speed values, oldest at the front.
    pub learning_speedhistory: VecDeque<f64>,
    /// Total number of `record_update` calls since creation or the last `reset`.
    pub update_count: usize,
    /// Upper bound on the length of each history.
    pub max_history_length: usize,
}

impl PerformanceTracker {
    /// Creates an empty tracker that keeps up to 100 values per history.
    pub fn new() -> Self {
        Self {
            accuracyhistory: VecDeque::new(),
            learning_speedhistory: VecDeque::new(),
            update_count: 0,
            max_history_length: 100,
        }
    }

    /// Appends one accuracy/learning-speed observation and evicts the oldest
    /// entries so that neither history exceeds `max_history_length`.
    pub fn record_update(&mut self, accuracy: f64, learning_speed: f64) {
        self.accuracyhistory.push_back(accuracy);
        self.learning_speedhistory.push_back(learning_speed);
        self.update_count += 1;

        let cap = self.max_history_length;
        Self::trim_oldest(&mut self.accuracyhistory, cap);
        Self::trim_oldest(&mut self.learning_speedhistory, cap);
    }

    /// Drops entries from the front of `history` until at most `cap` remain.
    ///
    /// A loop, not a single pop: `max_history_length` is public and may have been
    /// lowered, leaving more than one excess entry.
    fn trim_oldest(history: &mut VecDeque<f64>, cap: usize) {
        while history.len() > cap {
            history.pop_front();
        }
    }

    /// Arithmetic mean of `values`, or 0.0 when empty.
    fn mean(values: &VecDeque<f64>) -> f64 {
        if values.is_empty() {
            0.0
        } else {
            values.iter().sum::<f64>() / values.len() as f64
        }
    }

    /// Mean of the retained accuracy history (0.0 when empty).
    pub fn get_accuracy(&self) -> f64 {
        Self::mean(&self.accuracyhistory)
    }

    /// Mean of the retained learning-speed history (0.0 when empty).
    pub fn get_learning_speed(&self) -> f64 {
        Self::mean(&self.learning_speedhistory)
    }

    /// Heuristic efficiency in (0, 1], equal to `1 / (1 + update_count / 1000)`.
    pub fn get_memory_efficiency(&self) -> f64 {
        1.0 / (1.0 + self.update_count as f64 / 1000.0)
    }

    /// Mean of the latest 5 accuracies minus the mean of the 5 before them.
    ///
    /// Returns 0.0 until at least 10 accuracy values are retained.
    pub fn get_recent_performance_change(&self) -> f64 {
        if self.accuracyhistory.len() < 10 {
            return 0.0;
        }

        let recent: f64 = self.accuracyhistory.iter().rev().take(5).sum::<f64>() / 5.0;
        let older: f64 = self
            .accuracyhistory
            .iter()
            .rev()
            .skip(5)
            .take(5)
            .sum::<f64>()
            / 5.0;
        recent - older
    }

    /// Clears both histories and the update counter; `max_history_length` is kept.
    pub fn reset(&mut self) {
        self.accuracyhistory.clear();
        self.learning_speedhistory.clear();
        self.update_count = 0;
    }
}
398
399#[derive(Debug, Clone)]
401pub struct OnlineLearningSystem {
402 pub memory: HDCMemory,
404 pub continual_memory: ContinualLearningMemory,
406 pub performance_tracker: PerformanceTracker,
408 pub adaptation_params: AdaptationParameters,
410 pub learning_state: LearningState,
412}
413
/// Learning regime chosen from recent performance.
///
/// The enum is fieldless, so it derives `Copy` and equality traits. This lets
/// callers compare states (e.g. `state == LearningState::Conservative`) without
/// pattern-matching boilerplate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LearningState {
    /// Default regime.
    Normal,
    /// Chosen when accuracy is low and recently dropping.
    RapidAdaptation,
    /// Chosen when accuracy is high and stable.
    Conservative,
}
424
425impl OnlineLearningSystem {
426 pub fn new(config: &HDCConfig) -> Self {
428 Self {
429 memory: HDCMemory::new(config.clone()),
430 continual_memory: ContinualLearningMemory::new(config),
431 performance_tracker: PerformanceTracker::new(),
432 adaptation_params: AdaptationParameters::default(),
433 learning_state: LearningState::Normal,
434 }
435 }
436
437 pub fn predict(&self, input: &Hypervector) -> NdimageResult<PredictionResult> {
447 if let Some((label, confidence)) = self.memory.retrieve(input) {
449 let mut alternatives = Vec::new();
451 if let Some((alt_label, alt_confidence)) = self.continual_memory.retrieve(input) {
452 if alt_label != label {
453 alternatives.push((alt_label, alt_confidence));
454 }
455 }
456
457 return Ok(PredictionResult {
458 predicted_label: label,
459 confidence,
460 alternatives,
461 });
462 }
463
464 if let Some((label, confidence)) = self.continual_memory.retrieve(input) {
466 return Ok(PredictionResult {
467 predicted_label: label,
468 confidence,
469 alternatives: Vec::new(),
470 });
471 }
472
473 Ok(PredictionResult {
475 predicted_label: "unknown".to_string(),
476 confidence: 0.0,
477 alternatives: Vec::new(),
478 })
479 }
480
481 pub fn update_with_feedback(
494 &mut self,
495 input: &Hypervector,
496 correct_label: &str,
497 learning_rate: f64,
498 prediction_error: f64,
499 ) -> NdimageResult<UpdateResult> {
500 let accuracy = 1.0 - prediction_error;
502 self.performance_tracker
503 .record_update(accuracy, learning_rate);
504
505 self.memory
507 .update_pattern(correct_label.to_string(), input.clone(), learning_rate)?;
508
509 let experience = Experience {
511 encoding: input.clone(),
512 label: correct_label.to_string(),
513 timestamp: self.performance_tracker.update_count,
514 importance: 1.0 - prediction_error,
515 };
516
517 let consolidation = ConsolidationResult {
518 interference_prevented: 0,
519 effectiveness_score: accuracy,
520 replay_cycles_used: 1,
521 };
522
523 self.continual_memory
524 .add_experience(experience, &consolidation)?;
525
526 self.adaptation_params
528 .adjust_based_on_performance(&self.performance_tracker);
529
530 self.update_learning_state();
532
533 Ok(UpdateResult {
534 memory_updated: true,
535 learning_rate_used: learning_rate,
536 performance_change: self.performance_tracker.get_recent_performance_change(),
537 })
538 }
539
540 pub fn online_learning_step(
551 &mut self,
552 input: &Hypervector,
553 true_label: Option<&str>,
554 ) -> NdimageResult<OnlineLearningResult> {
555 let prediction = self.predict(input)?;
557
558 let learning_update = if let Some(label) = true_label {
560 let error = calculate_prediction_error(&prediction, label);
561 self.update_with_feedback(input, label, self.adaptation_params.current_rate, error)?
562 } else {
563 UpdateResult {
564 memory_updated: false,
565 learning_rate_used: 0.0,
566 performance_change: 0.0,
567 }
568 };
569
570 let system_performance = self.get_performancemetrics();
572
573 Ok(OnlineLearningResult {
574 prediction,
575 learning_update,
576 system_performance,
577 adaptation_rate: self.adaptation_params.current_rate,
578 })
579 }
580
581 fn update_learning_state(&mut self) {
583 let recent_change = self.performance_tracker.get_recent_performance_change();
584 let accuracy = self.performance_tracker.get_accuracy();
585
586 self.learning_state = if recent_change < -0.1 && accuracy < 0.7 {
587 LearningState::RapidAdaptation
588 } else if accuracy > 0.9 && recent_change.abs() < 0.05 {
589 LearningState::Conservative
590 } else {
591 LearningState::Normal
592 };
593 }
594
595 pub fn get_performancemetrics(&self) -> PerformanceMetrics {
597 PerformanceMetrics {
598 accuracy: self.performance_tracker.get_accuracy(),
599 learning_speed: self.performance_tracker.get_learning_speed(),
600 memory_efficiency: self.performance_tracker.get_memory_efficiency(),
601 adaptation_effectiveness: self.adaptation_params.current_rate
602 / self.adaptation_params.base_rate,
603 }
604 }
605
606 pub fn perform_maintenance_cycle(&mut self, _config: &HDCConfig) -> NdimageResult<()> {
608 if self.performance_tracker.get_accuracy() < 0.5
612 && self.performance_tracker.update_count > 100
613 {
614 self.adaptation_params.reset();
615 }
616
617 Ok(())
623 }
624
625 pub fn get_learning_state(&self) -> &LearningState {
627 &self.learning_state
628 }
629
630 pub fn get_memory_stats(&self) -> (usize, (usize, usize, usize)) {
632 (self.memory.size(), self.continual_memory.get_memory_stats())
633 }
634
635 pub fn compute_adaptive_learning_rate(&self, prediction_error: f64) -> f64 {
637 let base_rate = self.adaptation_params.base_rate;
639 let error_factor = 1.0 + prediction_error;
640 (base_rate * error_factor)
641 .min(self.adaptation_params.max_rate)
642 .max(self.adaptation_params.min_rate)
643 }
644
645 pub fn unsupervised_update(&mut self, input: &Hypervector) -> NdimageResult<UpdateResult> {
647 let synthetic_label = format!("unsupervised_{}", self.performance_tracker.update_count);
649 self.memory.store(synthetic_label, input.clone());
650
651 Ok(UpdateResult {
652 memory_updated: true,
653 learning_rate_used: self.adaptation_params.current_rate,
654 performance_change: 0.0, })
656 }
657
658 pub fn get_performance_metrics(&self) -> PerformanceMetrics {
660 self.get_performancemetrics()
661 }
662
663 pub fn get_current_adaptation_rate(&self) -> f64 {
665 self.adaptation_params.current_rate
666 }
667}
668
669pub fn calculate_prediction_error(prediction: &PredictionResult, true_label: &str) -> f64 {
680 if prediction.predicted_label == true_label {
681 0.0 } else {
683 1.0 - prediction.confidence }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hyperdimensional_computing::vector_ops::*;

    /// Store, retrieve, remove and clear on `HDCMemory`.
    #[test]
    fn test_hdc_memory_basic_operations() {
        let mut memory = HDCMemory::new(HDCConfig::default());

        let first = Hypervector::random(1000, 0.1);
        let second = Hypervector::random(1000, 0.1);

        memory.store("pattern1".to_string(), first.clone());
        memory.store("pattern2".to_string(), second);
        assert_eq!(memory.size(), 2);

        let lookup = memory.retrieve(&first);
        assert!(lookup.is_some());
        let (label, confidence) = lookup.expect("Operation failed");
        assert_eq!(label, "pattern1");
        assert!(confidence > 0.8);

        assert!(memory.remove("pattern1").is_some());
        assert_eq!(memory.size(), 1);

        memory.clear();
        assert_eq!(memory.size(), 0);
    }

    /// Adding an experience fills the buffer; interference stays within [0, 1].
    #[test]
    fn test_continual_learning_memory() {
        let config = HDCConfig::default();
        let mut memory = ContinualLearningMemory::new(&config);

        let encoding = Hypervector::random(config.hypervector_dim, config.sparsity);
        let experience = Experience {
            encoding: encoding.clone(),
            label: "test".to_string(),
            timestamp: 0,
            importance: 0.8,
        };
        let consolidation = ConsolidationResult {
            interference_prevented: 1,
            effectiveness_score: 0.9,
            replay_cycles_used: 3,
        };

        assert!(memory.add_experience(experience, &consolidation).is_ok());
        assert_eq!(memory.episodic_buffer.len(), 1);

        let interference = memory.calculate_interference(&encoding);
        assert!((0.0..=1.0).contains(&interference));
    }

    /// Averages, efficiency bounds and reset on `PerformanceTracker`.
    #[test]
    fn test_performance_tracker() {
        let mut tracker = PerformanceTracker::new();
        for (accuracy, speed) in [(0.8, 0.1), (0.9, 0.12), (0.85, 0.11)] {
            tracker.record_update(accuracy, speed);
        }
        assert_eq!(tracker.update_count, 3);

        let accuracy = tracker.get_accuracy();
        assert!(accuracy > 0.8 && accuracy < 0.9);

        let learning_speed = tracker.get_learning_speed();
        assert!(learning_speed > 0.1 && learning_speed < 0.13);

        let memory_efficiency = tracker.get_memory_efficiency();
        assert!(memory_efficiency > 0.0 && memory_efficiency <= 1.0);

        tracker.reset();
        assert_eq!(tracker.update_count, 0);
        assert!(tracker.accuracyhistory.is_empty());
    }

    /// End-to-end predict / feedback / online step / maintenance flow.
    #[test]
    fn test_online_learning_system() {
        let config = HDCConfig::default();
        let mut system = OnlineLearningSystem::new(&config);
        let encoding = Hypervector::random(config.hypervector_dim, config.sparsity);

        let cold = system.predict(&encoding).expect("Operation failed");
        assert_eq!(cold.predicted_label, "unknown");
        assert_eq!(cold.confidence, 0.0);

        let learning_rate = 0.1;
        let error = 0.5;
        let update = system
            .update_with_feedback(&encoding, "test_label", learning_rate, error)
            .expect("Operation failed");
        assert!(update.memory_updated);
        assert_eq!(update.learning_rate_used, learning_rate);

        let step = system
            .online_learning_step(&encoding, Some("test_label"))
            .expect("Operation failed");
        assert!(step.prediction.confidence > 0.0);
        assert!(step.system_performance.accuracy > 0.0);

        assert!(system.perform_maintenance_cycle(&config).is_ok());

        let metrics = system.get_performancemetrics();
        assert!(metrics.accuracy >= 0.0 && metrics.accuracy <= 1.0);
    }

    /// Matching label -> 0 error; mismatch -> 1 - confidence.
    #[test]
    fn test_calculate_prediction_error() {
        let correct = PredictionResult {
            predicted_label: "cat".to_string(),
            confidence: 0.9,
            alternatives: Vec::new(),
        };
        let incorrect = PredictionResult {
            predicted_label: "dog".to_string(),
            confidence: 0.7,
            alternatives: Vec::new(),
        };

        assert_eq!(calculate_prediction_error(&correct, "cat"), 0.0);
        assert_eq!(calculate_prediction_error(&incorrect, "cat"), 1.0 - 0.7);
    }

    /// Blending a new pattern moves the stored one away from the original.
    #[test]
    fn test_memory_update_pattern() {
        let mut memory = HDCMemory::new(HDCConfig::default());

        let original = Hypervector::random(1000, 0.1);
        let replacement = Hypervector::random(1000, 0.1);

        memory.store("test".to_string(), original.clone());
        assert!(memory
            .update_pattern("test".to_string(), replacement, 0.5)
            .is_ok());

        let stored = memory.patterns.get("test").expect("Operation failed");
        assert!(stored.similarity(&original) < 1.0);
    }
}