Skip to main content

tandem_memory/
distillation.rs

1use crate::types::{DistillationReport, DistilledFact, FactCategory, MemoryResult};
2use chrono::Utc;
3use std::sync::Arc;
4use tandem_providers::ProviderRegistry;
5
/// Prompt template used to ask a completion provider for session memories.
///
/// The `{conversation}` placeholder is replaced with the joined session
/// transcript before the prompt is sent. The reply is handed to `serde_json`
/// as-is, so the final instruction asks for a bare JSON array (and `[]` when
/// there is nothing to remember) instead of prose or a fenced code block.
const DISTILLATION_PROMPT: &str = r#"You are analyzing a conversation session to extract memories.

## Conversation History:
{conversation}

## Task:
Extract and summarize important information from this conversation. Return a JSON array of objects with this format:
[
  {
    "category": "user_preference|task_outcome|learning|fact",
    "content": "The extracted information",
    "importance": 0.0-1.0,
    "follow_up_needed": true|false
  }
]

Only include information that would be valuable for future sessions.
Respond with ONLY the JSON array - no explanations and no markdown code fences. If nothing is worth remembering, respond with [].
"#;
24
25pub struct SessionDistiller {
26    providers: Arc<ProviderRegistry>,
27    importance_threshold: f64,
28}
29
30impl SessionDistiller {
31    pub fn new(providers: Arc<ProviderRegistry>) -> Self {
32        Self {
33            providers,
34            importance_threshold: 0.5,
35        }
36    }
37
38    pub fn with_threshold(providers: Arc<ProviderRegistry>, importance_threshold: f64) -> Self {
39        Self {
40            providers,
41            importance_threshold,
42        }
43    }
44
45    pub async fn distill(
46        &self,
47        session_id: &str,
48        conversation: &[String],
49    ) -> MemoryResult<DistillationReport> {
50        let distillation_id = uuid::Uuid::new_v4().to_string();
51        let full_text = conversation.join("\n\n---\n\n");
52        let token_count = self.count_tokens(&full_text);
53
54        if token_count < 50 {
55            return Ok(DistillationReport {
56                distillation_id,
57                session_id: session_id.to_string(),
58                distilled_at: Utc::now(),
59                facts_extracted: 0,
60                importance_threshold: self.importance_threshold,
61                user_memory_updated: false,
62                agent_memory_updated: false,
63            });
64        }
65
66        let facts = self.extract_facts(&full_text, &distillation_id).await?;
67
68        let filtered_facts: Vec<&DistilledFact> = facts
69            .iter()
70            .filter(|f| f.importance_score >= self.importance_threshold)
71            .collect();
72
73        let user_memory_updated = self.update_user_memory(session_id, &filtered_facts).await?;
74        let agent_memory_updated = self
75            .update_agent_memory(session_id, &filtered_facts)
76            .await?;
77
78        Ok(DistillationReport {
79            distillation_id,
80            session_id: session_id.to_string(),
81            distilled_at: Utc::now(),
82            facts_extracted: filtered_facts.len(),
83            importance_threshold: self.importance_threshold,
84            user_memory_updated,
85            agent_memory_updated,
86        })
87    }
88
89    async fn extract_facts(
90        &self,
91        conversation: &str,
92        distillation_id: &str,
93    ) -> MemoryResult<Vec<DistilledFact>> {
94        let prompt = DISTILLATION_PROMPT.replace("{conversation}", conversation);
95
96        let response = match self.providers.complete_cheapest(&prompt, None, None).await {
97            Ok(r) => r,
98            Err(e) => {
99                tracing::warn!("Distillation LLM failed: {}", e);
100                return Ok(Vec::new());
101            }
102        };
103
104        let extracted: Vec<ExtractedFact> = match serde_json::from_str(&response) {
105            Ok(facts) => facts,
106            Err(e) => {
107                tracing::warn!("Failed to parse distillation response: {}", e);
108                return Ok(Vec::new());
109            }
110        };
111
112        let facts: Vec<DistilledFact> = extracted
113            .into_iter()
114            .map(|e| DistilledFact {
115                id: uuid::Uuid::new_v4().to_string(),
116                distillation_id: distillation_id.to_string(),
117                content: e.content,
118                category: parse_category(&e.category),
119                importance_score: e.importance,
120                source_message_ids: Vec::new(),
121                contradicts_fact_id: None,
122            })
123            .collect();
124
125        Ok(facts)
126    }
127
128    async fn update_user_memory(
129        &self,
130        session_id: &str,
131        facts: &[&DistilledFact],
132    ) -> MemoryResult<bool> {
133        if facts.is_empty() {
134            return Ok(false);
135        }
136
137        let user_facts: Vec<&DistilledFact> = facts
138            .iter()
139            .filter(|f| {
140                matches!(
141                    f.category,
142                    FactCategory::UserPreference | FactCategory::Fact
143                )
144            })
145            .cloned()
146            .collect();
147
148        if user_facts.is_empty() {
149            return Ok(false);
150        }
151
152        tracing::info!(
153            "Would update user memory with {} facts for session {}",
154            user_facts.len(),
155            session_id
156        );
157
158        Ok(true)
159    }
160
161    async fn update_agent_memory(
162        &self,
163        session_id: &str,
164        facts: &[&DistilledFact],
165    ) -> MemoryResult<bool> {
166        if facts.is_empty() {
167            return Ok(false);
168        }
169
170        let agent_facts: Vec<&DistilledFact> = facts
171            .iter()
172            .filter(|f| {
173                matches!(
174                    f.category,
175                    FactCategory::TaskOutcome | FactCategory::Learning
176                )
177            })
178            .cloned()
179            .collect();
180
181        if agent_facts.is_empty() {
182            return Ok(false);
183        }
184
185        tracing::info!(
186            "Would update agent memory with {} facts for session {}",
187            agent_facts.len(),
188            session_id
189        );
190
191        Ok(true)
192    }
193
194    fn count_tokens(&self, text: &str) -> i64 {
195        tiktoken_rs::cl100k_base()
196            .map(|bpe| bpe.encode_with_special_tokens(text).len() as i64)
197            .unwrap_or((text.len() / 4) as i64)
198    }
199}
200
201fn parse_category(s: &str) -> FactCategory {
202    match s.to_lowercase().as_str() {
203        "user_preference" => FactCategory::UserPreference,
204        "task_outcome" => FactCategory::TaskOutcome,
205        "learning" => FactCategory::Learning,
206        _ => FactCategory::Fact,
207    }
208}
209
210#[derive(Debug, Clone, serde::Deserialize)]
211#[allow(dead_code)]
212struct ExtractedFact {
213    category: String,
214    content: String,
215    importance: f64,
216    #[serde(default)]
217    follow_up_needed: bool,
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Every category name advertised in the prompt maps to its variant.
    #[test]
    fn test_parse_category() {
        assert_eq!(
            parse_category("user_preference"),
            FactCategory::UserPreference
        );
        assert_eq!(parse_category("task_outcome"), FactCategory::TaskOutcome);
        assert_eq!(parse_category("learning"), FactCategory::Learning);
        assert_eq!(parse_category("fact"), FactCategory::Fact);
    }

    /// Letter case in the model's reply must not affect the mapping.
    #[test]
    fn test_parse_category_is_case_insensitive() {
        assert_eq!(
            parse_category("USER_PREFERENCE"),
            FactCategory::UserPreference
        );
        assert_eq!(parse_category("Task_Outcome"), FactCategory::TaskOutcome);
        assert_eq!(parse_category("LEARNING"), FactCategory::Learning);
    }

    /// Unrecognised or empty labels degrade to the generic `Fact` category.
    #[test]
    fn test_parse_category_falls_back_to_fact() {
        assert_eq!(parse_category("unknown"), FactCategory::Fact);
        assert_eq!(parse_category(""), FactCategory::Fact);
    }

    // Ignored so the empty body is not reported as a passing check.
    #[tokio::test]
    #[ignore = "placeholder: needs a ProviderRegistry mock"]
    async fn test_distiller_requires_conversation() {
        // This test would require ProviderRegistry mock
        // Placeholder for now
    }
}