1use crate::types::{DistillationReport, DistilledFact, FactCategory, MemoryResult};
2use async_trait::async_trait;
3use chrono::Utc;
4use std::sync::Arc;
5use tandem_providers::ProviderRegistry;
6
/// Prompt template used to ask a provider for memory-worthy facts.
///
/// The `{conversation}` placeholder is replaced with the joined conversation text
/// before the prompt is sent. The model is asked to reply with a JSON array whose
/// objects carry `category`, `content`, `importance` and `follow_up_needed`.
const DISTILLATION_PROMPT: &str = r#"You are analyzing a conversation session to extract memories.

## Conversation History:
{conversation}

## Task:
Extract and summarize important information from this conversation. Return a JSON array of objects with this format:
[
 {
 "category": "user_preference|task_outcome|learning|fact",
 "content": "The extracted information",
 "importance": 0.0-1.0,
 "follow_up_needed": true|false
 }
]

Only include information that would be valuable for future sessions.
"#;
25
/// Distills a session's conversation into facts by prompting a provider from the
/// registry, then hands the facts that pass the importance threshold to a
/// [`DistillationMemoryWriter`].
pub struct SessionDistiller {
    /// Registry the distillation prompt is sent through (`complete_cheapest`).
    providers: Arc<ProviderRegistry>,
    /// Inclusive lower bound on `importance_score` for a fact to be kept.
    importance_threshold: f64,
}
30
/// Outcome a [`DistillationMemoryWriter`] reports for a single fact.
///
/// The `Default` value (all flags `false`, no ids) is what the no-op writer returns.
#[derive(Debug, Clone, Default)]
pub struct DistillationMemoryWrite {
    /// When `true`, the write is counted in the report's `stored_count`.
    pub stored: bool,
    /// When `true`, the write is counted in the report's `deduped_count`.
    pub deduped: bool,
    /// Identifier collected into the report's `memory_ids` when present.
    pub memory_id: Option<String>,
    /// Identifier collected into the report's `candidate_ids` when present.
    pub candidate_id: Option<String>,
}
38
/// Sink for distilled facts.
///
/// `SessionDistiller` calls one method per fact, sequentially, and propagates the
/// first error it receives.
#[async_trait]
pub trait DistillationMemoryWriter: Send + Sync {
    /// Receives facts categorised as `FactCategory::UserPreference` or
    /// `FactCategory::Fact`, together with the id of the session they came from.
    async fn store_user_fact(
        &self,
        session_id: &str,
        fact: &DistilledFact,
    ) -> MemoryResult<DistillationMemoryWrite>;

