smelt_memory/utility/
bellman.rs1use crate::storage::VectorStore;
7use crate::types::Episode;
8use std::collections::HashMap;
9use uuid::Uuid;
10
/// Summary statistics for one utility-propagation sweep.
///
/// Produced by [`bellman_propagate`] alongside the map of new utilities.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct PropagationResult {
    /// Number of episodes whose utility moved by more than the change
    /// threshold used by the sweep.
    pub episodes_updated: usize,
    /// Sum of the absolute utility changes over the counted episodes.
    pub total_change: f64,
    /// Largest single absolute utility change among the counted episodes.
    pub max_change: f64,
}
21
22pub fn bellman_propagate(
36 episodes: &[Episode],
37 vectors: &VectorStore,
38 learning_rate: f64,
39 discount: f64,
40 similarity_threshold: f64,
41) -> (HashMap<Uuid, f64>, PropagationResult) {
42 let mut new_utilities: HashMap<Uuid, f64> = HashMap::new();
43 let mut total_change: f64 = 0.0;
44 let mut max_change: f64 = 0.0;
45 let mut updates = 0;
46
47 let utility_map: HashMap<Uuid, f64> = episodes.iter().map(|e| (e.id, e.utility)).collect();
49
50 for episode in episodes {
51 let Some(embedding) = vectors.get(episode.id) else {
53 new_utilities.insert(episode.id, episode.utility);
54 continue;
55 };
56
57 let similar = vectors.search(embedding, 10);
59
60 let mut weighted_utility = 0.0;
62 let mut weight_sum = 0.0;
63
64 for (similar_id, similarity) in similar {
65 if similar_id == episode.id {
67 continue;
68 }
69
70 if similarity < similarity_threshold {
72 continue;
73 }
74
75 if let Some(&other_utility) = utility_map.get(&similar_id) {
77 weighted_utility += similarity * other_utility;
78 weight_sum += similarity;
79 }
80 }
81
82 let new_utility = if weight_sum > 0.0 {
84 let neighbor_contribution = weighted_utility / weight_sum;
85
86 let target = discount * neighbor_contribution;
88 let update = learning_rate * (target - episode.utility);
89
90 (episode.utility + update).clamp(0.0, 1.0)
92 } else {
93 episode.utility
94 };
95
96 let change = (new_utility - episode.utility).abs();
97 if change > 0.001 {
98 updates += 1;
99 total_change += change;
100 max_change = max_change.max(change);
101 }
102
103 new_utilities.insert(episode.id, new_utility);
104 }
105
106 let result = PropagationResult {
107 episodes_updated: updates,
108 total_change,
109 max_change,
110 };
111
112 (new_utilities, result)
113}
114
115pub fn temporal_credit_assignment(
127 episodes: &[Episode],
128 temporal_discount: f64,
129) -> HashMap<Uuid, f64> {
130 let mut credits: HashMap<Uuid, f64> = HashMap::new();
131
132 for episode in episodes {
134 credits.insert(episode.id, episode.utility);
135 }
136
137 let mut prev_utility = 0.0;
139
140 for episode in episodes.iter().rev() {
141 let current_utility = episode.utility;
142
143 let temporal_bonus = temporal_discount * prev_utility;
145 let new_utility = (current_utility + temporal_bonus).min(1.0);
146
147 credits.insert(episode.id, new_utility);
148 prev_utility = new_utility;
149 }
150
151 credits
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::EpisodeOutcome;

    /// Builds a successful test episode and overrides its id and utility.
    fn make_episode(id: Uuid, utility: f64) -> Episode {
        let mut episode = Episode::new(
            format!("Episode {id}"),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        episode.id = id;
        episode.utility = utility;
        episode
    }

    #[test]
    fn test_propagation_basic() {
        let (high_id, low_id) = (Uuid::new_v4(), Uuid::new_v4());
        let episodes = vec![make_episode(high_id, 1.0), make_episode(low_id, 0.2)];

        // Nearly parallel embeddings; expected to score above the 0.5 threshold.
        let mut vectors = VectorStore::new(3);
        vectors.store(high_id, vec![1.0, 0.0, 0.0]).unwrap();
        vectors.store(low_id, vec![0.9, 0.1, 0.0]).unwrap();

        let (utilities, result) = bellman_propagate(&episodes, &vectors, 0.5, 0.9, 0.5);

        // The low-utility episode should be pulled up by its neighbour.
        assert!(utilities[&low_id] > 0.2);
        assert!(result.episodes_updated > 0);
    }

    #[test]
    fn test_no_propagation_for_dissimilar() {
        let (high_id, low_id) = (Uuid::new_v4(), Uuid::new_v4());
        let episodes = vec![make_episode(high_id, 1.0), make_episode(low_id, 0.2)];

        // Orthogonal embeddings; expected to fall below the 0.5 threshold.
        let mut vectors = VectorStore::new(3);
        vectors.store(high_id, vec![1.0, 0.0, 0.0]).unwrap();
        vectors.store(low_id, vec![0.0, 1.0, 0.0]).unwrap();

        let (utilities, _) = bellman_propagate(&episodes, &vectors, 0.5, 0.9, 0.5);

        // The low-utility episode should stay close to where it started.
        assert!((utilities[&low_id] - 0.2).abs() < 0.1);
    }

    #[test]
    fn test_temporal_credit() {
        let (first_id, second_id, third_id) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());
        let episodes = vec![
            make_episode(first_id, 0.3),
            make_episode(second_id, 0.5),
            make_episode(third_id, 1.0),
        ];

        let credits = temporal_credit_assignment(&episodes, 0.5);

        // Both earlier episodes should gain credit from the final one.
        assert!(credits[&first_id] > 0.3);
        assert!(credits[&second_id] > 0.5);
    }
}