1use zeph_llm::provider::{Message, MessagePart, Role};
2
3use super::SqliteStore;
4use crate::error::MemoryError;
5use crate::types::{ConversationId, MessageId};
6
7fn parse_role(s: &str) -> Role {
8 match s {
9 "assistant" => Role::Assistant,
10 "system" => Role::System,
11 _ => Role::User,
12 }
13}
14
15#[must_use]
16pub fn role_str(role: Role) -> &'static str {
17 match role {
18 Role::System => "system",
19 Role::User => "user",
20 Role::Assistant => "assistant",
21 }
22}
23
24impl SqliteStore {
25 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
31 let row: (ConversationId,) =
32 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
33 .fetch_one(&self.pool)
34 .await?;
35 Ok(row.0)
36 }
37
38 pub async fn save_message(
44 &self,
45 conversation_id: ConversationId,
46 role: &str,
47 content: &str,
48 ) -> Result<MessageId, MemoryError> {
49 self.save_message_with_parts(conversation_id, role, content, "[]")
50 .await
51 }
52
53 pub async fn save_message_with_parts(
59 &self,
60 conversation_id: ConversationId,
61 role: &str,
62 content: &str,
63 parts_json: &str,
64 ) -> Result<MessageId, MemoryError> {
65 let row: (MessageId,) = sqlx::query_as(
66 "INSERT INTO messages (conversation_id, role, content, parts) VALUES (?, ?, ?, ?) RETURNING id",
67 )
68 .bind(conversation_id)
69 .bind(role)
70 .bind(content)
71 .bind(parts_json)
72 .fetch_one(&self.pool)
73 .await
74 ?;
75 Ok(row.0)
76 }
77
78 pub async fn load_history(
84 &self,
85 conversation_id: ConversationId,
86 limit: u32,
87 ) -> Result<Vec<Message>, MemoryError> {
88 let rows: Vec<(String, String, String)> = sqlx::query_as(
89 "SELECT role, content, parts FROM (\
90 SELECT role, content, parts, id FROM messages \
91 WHERE conversation_id = ? \
92 ORDER BY id DESC \
93 LIMIT ?\
94 ) ORDER BY id ASC",
95 )
96 .bind(conversation_id)
97 .bind(limit)
98 .fetch_all(&self.pool)
99 .await?;
100
101 let messages = rows
102 .into_iter()
103 .map(|(role_str, content, parts_json)| {
104 let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
105 Message {
106 role: parse_role(&role_str),
107 content,
108 parts,
109 }
110 })
111 .collect();
112 Ok(messages)
113 }
114
115 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
121 let row: Option<(ConversationId,)> =
122 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
123 .fetch_optional(&self.pool)
124 .await?;
125 Ok(row.map(|r| r.0))
126 }
127
128 pub async fn message_by_id(
134 &self,
135 message_id: MessageId,
136 ) -> Result<Option<Message>, MemoryError> {
137 let row: Option<(String, String, String)> =
138 sqlx::query_as("SELECT role, content, parts FROM messages WHERE id = ?")
139 .bind(message_id)
140 .fetch_optional(&self.pool)
141 .await?;
142
143 Ok(row.map(|(role_str, content, parts_json)| {
144 let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
145 Message {
146 role: parse_role(&role_str),
147 content,
148 parts,
149 }
150 }))
151 }
152
153 pub async fn messages_by_ids(
159 &self,
160 ids: &[MessageId],
161 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
162 if ids.is_empty() {
163 return Ok(Vec::new());
164 }
165
166 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
167
168 let query =
169 format!("SELECT id, role, content, parts FROM messages WHERE id IN ({placeholders})");
170 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
171 for &id in ids {
172 q = q.bind(id);
173 }
174
175 let rows = q.fetch_all(&self.pool).await?;
176
177 Ok(rows
178 .into_iter()
179 .map(|(id, role_str, content, parts_json)| {
180 let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
181 (
182 id,
183 Message {
184 role: parse_role(&role_str),
185 content,
186 parts,
187 },
188 )
189 })
190 .collect())
191 }
192
193 pub async fn unembedded_message_ids(
199 &self,
200 limit: Option<usize>,
201 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
202 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
203
204 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
205 "SELECT m.id, m.conversation_id, m.role, m.content \
206 FROM messages m \
207 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
208 WHERE em.id IS NULL \
209 ORDER BY m.id ASC \
210 LIMIT ?",
211 )
212 .bind(effective_limit)
213 .fetch_all(&self.pool)
214 .await?;
215
216 Ok(rows)
217 }
218
219 pub async fn count_messages(
225 &self,
226 conversation_id: ConversationId,
227 ) -> Result<i64, MemoryError> {
228 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
229 .bind(conversation_id)
230 .fetch_one(&self.pool)
231 .await?;
232 Ok(row.0)
233 }
234
235 pub async fn count_messages_after(
241 &self,
242 conversation_id: ConversationId,
243 after_id: MessageId,
244 ) -> Result<i64, MemoryError> {
245 let row: (i64,) =
246 sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
247 .bind(conversation_id)
248 .bind(after_id)
249 .fetch_one(&self.pool)
250 .await?;
251 Ok(row.0)
252 }
253
254 pub async fn keyword_search(
263 &self,
264 query: &str,
265 limit: usize,
266 conversation_id: Option<ConversationId>,
267 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
268 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
269
270 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
271 sqlx::query_as(
272 "SELECT m.id, -rank AS score \
273 FROM messages_fts f \
274 JOIN messages m ON m.id = f.rowid \
275 WHERE messages_fts MATCH ? AND m.conversation_id = ? \
276 ORDER BY rank \
277 LIMIT ?",
278 )
279 .bind(query)
280 .bind(cid)
281 .bind(effective_limit)
282 .fetch_all(&self.pool)
283 .await?
284 } else {
285 sqlx::query_as(
286 "SELECT f.rowid, -rank AS score \
287 FROM messages_fts f \
288 WHERE messages_fts MATCH ? \
289 ORDER BY rank \
290 LIMIT ?",
291 )
292 .bind(query)
293 .bind(effective_limit)
294 .fetch_all(&self.pool)
295 .await?
296 };
297
298 Ok(rows)
299 }
300
301 pub async fn message_timestamps(
309 &self,
310 ids: &[MessageId],
311 ) -> Result<std::collections::HashMap<MessageId, i64>, MemoryError> {
312 if ids.is_empty() {
313 return Ok(std::collections::HashMap::new());
314 }
315
316 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
317 let query = format!(
318 "SELECT id, COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) \
319 FROM messages WHERE id IN ({placeholders})"
320 );
321 let mut q = sqlx::query_as::<_, (MessageId, i64)>(&query);
322 for &id in ids {
323 q = q.bind(id);
324 }
325
326 let rows = q.fetch_all(&self.pool).await?;
327 Ok(rows.into_iter().collect())
328 }
329
330 pub async fn load_messages_range(
336 &self,
337 conversation_id: ConversationId,
338 after_message_id: MessageId,
339 limit: usize,
340 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
341 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
342
343 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
344 "SELECT id, role, content FROM messages \
345 WHERE conversation_id = ? AND id > ? \
346 ORDER BY id ASC LIMIT ?",
347 )
348 .bind(conversation_id)
349 .bind(after_message_id)
350 .bind(effective_limit)
351 .fetch_all(&self.pool)
352 .await?;
353
354 Ok(rows)
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory store so each test starts with its own empty database.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[test]
    fn role_str_round_trips_through_parse_role() {
        assert_eq!(parse_role(role_str(Role::System)), Role::System);
        assert_eq!(parse_role(role_str(Role::User)), Role::User);
        assert_eq!(parse_role(role_str(Role::Assistant)), Role::Assistant);
    }

    #[test]
    fn parse_role_falls_back_to_user() {
        assert_eq!(parse_role("tool"), Role::User);
        assert_eq!(parse_role(""), Role::User);
    }

    #[tokio::test]
    async fn create_conversation_returns_id() {
        let store = test_store().await;
        let id1 = store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(id1, ConversationId(1));
        assert_eq!(id2, ConversationId(2));
    }

    #[tokio::test]
    async fn save_and_load_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store
            .save_message(cid, "assistant", "hi there")
            .await
            .unwrap();

        assert_eq!(msg_id1, MessageId(1));
        assert_eq!(msg_id2, MessageId(2));

        let history = store.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[1].content, "hi there");
    }

    #[tokio::test]
    async fn load_history_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        // The newest three messages come back, oldest first.
        let history = store.load_history(cid, 3).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "msg 7");
        assert_eq!(history[1].content, "msg 8");
        assert_eq!(history[2].content, "msg 9");
    }

    #[tokio::test]
    async fn load_history_decodes_malformed_parts_as_empty() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        // Valid JSON, but not a list, so it cannot deserialize into `Vec<MessagePart>`.
        store
            .save_message_with_parts(cid, "user", "hello", "{}")
            .await
            .unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].content, "hello");
        assert!(history[0].parts.is_empty());
    }

    #[tokio::test]
    async fn latest_conversation_id_empty() {
        let store = test_store().await;
        assert!(store.latest_conversation_id().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn latest_conversation_id_returns_newest() {
        let store = test_store().await;
        store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
    }

    #[tokio::test]
    async fn messages_isolated_per_conversation() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store.save_message(cid1, "user", "conv1").await.unwrap();
        store.save_message(cid2, "user", "conv2").await.unwrap();

        let h1 = store.load_history(cid1, 50).await.unwrap();
        let h2 = store.load_history(cid2, 50).await.unwrap();
        assert_eq!(h1.len(), 1);
        assert_eq!(h1[0].content, "conv1");
        assert_eq!(h2.len(), 1);
        assert_eq!(h2[0].content, "conv2");
    }

    #[tokio::test]
    async fn pool_accessor_returns_valid_pool() {
        let store = test_store().await;
        let pool = store.pool();
        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embeddings_metadata_table_exists() {
        let store = test_store().await;
        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
        )
        .fetch_one(store.pool())
        .await
        .unwrap();
        assert_eq!(result.0, 1);
    }

    #[tokio::test]
    async fn cascade_delete_removes_embeddings_metadata() {
        let store = test_store().await;
        let pool = store.pool();

        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM messages WHERE id = ?")
            .bind(msg_id)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }

    #[tokio::test]
    async fn messages_by_ids_batch_fetch() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
        let id3 = store.save_message(cid, "user", "bye").await.unwrap();

        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
        assert_eq!(results.len(), 2);

        // Look rows up by id so the test does not depend on the row order of `IN (...)`.
        let by_id: std::collections::HashMap<MessageId, Message> =
            results.into_iter().collect();
        assert_eq!(by_id[&id1].content, "hello");
        assert_eq!(by_id[&id2].content, "hi");
        assert!(!by_id.contains_key(&id3));
    }

    #[tokio::test]
    async fn messages_by_ids_deduplicates_repeated_ids() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();

        let results = store.messages_by_ids(&[id1, id2, id1]).await.unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn messages_by_ids_empty_input() {
        let store = test_store().await;
        let results = store.messages_by_ids(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn messages_by_ids_nonexistent() {
        let store = test_store().await;
        let results = store
            .messages_by_ids(&[MessageId(999), MessageId(1000)])
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn message_by_id_fetches_existing() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();

        let msg = store.message_by_id(msg_id).await.unwrap();
        assert!(msg.is_some());
        let msg = msg.unwrap();
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
    }

    #[tokio::test]
    async fn message_by_id_returns_none_for_nonexistent() {
        let store = test_store().await;
        let msg = store.message_by_id(MessageId(999)).await.unwrap();
        assert!(msg.is_none());
    }

    #[tokio::test]
    async fn message_timestamps_empty_input() {
        let store = test_store().await;
        assert!(store.message_timestamps(&[]).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn message_timestamps_skips_missing_ids() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id = store.save_message(cid, "user", "hello").await.unwrap();

        let timestamps = store
            .message_timestamps(&[id, MessageId(999)])
            .await
            .unwrap();
        assert_eq!(timestamps.len(), 1);
        assert!(timestamps.contains_key(&id));
    }

    #[tokio::test]
    async fn unembedded_message_ids_returns_all_when_none_embedded() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 2);
        assert_eq!(unembedded[0].3, "msg1");
        assert_eq!(unembedded[1].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_excludes_embedded() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id1)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 1);
        assert_eq!(unembedded[0].0, msg_id2);
        assert_eq!(unembedded[0].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }

        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
        assert_eq!(unembedded.len(), 3);
    }

    #[tokio::test]
    async fn count_messages_returns_correct_count() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 0);

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn count_messages_after_filters_correctly() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        assert_eq!(
            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
            3
        );
        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn load_messages_range_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].0, msg_id2);
        assert_eq!(msgs[0].2, "msg2");
        assert_eq!(msgs[1].0, msg_id3);
        assert_eq!(msgs[1].2, "msg3");
    }

    #[tokio::test]
    async fn load_messages_range_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();
        store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store
            .load_messages_range(cid, MessageId(0), 2)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 2);
    }

    #[tokio::test]
    async fn keyword_search_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "rust programming language")
            .await
            .unwrap();
        store
            .save_message(cid, "assistant", "python is great too")
            .await
            .unwrap();
        store
            .save_message(cid, "user", "I love rust and cargo")
            .await
            .unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(_, score)| *score > 0.0));
    }

    #[tokio::test]
    async fn keyword_search_with_conversation_filter() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store
            .save_message(cid1, "user", "hello world")
            .await
            .unwrap();
        store
            .save_message(cid2, "user", "hello universe")
            .await
            .unwrap();

        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn keyword_search_no_match() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "hello world")
            .await
            .unwrap();

        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn keyword_search_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("test message {i}"))
                .await
                .unwrap();
        }

        let results = store.keyword_search("test", 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
    }
}