Skip to main content

smelt_memory/
memory.rs

1//! Main SmeltMemory implementation
2
3use crate::embedder::{Embedder, FastEmbedder, DEFAULT_DIMENSION};
4use crate::error::MemoryResult;
5use crate::storage::{EpisodeStorage, MemoryStats, VectorStore};
6use crate::types::{Episode, EpisodeOutcome, ErrorResolution, RankedEpisode};
7use crate::utility::{
8    bellman_propagate, temporal_credit_assignment, PropagationResult, UtilityRanker,
9};
10use smelt_core::IntentRecord;
11use std::path::Path;
12use std::sync::Arc;
13use uuid::Uuid;
14
15/// Main memory system for Smelt
16pub struct SmeltMemory {
17    /// Episode metadata storage
18    storage: EpisodeStorage,
19    /// Vector embeddings storage
20    vectors: VectorStore,
21    /// Text embedder
22    embedder: Arc<dyn Embedder>,
23    /// Utility-based ranker
24    ranker: UtilityRanker,
25    /// Current project name
26    project: Option<String>,
27}
28
29impl SmeltMemory {
30    /// Open or create a memory system in the given directory
31    pub fn open(path: &Path) -> MemoryResult<Self> {
32        let db_path = path.join("memory.db");
33        let vectors_path = path.join("vectors.json");
34
35        let storage = EpisodeStorage::open(&db_path)?;
36
37        // Try to create embedder, fall back to dummy dimension if model unavailable
38        let (embedder, dimension): (Arc<dyn Embedder>, usize) = match FastEmbedder::new() {
39            Ok(e) => {
40                let dim = e.dimension();
41                (Arc::new(e), dim)
42            }
43            Err(e) => {
44                tracing::warn!("Failed to initialize embedder: {}", e);
45                // Use a dummy embedder that returns zeros
46                (
47                    Arc::new(DummyEmbedder::new(DEFAULT_DIMENSION)),
48                    DEFAULT_DIMENSION,
49                )
50            }
51        };
52
53        let vectors = VectorStore::open(&vectors_path, dimension)?;
54
55        Ok(Self {
56            storage,
57            vectors,
58            embedder,
59            ranker: UtilityRanker::new(),
60            project: None,
61        })
62    }
63
64    /// Create an in-memory instance for testing
65    pub fn in_memory() -> MemoryResult<Self> {
66        let storage = EpisodeStorage::in_memory()?;
67        let vectors = VectorStore::new(DEFAULT_DIMENSION);
68
69        Ok(Self {
70            storage,
71            vectors,
72            embedder: Arc::new(DummyEmbedder::new(DEFAULT_DIMENSION)),
73            ranker: UtilityRanker::new(),
74            project: None,
75        })
76    }
77
78    /// Set the current project
79    pub fn with_project(mut self, project: String) -> Self {
80        self.project = Some(project);
81        self
82    }
83
84    /// Capture an episode from an intent and its outcome
85    pub fn capture_from_intent(
86        &mut self,
87        intent: &IntentRecord,
88        outcome: EpisodeOutcome,
89        files_modified: Vec<String>,
90        errors_resolved: Vec<ErrorResolution>,
91        tags: Vec<String>,
92        commit_sha: Option<String>,
93    ) -> MemoryResult<Uuid> {
94        let mut episode = Episode::new(intent.goal.clone(), "intent".to_string(), outcome)
95            .with_intent(intent.id)
96            .with_files(files_modified)
97            .with_errors(errors_resolved)
98            .with_tags(tags);
99
100        if let Some(sha) = commit_sha {
101            episode = episode.with_commit(sha);
102        }
103
104        if let Some(ref project) = self.project {
105            episode = episode.with_project(project.clone());
106        }
107
108        self.capture(episode)
109    }
110
111    /// Capture an episode directly
112    pub fn capture(&mut self, mut episode: Episode) -> MemoryResult<Uuid> {
113        // Set project if not already set
114        if episode.project.is_none() {
115            episode.project = self.project.clone();
116        }
117
118        let id = episode.id;
119
120        // Generate and store embedding
121        let text = episode.to_embedding_text();
122        match self.embedder.embed(&text) {
123            Ok(embedding) => {
124                self.vectors.store(id, embedding)?;
125            }
126            Err(e) => {
127                tracing::warn!("Failed to embed episode {}: {}", id, e);
128            }
129        }
130
131        // Store metadata
132        self.storage.store_episode(&episode)?;
133
134        Ok(id)
135    }
136
137    /// Retrieve relevant episodes for a query
138    pub fn retrieve(&self, query: &str, limit: usize) -> MemoryResult<Vec<RankedEpisode>> {
139        // Embed the query
140        let query_embedding = self.embedder.embed(query)?;
141
142        // Find similar episodes
143        let similar = self.vectors.search(&query_embedding, limit * 2);
144
145        if similar.is_empty() {
146            return Ok(Vec::new());
147        }
148
149        // Fetch episode metadata and filter by project
150        let mut episodes = Vec::new();
151        let mut similarities = Vec::new();
152
153        for (id, similarity) in similar {
154            if let Some(episode) = self.storage.get_episode(id)? {
155                // Filter by project if set
156                if let Some(ref project) = self.project {
157                    if episode.project.as_ref() != Some(project) {
158                        continue;
159                    }
160                }
161
162                episodes.push(episode);
163                similarities.push(similarity);
164
165                if episodes.len() >= limit {
166                    break;
167                }
168            }
169        }
170
171        // Rank with utility
172        let ranked = self.ranker.rank(episodes, similarities);
173
174        Ok(ranked.into_iter().take(limit).collect())
175    }
176
177    /// Record feedback for an episode
178    pub fn record_feedback(&mut self, episode_id: Uuid, helpful: bool) -> MemoryResult<()> {
179        // Update feedback counts
180        self.storage.record_feedback(episode_id, helpful)?;
181
182        // Update utility
183        if let Some(episode) = self.storage.get_episode(episode_id)? {
184            let new_utility = self
185                .ranker
186                .update_utility_from_feedback(&episode, helpful, 0.1);
187            self.storage.update_utility(episode_id, new_utility)?;
188        }
189
190        Ok(())
191    }
192
193    /// Run utility propagation to spread value through the memory
194    pub fn propagate_utility(&mut self, temporal: bool) -> MemoryResult<PropagationResult> {
195        let episodes = self.storage.list_episodes(self.project.as_deref())?;
196
197        if episodes.is_empty() {
198            return Ok(PropagationResult {
199                episodes_updated: 0,
200                total_change: 0.0,
201                max_change: 0.0,
202            });
203        }
204
205        // Run Bellman propagation
206        let (new_utilities, result) = bellman_propagate(
207            &episodes,
208            &self.vectors,
209            0.1, // learning_rate
210            0.9, // discount
211            0.5, // similarity_threshold
212        );
213
214        // Apply temporal credit if requested
215        let final_utilities = if temporal {
216            let temporal_credits = temporal_credit_assignment(&episodes, 0.5);
217
218            // Combine Bellman and temporal credits
219            new_utilities
220                .into_iter()
221                .map(|(id, bellman_u)| {
222                    let temporal_u = temporal_credits.get(&id).copied().unwrap_or(bellman_u);
223                    (id, (bellman_u + temporal_u) / 2.0)
224                })
225                .collect()
226        } else {
227            new_utilities
228        };
229
230        // Update storage
231        for (id, utility) in final_utilities {
232            self.storage.update_utility(id, utility)?;
233        }
234
235        Ok(result)
236    }
237
238    /// Get an episode by ID
239    pub fn get_episode(&self, id: Uuid) -> MemoryResult<Option<Episode>> {
240        self.storage.get_episode(id)
241    }
242
243    /// List all episodes
244    pub fn list_episodes(&self) -> MemoryResult<Vec<Episode>> {
245        self.storage.list_episodes(self.project.as_deref())
246    }
247
248    /// Get memory statistics
249    pub fn stats(&self) -> MemoryResult<MemoryStats> {
250        self.storage.get_stats(self.project.as_deref())
251    }
252}
253
254/// Dummy embedder for when fastembed is unavailable
255struct DummyEmbedder {
256    dimension: usize,
257}
258
259impl DummyEmbedder {
260    fn new(dimension: usize) -> Self {
261        Self { dimension }
262    }
263}
264
265impl Embedder for DummyEmbedder {
266    fn dimension(&self) -> usize {
267        self.dimension
268    }
269
270    fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
271        // Generate a simple hash-based embedding
272        // This is not semantically meaningful but allows basic testing
273        let mut embedding = vec![0.0f32; self.dimension];
274
275        for (i, byte) in text.bytes().enumerate() {
276            let idx = i % self.dimension;
277            embedding[idx] += byte as f32 / 255.0;
278        }
279
280        // Normalize
281        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
282        if norm > 0.0 {
283            for x in &mut embedding {
284                *x /= norm;
285            }
286        }
287
288        Ok(embedding)
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_capture_and_retrieve() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let episode = Episode::new(
            "Fixed authentication bug in login flow".to_string(),
            "bugfix".to_string(),
            EpisodeOutcome::Success,
        )
        .with_tags(vec!["auth".to_string(), "security".to_string()]);

        let id = memory.capture(episode).unwrap();

        // Retrieve with similar query
        let results = memory.retrieve("authentication login", 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].episode.id, id);
    }

    #[test]
    fn test_retrieve_zero_limit_is_empty() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        memory
            .capture(Episode::new(
                "Fixed authentication bug".to_string(),
                "bugfix".to_string(),
                EpisodeOutcome::Success,
            ))
            .unwrap();

        assert!(memory.retrieve("authentication", 0).unwrap().is_empty());
    }

    #[test]
    fn test_feedback() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let episode = Episode::new(
            "Test episode".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );

        let id = memory.capture(episode).unwrap();
        let initial_utility = memory.get_episode(id).unwrap().unwrap().utility;

        // Record positive feedback
        memory.record_feedback(id, true).unwrap();
        memory.record_feedback(id, true).unwrap();

        let updated = memory.get_episode(id).unwrap().unwrap();
        assert_eq!(updated.helpful_count, 2);
        // Positive feedback must never lower utility...
        assert!(updated.utility >= initial_utility);
        // ...and should drive it toward 1.0.
        assert!(updated.utility > 1.0 - 0.001);
    }

    #[test]
    fn test_propagation() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        // Create some episodes
        let ep1 = Episode::new(
            "Auth fix".to_string(),
            "bugfix".to_string(),
            EpisodeOutcome::Success,
        );
        let ep2 = Episode::new(
            "Auth test".to_string(),
            "test".to_string(),
            EpisodeOutcome::Partial,
        );

        memory.capture(ep1).unwrap();
        memory.capture(ep2).unwrap();

        // Run propagation; verify it completes and reports a sane aggregate.
        let result = memory.propagate_utility(false).unwrap();
        assert!(result.total_change >= 0.0);
    }

    #[test]
    fn test_project_filter() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let ep1 = Episode::new(
            "Project A work".to_string(),
            "feature".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("project-a".to_string());
        let ep2 = Episode::new(
            "Project B work".to_string(),
            "feature".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("project-b".to_string());

        memory.capture(ep1).unwrap();
        memory.capture(ep2).unwrap();

        // Without filter, should see both
        let all = memory.list_episodes().unwrap();
        assert_eq!(all.len(), 2);

        // With project filter, only project-a's episode is listed
        let memory_a = memory.with_project("project-a".to_string());
        let only_a = memory_a.list_episodes().unwrap();
        assert_eq!(only_a.len(), 1);
        assert_eq!(only_a[0].project.as_deref(), Some("project-a"));
    }

    #[test]
    fn test_retrieve_respects_project_scope() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        memory
            .capture(
                Episode::new(
                    "Project A work".to_string(),
                    "feature".to_string(),
                    EpisodeOutcome::Success,
                )
                .with_project("project-a".to_string()),
            )
            .unwrap();
        memory
            .capture(
                Episode::new(
                    "Project B work".to_string(),
                    "feature".to_string(),
                    EpisodeOutcome::Success,
                )
                .with_project("project-b".to_string()),
            )
            .unwrap();

        let memory = memory.with_project("project-a".to_string());
        let results = memory.retrieve("Project A work", 5).unwrap();
        assert!(!results.is_empty());
        assert!(results
            .iter()
            .all(|r| r.episode.project.as_deref() == Some("project-a")));
    }

    #[test]
    fn test_stats() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let ep1 = Episode::new(
            "Ep1".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        let ep2 = Episode::new(
            "Ep2".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );

        memory.capture(ep1).unwrap();
        memory.capture(ep2).unwrap();

        let stats = memory.stats().unwrap();
        assert_eq!(stats.total_episodes, 2);
    }
}