Skip to main content

semantic_memory/
conversation.rs

1//! Session and message CRUD for conversation storage.
2
3#[cfg(feature = "hnsw")]
4use crate::db::{enqueue_pending_index_op, PendingIndexOpKind};
5use crate::db::{parse_optional_json, parse_role, with_transaction};
6use crate::error::MemoryError;
7use crate::quantize::{self, Quantizer};
8use crate::search;
9use crate::types::{Message, Role, SearchResult, SearchSourceType, Session};
10use crate::{as_str_slice, merge_trace_ctx, to_owned_string_vec, MemoryStore};
11use rusqlite::{params, Connection};
12use stack_ids::TraceCtx;
13
14/// Create a new conversation session and return its UUID.
15pub fn create_session(
16    conn: &Connection,
17    channel: &str,
18    metadata: Option<&serde_json::Value>,
19) -> Result<String, MemoryError> {
20    let id = uuid::Uuid::new_v4().to_string();
21    let metadata_str = metadata.map(|m| m.to_string());
22    conn.execute(
23        "INSERT INTO sessions (id, channel, metadata) VALUES (?1, ?2, ?3)",
24        params![id, channel, metadata_str],
25    )?;
26    Ok(id)
27}
28
29/// Append a message to a session without search indexes.
30#[allow(dead_code)]
31pub fn add_message(
32    conn: &Connection,
33    session_id: &str,
34    role: Role,
35    content: &str,
36    token_count: Option<u32>,
37    metadata: Option<&serde_json::Value>,
38) -> Result<i64, MemoryError> {
39    let exists: bool = conn.query_row(
40        "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
41        params![session_id],
42        |row| row.get(0),
43    )?;
44    if !exists {
45        return Err(MemoryError::SessionNotFound(session_id.to_string()));
46    }
47
48    let metadata_str = metadata.map(|m| m.to_string());
49    with_transaction(conn, |tx| {
50        tx.execute(
51            "INSERT INTO messages (session_id, role, content, token_count, metadata)
52             VALUES (?1, ?2, ?3, ?4, ?5)",
53            params![
54                session_id,
55                role.as_str(),
56                content,
57                token_count,
58                metadata_str
59            ],
60        )?;
61        let msg_id = tx.last_insert_rowid();
62        tx.execute(
63            "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
64            params![session_id],
65        )?;
66        Ok(msg_id)
67    })
68}
69
70/// Get the most recent N messages from a session in chronological order.
71pub fn get_recent_messages(
72    conn: &Connection,
73    session_id: &str,
74    limit: usize,
75) -> Result<Vec<Message>, MemoryError> {
76    let mut stmt = conn.prepare(
77        "SELECT id, session_id, role, content, token_count, created_at, metadata
78         FROM messages
79         WHERE session_id = ?1
80         ORDER BY created_at DESC, id DESC
81         LIMIT ?2",
82    )?;
83
84    let mut messages: Vec<Message> = stmt
85        .query_map(params![session_id, limit as i64], |row| {
86            Ok((
87                row.get::<_, i64>(0)?,
88                row.get::<_, String>(1)?,
89                row.get::<_, String>(2)?,
90                row.get::<_, String>(3)?,
91                row.get::<_, Option<u32>>(4)?,
92                row.get::<_, String>(5)?,
93                row.get::<_, Option<String>>(6)?,
94            ))
95        })?
96        .collect::<Result<Vec<_>, _>>()?
97        .into_iter()
98        .map(
99            |(id, session_id, role_raw, content, token_count, created_at, metadata_raw)| {
100                Ok(Message {
101                    role: parse_role("messages", &id.to_string(), &role_raw)?,
102                    metadata: parse_optional_json(
103                        "messages",
104                        &id.to_string(),
105                        "metadata",
106                        metadata_raw.as_deref(),
107                    )?,
108                    id,
109                    session_id,
110                    content,
111                    token_count,
112                    created_at,
113                })
114            },
115        )
116        .collect::<Result<Vec<_>, MemoryError>>()?;
117
118    messages.reverse();
119    Ok(messages)
120}
121
122/// Get messages from a session while staying under the token budget.
123pub fn get_messages_within_budget(
124    conn: &Connection,
125    session_id: &str,
126    max_tokens: u32,
127) -> Result<Vec<Message>, MemoryError> {
128    let mut stmt = conn.prepare(
129        "SELECT id, session_id, role, content, token_count, created_at, metadata
130         FROM messages
131         WHERE session_id = ?1
132         ORDER BY created_at DESC, id DESC",
133    )?;
134
135    let all_messages: Vec<Message> = stmt
136        .query_map(params![session_id], |row| {
137            Ok((
138                row.get::<_, i64>(0)?,
139                row.get::<_, String>(1)?,
140                row.get::<_, String>(2)?,
141                row.get::<_, String>(3)?,
142                row.get::<_, Option<u32>>(4)?,
143                row.get::<_, String>(5)?,
144                row.get::<_, Option<String>>(6)?,
145            ))
146        })?
147        .collect::<Result<Vec<_>, _>>()?
148        .into_iter()
149        .map(
150            |(id, session_id, role_raw, content, token_count, created_at, metadata_raw)| {
151                Ok(Message {
152                    role: parse_role("messages", &id.to_string(), &role_raw)?,
153                    metadata: parse_optional_json(
154                        "messages",
155                        &id.to_string(),
156                        "metadata",
157                        metadata_raw.as_deref(),
158                    )?,
159                    id,
160                    session_id,
161                    content,
162                    token_count,
163                    created_at,
164                })
165            },
166        )
167        .collect::<Result<Vec<_>, MemoryError>>()?;
168
169    let mut collected = Vec::new();
170    let mut total_tokens = 0u32;
171    for msg in all_messages {
172        let msg_tokens = msg.token_count.unwrap_or(0);
173        if total_tokens + msg_tokens > max_tokens && !collected.is_empty() {
174            break;
175        }
176        total_tokens += msg_tokens;
177        collected.push(msg);
178    }
179
180    collected.reverse();
181    Ok(collected)
182}
183
184/// Get the total token count for a session.
185pub fn session_token_count(conn: &Connection, session_id: &str) -> Result<u64, MemoryError> {
186    let count: i64 = conn.query_row(
187        "SELECT COALESCE(SUM(token_count), 0) FROM messages WHERE session_id = ?1",
188        params![session_id],
189        |row| row.get(0),
190    )?;
191    Ok(count as u64)
192}
193
194/// Append a message with embedding + q8 + FTS entries.
195#[allow(clippy::too_many_arguments)]
196pub fn add_message_with_embedding_q8(
197    conn: &Connection,
198    session_id: &str,
199    role: Role,
200    content: &str,
201    token_count: Option<u32>,
202    metadata: Option<&serde_json::Value>,
203    embedding_bytes: &[u8],
204    q8_bytes: Option<&[u8]>,
205) -> Result<i64, MemoryError> {
206    let exists: bool = conn.query_row(
207        "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
208        params![session_id],
209        |row| row.get(0),
210    )?;
211    if !exists {
212        return Err(MemoryError::SessionNotFound(session_id.to_string()));
213    }
214
215    let metadata_str = metadata.map(|m| m.to_string());
216    with_transaction(conn, |tx| {
217        tx.execute(
218            "INSERT INTO messages (session_id, role, content, token_count, metadata, embedding, embedding_q8)
219             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
220            params![
221                session_id,
222                role.as_str(),
223                content,
224                token_count,
225                metadata_str,
226                embedding_bytes,
227                q8_bytes
228            ],
229        )?;
230        let msg_id = tx.last_insert_rowid();
231
232        tx.execute(
233            "INSERT INTO messages_rowid_map (message_id) VALUES (?1)",
234            params![msg_id],
235        )?;
236        let fts_rowid = tx.last_insert_rowid();
237        tx.execute(
238            "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
239            params![fts_rowid, content],
240        )?;
241
242        #[cfg(feature = "hnsw")]
243        enqueue_pending_index_op(
244            tx,
245            &format!("msg:{}", msg_id),
246            "message",
247            PendingIndexOpKind::Upsert,
248        )?;
249
250        tx.execute(
251            "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
252            params![session_id],
253        )?;
254
255        Ok(msg_id)
256    })
257}
258
259/// Backward-compatible wrapper for embedded messages without q8 input.
260#[allow(dead_code, clippy::too_many_arguments)]
261pub fn add_message_with_embedding(
262    conn: &Connection,
263    session_id: &str,
264    role: Role,
265    content: &str,
266    token_count: Option<u32>,
267    metadata: Option<&serde_json::Value>,
268    embedding_bytes: &[u8],
269) -> Result<i64, MemoryError> {
270    add_message_with_embedding_q8(
271        conn,
272        session_id,
273        role,
274        content,
275        token_count,
276        metadata,
277        embedding_bytes,
278        None,
279    )
280}
281
282/// Append a message with FTS indexing but no embedding.
283pub fn add_message_with_fts(
284    conn: &Connection,
285    session_id: &str,
286    role: Role,
287    content: &str,
288    token_count: Option<u32>,
289    metadata: Option<&serde_json::Value>,
290) -> Result<i64, MemoryError> {
291    let exists: bool = conn.query_row(
292        "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
293        params![session_id],
294        |row| row.get(0),
295    )?;
296    if !exists {
297        return Err(MemoryError::SessionNotFound(session_id.to_string()));
298    }
299
300    let metadata_str = metadata.map(|m| m.to_string());
301    with_transaction(conn, |tx| {
302        tx.execute(
303            "INSERT INTO messages (session_id, role, content, token_count, metadata, embedding, embedding_q8)
304             VALUES (?1, ?2, ?3, ?4, ?5, NULL, NULL)",
305            params![session_id, role.as_str(), content, token_count, metadata_str],
306        )?;
307        let msg_id = tx.last_insert_rowid();
308
309        tx.execute(
310            "INSERT INTO messages_rowid_map (message_id) VALUES (?1)",
311            params![msg_id],
312        )?;
313        let fts_rowid = tx.last_insert_rowid();
314        tx.execute(
315            "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
316            params![fts_rowid, content],
317        )?;
318        tx.execute(
319            "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
320            params![session_id],
321        )?;
322
323        Ok(msg_id)
324    })
325}
326
327/// Delete a session and all its messages.
328pub fn delete_session(conn: &Connection, session_id: &str) -> Result<(), MemoryError> {
329    with_transaction(conn, |tx| {
330        let fts_data: Vec<(i64, String, i64, bool)> = {
331            let mut stmt = tx.prepare(
332                "SELECT m.id, m.content, mm.rowid, m.embedding IS NOT NULL
333                 FROM messages m
334                 JOIN messages_rowid_map mm ON mm.message_id = m.id
335                 WHERE m.session_id = ?1",
336            )?;
337            let rows = stmt.query_map(params![session_id], |row| {
338                Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?))
339            })?;
340            rows.collect::<Result<Vec<_>, _>>()?
341        };
342
343        for (msg_id, content, fts_rowid, has_embedding) in &fts_data {
344            tx.execute(
345                "INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', ?1, ?2)",
346                params![fts_rowid, content],
347            )?;
348
349            #[cfg(feature = "hnsw")]
350            if *has_embedding {
351                enqueue_pending_index_op(
352                    tx,
353                    &format!("msg:{}", msg_id),
354                    "message",
355                    PendingIndexOpKind::Delete,
356                )?;
357            }
358
359            #[cfg(not(feature = "hnsw"))]
360            {
361                let _ = msg_id;
362                let _ = has_embedding;
363            }
364        }
365
366        let affected = tx.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
367        if affected == 0 {
368            return Err(MemoryError::SessionNotFound(session_id.to_string()));
369        }
370
371        Ok(())
372    })
373}
374
375/// List recent sessions with message counts.
376pub fn list_sessions(
377    conn: &Connection,
378    limit: usize,
379    offset: usize,
380) -> Result<Vec<Session>, MemoryError> {
381    let mut stmt = conn.prepare(
382        "SELECT s.id, s.channel, s.created_at, s.updated_at, s.metadata,
383                COUNT(m.id) AS message_count
384         FROM sessions s
385         LEFT JOIN messages m ON m.session_id = s.id
386         GROUP BY s.id
387         ORDER BY s.updated_at DESC
388         LIMIT ?1 OFFSET ?2",
389    )?;
390
391    let sessions = stmt
392        .query_map(params![limit as i64, offset as i64], |row| {
393            Ok((
394                row.get::<_, String>(0)?,
395                row.get::<_, String>(1)?,
396                row.get::<_, String>(2)?,
397                row.get::<_, String>(3)?,
398                row.get::<_, Option<String>>(4)?,
399                row.get::<_, i64>(5)? as u32,
400            ))
401        })?
402        .collect::<Result<Vec<_>, _>>()?
403        .into_iter()
404        .map(
405            |(id, channel, created_at, updated_at, metadata_raw, message_count)| {
406                Ok(Session {
407                    metadata: parse_optional_json(
408                        "sessions",
409                        &id,
410                        "metadata",
411                        metadata_raw.as_deref(),
412                    )?,
413                    id,
414                    channel,
415                    created_at,
416                    updated_at,
417                    message_count,
418                })
419            },
420        )
421        .collect::<Result<Vec<_>, MemoryError>>()?;
422
423    Ok(sessions)
424}
425
426/// Update a session channel.
427pub fn rename_session(
428    conn: &Connection,
429    session_id: &str,
430    new_channel: &str,
431) -> Result<(), MemoryError> {
432    let affected = conn.execute(
433        "UPDATE sessions SET channel = ?1, updated_at = datetime('now') WHERE id = ?2",
434        params![new_channel, session_id],
435    )?;
436    if affected == 0 {
437        return Err(MemoryError::SessionNotFound(session_id.to_string()));
438    }
439    Ok(())
440}
441
442impl MemoryStore {
443    /// Create a new conversation session. Returns the session ID (UUID v4).
444    pub async fn create_session(&self, channel: &str) -> Result<String, MemoryError> {
445        let channel = channel.to_string();
446        self.with_write_conn(move |conn| create_session(conn, &channel, None))
447            .await
448    }
449
450    /// Create a new conversation session with metadata.
451    ///
452    /// Metadata can be used to carry namespace tags and trace data for retention
453    /// and deletion policy decisions.
454    pub async fn create_session_with_metadata(
455        &self,
456        channel: &str,
457        metadata: Option<serde_json::Value>,
458    ) -> Result<String, MemoryError> {
459        let channel = channel.to_string();
460        self.with_write_conn(move |conn| create_session(conn, &channel, metadata.as_ref()))
461            .await
462    }
463
464    /// Rename a session's channel (display name).
465    pub async fn rename_session(
466        &self,
467        session_id: &str,
468        new_channel: &str,
469    ) -> Result<(), MemoryError> {
470        let sid = session_id.to_string();
471        let ch = new_channel.to_string();
472        self.with_write_conn(move |conn| rename_session(conn, &sid, &ch))
473            .await
474    }
475
476    /// List recent sessions, newest first.
477    pub async fn list_sessions(
478        &self,
479        limit: usize,
480        offset: usize,
481    ) -> Result<Vec<Session>, MemoryError> {
482        self.with_read_conn(move |conn| list_sessions(conn, limit, offset))
483            .await
484    }
485
486    /// Delete a session and all its messages.
487    ///
488    /// Cleans up HNSW entries for embedded messages before CASCADE delete.
489    pub async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError> {
490        let sid = session_id.to_string();
491        self.with_write_conn(move |conn| delete_session(conn, &sid))
492            .await?;
493
494        #[cfg(feature = "hnsw")]
495        self.sync_pending_hnsw_ops_best_effort("delete_session")
496            .await;
497
498        Ok(())
499    }
500
501    /// Append a message to a session. Returns the message's auto-increment ID.
502    pub async fn add_message(
503        &self,
504        session_id: &str,
505        role: Role,
506        content: &str,
507        token_count: Option<u32>,
508        metadata: Option<serde_json::Value>,
509    ) -> Result<i64, MemoryError> {
510        self.add_message_with_trace(session_id, role, content, token_count, metadata, None)
511            .await
512    }
513
514    /// Append a message to a session with optional trace metadata.
515    pub async fn add_message_with_trace(
516        &self,
517        session_id: &str,
518        role: Role,
519        content: &str,
520        token_count: Option<u32>,
521        metadata: Option<serde_json::Value>,
522        trace_ctx: Option<&TraceCtx>,
523    ) -> Result<i64, MemoryError> {
524        self.add_message_embedded_with_trace(
525            session_id,
526            role,
527            content,
528            token_count,
529            metadata,
530            trace_ctx,
531        )
532        .await
533    }
534
535    /// Append a message to a session with FTS indexing but no embedding.
536    ///
537    /// Fallback path when embedding fails: messages still appear in conversation
538    /// history and are findable via BM25 search, just not via vector search.
539    pub async fn add_message_fts(
540        &self,
541        session_id: &str,
542        role: Role,
543        content: &str,
544        token_count: Option<u32>,
545        metadata: Option<serde_json::Value>,
546    ) -> Result<i64, MemoryError> {
547        self.add_message_fts_with_trace(session_id, role, content, token_count, metadata, None)
548            .await
549    }
550
551    /// Append a message with FTS indexing and optional trace metadata.
552    pub async fn add_message_fts_with_trace(
553        &self,
554        session_id: &str,
555        role: Role,
556        content: &str,
557        token_count: Option<u32>,
558        metadata: Option<serde_json::Value>,
559        trace_ctx: Option<&TraceCtx>,
560    ) -> Result<i64, MemoryError> {
561        self.validate_content("message.content", content)?;
562
563        let effective_token_count =
564            token_count.or_else(|| Some(self.inner.token_counter.count_tokens(content) as u32));
565        let sid = session_id.to_string();
566        let ct = content.to_string();
567        let meta = merge_trace_ctx(metadata, trace_ctx);
568        self.with_write_conn(move |conn| {
569            add_message_with_fts(conn, &sid, role, &ct, effective_token_count, meta.as_ref())
570        })
571        .await
572    }
573
574    /// Get the most recent N messages from a session, in chronological order.
575    pub async fn get_recent_messages(
576        &self,
577        session_id: &str,
578        limit: usize,
579    ) -> Result<Vec<Message>, MemoryError> {
580        let sid = session_id.to_string();
581        self.with_read_conn(move |conn| get_recent_messages(conn, &sid, limit))
582            .await
583    }
584
585    /// Get messages from a session up to `max_tokens` total.
586    pub async fn get_messages_within_budget(
587        &self,
588        session_id: &str,
589        max_tokens: u32,
590    ) -> Result<Vec<Message>, MemoryError> {
591        let sid = session_id.to_string();
592        self.with_read_conn(move |conn| get_messages_within_budget(conn, &sid, max_tokens))
593            .await
594    }
595
596    /// Get total token count for a session.
597    pub async fn session_token_count(&self, session_id: &str) -> Result<u64, MemoryError> {
598        let sid = session_id.to_string();
599        self.with_read_conn(move |conn| session_token_count(conn, &sid))
600            .await
601    }
602
603    /// Append a message to a session with automatic embedding and FTS indexing.
604    pub async fn add_message_embedded(
605        &self,
606        session_id: &str,
607        role: Role,
608        content: &str,
609        token_count: Option<u32>,
610        metadata: Option<serde_json::Value>,
611    ) -> Result<i64, MemoryError> {
612        self.add_message_embedded_with_trace(session_id, role, content, token_count, metadata, None)
613            .await
614    }
615
616    /// Append an embedded message with optional trace metadata.
617    pub async fn add_message_embedded_with_trace(
618        &self,
619        session_id: &str,
620        role: Role,
621        content: &str,
622        token_count: Option<u32>,
623        metadata: Option<serde_json::Value>,
624        trace_ctx: Option<&TraceCtx>,
625    ) -> Result<i64, MemoryError> {
626        self.validate_content("message.content", content)?;
627
628        let effective_token_count =
629            token_count.or_else(|| Some(self.inner.token_counter.count_tokens(content) as u32));
630
631        let embedding = self.embed_text_internal(content).await?;
632        self.validate_embedding_dimensions(&embedding)?;
633        let embedding_bytes = crate::db::embedding_to_bytes(&embedding);
634        // INTENTIONAL: q8 quantization is an optional search optimization; missing q8 is non-fatal
635        let q8_bytes = Quantizer::new(self.inner.config.embedding.dimensions)
636            .quantize(&embedding)
637            .map(|qv| quantize::pack_quantized(&qv))
638            .ok();
639
640        let sid = session_id.to_string();
641        let ct = content.to_string();
642        let meta = merge_trace_ctx(metadata, trace_ctx);
643        let msg_id = self
644            .with_write_conn(move |conn| {
645                add_message_with_embedding_q8(
646                    conn,
647                    &sid,
648                    role,
649                    &ct,
650                    effective_token_count,
651                    meta.as_ref(),
652                    &embedding_bytes,
653                    q8_bytes.as_deref(),
654                )
655            })
656            .await?;
657
658        #[cfg(feature = "hnsw")]
659        self.sync_pending_hnsw_ops_best_effort("add_message_embedded")
660            .await;
661
662        Ok(msg_id)
663    }
664
665    /// Hybrid search over conversation messages only.
666    pub async fn search_conversations(
667        &self,
668        query: &str,
669        top_k: Option<usize>,
670        session_ids: Option<&[&str]>,
671    ) -> Result<Vec<SearchResult>, MemoryError> {
672        let k = top_k.unwrap_or(self.inner.config.search.default_top_k);
673
674        let query_embedding = self.embed_text_internal(query).await?;
675
676        #[cfg(feature = "hnsw")]
677        let hnsw_hits = {
678            let guard = self
679                .inner
680                .hnsw_index
681                .read()
682                .unwrap_or_else(|e| e.into_inner());
683            let candidates = self.inner.config.search.candidate_pool_size.max(k * 3);
684            match guard.search(&query_embedding, candidates) {
685                Ok(hits) => hits,
686                Err(err) => {
687                    tracing::error!(
688                        "HNSW conversation search failed, falling back to brute-force message search: {}",
689                        err
690                    );
691                    Vec::new()
692                }
693            }
694        };
695
696        let q = query.to_string();
697        let config = self.inner.config.search.clone();
698        let sids_owned = to_owned_string_vec(session_ids);
699
700        #[cfg(feature = "hnsw")]
701        let hnsw_hits_owned = hnsw_hits;
702
703        self.with_read_conn(move |conn| {
704            let sids_refs = as_str_slice(&sids_owned);
705            let sids_slice: Option<&[&str]> = sids_refs.as_deref();
706            #[cfg(feature = "hnsw")]
707            {
708                if hnsw_hits_owned.is_empty() {
709                    search::hybrid_search(
710                        conn,
711                        &q,
712                        &query_embedding,
713                        &config,
714                        k,
715                        None,
716                        Some(&[SearchSourceType::Messages]),
717                        sids_slice,
718                    )
719                } else {
720                    search::hybrid_search_with_hnsw(
721                        conn,
722                        &q,
723                        &query_embedding,
724                        &config,
725                        k,
726                        None,
727                        Some(&[SearchSourceType::Messages]),
728                        sids_slice,
729                        &hnsw_hits_owned,
730                    )
731                }
732            }
733            #[cfg(not(feature = "hnsw"))]
734            {
735                search::hybrid_search(
736                    conn,
737                    &q,
738                    &query_embedding,
739                    &config,
740                    k,
741                    None,
742                    Some(&[SearchSourceType::Messages]),
743                    sids_slice,
744                )
745            }
746        })
747        .await
748    }
749}