1use crate::embedder::{Embedder, FastEmbedder, DEFAULT_DIMENSION};
4use crate::error::MemoryResult;
5use crate::storage::{EpisodeStorage, MemoryStats, VectorStore};
6use crate::types::{Episode, EpisodeOutcome, ErrorResolution, RankedEpisode};
7use crate::utility::{
8 bellman_propagate, temporal_credit_assignment, PropagationResult, UtilityRanker,
9};
10use smelt_core::IntentRecord;
11use std::path::Path;
12use std::sync::Arc;
13use uuid::Uuid;
14
15pub struct SmeltMemory {
17 storage: EpisodeStorage,
19 vectors: VectorStore,
21 embedder: Arc<dyn Embedder>,
23 ranker: UtilityRanker,
25 project: Option<String>,
27}
28
29impl SmeltMemory {
30 pub fn open(path: &Path) -> MemoryResult<Self> {
32 let db_path = path.join("memory.db");
33 let vectors_path = path.join("vectors.json");
34
35 let storage = EpisodeStorage::open(&db_path)?;
36
37 let (embedder, dimension): (Arc<dyn Embedder>, usize) = match FastEmbedder::new() {
39 Ok(e) => {
40 let dim = e.dimension();
41 (Arc::new(e), dim)
42 }
43 Err(e) => {
44 tracing::warn!("Failed to initialize embedder: {}", e);
45 (
47 Arc::new(DummyEmbedder::new(DEFAULT_DIMENSION)),
48 DEFAULT_DIMENSION,
49 )
50 }
51 };
52
53 let vectors = VectorStore::open(&vectors_path, dimension)?;
54
55 Ok(Self {
56 storage,
57 vectors,
58 embedder,
59 ranker: UtilityRanker::new(),
60 project: None,
61 })
62 }
63
64 pub fn in_memory() -> MemoryResult<Self> {
66 let storage = EpisodeStorage::in_memory()?;
67 let vectors = VectorStore::new(DEFAULT_DIMENSION);
68
69 Ok(Self {
70 storage,
71 vectors,
72 embedder: Arc::new(DummyEmbedder::new(DEFAULT_DIMENSION)),
73 ranker: UtilityRanker::new(),
74 project: None,
75 })
76 }
77
78 pub fn with_project(mut self, project: String) -> Self {
80 self.project = Some(project);
81 self
82 }
83
84 pub fn capture_from_intent(
86 &mut self,
87 intent: &IntentRecord,
88 outcome: EpisodeOutcome,
89 files_modified: Vec<String>,
90 errors_resolved: Vec<ErrorResolution>,
91 tags: Vec<String>,
92 commit_sha: Option<String>,
93 ) -> MemoryResult<Uuid> {
94 let mut episode = Episode::new(intent.goal.clone(), "intent".to_string(), outcome)
95 .with_intent(intent.id)
96 .with_files(files_modified)
97 .with_errors(errors_resolved)
98 .with_tags(tags);
99
100 if let Some(sha) = commit_sha {
101 episode = episode.with_commit(sha);
102 }
103
104 if let Some(ref project) = self.project {
105 episode = episode.with_project(project.clone());
106 }
107
108 self.capture(episode)
109 }
110
111 pub fn capture(&mut self, mut episode: Episode) -> MemoryResult<Uuid> {
113 if episode.project.is_none() {
115 episode.project = self.project.clone();
116 }
117
118 let id = episode.id;
119
120 let text = episode.to_embedding_text();
122 match self.embedder.embed(&text) {
123 Ok(embedding) => {
124 self.vectors.store(id, embedding)?;
125 }
126 Err(e) => {
127 tracing::warn!("Failed to embed episode {}: {}", id, e);
128 }
129 }
130
131 self.storage.store_episode(&episode)?;
133
134 Ok(id)
135 }
136
137 pub fn retrieve(&self, query: &str, limit: usize) -> MemoryResult<Vec<RankedEpisode>> {
139 let query_embedding = self.embedder.embed(query)?;
141
142 let similar = self.vectors.search(&query_embedding, limit * 2);
144
145 if similar.is_empty() {
146 return Ok(Vec::new());
147 }
148
149 let mut episodes = Vec::new();
151 let mut similarities = Vec::new();
152
153 for (id, similarity) in similar {
154 if let Some(episode) = self.storage.get_episode(id)? {
155 if let Some(ref project) = self.project {
157 if episode.project.as_ref() != Some(project) {
158 continue;
159 }
160 }
161
162 episodes.push(episode);
163 similarities.push(similarity);
164
165 if episodes.len() >= limit {
166 break;
167 }
168 }
169 }
170
171 let ranked = self.ranker.rank(episodes, similarities);
173
174 Ok(ranked.into_iter().take(limit).collect())
175 }
176
177 pub fn record_feedback(&mut self, episode_id: Uuid, helpful: bool) -> MemoryResult<()> {
179 self.storage.record_feedback(episode_id, helpful)?;
181
182 if let Some(episode) = self.storage.get_episode(episode_id)? {
184 let new_utility = self
185 .ranker
186 .update_utility_from_feedback(&episode, helpful, 0.1);
187 self.storage.update_utility(episode_id, new_utility)?;
188 }
189
190 Ok(())
191 }
192
193 pub fn propagate_utility(&mut self, temporal: bool) -> MemoryResult<PropagationResult> {
195 let episodes = self.storage.list_episodes(self.project.as_deref())?;
196
197 if episodes.is_empty() {
198 return Ok(PropagationResult {
199 episodes_updated: 0,
200 total_change: 0.0,
201 max_change: 0.0,
202 });
203 }
204
205 let (new_utilities, result) = bellman_propagate(
207 &episodes,
208 &self.vectors,
209 0.1, 0.9, 0.5, );
213
214 let final_utilities = if temporal {
216 let temporal_credits = temporal_credit_assignment(&episodes, 0.5);
217
218 new_utilities
220 .into_iter()
221 .map(|(id, bellman_u)| {
222 let temporal_u = temporal_credits.get(&id).copied().unwrap_or(bellman_u);
223 (id, (bellman_u + temporal_u) / 2.0)
224 })
225 .collect()
226 } else {
227 new_utilities
228 };
229
230 for (id, utility) in final_utilities {
232 self.storage.update_utility(id, utility)?;
233 }
234
235 Ok(result)
236 }
237
238 pub fn get_episode(&self, id: Uuid) -> MemoryResult<Option<Episode>> {
240 self.storage.get_episode(id)
241 }
242
243 pub fn list_episodes(&self) -> MemoryResult<Vec<Episode>> {
245 self.storage.list_episodes(self.project.as_deref())
246 }
247
248 pub fn stats(&self) -> MemoryResult<MemoryStats> {
250 self.storage.get_stats(self.project.as_deref())
251 }
252}
253
/// Dependency-free fallback embedder that folds the bytes of a text into a fixed
/// number of buckets and L2-normalises the result.
struct DummyEmbedder {
    /// Length of every vector this embedder produces.
    dimension: usize,
}
258
259impl DummyEmbedder {
260 fn new(dimension: usize) -> Self {
261 Self { dimension }
262 }
263}
264
265impl Embedder for DummyEmbedder {
266 fn dimension(&self) -> usize {
267 self.dimension
268 }
269
270 fn embed(&self, text: &str) -> MemoryResult<Vec<f32>> {
271 let mut embedding = vec![0.0f32; self.dimension];
274
275 for (i, byte) in text.bytes().enumerate() {
276 let idx = i % self.dimension;
277 embedding[idx] += byte as f32 / 255.0;
278 }
279
280 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
282 if norm > 0.0 {
283 for x in &mut embedding {
284 *x /= norm;
285 }
286 }
287
288 Ok(embedding)
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_capture_and_retrieve() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let episode = Episode::new(
            "Fixed authentication bug in login flow".to_string(),
            "bugfix".to_string(),
            EpisodeOutcome::Success,
        )
        .with_tags(vec!["auth".to_string(), "security".to_string()]);

        let id = memory.capture(episode).unwrap();

        let results = memory.retrieve("authentication login", 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].episode.id, id);
    }

    #[test]
    fn test_feedback() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let episode = Episode::new(
            "Test episode".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );

        let id = memory.capture(episode).unwrap();

        memory.record_feedback(id, true).unwrap();
        memory.record_feedback(id, true).unwrap();

        let updated = memory.get_episode(id).unwrap().unwrap();
        assert_eq!(updated.helpful_count, 2);
        assert!(updated.utility > 1.0 - 0.001);
    }

    #[test]
    fn test_propagation() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let ep1 = Episode::new(
            "Auth fix".to_string(),
            "bugfix".to_string(),
            EpisodeOutcome::Success,
        );
        let ep2 = Episode::new(
            "Auth test".to_string(),
            "test".to_string(),
            EpisodeOutcome::Partial,
        );

        memory.capture(ep1).unwrap();
        memory.capture(ep2).unwrap();

        let result = memory.propagate_utility(false).unwrap();
        assert!(result.total_change >= 0.0);
    }

    #[test]
    fn test_project_filter() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let ep1 = Episode::new(
            "Project A work".to_string(),
            "feature".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("project-a".to_string());
        let ep2 = Episode::new(
            "Project B work".to_string(),
            "feature".to_string(),
            EpisodeOutcome::Success,
        )
        .with_project("project-b".to_string());

        memory.capture(ep1).unwrap();
        memory.capture(ep2).unwrap();

        // Unscoped memory sees every project's episodes.
        let all = memory.list_episodes().unwrap();
        assert_eq!(all.len(), 2);

        // Scoping to a project must hide the other project's episodes.
        let memory_a = memory.with_project("project-a".to_string());
        let scoped = memory_a.list_episodes().unwrap();
        assert_eq!(scoped.len(), 1);
        assert_eq!(scoped[0].project.as_deref(), Some("project-a"));
    }

    #[test]
    fn test_stats() {
        let mut memory = SmeltMemory::in_memory().unwrap();

        let ep1 = Episode::new(
            "Ep1".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );
        let ep2 = Episode::new(
            "Ep2".to_string(),
            "test".to_string(),
            EpisodeOutcome::Success,
        );

        memory.capture(ep1).unwrap();
        memory.capture(ep2).unwrap();

        let stats = memory.stats().unwrap();
        assert_eq!(stats.total_episodes, 2);
    }
}