Skip to main content

smelt_memory/utility/
bellman.rs

1//! Bellman propagation for utility spreading
2//!
3//! Implements temporal difference-style learning to spread utility
4//! from helpful episodes to similar ones, improving future retrieval.
5
6use crate::storage::VectorStore;
7use crate::types::Episode;
8use std::collections::HashMap;
9use uuid::Uuid;
10
/// Summary statistics for one propagation run
///
/// As filled in by `bellman_propagate`, only episodes whose utility moved by
/// more than `0.001` (in absolute value) contribute to any of the fields, so
/// a run in which every change is at or below that size reports all zeros.
/// `PropagationResult::default()` is that all-zero value.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct PropagationResult {
    /// Number of episodes whose utility changed by more than the reporting threshold
    pub episodes_updated: usize,
    /// Sum of the absolute utility changes of the counted episodes
    pub total_change: f64,
    /// Largest absolute utility change among the counted episodes
    pub max_change: f64,
}
21
22/// Run Bellman propagation to spread utility through the memory
23///
24/// Uses TD-style update: Q(s) = Q(s) + α * (γ * Q(s') * sim - Q(s))
25///
26/// # Arguments
27/// * `episodes` - Episodes with their current utilities
28/// * `vectors` - Vector store for similarity lookup
29/// * `learning_rate` - Alpha: how quickly to update (0.0 to 1.0)
30/// * `discount` - Gamma: how much to trust similar episodes (0.0 to 1.0)
31/// * `similarity_threshold` - Minimum similarity to consider
32///
33/// # Returns
34/// Map of episode IDs to new utilities and propagation statistics
35pub fn bellman_propagate(
36    episodes: &[Episode],
37    vectors: &VectorStore,
38    learning_rate: f64,
39    discount: f64,
40    similarity_threshold: f64,
41) -> (HashMap<Uuid, f64>, PropagationResult) {
42    let mut new_utilities: HashMap<Uuid, f64> = HashMap::new();
43    let mut total_change: f64 = 0.0;
44    let mut max_change: f64 = 0.0;
45    let mut updates = 0;
46
47    // Build lookup for quick access
48    let utility_map: HashMap<Uuid, f64> = episodes.iter().map(|e| (e.id, e.utility)).collect();
49
50    for episode in episodes {
51        // Get this episode's embedding
52        let Some(embedding) = vectors.get(episode.id) else {
53            new_utilities.insert(episode.id, episode.utility);
54            continue;
55        };
56
57        // Find similar episodes
58        let similar = vectors.search(embedding, 10);
59
60        // Calculate utility update from similar episodes
61        let mut weighted_utility = 0.0;
62        let mut weight_sum = 0.0;
63
64        for (similar_id, similarity) in similar {
65            // Skip self
66            if similar_id == episode.id {
67                continue;
68            }
69
70            // Skip low similarity
71            if similarity < similarity_threshold {
72                continue;
73            }
74
75            // Get the similar episode's utility
76            if let Some(&other_utility) = utility_map.get(&similar_id) {
77                weighted_utility += similarity * other_utility;
78                weight_sum += similarity;
79            }
80        }
81
82        // Calculate new utility
83        let new_utility = if weight_sum > 0.0 {
84            let neighbor_contribution = weighted_utility / weight_sum;
85
86            // TD update: Q(s) = Q(s) + α * (γ * Q(s') - Q(s))
87            let target = discount * neighbor_contribution;
88            let update = learning_rate * (target - episode.utility);
89
90            // Clamp to valid range
91            (episode.utility + update).clamp(0.0, 1.0)
92        } else {
93            episode.utility
94        };
95
96        let change = (new_utility - episode.utility).abs();
97        if change > 0.001 {
98            updates += 1;
99            total_change += change;
100            max_change = max_change.max(change);
101        }
102
103        new_utilities.insert(episode.id, new_utility);
104    }
105
106    let result = PropagationResult {
107        episodes_updated: updates,
108        total_change,
109        max_change,
110    };
111
112    (new_utilities, result)
113}
114
115/// Run temporal credit assignment
116///
117/// Credits episodes that preceded successful outcomes, creating
118/// a temporal chain of utility.
119///
120/// # Arguments
121/// * `episodes` - Episodes sorted by time (oldest first)
122/// * `temporal_discount` - How much to credit earlier episodes (0.0 to 1.0)
123///
124/// # Returns
125/// Map of episode IDs to temporal credits
126pub fn temporal_credit_assignment(
127    episodes: &[Episode],
128    temporal_discount: f64,
129) -> HashMap<Uuid, f64> {
130    let mut credits: HashMap<Uuid, f64> = HashMap::new();
131
132    // Initialize with current utilities
133    for episode in episodes {
134        credits.insert(episode.id, episode.utility);
135    }
136
137    // Work backwards through time
138    let mut prev_utility = 0.0;
139
140    for episode in episodes.iter().rev() {
141        let current_utility = episode.utility;
142
143        // Add discounted future utility
144        let temporal_bonus = temporal_discount * prev_utility;
145        let new_utility = (current_utility + temporal_bonus).min(1.0);
146
147        credits.insert(episode.id, new_utility);
148        prev_utility = new_utility;
149    }
150
151    credits
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::EpisodeOutcome;

    /// Tolerance for comparing computed utilities with hand-derived values.
    const EPS: f64 = 1e-9;

    /// Build a successful test episode with a fixed id and utility.
    fn make_episode(id: Uuid, utility: f64) -> Episode {
        let mut ep = Episode::new(
            format!("Episode {}", id),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        ep.id = id;
        ep.utility = utility;
        ep
    }

    #[test]
    fn test_propagation_basic() {
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();

        let ep1 = make_episode(id1, 1.0); // High utility
        let ep2 = make_episode(id2, 0.2); // Low utility

        let episodes = vec![ep1, ep2];

        // Create vector store with similar embeddings
        let mut vectors = VectorStore::new(3);
        vectors.store(id1, vec![1.0, 0.0, 0.0]).unwrap();
        vectors.store(id2, vec![0.9, 0.1, 0.0]).unwrap(); // Similar to id1

        let (utilities, result) = bellman_propagate(&episodes, &vectors, 0.5, 0.9, 0.5);

        assert_eq!(utilities.len(), 2);

        // Low utility episode should increase due to similarity to high utility
        assert!(utilities[&id2] > 0.2);

        // id1 is id2's only other neighbour, so the weighted mean is exactly
        // id1's utility: 0.2 + 0.5 * (0.9 * 1.0 - 0.2) = 0.55
        assert!((utilities[&id2] - 0.55).abs() < EPS);

        // The high utility episode is pulled toward its low utility neighbour
        assert!(utilities[&id1] < 1.0);

        assert!(result.episodes_updated > 0);
        assert!(result.max_change >= 0.35 - EPS);
        assert!(result.total_change >= result.max_change);
    }

    #[test]
    fn test_no_propagation_for_dissimilar() {
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();

        let ep1 = make_episode(id1, 1.0);
        let ep2 = make_episode(id2, 0.2);

        let episodes = vec![ep1, ep2];

        // Create vector store with dissimilar embeddings
        let mut vectors = VectorStore::new(3);
        vectors.store(id1, vec![1.0, 0.0, 0.0]).unwrap();
        vectors.store(id2, vec![0.0, 1.0, 0.0]).unwrap(); // Orthogonal

        let (utilities, _) = bellman_propagate(&episodes, &vectors, 0.5, 0.9, 0.5);

        // Should be unchanged or minimal change due to low similarity
        assert!((utilities[&id2] - 0.2).abs() < 0.1);
    }

    #[test]
    fn test_missing_embedding_keeps_utility() {
        let id = Uuid::new_v4();
        let episodes = vec![make_episode(id, 0.4)];

        // Nothing is stored, so the episode has no embedding to search with
        let vectors = VectorStore::new(3);

        let (utilities, result) = bellman_propagate(&episodes, &vectors, 0.5, 0.9, 0.5);

        assert_eq!(utilities[&id], 0.4);
        assert_eq!(result.episodes_updated, 0);
    }

    #[test]
    fn test_empty_inputs() {
        let vectors = VectorStore::new(3);

        let (utilities, result) = bellman_propagate(&[], &vectors, 0.5, 0.9, 0.5);
        assert!(utilities.is_empty());
        assert_eq!(result.episodes_updated, 0);
        assert_eq!(result.total_change, 0.0);
        assert_eq!(result.max_change, 0.0);

        assert!(temporal_credit_assignment(&[], 0.5).is_empty());
    }

    #[test]
    fn test_temporal_credit() {
        let id1 = Uuid::new_v4();
        let id2 = Uuid::new_v4();
        let id3 = Uuid::new_v4();

        let ep1 = make_episode(id1, 0.3); // Earlier, low utility
        let ep2 = make_episode(id2, 0.5); // Middle
        let ep3 = make_episode(id3, 1.0); // Latest, high utility

        let episodes = vec![ep1, ep2, ep3]; // Sorted by time

        let credits = temporal_credit_assignment(&episodes, 0.5);

        // Earlier episodes should get credit from later successes
        assert!(credits[&id1] > 0.3);
        assert!(credits[&id2] > 0.5);

        // Latest episode has nothing after it: credit stays at its utility
        assert!((credits[&id3] - 1.0).abs() < EPS);
        // 0.5 + 0.5 * 1.0 = 1.0 (already at the cap)
        assert!((credits[&id2] - 1.0).abs() < EPS);
        // 0.3 + 0.5 * 1.0 = 0.8
        assert!((credits[&id1] - 0.8).abs() < EPS);
    }
}