Skip to main content

tandem_memory/
distillation.rs

1use crate::types::{DistillationReport, DistilledFact, FactCategory, MemoryResult};
2use async_trait::async_trait;
3use chrono::Utc;
4use std::sync::Arc;
5use tandem_providers::ProviderRegistry;
6
/// Prompt template sent to the LLM to extract memories from a session.
///
/// The single `{conversation}` placeholder is substituted with the joined
/// conversation text before the prompt is sent. The model is asked to answer
/// with a JSON array whose objects carry `category`, `content`, `importance`
/// and `follow_up_needed`.
const DISTILLATION_PROMPT: &str = r#"You are analyzing a conversation session to extract memories.

## Conversation History:
{conversation}

## Task:
Extract and summarize important information from this conversation. Return a JSON array of objects with this format:
[
  {
    "category": "user_preference|task_outcome|learning|fact",
    "content": "The extracted information",
    "importance": 0.0-1.0,
    "follow_up_needed": true|false
  }
]

Only include information that would be valuable for future sessions.
"#;
25
26pub struct SessionDistiller {
27    providers: Arc<ProviderRegistry>,
28    importance_threshold: f64,
29}
30
/// Outcome reported by a memory writer for one distilled fact.
///
/// The default value describes a write that did nothing: both flags are
/// `false` and no identifiers are set.
#[derive(Debug, Clone, Default)]
pub struct DistillationMemoryWrite {
    /// `true` when the writer reports that it stored the fact.
    pub stored: bool,
    /// `true` when the writer reports the fact as a duplicate
    /// (presumably of an existing memory; confirm with the writer implementation).
    pub deduped: bool,
    /// Identifier of the memory entry associated with this write, if any.
    pub memory_id: Option<String>,
    /// Identifier of the candidate associated with this write, if any.
    pub candidate_id: Option<String>,
}
38
39#[async_trait]
40pub trait DistillationMemoryWriter: Send + Sync {
41    async fn store_user_fact(
42        &self,
43        session_id: &str,
44        fact: &DistilledFact,
45    ) -> MemoryResult<DistillationMemoryWrite>;
46
47    async fn store_agent_fact(
48        &self,
49        session_id: &str,
50        fact: &DistilledFact,
51    ) -> MemoryResult<DistillationMemoryWrite>;
52}
53
/// Memory writer that stores nothing; used when the caller supplies no writer.
struct NoopDistillationMemoryWriter;
55
56#[async_trait]
57impl DistillationMemoryWriter for NoopDistillationMemoryWriter {
58    async fn store_user_fact(
59        &self,
60        _session_id: &str,
61        _fact: &DistilledFact,
62    ) -> MemoryResult<DistillationMemoryWrite> {
63        Ok(DistillationMemoryWrite::default())
64    }
65
66    async fn store_agent_fact(
67        &self,
68        _session_id: &str,
69        _fact: &DistilledFact,
70    ) -> MemoryResult<DistillationMemoryWrite> {
71        Ok(DistillationMemoryWrite::default())
72    }
73}
74
75impl SessionDistiller {
76    pub fn new(providers: Arc<ProviderRegistry>) -> Self {
77        Self {
78            providers,
79            importance_threshold: 0.5,
80        }
81    }
82
83    pub fn with_threshold(providers: Arc<ProviderRegistry>, importance_threshold: f64) -> Self {
84        Self {
85            providers,
86            importance_threshold,
87        }
88    }
89
90    pub async fn distill(
91        &self,
92        session_id: &str,
93        conversation: &[String],
94    ) -> MemoryResult<DistillationReport> {
95        self.distill_with_writer(session_id, conversation, &NoopDistillationMemoryWriter)
96            .await
97    }
98
99    pub async fn distill_with_writer<W: DistillationMemoryWriter>(
100        &self,
101        session_id: &str,
102        conversation: &[String],
103        writer: &W,
104    ) -> MemoryResult<DistillationReport> {
105        let distillation_id = uuid::Uuid::new_v4().to_string();
106        let full_text = conversation.join("\n\n---\n\n");
107        let token_count = self.count_tokens(&full_text);
108
109        if token_count < 50 {
110            return Ok(DistillationReport {
111                distillation_id,
112                session_id: session_id.to_string(),
113                distilled_at: Utc::now(),
114                facts_extracted: 0,
115                importance_threshold: self.importance_threshold,
116                user_memory_updated: false,
117                agent_memory_updated: false,
118                stored_count: 0,
119                deduped_count: 0,
120                memory_ids: Vec::new(),
121                candidate_ids: Vec::new(),
122                status: "skipped_short_conversation".to_string(),
123            });
124        }
125
126        let facts = self.extract_facts(&full_text, &distillation_id).await?;
127
128        let filtered_facts: Vec<&DistilledFact> = facts
129            .iter()
130            .filter(|f| f.importance_score >= self.importance_threshold)
131            .collect();
132
133        Self::build_distillation_report(
134            distillation_id,
135            session_id,
136            self.importance_threshold,
137            &filtered_facts,
138            writer,
139        )
140        .await
141    }
142
143    async fn build_distillation_report<W: DistillationMemoryWriter>(
144        distillation_id: String,
145        session_id: &str,
146        importance_threshold: f64,
147        filtered_facts: &[&DistilledFact],
148        writer: &W,
149    ) -> MemoryResult<DistillationReport> {
150        if filtered_facts.is_empty() {
151            return Ok(DistillationReport {
152                distillation_id,
153                session_id: session_id.to_string(),
154                distilled_at: Utc::now(),
155                facts_extracted: 0,
156                importance_threshold,
157                user_memory_updated: false,
158                agent_memory_updated: false,
159                stored_count: 0,
160                deduped_count: 0,
161                memory_ids: Vec::new(),
162                candidate_ids: Vec::new(),
163                status: "no_important_facts".to_string(),
164            });
165        }
166
167        let user_results =
168            Self::write_user_memory_facts(session_id, filtered_facts, writer).await?;
169        let agent_results =
170            Self::write_agent_memory_facts(session_id, filtered_facts, writer).await?;
171        let all_results = user_results
172            .iter()
173            .chain(agent_results.iter())
174            .cloned()
175            .collect::<Vec<_>>();
176        let stored_count = all_results.iter().filter(|row| row.stored).count();
177        let deduped_count = all_results.iter().filter(|row| row.deduped).count();
178        let memory_ids = all_results
179            .iter()
180            .filter_map(|row| row.memory_id.clone())
181            .collect::<Vec<_>>();
182        let candidate_ids = all_results
183            .iter()
184            .filter_map(|row| row.candidate_id.clone())
185            .collect::<Vec<_>>();
186
187        Ok(DistillationReport {
188            distillation_id,
189            session_id: session_id.to_string(),
190            distilled_at: Utc::now(),
191            facts_extracted: filtered_facts.len(),
192            importance_threshold,
193            user_memory_updated: !user_results.is_empty(),
194            agent_memory_updated: !agent_results.is_empty(),
195            stored_count,
196            deduped_count,
197            memory_ids,
198            candidate_ids,
199            status: if stored_count > 0 || deduped_count > 0 {
200                "stored".to_string()
201            } else {
202                "facts_extracted_only".to_string()
203            },
204        })
205    }
206
207    async fn extract_facts(
208        &self,
209        conversation: &str,
210        distillation_id: &str,
211    ) -> MemoryResult<Vec<DistilledFact>> {
212        let prompt = DISTILLATION_PROMPT.replace("{conversation}", conversation);
213
214        let response = match self.providers.complete_cheapest(&prompt, None, None).await {
215            Ok(r) => r,
216            Err(e) => {
217                tracing::warn!("Distillation LLM failed: {}", e);
218                return Ok(Vec::new());
219            }
220        };
221
222        let extracted: Vec<ExtractedFact> = match serde_json::from_str(&response) {
223            Ok(facts) => facts,
224            Err(e) => {
225                tracing::warn!("Failed to parse distillation response: {}", e);
226                return Ok(Vec::new());
227            }
228        };
229
230        let facts: Vec<DistilledFact> = extracted
231            .into_iter()
232            .map(|e| DistilledFact {
233                id: uuid::Uuid::new_v4().to_string(),
234                distillation_id: distillation_id.to_string(),
235                content: e.content,
236                category: parse_category(&e.category),
237                importance_score: e.importance,
238                source_message_ids: Vec::new(),
239                contradicts_fact_id: None,
240            })
241            .collect();
242
243        Ok(facts)
244    }
245
246    async fn write_user_memory_facts(
247        session_id: &str,
248        facts: &[&DistilledFact],
249        writer: &impl DistillationMemoryWriter,
250    ) -> MemoryResult<Vec<DistillationMemoryWrite>> {
251        if facts.is_empty() {
252            return Ok(Vec::new());
253        }
254
255        let user_facts: Vec<&DistilledFact> = facts
256            .iter()
257            .filter(|f| {
258                matches!(
259                    f.category,
260                    FactCategory::UserPreference | FactCategory::Fact
261                )
262            })
263            .cloned()
264            .collect();
265
266        if user_facts.is_empty() {
267            return Ok(Vec::new());
268        }
269
270        let mut writes = Vec::new();
271        for fact in user_facts {
272            writes.push(writer.store_user_fact(session_id, fact).await?);
273        }
274        Ok(writes)
275    }
276
277    async fn write_agent_memory_facts(
278        session_id: &str,
279        facts: &[&DistilledFact],
280        writer: &impl DistillationMemoryWriter,
281    ) -> MemoryResult<Vec<DistillationMemoryWrite>> {
282        if facts.is_empty() {
283            return Ok(Vec::new());
284        }
285
286        let agent_facts: Vec<&DistilledFact> = facts
287            .iter()
288            .filter(|f| {
289                matches!(
290                    f.category,
291                    FactCategory::TaskOutcome | FactCategory::Learning
292                )
293            })
294            .cloned()
295            .collect();
296
297        if agent_facts.is_empty() {
298            return Ok(Vec::new());
299        }
300
301        let mut writes = Vec::new();
302        for fact in agent_facts {
303            writes.push(writer.store_agent_fact(session_id, fact).await?);
304        }
305        Ok(writes)
306    }
307
308    fn count_tokens(&self, text: &str) -> i64 {
309        tiktoken_rs::cl100k_base()
310            .map(|bpe| bpe.encode_with_special_tokens(text).len() as i64)
311            .unwrap_or((text.len() / 4) as i64)
312    }
313}
314
315fn parse_category(s: &str) -> FactCategory {
316    match s.to_lowercase().as_str() {
317        "user_preference" => FactCategory::UserPreference,
318        "task_outcome" => FactCategory::TaskOutcome,
319        "learning" => FactCategory::Learning,
320        _ => FactCategory::Fact,
321    }
322}
323
324#[derive(Debug, Clone, serde::Deserialize)]
325#[allow(dead_code)]
326struct ExtractedFact {
327    category: String,
328    content: String,
329    importance: f64,
330    #[serde(default)]
331    follow_up_needed: bool,
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Test double that records the ids of the facts reaching each writer method.
    ///
    /// With `dedupe` set, every write is reported as deduped instead of stored.
    #[derive(Clone, Default)]
    struct RecordingWriter {
        user_calls: Arc<Mutex<Vec<String>>>,
        agent_calls: Arc<Mutex<Vec<String>>>,
        dedupe: bool,
    }

    impl RecordingWriter {
        /// Records `fact` in `calls` and builds the write result for it.
        fn record(
            &self,
            calls: &Mutex<Vec<String>>,
            label: &str,
            fact: &DistilledFact,
        ) -> DistillationMemoryWrite {
            calls.lock().expect(label).push(fact.id.clone());
            DistillationMemoryWrite {
                stored: !self.dedupe,
                deduped: self.dedupe,
                memory_id: Some(format!("memory-{}", fact.id)),
                candidate_id: Some(format!("candidate-{}", fact.id)),
            }
        }

        /// Snapshot of the fact ids passed to `store_user_fact`.
        fn recorded_user_calls(&self) -> Vec<String> {
            self.user_calls.lock().expect("user calls").clone()
        }

        /// Snapshot of the fact ids passed to `store_agent_fact`.
        fn recorded_agent_calls(&self) -> Vec<String> {
            self.agent_calls.lock().expect("agent calls").clone()
        }
    }

    #[async_trait]
    impl DistillationMemoryWriter for RecordingWriter {
        async fn store_user_fact(
            &self,
            _session_id: &str,
            fact: &DistilledFact,
        ) -> MemoryResult<DistillationMemoryWrite> {
            Ok(self.record(&self.user_calls, "user calls", fact))
        }

        async fn store_agent_fact(
            &self,
            _session_id: &str,
            fact: &DistilledFact,
        ) -> MemoryResult<DistillationMemoryWrite> {
            Ok(self.record(&self.agent_calls, "agent calls", fact))
        }
    }

    /// Builds a high-importance fact with the given id and category.
    fn test_fact(id: &str, category: FactCategory) -> DistilledFact {
        DistilledFact {
            id: id.to_string(),
            distillation_id: "distill-test".to_string(),
            content: format!("content for {id}"),
            category,
            importance_score: 0.9,
            source_message_ids: Vec::new(),
            contradicts_fact_id: None,
        }
    }

    #[test]
    fn test_parse_category() {
        assert_eq!(
            parse_category("user_preference"),
            FactCategory::UserPreference
        );
        assert_eq!(parse_category("task_outcome"), FactCategory::TaskOutcome);
        assert_eq!(parse_category("learning"), FactCategory::Learning);
        assert_eq!(parse_category("unknown"), FactCategory::Fact);
    }

    #[test]
    fn parse_category_ignores_case_and_maps_fact_label() {
        assert_eq!(
            parse_category("User_Preference"),
            FactCategory::UserPreference
        );
        assert_eq!(parse_category("LEARNING"), FactCategory::Learning);
        assert_eq!(parse_category("fact"), FactCategory::Fact);
        assert_eq!(parse_category(""), FactCategory::Fact);
    }

    // Ignored rather than left as an empty body, so it does not show up as a
    // passing test while it verifies nothing.
    #[tokio::test]
    #[ignore = "requires a ProviderRegistry mock"]
    async fn test_distiller_requires_conversation() {
        // This test would require ProviderRegistry mock
        // Placeholder for now
    }

    #[tokio::test]
    async fn build_distillation_report_without_facts_skips_writer() {
        let writer = RecordingWriter::default();

        let report = SessionDistiller::build_distillation_report(
            "distill-0".to_string(),
            "session-0",
            0.5,
            &[],
            &writer,
        )
        .await
        .expect("distillation report");

        assert_eq!(report.distillation_id, "distill-0");
        assert_eq!(report.session_id, "session-0");
        assert_eq!(report.facts_extracted, 0);
        assert!(!report.user_memory_updated);
        assert!(!report.agent_memory_updated);
        assert_eq!(report.stored_count, 0);
        assert_eq!(report.deduped_count, 0);
        assert!(report.memory_ids.is_empty());
        assert!(report.candidate_ids.is_empty());
        assert_eq!(report.status, "no_important_facts");

        assert!(writer.recorded_user_calls().is_empty());
        assert!(writer.recorded_agent_calls().is_empty());
    }

    #[tokio::test]
    async fn build_distillation_report_with_noop_writer_reports_extraction_only() {
        let facts = vec![test_fact("fact-noop", FactCategory::Learning)];
        let fact_refs = facts.iter().collect::<Vec<_>>();

        let report = SessionDistiller::build_distillation_report(
            "distill-3".to_string(),
            "session-3",
            0.5,
            &fact_refs,
            &NoopDistillationMemoryWriter,
        )
        .await
        .expect("distillation report");

        assert_eq!(report.facts_extracted, 1);
        assert_eq!(report.stored_count, 0);
        assert_eq!(report.deduped_count, 0);
        assert!(report.memory_ids.is_empty());
        assert!(report.candidate_ids.is_empty());
        assert_eq!(report.status, "facts_extracted_only");
    }

    #[tokio::test]
    async fn build_distillation_report_routes_user_and_agent_facts_to_writer() {
        let writer = RecordingWriter::default();
        let facts = vec![
            test_fact("fact-user-preference", FactCategory::UserPreference),
            test_fact("fact-user-fact", FactCategory::Fact),
            test_fact("fact-agent-outcome", FactCategory::TaskOutcome),
            test_fact("fact-agent-learning", FactCategory::Learning),
        ];
        let fact_refs = facts.iter().collect::<Vec<_>>();

        let report = SessionDistiller::build_distillation_report(
            "distill-1".to_string(),
            "session-1",
            0.5,
            &fact_refs,
            &writer,
        )
        .await
        .expect("distillation report");

        assert_eq!(report.facts_extracted, 4);
        assert!(report.user_memory_updated);
        assert!(report.agent_memory_updated);
        assert_eq!(report.stored_count, 4);
        assert_eq!(report.deduped_count, 0);
        assert_eq!(report.memory_ids.len(), 4);
        assert_eq!(report.candidate_ids.len(), 4);
        assert_eq!(report.status, "stored");

        let user_calls = writer.recorded_user_calls();
        let agent_calls = writer.recorded_agent_calls();
        assert_eq!(user_calls.len(), 2);
        assert_eq!(agent_calls.len(), 2);
        assert!(user_calls.iter().any(|id| id == "fact-user-preference"));
        assert!(user_calls.iter().any(|id| id == "fact-user-fact"));
        assert!(agent_calls.iter().any(|id| id == "fact-agent-outcome"));
        assert!(agent_calls.iter().any(|id| id == "fact-agent-learning"));
    }

    #[tokio::test]
    async fn build_distillation_report_counts_deduped_writes_as_stored_status() {
        let writer = RecordingWriter {
            dedupe: true,
            ..RecordingWriter::default()
        };
        let facts = vec![test_fact("fact-deduped", FactCategory::Fact)];
        let fact_refs = facts.iter().collect::<Vec<_>>();

        let report = SessionDistiller::build_distillation_report(
            "distill-2".to_string(),
            "session-2",
            0.5,
            &fact_refs,
            &writer,
        )
        .await
        .expect("distillation report");

        assert_eq!(report.facts_extracted, 1);
        assert_eq!(report.stored_count, 0);
        assert_eq!(report.deduped_count, 1);
        assert_eq!(report.status, "stored");
        assert_eq!(report.memory_ids, vec!["memory-fact-deduped".to_string()]);
        assert_eq!(
            report.candidate_ids,
            vec!["candidate-fact-deduped".to_string()]
        );
    }
}