tandem_memory/
distillation.rs1use crate::types::{DistillationReport, DistilledFact, FactCategory, MemoryResult};
2use chrono::Utc;
3use std::sync::Arc;
4use tandem_providers::ProviderRegistry;
5
/// Prompt template used to ask the model for distilled memories.
///
/// The `{conversation}` placeholder is replaced with the joined session transcript
/// before the prompt is sent. The model is asked to answer with a JSON array whose
/// objects carry `category`, `content`, `importance` and `follow_up_needed`.
const DISTILLATION_PROMPT: &str = r#"You are analyzing a conversation session to extract memories.

## Conversation History:
{conversation}

## Task:
Extract and summarize important information from this conversation. Return a JSON array of objects with this format:
[
  {
    "category": "user_preference|task_outcome|learning|fact",
    "content": "The extracted information",
    "importance": 0.0-1.0,
    "follow_up_needed": true|false
  }
]

Only include information that would be valuable for future sessions.
"#;
24
25pub struct SessionDistiller {
26 providers: Arc<ProviderRegistry>,
27 importance_threshold: f64,
28}
29
30impl SessionDistiller {
31 pub fn new(providers: Arc<ProviderRegistry>) -> Self {
32 Self {
33 providers,
34 importance_threshold: 0.5,
35 }
36 }
37
38 pub fn with_threshold(providers: Arc<ProviderRegistry>, importance_threshold: f64) -> Self {
39 Self {
40 providers,
41 importance_threshold,
42 }
43 }
44
45 pub async fn distill(
46 &self,
47 session_id: &str,
48 conversation: &[String],
49 ) -> MemoryResult<DistillationReport> {
50 let distillation_id = uuid::Uuid::new_v4().to_string();
51 let full_text = conversation.join("\n\n---\n\n");
52 let token_count = self.count_tokens(&full_text);
53
54 if token_count < 50 {
55 return Ok(DistillationReport {
56 distillation_id,
57 session_id: session_id.to_string(),
58 distilled_at: Utc::now(),
59 facts_extracted: 0,
60 importance_threshold: self.importance_threshold,
61 user_memory_updated: false,
62 agent_memory_updated: false,
63 });
64 }
65
66 let facts = self.extract_facts(&full_text, &distillation_id).await?;
67
68 let filtered_facts: Vec<&DistilledFact> = facts
69 .iter()
70 .filter(|f| f.importance_score >= self.importance_threshold)
71 .collect();
72
73 let user_memory_updated = self.update_user_memory(session_id, &filtered_facts).await?;
74 let agent_memory_updated = self
75 .update_agent_memory(session_id, &filtered_facts)
76 .await?;
77
78 Ok(DistillationReport {
79 distillation_id,
80 session_id: session_id.to_string(),
81 distilled_at: Utc::now(),
82 facts_extracted: filtered_facts.len(),
83 importance_threshold: self.importance_threshold,
84 user_memory_updated,
85 agent_memory_updated,
86 })
87 }
88
89 async fn extract_facts(
90 &self,
91 conversation: &str,
92 distillation_id: &str,
93 ) -> MemoryResult<Vec<DistilledFact>> {
94 let prompt = DISTILLATION_PROMPT.replace("{conversation}", conversation);
95
96 let response = match self.providers.complete_cheapest(&prompt, None, None).await {
97 Ok(r) => r,
98 Err(e) => {
99 tracing::warn!("Distillation LLM failed: {}", e);
100 return Ok(Vec::new());
101 }
102 };
103
104 let extracted: Vec<ExtractedFact> = match serde_json::from_str(&response) {
105 Ok(facts) => facts,
106 Err(e) => {
107 tracing::warn!("Failed to parse distillation response: {}", e);
108 return Ok(Vec::new());
109 }
110 };
111
112 let facts: Vec<DistilledFact> = extracted
113 .into_iter()
114 .map(|e| DistilledFact {
115 id: uuid::Uuid::new_v4().to_string(),
116 distillation_id: distillation_id.to_string(),
117 content: e.content,
118 category: parse_category(&e.category),
119 importance_score: e.importance,
120 source_message_ids: Vec::new(),
121 contradicts_fact_id: None,
122 })
123 .collect();
124
125 Ok(facts)
126 }
127
128 async fn update_user_memory(
129 &self,
130 session_id: &str,
131 facts: &[&DistilledFact],
132 ) -> MemoryResult<bool> {
133 if facts.is_empty() {
134 return Ok(false);
135 }
136
137 let user_facts: Vec<&DistilledFact> = facts
138 .iter()
139 .filter(|f| {
140 matches!(
141 f.category,
142 FactCategory::UserPreference | FactCategory::Fact
143 )
144 })
145 .cloned()
146 .collect();
147
148 if user_facts.is_empty() {
149 return Ok(false);
150 }
151
152 tracing::info!(
153 "Would update user memory with {} facts for session {}",
154 user_facts.len(),
155 session_id
156 );
157
158 Ok(true)
159 }
160
161 async fn update_agent_memory(
162 &self,
163 session_id: &str,
164 facts: &[&DistilledFact],
165 ) -> MemoryResult<bool> {
166 if facts.is_empty() {
167 return Ok(false);
168 }
169
170 let agent_facts: Vec<&DistilledFact> = facts
171 .iter()
172 .filter(|f| {
173 matches!(
174 f.category,
175 FactCategory::TaskOutcome | FactCategory::Learning
176 )
177 })
178 .cloned()
179 .collect();
180
181 if agent_facts.is_empty() {
182 return Ok(false);
183 }
184
185 tracing::info!(
186 "Would update agent memory with {} facts for session {}",
187 agent_facts.len(),
188 session_id
189 );
190
191 Ok(true)
192 }
193
194 fn count_tokens(&self, text: &str) -> i64 {
195 tiktoken_rs::cl100k_base()
196 .map(|bpe| bpe.encode_with_special_tokens(text).len() as i64)
197 .unwrap_or((text.len() / 4) as i64)
198 }
199}
200
201fn parse_category(s: &str) -> FactCategory {
202 match s.to_lowercase().as_str() {
203 "user_preference" => FactCategory::UserPreference,
204 "task_outcome" => FactCategory::TaskOutcome,
205 "learning" => FactCategory::Learning,
206 _ => FactCategory::Fact,
207 }
208}
209
210#[derive(Debug, Clone, serde::Deserialize)]
211#[allow(dead_code)]
212struct ExtractedFact {
213 category: String,
214 content: String,
215 importance: f64,
216 #[serde(default)]
217 follow_up_needed: bool,
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Known labels map to their variants; anything else falls back to `Fact`.
    #[test]
    fn test_parse_category() {
        let expectations = [
            ("user_preference", FactCategory::UserPreference),
            ("task_outcome", FactCategory::TaskOutcome),
            ("learning", FactCategory::Learning),
            ("unknown", FactCategory::Fact),
        ];

        for (label, expected) in expectations {
            assert_eq!(parse_category(label), expected);
        }
    }

    // Placeholder: this test currently has no body and asserts nothing.
    #[tokio::test]
    async fn test_distiller_requires_conversation() {}
}