1use zeph_llm::provider::{Message, MessagePart, Role};
2
3use super::SqliteStore;
4use crate::error::MemoryError;
5use crate::types::{ConversationId, MessageId};
6
7fn parse_role(s: &str) -> Role {
8 match s {
9 "assistant" => Role::Assistant,
10 "system" => Role::System,
11 _ => Role::User,
12 }
13}
14
15#[must_use]
16pub fn role_str(role: Role) -> &'static str {
17 match role {
18 Role::System => "system",
19 Role::User => "user",
20 Role::Assistant => "assistant",
21 }
22}
23
24impl SqliteStore {
25 pub async fn create_conversation(&self) -> Result<ConversationId, MemoryError> {
31 let row: (ConversationId,) =
32 sqlx::query_as("INSERT INTO conversations DEFAULT VALUES RETURNING id")
33 .fetch_one(&self.pool)
34 .await?;
35 Ok(row.0)
36 }
37
38 pub async fn save_message(
44 &self,
45 conversation_id: ConversationId,
46 role: &str,
47 content: &str,
48 ) -> Result<MessageId, MemoryError> {
49 self.save_message_with_parts(conversation_id, role, content, "[]")
50 .await
51 }
52
53 pub async fn save_message_with_parts(
59 &self,
60 conversation_id: ConversationId,
61 role: &str,
62 content: &str,
63 parts_json: &str,
64 ) -> Result<MessageId, MemoryError> {
65 let row: (MessageId,) = sqlx::query_as(
66 "INSERT INTO messages (conversation_id, role, content, parts) VALUES (?, ?, ?, ?) RETURNING id",
67 )
68 .bind(conversation_id)
69 .bind(role)
70 .bind(content)
71 .bind(parts_json)
72 .fetch_one(&self.pool)
73 .await
74 ?;
75 Ok(row.0)
76 }
77
78 pub async fn load_history(
84 &self,
85 conversation_id: ConversationId,
86 limit: u32,
87 ) -> Result<Vec<Message>, MemoryError> {
88 let rows: Vec<(String, String, String)> = sqlx::query_as(
89 "SELECT role, content, parts FROM (\
90 SELECT role, content, parts, id FROM messages \
91 WHERE conversation_id = ? \
92 ORDER BY id DESC \
93 LIMIT ?\
94 ) ORDER BY id ASC",
95 )
96 .bind(conversation_id)
97 .bind(limit)
98 .fetch_all(&self.pool)
99 .await?;
100
101 let messages = rows
102 .into_iter()
103 .map(|(role_str, content, parts_json)| {
104 let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
105 Message {
106 role: parse_role(&role_str),
107 content,
108 parts,
109 }
110 })
111 .collect();
112 Ok(messages)
113 }
114
115 pub async fn latest_conversation_id(&self) -> Result<Option<ConversationId>, MemoryError> {
121 let row: Option<(ConversationId,)> =
122 sqlx::query_as("SELECT id FROM conversations ORDER BY id DESC LIMIT 1")
123 .fetch_optional(&self.pool)
124 .await?;
125 Ok(row.map(|r| r.0))
126 }
127
128 pub async fn message_by_id(
134 &self,
135 message_id: MessageId,
136 ) -> Result<Option<Message>, MemoryError> {
137 let row: Option<(String, String, String)> =
138 sqlx::query_as("SELECT role, content, parts FROM messages WHERE id = ?")
139 .bind(message_id)
140 .fetch_optional(&self.pool)
141 .await?;
142
143 Ok(row.map(|(role_str, content, parts_json)| {
144 let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
145 Message {
146 role: parse_role(&role_str),
147 content,
148 parts,
149 }
150 }))
151 }
152
153 pub async fn messages_by_ids(
159 &self,
160 ids: &[MessageId],
161 ) -> Result<Vec<(MessageId, Message)>, MemoryError> {
162 if ids.is_empty() {
163 return Ok(Vec::new());
164 }
165
166 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
167
168 let query =
169 format!("SELECT id, role, content, parts FROM messages WHERE id IN ({placeholders})");
170 let mut q = sqlx::query_as::<_, (MessageId, String, String, String)>(&query);
171 for &id in ids {
172 q = q.bind(id);
173 }
174
175 let rows = q.fetch_all(&self.pool).await?;
176
177 Ok(rows
178 .into_iter()
179 .map(|(id, role_str, content, parts_json)| {
180 let parts: Vec<MessagePart> = serde_json::from_str(&parts_json).unwrap_or_default();
181 (
182 id,
183 Message {
184 role: parse_role(&role_str),
185 content,
186 parts,
187 },
188 )
189 })
190 .collect())
191 }
192
193 pub async fn unembedded_message_ids(
199 &self,
200 limit: Option<usize>,
201 ) -> Result<Vec<(MessageId, ConversationId, String, String)>, MemoryError> {
202 let effective_limit = limit.map_or(i64::MAX, |l| i64::try_from(l).unwrap_or(i64::MAX));
203
204 let rows: Vec<(MessageId, ConversationId, String, String)> = sqlx::query_as(
205 "SELECT m.id, m.conversation_id, m.role, m.content \
206 FROM messages m \
207 LEFT JOIN embeddings_metadata em ON m.id = em.message_id \
208 WHERE em.id IS NULL \
209 ORDER BY m.id ASC \
210 LIMIT ?",
211 )
212 .bind(effective_limit)
213 .fetch_all(&self.pool)
214 .await?;
215
216 Ok(rows)
217 }
218
219 pub async fn count_messages(
225 &self,
226 conversation_id: ConversationId,
227 ) -> Result<i64, MemoryError> {
228 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ?")
229 .bind(conversation_id)
230 .fetch_one(&self.pool)
231 .await?;
232 Ok(row.0)
233 }
234
235 pub async fn count_messages_after(
241 &self,
242 conversation_id: ConversationId,
243 after_id: MessageId,
244 ) -> Result<i64, MemoryError> {
245 let row: (i64,) =
246 sqlx::query_as("SELECT COUNT(*) FROM messages WHERE conversation_id = ? AND id > ?")
247 .bind(conversation_id)
248 .bind(after_id)
249 .fetch_one(&self.pool)
250 .await?;
251 Ok(row.0)
252 }
253
254 pub async fn keyword_search(
263 &self,
264 query: &str,
265 limit: usize,
266 conversation_id: Option<ConversationId>,
267 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
268 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
269
270 let rows: Vec<(MessageId, f64)> = if let Some(cid) = conversation_id {
271 sqlx::query_as(
272 "SELECT m.id, -rank AS score \
273 FROM messages_fts f \
274 JOIN messages m ON m.id = f.rowid \
275 WHERE messages_fts MATCH ? AND m.conversation_id = ? \
276 ORDER BY rank \
277 LIMIT ?",
278 )
279 .bind(query)
280 .bind(cid)
281 .bind(effective_limit)
282 .fetch_all(&self.pool)
283 .await?
284 } else {
285 sqlx::query_as(
286 "SELECT f.rowid, -rank AS score \
287 FROM messages_fts f \
288 WHERE messages_fts MATCH ? \
289 ORDER BY rank \
290 LIMIT ?",
291 )
292 .bind(query)
293 .bind(effective_limit)
294 .fetch_all(&self.pool)
295 .await?
296 };
297
298 Ok(rows)
299 }
300
301 pub async fn load_messages_range(
307 &self,
308 conversation_id: ConversationId,
309 after_message_id: MessageId,
310 limit: usize,
311 ) -> Result<Vec<(MessageId, String, String)>, MemoryError> {
312 let effective_limit = i64::try_from(limit).unwrap_or(i64::MAX);
313
314 let rows: Vec<(MessageId, String, String)> = sqlx::query_as(
315 "SELECT id, role, content FROM messages \
316 WHERE conversation_id = ? AND id > ? \
317 ORDER BY id ASC LIMIT ?",
318 )
319 .bind(conversation_id)
320 .bind(after_message_id)
321 .bind(effective_limit)
322 .fetch_all(&self.pool)
323 .await?;
324
325 Ok(rows)
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh in-memory store per test, so ids always start at 1.
    async fn test_store() -> SqliteStore {
        SqliteStore::new(":memory:").await.unwrap()
    }

    #[tokio::test]
    async fn create_conversation_returns_id() {
        let store = test_store().await;
        let id1 = store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(id1, ConversationId(1));
        assert_eq!(id2, ConversationId(2));
    }

    #[tokio::test]
    async fn save_and_load_messages() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let msg_id2 = store
            .save_message(cid, "assistant", "hi there")
            .await
            .unwrap();

        assert_eq!(msg_id1, MessageId(1));
        assert_eq!(msg_id2, MessageId(2));

        let history = store.load_history(cid, 50).await.unwrap();
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].role, Role::User);
        assert_eq!(history[0].content, "hello");
        assert_eq!(history[1].role, Role::Assistant);
        assert_eq!(history[1].content, "hi there");
    }

    #[tokio::test]
    async fn load_history_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg {i}"))
                .await
                .unwrap();
        }

        let history = store.load_history(cid, 3).await.unwrap();
        assert_eq!(history.len(), 3);
        assert_eq!(history[0].content, "msg 7");
        assert_eq!(history[1].content, "msg 8");
        assert_eq!(history[2].content, "msg 9");
    }

    #[tokio::test]
    async fn load_history_empty_conversation() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        assert!(store.load_history(cid, 50).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn load_history_preserves_system_role() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "system", "be brief").await.unwrap();

        let history = store.load_history(cid, 10).await.unwrap();
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].role, Role::System);
        assert_eq!(history[0].content, "be brief");
    }

    #[tokio::test]
    async fn save_message_with_parts_empty_array_roundtrip() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store
            .save_message_with_parts(cid, "assistant", "with parts", "[]")
            .await
            .unwrap();

        let msg = store.message_by_id(id).await.unwrap().unwrap();
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.content, "with parts");
        assert!(msg.parts.is_empty());
    }

    #[tokio::test]
    async fn latest_conversation_id_empty() {
        let store = test_store().await;
        assert!(store.latest_conversation_id().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn latest_conversation_id_returns_newest() {
        let store = test_store().await;
        store.create_conversation().await.unwrap();
        let id2 = store.create_conversation().await.unwrap();
        assert_eq!(store.latest_conversation_id().await.unwrap(), Some(id2));
    }

    #[tokio::test]
    async fn messages_isolated_per_conversation() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store.save_message(cid1, "user", "conv1").await.unwrap();
        store.save_message(cid2, "user", "conv2").await.unwrap();

        let h1 = store.load_history(cid1, 50).await.unwrap();
        let h2 = store.load_history(cid2, 50).await.unwrap();
        assert_eq!(h1.len(), 1);
        assert_eq!(h1[0].content, "conv1");
        assert_eq!(h2.len(), 1);
        assert_eq!(h2[0].content, "conv2");
    }

    #[tokio::test]
    async fn pool_accessor_returns_valid_pool() {
        let store = test_store().await;
        let pool = store.pool();
        let row: (i64,) = sqlx::query_as("SELECT 1").fetch_one(pool).await.unwrap();
        assert_eq!(row.0, 1);
    }

    #[tokio::test]
    async fn embeddings_metadata_table_exists() {
        let store = test_store().await;
        let result: (i64,) = sqlx::query_as(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embeddings_metadata'",
        )
        .fetch_one(store.pool())
        .await
        .unwrap();
        assert_eq!(result.0, 1);
    }

    #[tokio::test]
    async fn cascade_delete_removes_embeddings_metadata() {
        let store = test_store().await;
        let pool = store.pool();

        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "test").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let before: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(before.0, 1);

        sqlx::query("DELETE FROM messages WHERE id = ?")
            .bind(msg_id)
            .execute(pool)
            .await
            .unwrap();

        let after: (i64,) =
            sqlx::query_as("SELECT COUNT(*) FROM embeddings_metadata WHERE message_id = ?")
                .bind(msg_id)
                .fetch_one(pool)
                .await
                .unwrap();
        assert_eq!(after.0, 0);
    }

    #[tokio::test]
    async fn messages_by_ids_batch_fetch() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id1 = store.save_message(cid, "user", "hello").await.unwrap();
        let id2 = store.save_message(cid, "assistant", "hi").await.unwrap();
        let _id3 = store.save_message(cid, "user", "bye").await.unwrap();

        let results = store.messages_by_ids(&[id1, id2]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[0].1.content, "hello");
        assert_eq!(results[1].0, id2);
        assert_eq!(results[1].1.content, "hi");
    }

    #[tokio::test]
    async fn messages_by_ids_empty_input() {
        let store = test_store().await;
        let results = store.messages_by_ids(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn messages_by_ids_nonexistent() {
        let store = test_store().await;
        let results = store
            .messages_by_ids(&[MessageId(999), MessageId(1000)])
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn messages_by_ids_deduplicates_repeated_ids() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let id = store.save_message(cid, "user", "once").await.unwrap();

        let results = store.messages_by_ids(&[id, id]).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, id);
        assert_eq!(results[0].1.content, "once");
    }

    #[tokio::test]
    async fn message_by_id_fetches_existing() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();
        let msg_id = store.save_message(cid, "user", "hello").await.unwrap();

        let msg = store.message_by_id(msg_id).await.unwrap();
        assert!(msg.is_some());
        let msg = msg.unwrap();
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.content, "hello");
    }

    #[tokio::test]
    async fn message_by_id_returns_none_for_nonexistent() {
        let store = test_store().await;
        let msg = store.message_by_id(MessageId(999)).await.unwrap();
        assert!(msg.is_none());
    }

    #[tokio::test]
    async fn unembedded_message_ids_returns_all_when_none_embedded() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 2);
        assert_eq!(unembedded[0].3, "msg1");
        assert_eq!(unembedded[1].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_reports_conversation_and_role() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id = store.save_message(cid, "assistant", "reply").await.unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 1);
        assert_eq!(unembedded[0].0, id);
        assert_eq!(unembedded[0].1, cid);
        assert_eq!(unembedded[0].2, "assistant");
        assert_eq!(unembedded[0].3, "reply");
    }

    #[tokio::test]
    async fn unembedded_message_ids_excludes_embedded() {
        let store = test_store().await;
        let pool = store.pool();
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();

        let point_id = uuid::Uuid::new_v4().to_string();
        sqlx::query(
            "INSERT INTO embeddings_metadata (message_id, qdrant_point_id, dimensions) \
             VALUES (?, ?, ?)",
        )
        .bind(msg_id1)
        .bind(&point_id)
        .bind(768_i64)
        .execute(pool)
        .await
        .unwrap();

        let unembedded = store.unembedded_message_ids(None).await.unwrap();
        assert_eq!(unembedded.len(), 1);
        assert_eq!(unembedded[0].0, msg_id2);
        assert_eq!(unembedded[0].3, "msg2");
    }

    #[tokio::test]
    async fn unembedded_message_ids_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("msg{i}"))
                .await
                .unwrap();
        }

        let unembedded = store.unembedded_message_ids(Some(3)).await.unwrap();
        assert_eq!(unembedded.len(), 3);
    }

    #[tokio::test]
    async fn count_messages_returns_correct_count() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 0);

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();

        assert_eq!(store.count_messages(cid).await.unwrap(), 2);
    }

    #[tokio::test]
    async fn count_messages_isolated_per_conversation() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store.save_message(cid1, "user", "a").await.unwrap();
        store.save_message(cid1, "assistant", "b").await.unwrap();
        store.save_message(cid2, "user", "c").await.unwrap();

        assert_eq!(store.count_messages(cid1).await.unwrap(), 2);
        assert_eq!(store.count_messages(cid2).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn count_messages_after_filters_correctly() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let _id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        assert_eq!(
            store.count_messages_after(cid, MessageId(0)).await.unwrap(),
            3
        );
        assert_eq!(store.count_messages_after(cid, id1).await.unwrap(), 2);
        assert_eq!(store.count_messages_after(cid, id3).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn load_messages_range_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        let msg_id1 = store.save_message(cid, "user", "msg1").await.unwrap();
        let msg_id2 = store.save_message(cid, "assistant", "msg2").await.unwrap();
        let msg_id3 = store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store.load_messages_range(cid, msg_id1, 10).await.unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].0, msg_id2);
        assert_eq!(msgs[0].2, "msg2");
        assert_eq!(msgs[1].0, msg_id3);
        assert_eq!(msgs[1].2, "msg3");
    }

    #[tokio::test]
    async fn load_messages_range_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store.save_message(cid, "user", "msg1").await.unwrap();
        store.save_message(cid, "assistant", "msg2").await.unwrap();
        store.save_message(cid, "user", "msg3").await.unwrap();

        let msgs = store
            .load_messages_range(cid, MessageId(0), 2)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 2);
    }

    #[tokio::test]
    async fn load_messages_range_isolated_per_conversation() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        let a1 = store.save_message(cid1, "user", "a1").await.unwrap();
        store.save_message(cid2, "user", "b1").await.unwrap();
        let a2 = store.save_message(cid1, "assistant", "a2").await.unwrap();

        let msgs = store
            .load_messages_range(cid1, MessageId(0), 10)
            .await
            .unwrap();
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].0, a1);
        assert_eq!(msgs[0].1, "user");
        assert_eq!(msgs[1].0, a2);
        assert_eq!(msgs[1].1, "assistant");
    }

    #[tokio::test]
    async fn keyword_search_basic() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "rust programming language")
            .await
            .unwrap();
        store
            .save_message(cid, "assistant", "python is great too")
            .await
            .unwrap();
        store
            .save_message(cid, "user", "I love rust and cargo")
            .await
            .unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|(_, score)| *score > 0.0));
    }

    #[tokio::test]
    async fn keyword_search_orders_by_descending_score() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "rust rust rust tooling")
            .await
            .unwrap();
        store
            .save_message(cid, "assistant", "a long sentence that mentions rust only once")
            .await
            .unwrap();
        store
            .save_message(cid, "user", "python only")
            .await
            .unwrap();

        let results = store.keyword_search("rust", 10, None).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.windows(2).all(|w| w[0].1 >= w[1].1));
    }

    #[tokio::test]
    async fn keyword_search_with_conversation_filter() {
        let store = test_store().await;
        let cid1 = store.create_conversation().await.unwrap();
        let cid2 = store.create_conversation().await.unwrap();

        store
            .save_message(cid1, "user", "hello world")
            .await
            .unwrap();
        store
            .save_message(cid2, "user", "hello universe")
            .await
            .unwrap();

        let results = store.keyword_search("hello", 10, Some(cid1)).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn keyword_search_no_match() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        store
            .save_message(cid, "user", "hello world")
            .await
            .unwrap();

        let results = store.keyword_search("nonexistent", 10, None).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn keyword_search_respects_limit() {
        let store = test_store().await;
        let cid = store.create_conversation().await.unwrap();

        for i in 0..10 {
            store
                .save_message(cid, "user", &format!("test message {i}"))
                .await
                .unwrap();
        }

        let results = store.keyword_search("test", 3, None).await.unwrap();
        assert_eq!(results.len(), 3);
    }
}