    /// Receives facts categorised as `FactCategory::TaskOutcome` or
    /// `FactCategory::Learning`, together with the id of the session they came from.
    async fn store_agent_fact(
        &self,
        session_id: &str,
        fact: &DistilledFact,
    ) -> MemoryResult<DistillationMemoryWrite>;
}
53
54struct NoopDistillationMemoryWriter;
55
56#[async_trait]
57impl DistillationMemoryWriter for NoopDistillationMemoryWriter {
58 async fn store_user_fact(
59 &self,
60 _session_id: &str,
61 _fact: &DistilledFact,
62 ) -> MemoryResult<DistillationMemoryWrite> {
63 Ok(DistillationMemoryWrite::default())
64 }
65
66 async fn store_agent_fact(
67 &self,
68 _session_id: &str,
69 _fact: &DistilledFact,
70 ) -> MemoryResult<DistillationMemoryWrite> {
71 Ok(DistillationMemoryWrite::default())
72 }
73}
74
75impl SessionDistiller {
76 pub fn new(providers: Arc<ProviderRegistry>) -> Self {
77 Self {
78 providers,
79 importance_threshold: 0.5,
80 }
81 }
82
83 pub fn with_threshold(providers: Arc<ProviderRegistry>, importance_threshold: f64) -> Self {
84 Self {
85 providers,
86 importance_threshold,
87 }
88 }
89
90 pub async fn distill(
91 &self,
92 session_id: &str,
93 conversation: &[String],
94 ) -> MemoryResult<DistillationReport> {
95 self.distill_with_writer(session_id, conversation, &NoopDistillationMemoryWriter)
96 .await
97 }
98
99 pub async fn distill_with_writer<W: DistillationMemoryWriter>(
100 &self,
101 session_id: &str,
102 conversation: &[String],
103 writer: &W,
104 ) -> MemoryResult<DistillationReport> {
105 let distillation_id = uuid::Uuid::new_v4().to_string();
106 let full_text = conversation.join("\n\n---\n\n");
107 let token_count = self.count_tokens(&full_text);
108
109 if token_count < 50 {
110 return Ok(DistillationReport {
111 distillation_id,
112 session_id: session_id.to_string(),
113 distilled_at: Utc::now(),
114 facts_extracted: 0,
115 importance_threshold: self.importance_threshold,
116 user_memory_updated: false,
117 agent_memory_updated: false,
118 stored_count: 0,
119 deduped_count: 0,
120 memory_ids: Vec::new(),
121 candidate_ids: Vec::new(),
122 status: "skipped_short_conversation".to_string(),
123 });
124 }
125
126 let facts = self.extract_facts(&full_text, &distillation_id).await?;
127
128 let filtered_facts: Vec<&DistilledFact> = facts
129 .iter()
130 .filter(|f| f.importance_score >= self.importance_threshold)
131 .collect();
132
133 Self::build_distillation_report(
134 distillation_id,
135 session_id,
136 self.importance_threshold,
137 &filtered_facts,
138 writer,
139 )
140 .await
141 }
142
143 async fn build_distillation_report<W: DistillationMemoryWriter>(
144 distillation_id: String,
145 session_id: &str,
146 importance_threshold: f64,
147 filtered_facts: &[&DistilledFact],
148 writer: &W,
149 ) -> MemoryResult<DistillationReport> {
150 if filtered_facts.is_empty() {
151 return Ok(DistillationReport {
152 distillation_id,
153 session_id: session_id.to_string(),
154 distilled_at: Utc::now(),
155 facts_extracted: 0,
156 importance_threshold,
157 user_memory_updated: false,
158 agent_memory_updated: false,
159 stored_count: 0,
160 deduped_count: 0,
161 memory_ids: Vec::new(),
162 candidate_ids: Vec::new(),
163 status: "no_important_facts".to_string(),
164 });
165 }
166
167 let user_results =
168 Self::write_user_memory_facts(session_id, filtered_facts, writer).await?;
169 let agent_results =
170 Self::write_agent_memory_facts(session_id, filtered_facts, writer).await?;
171 let all_results = user_results
172 .iter()
173 .chain(agent_results.iter())
174 .cloned()
175 .collect::<Vec<_>>();
176 let stored_count = all_results.iter().filter(|row| row.stored).count();
177 let deduped_count = all_results.iter().filter(|row| row.deduped).count();
178 let memory_ids = all_results
179 .iter()
180 .filter_map(|row| row.memory_id.clone())
181 .collect::<Vec<_>>();
182 let candidate_ids = all_results
183 .iter()
184 .filter_map(|row| row.candidate_id.clone())
185 .collect::<Vec<_>>();
186
187 Ok(DistillationReport {
188 distillation_id,
189 session_id: session_id.to_string(),
190 distilled_at: Utc::now(),
191 facts_extracted: filtered_facts.len(),
192 importance_threshold,
193 user_memory_updated: !user_results.is_empty(),
194 agent_memory_updated: !agent_results.is_empty(),
195 stored_count,
196 deduped_count,
197 memory_ids,
198 candidate_ids,
199 status: if stored_count > 0 || deduped_count > 0 {
200 "stored".to_string()
201 } else {
202 "facts_extracted_only".to_string()
203 },
204 })
205 }
206
207 async fn extract_facts(
208 &self,
209 conversation: &str,
210 distillation_id: &str,
211 ) -> MemoryResult<Vec<DistilledFact>> {
212 let prompt = DISTILLATION_PROMPT.replace("{conversation}", conversation);
213
214 let response = match self.providers.complete_cheapest(&prompt, None, None).await {
215 Ok(r) => r,
216 Err(e) => {
217 tracing::warn!("Distillation LLM failed: {}", e);
218 return Ok(Vec::new());
219 }
220 };
221
222 let extracted: Vec<ExtractedFact> = match serde_json::from_str(&response) {
223 Ok(facts) => facts,
224 Err(e) => {
225 tracing::warn!("Failed to parse distillation response: {}", e);
226 return Ok(Vec::new());
227 }
228 };
229
230 let facts: Vec<DistilledFact> = extracted
231 .into_iter()
232 .map(|e| DistilledFact {
233 id: uuid::Uuid::new_v4().to_string(),
234 distillation_id: distillation_id.to_string(),
235 content: e.content,
236 category: parse_category(&e.category),
237 importance_score: e.importance,
238 source_message_ids: Vec::new(),
239 contradicts_fact_id: None,
240 })
241 .collect();
242
243 Ok(facts)
244 }
245
246 async fn write_user_memory_facts(
247 session_id: &str,
248 facts: &[&DistilledFact],
249 writer: &impl DistillationMemoryWriter,
250 ) -> MemoryResult<Vec<DistillationMemoryWrite>> {
251 if facts.is_empty() {
252 return Ok(Vec::new());
253 }
254
255 let user_facts: Vec<&DistilledFact> = facts
256 .iter()
257 .filter(|f| {
258 matches!(
259 f.category,
260 FactCategory::UserPreference | FactCategory::Fact
261 )
262 })
263 .cloned()
264 .collect();
265
266 if user_facts.is_empty() {
267 return Ok(Vec::new());
268 }
269
270 let mut writes = Vec::new();
271 for fact in user_facts {
272 writes.push(writer.store_user_fact(session_id, fact).await?);
273 }
274 Ok(writes)
275 }
276
277 async fn write_agent_memory_facts(
278 session_id: &str,
279 facts: &[&DistilledFact],
280 writer: &impl DistillationMemoryWriter,
281 ) -> MemoryResult<Vec<DistillationMemoryWrite>> {
282 if facts.is_empty() {
283 return Ok(Vec::new());
284 }
285
286 let agent_facts: Vec<&DistilledFact> = facts
287 .iter()
288 .filter(|f| {
289 matches!(
290 f.category,
291 FactCategory::TaskOutcome | FactCategory::Learning
292 )
293 })
294 .cloned()
295 .collect();
296
297 if agent_facts.is_empty() {
298 return Ok(Vec::new());
299 }
300
301 let mut writes = Vec::new();
302 for fact in agent_facts {
303 writes.push(writer.store_agent_fact(session_id, fact).await?);
304 }
305 Ok(writes)
306 }
307
308 fn count_tokens(&self, text: &str) -> i64 {
309 tiktoken_rs::cl100k_base()
310 .map(|bpe| bpe.encode_with_special_tokens(text).len() as i64)
311 .unwrap_or((text.len() / 4) as i64)
312 }
313}
314
315fn parse_category(s: &str) -> FactCategory {
316 match s.to_lowercase().as_str() {
317 "user_preference" => FactCategory::UserPreference,
318 "task_outcome" => FactCategory::TaskOutcome,
319 "learning" => FactCategory::Learning,
320 _ => FactCategory::Fact,
321 }
322}
323
/// One element of the JSON array the distillation prompt asks the model to return.
#[derive(Debug, Clone, serde::Deserialize)]
// NOTE(review): `follow_up_needed` is deserialized but not read anywhere in this file,
// which is presumably why dead-code warnings are silenced here — confirm.
#[allow(dead_code)]
struct ExtractedFact {
    /// Category label; converted to a `FactCategory` by `parse_category`.
    category: String,
    /// Text of the extracted fact.
    content: String,
    /// Importance reported by the model; the prompt asks for a value in 0.0-1.0.
    importance: f64,
    /// Optional in the reply; defaults to `false` when the key is missing.
    #[serde(default)]
    follow_up_needed: bool,
}
333
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Writer double that records the id of every fact it receives and answers with
    /// ids derived from the fact id. With `dedupe` set it reports each write as
    /// deduplicated instead of stored.
    #[derive(Clone, Default)]
    struct RecordingWriter {
        user_calls: Arc<Mutex<Vec<String>>>,
        agent_calls: Arc<Mutex<Vec<String>>>,
        dedupe: bool,
    }

    impl RecordingWriter {
        /// Outcome returned for `fact` by both writer methods.
        fn outcome_for(&self, fact: &DistilledFact) -> DistillationMemoryWrite {
            DistillationMemoryWrite {
                stored: !self.dedupe,
                deduped: self.dedupe,
                memory_id: Some(format!("memory-{}", fact.id)),
                candidate_id: Some(format!("candidate-{}", fact.id)),
            }
        }
    }

    #[async_trait]
    impl DistillationMemoryWriter for RecordingWriter {
        async fn store_user_fact(
            &self,
            _session_id: &str,
            fact: &DistilledFact,
        ) -> MemoryResult<DistillationMemoryWrite> {
            {
                let mut seen = self.user_calls.lock().expect("user calls");
                seen.push(fact.id.clone());
            }
            Ok(self.outcome_for(fact))
        }

        async fn store_agent_fact(
            &self,
            _session_id: &str,
            fact: &DistilledFact,
        ) -> MemoryResult<DistillationMemoryWrite> {
            {
                let mut seen = self.agent_calls.lock().expect("agent calls");
                seen.push(fact.id.clone());
            }
            Ok(self.outcome_for(fact))
        }
    }

    /// Builds a high-importance fact with the given id and category.
    fn test_fact(id: &str, category: FactCategory) -> DistilledFact {
        let content = format!("content for {id}");
        DistilledFact {
            id: id.to_string(),
            distillation_id: "distill-test".to_string(),
            content,
            category,
            importance_score: 0.9,
            source_message_ids: Vec::new(),
            contradicts_fact_id: None,
        }
    }

    #[test]
    fn test_parse_category() {
        let cases = vec![
            ("user_preference", FactCategory::UserPreference),
            ("task_outcome", FactCategory::TaskOutcome),
            ("learning", FactCategory::Learning),
            ("unknown", FactCategory::Fact),
        ];
        for (label, expected) in cases {
            assert_eq!(parse_category(label), expected);
        }
    }

    // NOTE(review): empty placeholder. Exercising `SessionDistiller::distill` needs a
    // `ProviderRegistry`, which nothing in this module constructs.
    #[tokio::test]
    async fn test_distiller_requires_conversation() {}

    #[tokio::test]
    async fn build_distillation_report_routes_user_and_agent_facts_to_writer() {
        let writer = RecordingWriter::default();
        let facts = [
            test_fact("fact-user-preference", FactCategory::UserPreference),
            test_fact("fact-user-fact", FactCategory::Fact),
            test_fact("fact-agent-outcome", FactCategory::TaskOutcome),
            test_fact("fact-agent-learning", FactCategory::Learning),
        ];
        let fact_refs: Vec<&DistilledFact> = facts.iter().collect();

        let report = SessionDistiller::build_distillation_report(
            "distill-1".to_string(),
            "session-1",
            0.5,
            &fact_refs,
            &writer,
        )
        .await
        .expect("distillation report");

        // Every fact is counted and reported as stored.
        assert_eq!(report.facts_extracted, 4);
        assert!(report.user_memory_updated);
        assert!(report.agent_memory_updated);
        assert_eq!(report.stored_count, 4);
        assert_eq!(report.deduped_count, 0);
        assert_eq!(report.memory_ids.len(), 4);
        assert_eq!(report.candidate_ids.len(), 4);
        assert_eq!(report.status, "stored");

        // Each category reached the matching writer method.
        let user_calls = writer.user_calls.lock().expect("user calls").clone();
        let agent_calls = writer.agent_calls.lock().expect("agent calls").clone();
        assert_eq!(user_calls.len(), 2);
        assert_eq!(agent_calls.len(), 2);
        assert!(user_calls.contains(&"fact-user-preference".to_string()));
        assert!(user_calls.contains(&"fact-user-fact".to_string()));
        assert!(agent_calls.contains(&"fact-agent-outcome".to_string()));
        assert!(agent_calls.contains(&"fact-agent-learning".to_string()));
    }

    #[tokio::test]
    async fn build_distillation_report_counts_deduped_writes_as_stored_status() {
        let writer = RecordingWriter {
            dedupe: true,
            ..RecordingWriter::default()
        };
        let facts = [test_fact("fact-deduped", FactCategory::Fact)];
        let fact_refs: Vec<&DistilledFact> = facts.iter().collect();

        let report = SessionDistiller::build_distillation_report(
            "distill-2".to_string(),
            "session-2",
            0.5,
            &fact_refs,
            &writer,
        )
        .await
        .expect("distillation report");

        assert_eq!(report.facts_extracted, 1);
        assert_eq!(report.stored_count, 0);
        assert_eq!(report.deduped_count, 1);
        assert_eq!(report.status, "stored");
        assert_eq!(report.memory_ids, vec!["memory-fact-deduped".to_string()]);
        assert_eq!(
            report.candidate_ids,
            vec!["candidate-fact-deduped".to_string()]
        );
    }
}