1use zeph_llm::provider::{LlmProvider as _, Message};
5
6use crate::embedding_store::{MessageKind, SearchFilter};
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
10use super::SemanticMemory;
11use super::algorithms::{apply_mmr, apply_temporal_decay};
12
/// A message returned from `recall`, paired with its relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    // The recalled chat message, hydrated from SQLite.
    pub message: Message,
    // Merged relevance score (higher = more relevant), produced by the
    // weighted keyword/vector merge and optional decay/MMR re-ranking.
    pub score: f32,
}
18
19impl SemanticMemory {
20 pub async fn remember(
29 &self,
30 conversation_id: ConversationId,
31 role: &str,
32 content: &str,
33 ) -> Result<MessageId, MemoryError> {
34 let message_id = self
35 .sqlite
36 .save_message(conversation_id, role, content)
37 .await?;
38
39 if let Some(qdrant) = &self.qdrant
40 && self.provider.supports_embeddings()
41 {
42 match self.provider.embed(content).await {
43 Ok(vector) => {
44 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
45 if let Err(e) = qdrant.ensure_collection(vector_size).await {
46 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
47 } else if let Err(e) = qdrant
48 .store(
49 message_id,
50 conversation_id,
51 role,
52 vector,
53 MessageKind::Regular,
54 &self.embedding_model,
55 )
56 .await
57 {
58 tracing::warn!("Failed to store embedding: {e:#}");
59 }
60 }
61 Err(e) => {
62 tracing::warn!("Failed to generate embedding: {e:#}");
63 }
64 }
65 }
66
67 Ok(message_id)
68 }
69
70 pub async fn remember_with_parts(
79 &self,
80 conversation_id: ConversationId,
81 role: &str,
82 content: &str,
83 parts_json: &str,
84 ) -> Result<(MessageId, bool), MemoryError> {
85 let message_id = self
86 .sqlite
87 .save_message_with_parts(conversation_id, role, content, parts_json)
88 .await?;
89
90 let mut embedding_stored = false;
91
92 if let Some(qdrant) = &self.qdrant
93 && self.provider.supports_embeddings()
94 {
95 match self.provider.embed(content).await {
96 Ok(vector) => {
97 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
98 if let Err(e) = qdrant.ensure_collection(vector_size).await {
99 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
100 } else if let Err(e) = qdrant
101 .store(
102 message_id,
103 conversation_id,
104 role,
105 vector,
106 MessageKind::Regular,
107 &self.embedding_model,
108 )
109 .await
110 {
111 tracing::warn!("Failed to store embedding: {e:#}");
112 } else {
113 embedding_stored = true;
114 }
115 }
116 Err(e) => {
117 tracing::warn!("Failed to generate embedding: {e:#}");
118 }
119 }
120 }
121
122 Ok((message_id, embedding_stored))
123 }
124
125 pub async fn save_only(
133 &self,
134 conversation_id: ConversationId,
135 role: &str,
136 content: &str,
137 parts_json: &str,
138 ) -> Result<MessageId, MemoryError> {
139 self.sqlite
140 .save_message_with_parts(conversation_id, role, content, parts_json)
141 .await
142 }
143
144 pub async fn recall(
154 &self,
155 query: &str,
156 limit: usize,
157 filter: Option<SearchFilter>,
158 ) -> Result<Vec<RecalledMessage>, MemoryError> {
159 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
160
161 tracing::debug!(
162 query_len = query.len(),
163 limit,
164 has_filter = filter.is_some(),
165 conversation_id = conversation_id.map(|c| c.0),
166 has_qdrant = self.qdrant.is_some(),
167 "recall: starting hybrid search"
168 );
169
170 let keyword_results = match self
171 .sqlite
172 .keyword_search(query, limit * 2, conversation_id)
173 .await
174 {
175 Ok(results) => results,
176 Err(e) => {
177 tracing::warn!("FTS5 keyword search failed: {e:#}");
178 Vec::new()
179 }
180 };
181
182 let vector_results = if let Some(qdrant) = &self.qdrant
183 && self.provider.supports_embeddings()
184 {
185 let query_vector = self.provider.embed(query).await?;
186 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
187 qdrant.ensure_collection(vector_size).await?;
188 qdrant.search(&query_vector, limit * 2, filter).await?
189 } else {
190 Vec::new()
191 };
192
193 self.recall_merge_and_rank(keyword_results, vector_results, limit)
194 .await
195 }
196
197 pub(super) async fn recall_fts5_raw(
198 &self,
199 query: &str,
200 limit: usize,
201 conversation_id: Option<ConversationId>,
202 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
203 self.sqlite
204 .keyword_search(query, limit * 2, conversation_id)
205 .await
206 }
207
208 pub(super) async fn recall_vectors_raw(
209 &self,
210 query: &str,
211 limit: usize,
212 filter: Option<SearchFilter>,
213 ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
214 let Some(qdrant) = &self.qdrant else {
215 return Ok(Vec::new());
216 };
217 if !self.provider.supports_embeddings() {
218 return Ok(Vec::new());
219 }
220 let query_vector = self.provider.embed(query).await?;
221 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
222 qdrant.ensure_collection(vector_size).await?;
223 qdrant.search(&query_vector, limit * 2, filter).await
224 }
225
226 #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
235 pub(super) async fn recall_merge_and_rank(
236 &self,
237 keyword_results: Vec<(MessageId, f64)>,
238 vector_results: Vec<crate::embedding_store::SearchResult>,
239 limit: usize,
240 ) -> Result<Vec<RecalledMessage>, MemoryError> {
241 tracing::debug!(
242 vector_count = vector_results.len(),
243 keyword_count = keyword_results.len(),
244 limit,
245 "recall: merging search results"
246 );
247
248 let mut scores: std::collections::HashMap<MessageId, f64> =
249 std::collections::HashMap::new();
250
251 if !vector_results.is_empty() {
252 let max_vs = vector_results
253 .iter()
254 .map(|r| r.score)
255 .fold(f32::NEG_INFINITY, f32::max);
256 let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
257 for r in &vector_results {
258 let normalized = f64::from(r.score / norm);
259 *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
260 }
261 }
262
263 if !keyword_results.is_empty() {
264 let max_ks = keyword_results
265 .iter()
266 .map(|r| r.1)
267 .fold(f64::NEG_INFINITY, f64::max);
268 let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
269 for &(msg_id, score) in &keyword_results {
270 let normalized = score / norm;
271 *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
272 }
273 }
274
275 if scores.is_empty() {
276 tracing::debug!("recall: empty merge, no overlapping scores");
277 return Ok(Vec::new());
278 }
279
280 let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
281 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
282
283 tracing::debug!(
284 merged = ranked.len(),
285 top_score = ranked.first().map(|r| r.1),
286 bottom_score = ranked.last().map(|r| r.1),
287 vector_weight = %self.vector_weight,
288 keyword_weight = %self.keyword_weight,
289 "recall: weighted merge complete"
290 );
291
292 if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
293 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
294 match self.sqlite.message_timestamps(&ids).await {
295 Ok(timestamps) => {
296 apply_temporal_decay(
297 &mut ranked,
298 ×tamps,
299 self.temporal_decay_half_life_days,
300 );
301 ranked
302 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
303 tracing::debug!(
304 half_life_days = self.temporal_decay_half_life_days,
305 top_score_after = ranked.first().map(|r| r.1),
306 "recall: temporal decay applied"
307 );
308 }
309 Err(e) => {
310 tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
311 }
312 }
313 }
314
315 if self.mmr_enabled && !vector_results.is_empty() {
316 if let Some(qdrant) = &self.qdrant {
317 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
318 match qdrant.get_vectors(&ids).await {
319 Ok(vec_map) if !vec_map.is_empty() => {
320 let ranked_len_before = ranked.len();
321 ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
322 tracing::debug!(
323 before = ranked_len_before,
324 after = ranked.len(),
325 lambda = %self.mmr_lambda,
326 "recall: mmr re-ranked"
327 );
328 }
329 Ok(_) => {
330 ranked.truncate(limit);
331 }
332 Err(e) => {
333 tracing::warn!("MMR: failed to fetch vectors: {e:#}");
334 ranked.truncate(limit);
335 }
336 }
337 } else {
338 ranked.truncate(limit);
339 }
340 } else {
341 ranked.truncate(limit);
342 }
343
344 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
345 let messages = self.sqlite.messages_by_ids(&ids).await?;
346 let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();
347
348 let recalled: Vec<RecalledMessage> = ranked
349 .iter()
350 .filter_map(|(msg_id, score)| {
351 msg_map.get(msg_id).map(|msg| RecalledMessage {
352 message: msg.clone(),
353 #[expect(clippy::cast_possible_truncation)]
354 score: *score as f32,
355 })
356 })
357 .collect();
358
359 tracing::debug!(final_count = recalled.len(), "recall: final results");
360
361 Ok(recalled)
362 }
363
364 pub async fn recall_routed(
373 &self,
374 query: &str,
375 limit: usize,
376 filter: Option<SearchFilter>,
377 router: &dyn crate::router::MemoryRouter,
378 ) -> Result<Vec<RecalledMessage>, MemoryError> {
379 use crate::router::MemoryRoute;
380
381 let route = router.route(query);
382 tracing::debug!(?route, query_len = query.len(), "memory routing decision");
383
384 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
385
386 let (keyword_results, vector_results): (
387 Vec<(MessageId, f64)>,
388 Vec<crate::embedding_store::SearchResult>,
389 ) = match route {
390 MemoryRoute::Keyword => {
391 let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
392 (kw, Vec::new())
393 }
394 MemoryRoute::Semantic => {
395 let vr = self.recall_vectors_raw(query, limit, filter).await?;
396 (Vec::new(), vr)
397 }
398 MemoryRoute::Hybrid => {
399 let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
400 Ok(r) => r,
401 Err(e) => {
402 tracing::warn!("FTS5 keyword search failed: {e:#}");
403 Vec::new()
404 }
405 };
406 let vr = self.recall_vectors_raw(query, limit, filter).await?;
407 (kw, vr)
408 }
409 MemoryRoute::Episodic => {
418 let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
419 let cleaned = crate::router::strip_temporal_keywords(query);
420 let search_query = if cleaned.is_empty() { query } else { &cleaned };
421 let kw = if let Some(ref r) = range {
422 self.sqlite
423 .keyword_search_with_time_range(
424 search_query,
425 limit,
426 conversation_id,
427 r.after.as_deref(),
428 r.before.as_deref(),
429 )
430 .await?
431 } else {
432 self.recall_fts5_raw(search_query, limit, conversation_id)
433 .await?
434 };
435 tracing::debug!(
436 has_range = range.is_some(),
437 cleaned_query = %search_query,
438 keyword_count = kw.len(),
439 "recall: episodic path"
440 );
441 (kw, Vec::new())
442 }
443 MemoryRoute::Graph => {
446 let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
447 Ok(r) => r,
448 Err(e) => {
449 tracing::warn!("FTS5 keyword search failed (graphâhybrid fallback): {e:#}");
450 Vec::new()
451 }
452 };
453 let vr = self.recall_vectors_raw(query, limit, filter).await?;
454 (kw, vr)
455 }
456 };
457
458 tracing::debug!(
459 keyword_count = keyword_results.len(),
460 vector_count = vector_results.len(),
461 "recall: routed search results"
462 );
463
464 self.recall_merge_and_rank(keyword_results, vector_results, limit)
465 .await
466 }
467
468 pub async fn recall_graph(
482 &self,
483 query: &str,
484 limit: usize,
485 max_hops: u32,
486 at_timestamp: Option<&str>,
487 temporal_decay_rate: f64,
488 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
489 let Some(store) = &self.graph_store else {
490 return Ok(Vec::new());
491 };
492
493 tracing::debug!(
494 query_len = query.len(),
495 limit,
496 max_hops,
497 "graph: starting recall"
498 );
499
500 let results = crate::graph::retrieval::graph_recall(
501 store,
502 self.qdrant.as_deref(),
503 &self.provider,
504 query,
505 limit,
506 max_hops,
507 at_timestamp,
508 temporal_decay_rate,
509 )
510 .await?;
511
512 tracing::debug!(result_count = results.len(), "graph: recall complete");
513
514 Ok(results)
515 }
516
517 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
523 match &self.qdrant {
524 Some(qdrant) => qdrant.has_embedding(message_id).await,
525 None => Ok(false),
526 }
527 }
528
529 pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
538 let Some(qdrant) = &self.qdrant else {
539 return Ok(0);
540 };
541 if !self.provider.supports_embeddings() {
542 return Ok(0);
543 }
544
545 let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;
546
547 if unembedded.is_empty() {
548 return Ok(0);
549 }
550
551 let probe = self.provider.embed("probe").await?;
552 let vector_size = u64::try_from(probe.len())?;
553 qdrant.ensure_collection(vector_size).await?;
554
555 let mut count = 0;
556 for (msg_id, conversation_id, role, content) in &unembedded {
557 match self.provider.embed(content).await {
558 Ok(vector) => {
559 if let Err(e) = qdrant
560 .store(
561 *msg_id,
562 *conversation_id,
563 role,
564 vector,
565 MessageKind::Regular,
566 &self.embedding_model,
567 )
568 .await
569 {
570 tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
571 continue;
572 }
573 count += 1;
574 }
575 Err(e) => {
576 tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
577 }
578 }
579 }
580
581 tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
582 Ok(count)
583 }
584}