1#[cfg(feature = "hnsw")]
4use crate::db::{enqueue_pending_index_op, PendingIndexOpKind};
5use crate::db::{parse_optional_json, parse_role, with_transaction};
6use crate::error::MemoryError;
7use crate::quantize::{self, Quantizer};
8use crate::search;
9use crate::types::{Message, Role, SearchResult, SearchSourceType, Session};
10use crate::{as_str_slice, merge_trace_ctx, to_owned_string_vec, MemoryStore};
11use rusqlite::{params, Connection};
12use stack_ids::TraceCtx;
13
14pub fn create_session(
16 conn: &Connection,
17 channel: &str,
18 metadata: Option<&serde_json::Value>,
19) -> Result<String, MemoryError> {
20 let id = uuid::Uuid::new_v4().to_string();
21 let metadata_str = metadata.map(|m| m.to_string());
22 conn.execute(
23 "INSERT INTO sessions (id, channel, metadata) VALUES (?1, ?2, ?3)",
24 params![id, channel, metadata_str],
25 )?;
26 Ok(id)
27}
28
29#[allow(dead_code)]
31pub fn add_message(
32 conn: &Connection,
33 session_id: &str,
34 role: Role,
35 content: &str,
36 token_count: Option<u32>,
37 metadata: Option<&serde_json::Value>,
38) -> Result<i64, MemoryError> {
39 let exists: bool = conn.query_row(
40 "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
41 params![session_id],
42 |row| row.get(0),
43 )?;
44 if !exists {
45 return Err(MemoryError::SessionNotFound(session_id.to_string()));
46 }
47
48 let metadata_str = metadata.map(|m| m.to_string());
49 with_transaction(conn, |tx| {
50 tx.execute(
51 "INSERT INTO messages (session_id, role, content, token_count, metadata)
52 VALUES (?1, ?2, ?3, ?4, ?5)",
53 params![
54 session_id,
55 role.as_str(),
56 content,
57 token_count,
58 metadata_str
59 ],
60 )?;
61 let msg_id = tx.last_insert_rowid();
62 tx.execute(
63 "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
64 params![session_id],
65 )?;
66 Ok(msg_id)
67 })
68}
69
70pub fn get_recent_messages(
72 conn: &Connection,
73 session_id: &str,
74 limit: usize,
75) -> Result<Vec<Message>, MemoryError> {
76 let mut stmt = conn.prepare(
77 "SELECT id, session_id, role, content, token_count, created_at, metadata
78 FROM messages
79 WHERE session_id = ?1
80 ORDER BY created_at DESC, id DESC
81 LIMIT ?2",
82 )?;
83
84 let mut messages: Vec<Message> = stmt
85 .query_map(params![session_id, limit as i64], |row| {
86 Ok((
87 row.get::<_, i64>(0)?,
88 row.get::<_, String>(1)?,
89 row.get::<_, String>(2)?,
90 row.get::<_, String>(3)?,
91 row.get::<_, Option<u32>>(4)?,
92 row.get::<_, String>(5)?,
93 row.get::<_, Option<String>>(6)?,
94 ))
95 })?
96 .collect::<Result<Vec<_>, _>>()?
97 .into_iter()
98 .map(
99 |(id, session_id, role_raw, content, token_count, created_at, metadata_raw)| {
100 Ok(Message {
101 role: parse_role("messages", &id.to_string(), &role_raw)?,
102 metadata: parse_optional_json(
103 "messages",
104 &id.to_string(),
105 "metadata",
106 metadata_raw.as_deref(),
107 )?,
108 id,
109 session_id,
110 content,
111 token_count,
112 created_at,
113 })
114 },
115 )
116 .collect::<Result<Vec<_>, MemoryError>>()?;
117
118 messages.reverse();
119 Ok(messages)
120}
121
122pub fn get_messages_within_budget(
124 conn: &Connection,
125 session_id: &str,
126 max_tokens: u32,
127) -> Result<Vec<Message>, MemoryError> {
128 let mut stmt = conn.prepare(
129 "SELECT id, session_id, role, content, token_count, created_at, metadata
130 FROM messages
131 WHERE session_id = ?1
132 ORDER BY created_at DESC, id DESC",
133 )?;
134
135 let all_messages: Vec<Message> = stmt
136 .query_map(params![session_id], |row| {
137 Ok((
138 row.get::<_, i64>(0)?,
139 row.get::<_, String>(1)?,
140 row.get::<_, String>(2)?,
141 row.get::<_, String>(3)?,
142 row.get::<_, Option<u32>>(4)?,
143 row.get::<_, String>(5)?,
144 row.get::<_, Option<String>>(6)?,
145 ))
146 })?
147 .collect::<Result<Vec<_>, _>>()?
148 .into_iter()
149 .map(
150 |(id, session_id, role_raw, content, token_count, created_at, metadata_raw)| {
151 Ok(Message {
152 role: parse_role("messages", &id.to_string(), &role_raw)?,
153 metadata: parse_optional_json(
154 "messages",
155 &id.to_string(),
156 "metadata",
157 metadata_raw.as_deref(),
158 )?,
159 id,
160 session_id,
161 content,
162 token_count,
163 created_at,
164 })
165 },
166 )
167 .collect::<Result<Vec<_>, MemoryError>>()?;
168
169 let mut collected = Vec::new();
170 let mut total_tokens = 0u32;
171 for msg in all_messages {
172 let msg_tokens = msg
173 .token_count
174 .unwrap_or_else(|| (msg.content.len() / 4).max(1) as u32);
175 let next_total = total_tokens.saturating_add(msg_tokens);
176 if next_total > max_tokens && !collected.is_empty() {
177 break;
178 }
179 total_tokens = next_total;
180 collected.push(msg);
181 }
182
183 collected.reverse();
184 Ok(collected)
185}
186
187pub fn session_token_count(conn: &Connection, session_id: &str) -> Result<u64, MemoryError> {
189 let count: i64 = conn.query_row(
190 "SELECT COALESCE(SUM(token_count), 0) FROM messages WHERE session_id = ?1",
191 params![session_id],
192 |row| row.get(0),
193 )?;
194 if count < 0 {
195 return Err(MemoryError::CorruptData {
196 table: "messages",
197 row_id: session_id.to_string(),
198 detail: format!("negative token_count aggregate: {count}"),
199 });
200 }
201 Ok(count as u64)
202}
203
204#[allow(clippy::too_many_arguments)]
206pub fn add_message_with_embedding_q8(
207 conn: &Connection,
208 session_id: &str,
209 role: Role,
210 content: &str,
211 token_count: Option<u32>,
212 metadata: Option<&serde_json::Value>,
213 embedding_bytes: &[u8],
214 q8_bytes: Option<&[u8]>,
215) -> Result<i64, MemoryError> {
216 let exists: bool = conn.query_row(
217 "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
218 params![session_id],
219 |row| row.get(0),
220 )?;
221 if !exists {
222 return Err(MemoryError::SessionNotFound(session_id.to_string()));
223 }
224
225 let metadata_str = metadata.map(|m| m.to_string());
226 with_transaction(conn, |tx| {
227 tx.execute(
228 "INSERT INTO messages (session_id, role, content, token_count, metadata, embedding, embedding_q8)
229 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
230 params![
231 session_id,
232 role.as_str(),
233 content,
234 token_count,
235 metadata_str,
236 embedding_bytes,
237 q8_bytes
238 ],
239 )?;
240 let msg_id = tx.last_insert_rowid();
241
242 tx.execute(
243 "INSERT INTO messages_rowid_map (message_id) VALUES (?1)",
244 params![msg_id],
245 )?;
246 let fts_rowid = tx.last_insert_rowid();
247 tx.execute(
248 "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
249 params![fts_rowid, content],
250 )?;
251
252 #[cfg(feature = "hnsw")]
253 enqueue_pending_index_op(
254 tx,
255 &format!("msg:{}", msg_id),
256 "message",
257 PendingIndexOpKind::Upsert,
258 )?;
259 crate::db::invalidate_derived_vector_artifact(tx, &format!("msg:{msg_id}"))?;
260
261 tx.execute(
262 "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
263 params![session_id],
264 )?;
265
266 Ok(msg_id)
267 })
268}
269
270#[allow(dead_code, clippy::too_many_arguments)]
272pub fn add_message_with_embedding(
273 conn: &Connection,
274 session_id: &str,
275 role: Role,
276 content: &str,
277 token_count: Option<u32>,
278 metadata: Option<&serde_json::Value>,
279 embedding_bytes: &[u8],
280) -> Result<i64, MemoryError> {
281 add_message_with_embedding_q8(
282 conn,
283 session_id,
284 role,
285 content,
286 token_count,
287 metadata,
288 embedding_bytes,
289 None,
290 )
291}
292
293pub fn add_message_with_fts(
295 conn: &Connection,
296 session_id: &str,
297 role: Role,
298 content: &str,
299 token_count: Option<u32>,
300 metadata: Option<&serde_json::Value>,
301) -> Result<i64, MemoryError> {
302 let exists: bool = conn.query_row(
303 "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
304 params![session_id],
305 |row| row.get(0),
306 )?;
307 if !exists {
308 return Err(MemoryError::SessionNotFound(session_id.to_string()));
309 }
310
311 let metadata_str = metadata.map(|m| m.to_string());
312 with_transaction(conn, |tx| {
313 tx.execute(
314 "INSERT INTO messages (session_id, role, content, token_count, metadata, embedding, embedding_q8)
315 VALUES (?1, ?2, ?3, ?4, ?5, NULL, NULL)",
316 params![session_id, role.as_str(), content, token_count, metadata_str],
317 )?;
318 let msg_id = tx.last_insert_rowid();
319
320 tx.execute(
321 "INSERT INTO messages_rowid_map (message_id) VALUES (?1)",
322 params![msg_id],
323 )?;
324 let fts_rowid = tx.last_insert_rowid();
325 tx.execute(
326 "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
327 params![fts_rowid, content],
328 )?;
329 tx.execute(
330 "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
331 params![session_id],
332 )?;
333
334 Ok(msg_id)
335 })
336}
337
338pub fn delete_session(conn: &Connection, session_id: &str) -> Result<(), MemoryError> {
340 with_transaction(conn, |tx| {
341 let fts_data: Vec<(i64, String, i64, bool)> = {
342 let mut stmt = tx.prepare(
343 "SELECT m.id, m.content, mm.rowid, m.embedding IS NOT NULL
344 FROM messages m
345 JOIN messages_rowid_map mm ON mm.message_id = m.id
346 WHERE m.session_id = ?1",
347 )?;
348 let rows = stmt.query_map(params![session_id], |row| {
349 Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?))
350 })?;
351 rows.collect::<Result<Vec<_>, _>>()?
352 };
353
354 for (msg_id, content, fts_rowid, has_embedding) in &fts_data {
355 tx.execute(
356 "INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', ?1, ?2)",
357 params![fts_rowid, content],
358 )?;
359
360 #[cfg(feature = "hnsw")]
361 if *has_embedding {
362 enqueue_pending_index_op(
363 tx,
364 &format!("msg:{}", msg_id),
365 "message",
366 PendingIndexOpKind::Delete,
367 )?;
368 }
369
370 #[cfg(not(feature = "hnsw"))]
371 {
372 let _ = msg_id;
373 let _ = has_embedding;
374 }
375 if *has_embedding {
376 crate::db::invalidate_derived_vector_artifact(tx, &format!("msg:{msg_id}"))?;
377 }
378 }
379
380 let affected = tx.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
381 if affected == 0 {
382 return Err(MemoryError::SessionNotFound(session_id.to_string()));
383 }
384
385 Ok(())
386 })
387}
388
389pub fn list_sessions(
391 conn: &Connection,
392 limit: usize,
393 offset: usize,
394) -> Result<Vec<Session>, MemoryError> {
395 let mut stmt = conn.prepare(
396 "SELECT s.id, s.channel, s.created_at, s.updated_at, s.metadata,
397 COUNT(m.id) AS message_count
398 FROM sessions s
399 LEFT JOIN messages m ON m.session_id = s.id
400 GROUP BY s.id
401 ORDER BY s.updated_at DESC
402 LIMIT ?1 OFFSET ?2",
403 )?;
404
405 let sessions = stmt
406 .query_map(params![limit as i64, offset as i64], |row| {
407 Ok((
408 row.get::<_, String>(0)?,
409 row.get::<_, String>(1)?,
410 row.get::<_, String>(2)?,
411 row.get::<_, String>(3)?,
412 row.get::<_, Option<String>>(4)?,
413 row.get::<_, i64>(5)? as u32,
414 ))
415 })?
416 .collect::<Result<Vec<_>, _>>()?
417 .into_iter()
418 .map(
419 |(id, channel, created_at, updated_at, metadata_raw, message_count)| {
420 Ok(Session {
421 metadata: parse_optional_json(
422 "sessions",
423 &id,
424 "metadata",
425 metadata_raw.as_deref(),
426 )?,
427 id,
428 channel,
429 created_at,
430 updated_at,
431 message_count,
432 })
433 },
434 )
435 .collect::<Result<Vec<_>, MemoryError>>()?;
436
437 Ok(sessions)
438}
439
440pub fn rename_session(
442 conn: &Connection,
443 session_id: &str,
444 new_channel: &str,
445) -> Result<(), MemoryError> {
446 let affected = conn.execute(
447 "UPDATE sessions SET channel = ?1, updated_at = datetime('now') WHERE id = ?2",
448 params![new_channel, session_id],
449 )?;
450 if affected == 0 {
451 return Err(MemoryError::SessionNotFound(session_id.to_string()));
452 }
453 Ok(())
454}
455
456impl MemoryStore {
457 pub async fn create_session(&self, channel: &str) -> Result<String, MemoryError> {
459 let channel = channel.to_string();
460 self.with_write_conn(move |conn| create_session(conn, &channel, None))
461 .await
462 }
463
464 pub async fn create_session_with_metadata(
469 &self,
470 channel: &str,
471 metadata: Option<serde_json::Value>,
472 ) -> Result<String, MemoryError> {
473 let channel = channel.to_string();
474 self.with_write_conn(move |conn| create_session(conn, &channel, metadata.as_ref()))
475 .await
476 }
477
478 pub async fn rename_session(
480 &self,
481 session_id: &str,
482 new_channel: &str,
483 ) -> Result<(), MemoryError> {
484 let sid = session_id.to_string();
485 let ch = new_channel.to_string();
486 self.with_write_conn(move |conn| rename_session(conn, &sid, &ch))
487 .await
488 }
489
490 pub async fn list_sessions(
492 &self,
493 limit: usize,
494 offset: usize,
495 ) -> Result<Vec<Session>, MemoryError> {
496 self.with_read_conn(move |conn| list_sessions(conn, limit, offset))
497 .await
498 }
499
500 pub async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError> {
504 let sid = session_id.to_string();
505 self.with_write_conn(move |conn| delete_session(conn, &sid))
506 .await?;
507
508 #[cfg(feature = "hnsw")]
509 self.sync_pending_hnsw_ops_best_effort("delete_session")
510 .await;
511
512 Ok(())
513 }
514
515 pub async fn add_message(
517 &self,
518 session_id: &str,
519 role: Role,
520 content: &str,
521 token_count: Option<u32>,
522 metadata: Option<serde_json::Value>,
523 ) -> Result<i64, MemoryError> {
524 self.add_message_with_trace(session_id, role, content, token_count, metadata, None)
525 .await
526 }
527
528 pub async fn add_message_with_trace(
530 &self,
531 session_id: &str,
532 role: Role,
533 content: &str,
534 token_count: Option<u32>,
535 metadata: Option<serde_json::Value>,
536 trace_ctx: Option<&TraceCtx>,
537 ) -> Result<i64, MemoryError> {
538 self.add_message_embedded_with_trace(
539 session_id,
540 role,
541 content,
542 token_count,
543 metadata,
544 trace_ctx,
545 )
546 .await
547 }
548
549 pub async fn add_message_fts(
554 &self,
555 session_id: &str,
556 role: Role,
557 content: &str,
558 token_count: Option<u32>,
559 metadata: Option<serde_json::Value>,
560 ) -> Result<i64, MemoryError> {
561 self.add_message_fts_with_trace(session_id, role, content, token_count, metadata, None)
562 .await
563 }
564
565 pub async fn add_message_fts_with_trace(
567 &self,
568 session_id: &str,
569 role: Role,
570 content: &str,
571 token_count: Option<u32>,
572 metadata: Option<serde_json::Value>,
573 trace_ctx: Option<&TraceCtx>,
574 ) -> Result<i64, MemoryError> {
575 self.validate_content("message.content", content)?;
576
577 let effective_token_count =
578 token_count.or_else(|| Some(self.inner.token_counter.count_tokens(content) as u32));
579 let sid = session_id.to_string();
580 let ct = content.to_string();
581 let meta = merge_trace_ctx(metadata, trace_ctx);
582 self.with_write_conn(move |conn| {
583 add_message_with_fts(conn, &sid, role, &ct, effective_token_count, meta.as_ref())
584 })
585 .await
586 }
587
588 pub async fn get_recent_messages(
590 &self,
591 session_id: &str,
592 limit: usize,
593 ) -> Result<Vec<Message>, MemoryError> {
594 let sid = session_id.to_string();
595 self.with_read_conn(move |conn| get_recent_messages(conn, &sid, limit))
596 .await
597 }
598
599 pub async fn get_messages_within_budget(
601 &self,
602 session_id: &str,
603 max_tokens: u32,
604 ) -> Result<Vec<Message>, MemoryError> {
605 let sid = session_id.to_string();
606 self.with_read_conn(move |conn| get_messages_within_budget(conn, &sid, max_tokens))
607 .await
608 }
609
610 pub async fn session_token_count(&self, session_id: &str) -> Result<u64, MemoryError> {
612 let sid = session_id.to_string();
613 self.with_read_conn(move |conn| session_token_count(conn, &sid))
614 .await
615 }
616
617 pub async fn add_message_embedded(
619 &self,
620 session_id: &str,
621 role: Role,
622 content: &str,
623 token_count: Option<u32>,
624 metadata: Option<serde_json::Value>,
625 ) -> Result<i64, MemoryError> {
626 self.add_message_embedded_with_trace(session_id, role, content, token_count, metadata, None)
627 .await
628 }
629
630 pub async fn add_message_embedded_with_trace(
632 &self,
633 session_id: &str,
634 role: Role,
635 content: &str,
636 token_count: Option<u32>,
637 metadata: Option<serde_json::Value>,
638 trace_ctx: Option<&TraceCtx>,
639 ) -> Result<i64, MemoryError> {
640 self.validate_content("message.content", content)?;
641
642 let effective_token_count =
643 token_count.or_else(|| Some(self.inner.token_counter.count_tokens(content) as u32));
644
645 let embedding = self.embed_text_internal(content).await?;
646 self.validate_embedding_dimensions(&embedding)?;
647 let embedding_bytes = crate::db::embedding_to_bytes(&embedding);
648 let q8_bytes = Quantizer::new(self.inner.config.embedding.dimensions)
650 .quantize(&embedding)
651 .map(|qv| quantize::pack_quantized(&qv))
652 .ok();
653
654 let sid = session_id.to_string();
655 let ct = content.to_string();
656 let meta = merge_trace_ctx(metadata, trace_ctx);
657 let msg_id = self
658 .with_write_conn(move |conn| {
659 add_message_with_embedding_q8(
660 conn,
661 &sid,
662 role,
663 &ct,
664 effective_token_count,
665 meta.as_ref(),
666 &embedding_bytes,
667 q8_bytes.as_deref(),
668 )
669 })
670 .await?;
671
672 #[cfg(feature = "hnsw")]
673 self.sync_pending_hnsw_ops_best_effort("add_message_embedded")
674 .await;
675
676 Ok(msg_id)
677 }
678
679 pub async fn search_conversations(
681 &self,
682 query: &str,
683 top_k: Option<usize>,
684 session_ids: Option<&[&str]>,
685 ) -> Result<Vec<SearchResult>, MemoryError> {
686 const MAX_TOP_K: usize = 1_000;
687 let k = top_k
688 .unwrap_or(self.inner.config.search.default_top_k)
689 .min(MAX_TOP_K);
690
691 let query_embedding = self.embed_text_internal(query).await?;
692
693 #[cfg(feature = "hnsw")]
694 let hnsw_hits = {
695 let index = self
696 .inner
697 .hnsw_index
698 .read()
699 .unwrap_or_else(|e| e.into_inner())
700 .clone();
701 let candidates = self
702 .inner
703 .config
704 .search
705 .candidate_pool_size
706 .max(k.saturating_mul(3))
707 .min(MAX_TOP_K.saturating_mul(10));
708 let query_embedding_for_hnsw = query_embedding.clone();
709 match tokio::task::spawn_blocking(move || {
710 index.search(&query_embedding_for_hnsw, candidates)
711 })
712 .await
713 {
714 Ok(Ok(hits)) => hits,
715 Ok(Err(err)) => {
716 tracing::error!(
717 "HNSW conversation search failed, falling back to brute-force message search: {}",
718 err
719 );
720 Vec::new()
721 }
722 Err(err) => {
723 tracing::error!(
724 "HNSW conversation search task failed, falling back to brute-force message search: {}",
725 err
726 );
727 Vec::new()
728 }
729 }
730 };
731
732 let q = query.to_string();
733 let config = self.inner.config.search.clone();
734 let sids_owned = to_owned_string_vec(session_ids);
735
736 #[cfg(feature = "hnsw")]
737 let hnsw_hits_owned = hnsw_hits;
738
739 self.with_read_conn(move |conn| {
740 let sids_refs = as_str_slice(&sids_owned);
741 let sids_slice: Option<&[&str]> = sids_refs.as_deref();
742 #[cfg(feature = "hnsw")]
743 {
744 if hnsw_hits_owned.is_empty() {
745 search::hybrid_search(
746 conn,
747 &q,
748 &query_embedding,
749 &config,
750 k,
751 None,
752 Some(&[SearchSourceType::Messages]),
753 sids_slice,
754 )
755 } else {
756 search::hybrid_search_with_hnsw(
757 conn,
758 &q,
759 &query_embedding,
760 &config,
761 k,
762 None,
763 Some(&[SearchSourceType::Messages]),
764 sids_slice,
765 &hnsw_hits_owned,
766 )
767 }
768 }
769 #[cfg(not(feature = "hnsw"))]
770 {
771 search::hybrid_search(
772 conn,
773 &q,
774 &query_embedding,
775 &config,
776 k,
777 None,
778 Some(&[SearchSourceType::Messages]),
779 sids_slice,
780 )
781 }
782 })
783 .await
784 }
785}