1#[cfg(feature = "hnsw")]
4use crate::db::{enqueue_pending_index_op, PendingIndexOpKind};
5use crate::db::{parse_optional_json, parse_role, with_transaction};
6use crate::error::MemoryError;
7use crate::quantize::{self, Quantizer};
8use crate::search;
9use crate::types::{Message, Role, SearchResult, SearchSourceType, Session};
10use crate::{as_str_slice, merge_trace_ctx, to_owned_string_vec, MemoryStore};
11use rusqlite::{params, Connection};
12use stack_ids::TraceCtx;
13
14pub fn create_session(
16 conn: &Connection,
17 channel: &str,
18 metadata: Option<&serde_json::Value>,
19) -> Result<String, MemoryError> {
20 let id = uuid::Uuid::new_v4().to_string();
21 let metadata_str = metadata.map(|m| m.to_string());
22 conn.execute(
23 "INSERT INTO sessions (id, channel, metadata) VALUES (?1, ?2, ?3)",
24 params![id, channel, metadata_str],
25 )?;
26 Ok(id)
27}
28
29#[allow(dead_code)]
31pub fn add_message(
32 conn: &Connection,
33 session_id: &str,
34 role: Role,
35 content: &str,
36 token_count: Option<u32>,
37 metadata: Option<&serde_json::Value>,
38) -> Result<i64, MemoryError> {
39 let exists: bool = conn.query_row(
40 "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
41 params![session_id],
42 |row| row.get(0),
43 )?;
44 if !exists {
45 return Err(MemoryError::SessionNotFound(session_id.to_string()));
46 }
47
48 let metadata_str = metadata.map(|m| m.to_string());
49 with_transaction(conn, |tx| {
50 tx.execute(
51 "INSERT INTO messages (session_id, role, content, token_count, metadata)
52 VALUES (?1, ?2, ?3, ?4, ?5)",
53 params![
54 session_id,
55 role.as_str(),
56 content,
57 token_count,
58 metadata_str
59 ],
60 )?;
61 let msg_id = tx.last_insert_rowid();
62 tx.execute(
63 "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
64 params![session_id],
65 )?;
66 Ok(msg_id)
67 })
68}
69
70pub fn get_recent_messages(
72 conn: &Connection,
73 session_id: &str,
74 limit: usize,
75) -> Result<Vec<Message>, MemoryError> {
76 let mut stmt = conn.prepare(
77 "SELECT id, session_id, role, content, token_count, created_at, metadata
78 FROM messages
79 WHERE session_id = ?1
80 ORDER BY created_at DESC, id DESC
81 LIMIT ?2",
82 )?;
83
84 let mut messages: Vec<Message> = stmt
85 .query_map(params![session_id, limit as i64], |row| {
86 Ok((
87 row.get::<_, i64>(0)?,
88 row.get::<_, String>(1)?,
89 row.get::<_, String>(2)?,
90 row.get::<_, String>(3)?,
91 row.get::<_, Option<u32>>(4)?,
92 row.get::<_, String>(5)?,
93 row.get::<_, Option<String>>(6)?,
94 ))
95 })?
96 .collect::<Result<Vec<_>, _>>()?
97 .into_iter()
98 .map(
99 |(id, session_id, role_raw, content, token_count, created_at, metadata_raw)| {
100 Ok(Message {
101 role: parse_role("messages", &id.to_string(), &role_raw)?,
102 metadata: parse_optional_json(
103 "messages",
104 &id.to_string(),
105 "metadata",
106 metadata_raw.as_deref(),
107 )?,
108 id,
109 session_id,
110 content,
111 token_count,
112 created_at,
113 })
114 },
115 )
116 .collect::<Result<Vec<_>, MemoryError>>()?;
117
118 messages.reverse();
119 Ok(messages)
120}
121
122pub fn get_messages_within_budget(
124 conn: &Connection,
125 session_id: &str,
126 max_tokens: u32,
127) -> Result<Vec<Message>, MemoryError> {
128 let mut stmt = conn.prepare(
129 "SELECT id, session_id, role, content, token_count, created_at, metadata
130 FROM messages
131 WHERE session_id = ?1
132 ORDER BY created_at DESC, id DESC",
133 )?;
134
135 let all_messages: Vec<Message> = stmt
136 .query_map(params![session_id], |row| {
137 Ok((
138 row.get::<_, i64>(0)?,
139 row.get::<_, String>(1)?,
140 row.get::<_, String>(2)?,
141 row.get::<_, String>(3)?,
142 row.get::<_, Option<u32>>(4)?,
143 row.get::<_, String>(5)?,
144 row.get::<_, Option<String>>(6)?,
145 ))
146 })?
147 .collect::<Result<Vec<_>, _>>()?
148 .into_iter()
149 .map(
150 |(id, session_id, role_raw, content, token_count, created_at, metadata_raw)| {
151 Ok(Message {
152 role: parse_role("messages", &id.to_string(), &role_raw)?,
153 metadata: parse_optional_json(
154 "messages",
155 &id.to_string(),
156 "metadata",
157 metadata_raw.as_deref(),
158 )?,
159 id,
160 session_id,
161 content,
162 token_count,
163 created_at,
164 })
165 },
166 )
167 .collect::<Result<Vec<_>, MemoryError>>()?;
168
169 let mut collected = Vec::new();
170 let mut total_tokens = 0u32;
171 for msg in all_messages {
172 let msg_tokens = msg.token_count.unwrap_or(0);
173 if total_tokens + msg_tokens > max_tokens && !collected.is_empty() {
174 break;
175 }
176 total_tokens += msg_tokens;
177 collected.push(msg);
178 }
179
180 collected.reverse();
181 Ok(collected)
182}
183
184pub fn session_token_count(conn: &Connection, session_id: &str) -> Result<u64, MemoryError> {
186 let count: i64 = conn.query_row(
187 "SELECT COALESCE(SUM(token_count), 0) FROM messages WHERE session_id = ?1",
188 params![session_id],
189 |row| row.get(0),
190 )?;
191 Ok(count as u64)
192}
193
194#[allow(clippy::too_many_arguments)]
196pub fn add_message_with_embedding_q8(
197 conn: &Connection,
198 session_id: &str,
199 role: Role,
200 content: &str,
201 token_count: Option<u32>,
202 metadata: Option<&serde_json::Value>,
203 embedding_bytes: &[u8],
204 q8_bytes: Option<&[u8]>,
205) -> Result<i64, MemoryError> {
206 let exists: bool = conn.query_row(
207 "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
208 params![session_id],
209 |row| row.get(0),
210 )?;
211 if !exists {
212 return Err(MemoryError::SessionNotFound(session_id.to_string()));
213 }
214
215 let metadata_str = metadata.map(|m| m.to_string());
216 with_transaction(conn, |tx| {
217 tx.execute(
218 "INSERT INTO messages (session_id, role, content, token_count, metadata, embedding, embedding_q8)
219 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
220 params![
221 session_id,
222 role.as_str(),
223 content,
224 token_count,
225 metadata_str,
226 embedding_bytes,
227 q8_bytes
228 ],
229 )?;
230 let msg_id = tx.last_insert_rowid();
231
232 tx.execute(
233 "INSERT INTO messages_rowid_map (message_id) VALUES (?1)",
234 params![msg_id],
235 )?;
236 let fts_rowid = tx.last_insert_rowid();
237 tx.execute(
238 "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
239 params![fts_rowid, content],
240 )?;
241
242 #[cfg(feature = "hnsw")]
243 enqueue_pending_index_op(
244 tx,
245 &format!("msg:{}", msg_id),
246 "message",
247 PendingIndexOpKind::Upsert,
248 )?;
249
250 tx.execute(
251 "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
252 params![session_id],
253 )?;
254
255 Ok(msg_id)
256 })
257}
258
259#[allow(dead_code, clippy::too_many_arguments)]
261pub fn add_message_with_embedding(
262 conn: &Connection,
263 session_id: &str,
264 role: Role,
265 content: &str,
266 token_count: Option<u32>,
267 metadata: Option<&serde_json::Value>,
268 embedding_bytes: &[u8],
269) -> Result<i64, MemoryError> {
270 add_message_with_embedding_q8(
271 conn,
272 session_id,
273 role,
274 content,
275 token_count,
276 metadata,
277 embedding_bytes,
278 None,
279 )
280}
281
282pub fn add_message_with_fts(
284 conn: &Connection,
285 session_id: &str,
286 role: Role,
287 content: &str,
288 token_count: Option<u32>,
289 metadata: Option<&serde_json::Value>,
290) -> Result<i64, MemoryError> {
291 let exists: bool = conn.query_row(
292 "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
293 params![session_id],
294 |row| row.get(0),
295 )?;
296 if !exists {
297 return Err(MemoryError::SessionNotFound(session_id.to_string()));
298 }
299
300 let metadata_str = metadata.map(|m| m.to_string());
301 with_transaction(conn, |tx| {
302 tx.execute(
303 "INSERT INTO messages (session_id, role, content, token_count, metadata, embedding, embedding_q8)
304 VALUES (?1, ?2, ?3, ?4, ?5, NULL, NULL)",
305 params![session_id, role.as_str(), content, token_count, metadata_str],
306 )?;
307 let msg_id = tx.last_insert_rowid();
308
309 tx.execute(
310 "INSERT INTO messages_rowid_map (message_id) VALUES (?1)",
311 params![msg_id],
312 )?;
313 let fts_rowid = tx.last_insert_rowid();
314 tx.execute(
315 "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
316 params![fts_rowid, content],
317 )?;
318 tx.execute(
319 "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
320 params![session_id],
321 )?;
322
323 Ok(msg_id)
324 })
325}
326
327pub fn delete_session(conn: &Connection, session_id: &str) -> Result<(), MemoryError> {
329 with_transaction(conn, |tx| {
330 let fts_data: Vec<(i64, String, i64, bool)> = {
331 let mut stmt = tx.prepare(
332 "SELECT m.id, m.content, mm.rowid, m.embedding IS NOT NULL
333 FROM messages m
334 JOIN messages_rowid_map mm ON mm.message_id = m.id
335 WHERE m.session_id = ?1",
336 )?;
337 let rows = stmt.query_map(params![session_id], |row| {
338 Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?))
339 })?;
340 rows.collect::<Result<Vec<_>, _>>()?
341 };
342
343 for (msg_id, content, fts_rowid, has_embedding) in &fts_data {
344 tx.execute(
345 "INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', ?1, ?2)",
346 params![fts_rowid, content],
347 )?;
348
349 #[cfg(feature = "hnsw")]
350 if *has_embedding {
351 enqueue_pending_index_op(
352 tx,
353 &format!("msg:{}", msg_id),
354 "message",
355 PendingIndexOpKind::Delete,
356 )?;
357 }
358
359 #[cfg(not(feature = "hnsw"))]
360 {
361 let _ = msg_id;
362 let _ = has_embedding;
363 }
364 }
365
366 let affected = tx.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
367 if affected == 0 {
368 return Err(MemoryError::SessionNotFound(session_id.to_string()));
369 }
370
371 Ok(())
372 })
373}
374
375pub fn list_sessions(
377 conn: &Connection,
378 limit: usize,
379 offset: usize,
380) -> Result<Vec<Session>, MemoryError> {
381 let mut stmt = conn.prepare(
382 "SELECT s.id, s.channel, s.created_at, s.updated_at, s.metadata,
383 COUNT(m.id) AS message_count
384 FROM sessions s
385 LEFT JOIN messages m ON m.session_id = s.id
386 GROUP BY s.id
387 ORDER BY s.updated_at DESC
388 LIMIT ?1 OFFSET ?2",
389 )?;
390
391 let sessions = stmt
392 .query_map(params![limit as i64, offset as i64], |row| {
393 Ok((
394 row.get::<_, String>(0)?,
395 row.get::<_, String>(1)?,
396 row.get::<_, String>(2)?,
397 row.get::<_, String>(3)?,
398 row.get::<_, Option<String>>(4)?,
399 row.get::<_, i64>(5)? as u32,
400 ))
401 })?
402 .collect::<Result<Vec<_>, _>>()?
403 .into_iter()
404 .map(
405 |(id, channel, created_at, updated_at, metadata_raw, message_count)| {
406 Ok(Session {
407 metadata: parse_optional_json(
408 "sessions",
409 &id,
410 "metadata",
411 metadata_raw.as_deref(),
412 )?,
413 id,
414 channel,
415 created_at,
416 updated_at,
417 message_count,
418 })
419 },
420 )
421 .collect::<Result<Vec<_>, MemoryError>>()?;
422
423 Ok(sessions)
424}
425
426pub fn rename_session(
428 conn: &Connection,
429 session_id: &str,
430 new_channel: &str,
431) -> Result<(), MemoryError> {
432 let affected = conn.execute(
433 "UPDATE sessions SET channel = ?1, updated_at = datetime('now') WHERE id = ?2",
434 params![new_channel, session_id],
435 )?;
436 if affected == 0 {
437 return Err(MemoryError::SessionNotFound(session_id.to_string()));
438 }
439 Ok(())
440}
441
442impl MemoryStore {
443 pub async fn create_session(&self, channel: &str) -> Result<String, MemoryError> {
445 let channel = channel.to_string();
446 self.with_write_conn(move |conn| create_session(conn, &channel, None))
447 .await
448 }
449
450 pub async fn create_session_with_metadata(
455 &self,
456 channel: &str,
457 metadata: Option<serde_json::Value>,
458 ) -> Result<String, MemoryError> {
459 let channel = channel.to_string();
460 self.with_write_conn(move |conn| create_session(conn, &channel, metadata.as_ref()))
461 .await
462 }
463
464 pub async fn rename_session(
466 &self,
467 session_id: &str,
468 new_channel: &str,
469 ) -> Result<(), MemoryError> {
470 let sid = session_id.to_string();
471 let ch = new_channel.to_string();
472 self.with_write_conn(move |conn| rename_session(conn, &sid, &ch))
473 .await
474 }
475
476 pub async fn list_sessions(
478 &self,
479 limit: usize,
480 offset: usize,
481 ) -> Result<Vec<Session>, MemoryError> {
482 self.with_read_conn(move |conn| list_sessions(conn, limit, offset))
483 .await
484 }
485
486 pub async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError> {
490 let sid = session_id.to_string();
491 self.with_write_conn(move |conn| delete_session(conn, &sid))
492 .await?;
493
494 #[cfg(feature = "hnsw")]
495 self.sync_pending_hnsw_ops_best_effort("delete_session")
496 .await;
497
498 Ok(())
499 }
500
501 pub async fn add_message(
503 &self,
504 session_id: &str,
505 role: Role,
506 content: &str,
507 token_count: Option<u32>,
508 metadata: Option<serde_json::Value>,
509 ) -> Result<i64, MemoryError> {
510 self.add_message_with_trace(session_id, role, content, token_count, metadata, None)
511 .await
512 }
513
514 pub async fn add_message_with_trace(
516 &self,
517 session_id: &str,
518 role: Role,
519 content: &str,
520 token_count: Option<u32>,
521 metadata: Option<serde_json::Value>,
522 trace_ctx: Option<&TraceCtx>,
523 ) -> Result<i64, MemoryError> {
524 self.add_message_embedded_with_trace(
525 session_id,
526 role,
527 content,
528 token_count,
529 metadata,
530 trace_ctx,
531 )
532 .await
533 }
534
535 pub async fn add_message_fts(
540 &self,
541 session_id: &str,
542 role: Role,
543 content: &str,
544 token_count: Option<u32>,
545 metadata: Option<serde_json::Value>,
546 ) -> Result<i64, MemoryError> {
547 self.add_message_fts_with_trace(session_id, role, content, token_count, metadata, None)
548 .await
549 }
550
551 pub async fn add_message_fts_with_trace(
553 &self,
554 session_id: &str,
555 role: Role,
556 content: &str,
557 token_count: Option<u32>,
558 metadata: Option<serde_json::Value>,
559 trace_ctx: Option<&TraceCtx>,
560 ) -> Result<i64, MemoryError> {
561 self.validate_content("message.content", content)?;
562
563 let effective_token_count =
564 token_count.or_else(|| Some(self.inner.token_counter.count_tokens(content) as u32));
565 let sid = session_id.to_string();
566 let ct = content.to_string();
567 let meta = merge_trace_ctx(metadata, trace_ctx);
568 self.with_write_conn(move |conn| {
569 add_message_with_fts(conn, &sid, role, &ct, effective_token_count, meta.as_ref())
570 })
571 .await
572 }
573
574 pub async fn get_recent_messages(
576 &self,
577 session_id: &str,
578 limit: usize,
579 ) -> Result<Vec<Message>, MemoryError> {
580 let sid = session_id.to_string();
581 self.with_read_conn(move |conn| get_recent_messages(conn, &sid, limit))
582 .await
583 }
584
585 pub async fn get_messages_within_budget(
587 &self,
588 session_id: &str,
589 max_tokens: u32,
590 ) -> Result<Vec<Message>, MemoryError> {
591 let sid = session_id.to_string();
592 self.with_read_conn(move |conn| get_messages_within_budget(conn, &sid, max_tokens))
593 .await
594 }
595
596 pub async fn session_token_count(&self, session_id: &str) -> Result<u64, MemoryError> {
598 let sid = session_id.to_string();
599 self.with_read_conn(move |conn| session_token_count(conn, &sid))
600 .await
601 }
602
603 pub async fn add_message_embedded(
605 &self,
606 session_id: &str,
607 role: Role,
608 content: &str,
609 token_count: Option<u32>,
610 metadata: Option<serde_json::Value>,
611 ) -> Result<i64, MemoryError> {
612 self.add_message_embedded_with_trace(session_id, role, content, token_count, metadata, None)
613 .await
614 }
615
616 pub async fn add_message_embedded_with_trace(
618 &self,
619 session_id: &str,
620 role: Role,
621 content: &str,
622 token_count: Option<u32>,
623 metadata: Option<serde_json::Value>,
624 trace_ctx: Option<&TraceCtx>,
625 ) -> Result<i64, MemoryError> {
626 self.validate_content("message.content", content)?;
627
628 let effective_token_count =
629 token_count.or_else(|| Some(self.inner.token_counter.count_tokens(content) as u32));
630
631 let embedding = self.embed_text_internal(content).await?;
632 self.validate_embedding_dimensions(&embedding)?;
633 let embedding_bytes = crate::db::embedding_to_bytes(&embedding);
634 let q8_bytes = Quantizer::new(self.inner.config.embedding.dimensions)
636 .quantize(&embedding)
637 .map(|qv| quantize::pack_quantized(&qv))
638 .ok();
639
640 let sid = session_id.to_string();
641 let ct = content.to_string();
642 let meta = merge_trace_ctx(metadata, trace_ctx);
643 let msg_id = self
644 .with_write_conn(move |conn| {
645 add_message_with_embedding_q8(
646 conn,
647 &sid,
648 role,
649 &ct,
650 effective_token_count,
651 meta.as_ref(),
652 &embedding_bytes,
653 q8_bytes.as_deref(),
654 )
655 })
656 .await?;
657
658 #[cfg(feature = "hnsw")]
659 self.sync_pending_hnsw_ops_best_effort("add_message_embedded")
660 .await;
661
662 Ok(msg_id)
663 }
664
665 pub async fn search_conversations(
667 &self,
668 query: &str,
669 top_k: Option<usize>,
670 session_ids: Option<&[&str]>,
671 ) -> Result<Vec<SearchResult>, MemoryError> {
672 let k = top_k.unwrap_or(self.inner.config.search.default_top_k);
673
674 let query_embedding = self.embed_text_internal(query).await?;
675
676 #[cfg(feature = "hnsw")]
677 let hnsw_hits = {
678 let guard = self
679 .inner
680 .hnsw_index
681 .read()
682 .unwrap_or_else(|e| e.into_inner());
683 let candidates = self.inner.config.search.candidate_pool_size.max(k * 3);
684 match guard.search(&query_embedding, candidates) {
685 Ok(hits) => hits,
686 Err(err) => {
687 tracing::error!(
688 "HNSW conversation search failed, falling back to brute-force message search: {}",
689 err
690 );
691 Vec::new()
692 }
693 }
694 };
695
696 let q = query.to_string();
697 let config = self.inner.config.search.clone();
698 let sids_owned = to_owned_string_vec(session_ids);
699
700 #[cfg(feature = "hnsw")]
701 let hnsw_hits_owned = hnsw_hits;
702
703 self.with_read_conn(move |conn| {
704 let sids_refs = as_str_slice(&sids_owned);
705 let sids_slice: Option<&[&str]> = sids_refs.as_deref();
706 #[cfg(feature = "hnsw")]
707 {
708 if hnsw_hits_owned.is_empty() {
709 search::hybrid_search(
710 conn,
711 &q,
712 &query_embedding,
713 &config,
714 k,
715 None,
716 Some(&[SearchSourceType::Messages]),
717 sids_slice,
718 )
719 } else {
720 search::hybrid_search_with_hnsw(
721 conn,
722 &q,
723 &query_embedding,
724 &config,
725 k,
726 None,
727 Some(&[SearchSourceType::Messages]),
728 sids_slice,
729 &hnsw_hits_owned,
730 )
731 }
732 }
733 #[cfg(not(feature = "hnsw"))]
734 {
735 search::hybrid_search(
736 conn,
737 &q,
738 &query_embedding,
739 &config,
740 k,
741 None,
742 Some(&[SearchSourceType::Messages]),
743 sids_slice,
744 )
745 }
746 })
747 .await
748 }
749}