Skip to main content

semantic_memory/
conversation.rs

1//! Session and message CRUD for conversation storage.
2
3use crate::db::with_transaction;
4use crate::error::MemoryError;
5use crate::types::{Message, Role, Session};
6use rusqlite::{params, Connection};
7
8/// Create a new conversation session.
9///
10/// Returns the session ID (UUID v4).
11pub fn create_session(
12    conn: &Connection,
13    channel: &str,
14    metadata: Option<&serde_json::Value>,
15) -> Result<String, MemoryError> {
16    let id = uuid::Uuid::new_v4().to_string();
17    let metadata_str = metadata.map(|m| m.to_string());
18    conn.execute(
19        "INSERT INTO sessions (id, channel, metadata) VALUES (?1, ?2, ?3)",
20        params![id, channel, metadata_str],
21    )?;
22    Ok(id)
23}
24
25/// Append a message to a session.
26///
27/// Updates the session's `updated_at` timestamp. Returns the message's auto-increment ID.
28pub fn add_message(
29    conn: &Connection,
30    session_id: &str,
31    role: Role,
32    content: &str,
33    token_count: Option<u32>,
34    metadata: Option<&serde_json::Value>,
35) -> Result<i64, MemoryError> {
36    // Verify session exists
37    let exists: bool = conn.query_row(
38        "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
39        params![session_id],
40        |row| row.get(0),
41    )?;
42    if !exists {
43        return Err(MemoryError::SessionNotFound(session_id.to_string()));
44    }
45
46    let metadata_str = metadata.map(|m| m.to_string());
47    with_transaction(conn, |tx| {
48        tx.execute(
49            "INSERT INTO messages (session_id, role, content, token_count, metadata) VALUES (?1, ?2, ?3, ?4, ?5)",
50            params![session_id, role.as_str(), content, token_count, metadata_str],
51        )?;
52        let msg_id = tx.last_insert_rowid();
53
54        tx.execute(
55            "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
56            params![session_id],
57        )?;
58
59        Ok(msg_id)
60    })
61}
62
63/// Get the most recent N messages from a session, in chronological order.
64pub fn get_recent_messages(
65    conn: &Connection,
66    session_id: &str,
67    limit: usize,
68) -> Result<Vec<Message>, MemoryError> {
69    let mut stmt = conn.prepare(
70        "SELECT id, session_id, role, content, token_count, created_at, metadata
71         FROM messages
72         WHERE session_id = ?1
73         ORDER BY created_at DESC, id DESC
74         LIMIT ?2",
75    )?;
76
77    let mut messages: Vec<Message> = stmt
78        .query_map(params![session_id, limit as i64], |row| {
79            let role_str: String = row.get(2)?;
80            let metadata_str: Option<String> = row.get(6)?;
81            Ok(Message {
82                id: row.get(0)?,
83                session_id: row.get(1)?,
84                role: Role::from_str_value(&role_str).unwrap_or(Role::User),
85                content: row.get(3)?,
86                token_count: row.get(4)?,
87                created_at: row.get(5)?,
88                metadata: metadata_str.and_then(|s| serde_json::from_str(&s).ok()),
89            })
90        })?
91        .collect::<Result<Vec<_>, _>>()?;
92
93    // Reverse to chronological order (we fetched newest-first)
94    messages.reverse();
95    Ok(messages)
96}
97
98/// Get messages from a session up to `max_tokens` total.
99///
100/// Walks backward from newest, accumulating token counts, stops when
101/// the budget is exceeded. Returns messages in chronological order.
102///
103/// **Edge case:** The first (most recent) message is always included even
104/// if it alone exceeds `max_tokens`. This ensures the method never returns
105/// an empty Vec for a non-empty session. Callers that need strict budget
106/// enforcement should check the total token count of returned messages.
107pub fn get_messages_within_budget(
108    conn: &Connection,
109    session_id: &str,
110    max_tokens: u32,
111) -> Result<Vec<Message>, MemoryError> {
112    let mut stmt = conn.prepare(
113        "SELECT id, session_id, role, content, token_count, created_at, metadata
114         FROM messages
115         WHERE session_id = ?1
116         ORDER BY created_at DESC, id DESC",
117    )?;
118
119    let all_messages: Vec<Message> = stmt
120        .query_map(params![session_id], |row| {
121            let role_str: String = row.get(2)?;
122            let metadata_str: Option<String> = row.get(6)?;
123            Ok(Message {
124                id: row.get(0)?,
125                session_id: row.get(1)?,
126                role: Role::from_str_value(&role_str).unwrap_or(Role::User),
127                content: row.get(3)?,
128                token_count: row.get(4)?,
129                created_at: row.get(5)?,
130                metadata: metadata_str.and_then(|s| serde_json::from_str(&s).ok()),
131            })
132        })?
133        .collect::<Result<Vec<_>, _>>()?;
134
135    let mut collected = Vec::new();
136    let mut total_tokens: u32 = 0;
137
138    for msg in all_messages {
139        let msg_tokens = msg.token_count.unwrap_or(0);
140        if total_tokens + msg_tokens > max_tokens && !collected.is_empty() {
141            break;
142        }
143        total_tokens += msg_tokens;
144        collected.push(msg);
145    }
146
147    // Reverse to chronological order
148    collected.reverse();
149    Ok(collected)
150}
151
152/// Get total token count for a session.
153pub fn session_token_count(conn: &Connection, session_id: &str) -> Result<u64, MemoryError> {
154    let count: i64 = conn.query_row(
155        "SELECT COALESCE(SUM(token_count), 0) FROM messages WHERE session_id = ?1",
156        params![session_id],
157        |row| row.get(0),
158    )?;
159    Ok(count as u64)
160}
161
162/// Append a message to a session with a pre-computed embedding and FTS entry.
163///
164/// Same as `add_message` but also stores the embedding BLOB and inserts into
165/// the FTS bridge + FTS table. Wrap in a transaction.
166pub fn add_message_with_embedding(
167    conn: &Connection,
168    session_id: &str,
169    role: Role,
170    content: &str,
171    token_count: Option<u32>,
172    metadata: Option<&serde_json::Value>,
173    embedding_bytes: &[u8],
174) -> Result<i64, MemoryError> {
175    add_message_with_embedding_q8(conn, session_id, role, content, token_count, metadata, embedding_bytes, None)
176}
177
178/// Append an embedded message with optional quantized embedding.
179#[allow(clippy::too_many_arguments)]
180pub fn add_message_with_embedding_q8(
181    conn: &Connection,
182    session_id: &str,
183    role: Role,
184    content: &str,
185    token_count: Option<u32>,
186    metadata: Option<&serde_json::Value>,
187    embedding_bytes: &[u8],
188    q8_bytes: Option<&[u8]>,
189) -> Result<i64, MemoryError> {
190    // Verify session exists
191    let exists: bool = conn.query_row(
192        "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
193        params![session_id],
194        |row| row.get(0),
195    )?;
196    if !exists {
197        return Err(MemoryError::SessionNotFound(session_id.to_string()));
198    }
199
200    let metadata_str = metadata.map(|m| m.to_string());
201    with_transaction(conn, |tx| {
202        // INSERT into messages (with embedding + q8 BLOBs)
203        tx.execute(
204            "INSERT INTO messages (session_id, role, content, token_count, metadata, embedding, embedding_q8) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
205            params![session_id, role.as_str(), content, token_count, metadata_str, embedding_bytes, q8_bytes],
206        )?;
207        let msg_id = tx.last_insert_rowid();
208
209        // INSERT into messages_rowid_map
210        tx.execute(
211            "INSERT INTO messages_rowid_map (message_id) VALUES (?1)",
212            params![msg_id],
213        )?;
214        let fts_rowid = tx.last_insert_rowid();
215
216        // INSERT into messages_fts
217        tx.execute(
218            "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
219            params![fts_rowid, content],
220        )?;
221
222        // UPDATE sessions SET updated_at
223        tx.execute(
224            "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
225            params![session_id],
226        )?;
227
228        Ok(msg_id)
229    })
230}
231
232/// Delete FTS entries for a single message. Needed for cleanup.
233pub fn delete_message_fts(conn: &Connection, message_id: i64) -> Result<(), MemoryError> {
234    // Get content and FTS rowid
235    let result: Result<(String, i64), _> = conn.query_row(
236        "SELECT m.content, mm.rowid
237         FROM messages m
238         JOIN messages_rowid_map mm ON mm.message_id = m.id
239         WHERE m.id = ?1",
240        params![message_id],
241        |row| Ok((row.get(0)?, row.get(1)?)),
242    );
243
244    if let Ok((content, fts_rowid)) = result {
245        conn.execute(
246            "INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', ?1, ?2)",
247            params![fts_rowid, content],
248        )?;
249        conn.execute(
250            "DELETE FROM messages_rowid_map WHERE message_id = ?1",
251            params![message_id],
252        )?;
253    }
254    // If no FTS entry exists (non-embedded message), that's fine — nothing to clean up.
255
256    Ok(())
257}
258
259/// Delete a session and all its messages (CASCADE).
260///
261/// Cleans up message FTS entries before CASCADE to avoid ghost entries.
262pub fn delete_session(conn: &Connection, session_id: &str) -> Result<(), MemoryError> {
263    with_transaction(conn, |tx| {
264        // Clean up message FTS entries before CASCADE
265        let fts_data: Vec<(i64, String, i64)> = {
266            let mut stmt = tx.prepare(
267                "SELECT m.id, m.content, mm.rowid
268                 FROM messages m
269                 JOIN messages_rowid_map mm ON mm.message_id = m.id
270                 WHERE m.session_id = ?1",
271            )?;
272            let result = stmt
273                .query_map(params![session_id], |row| {
274                    Ok((row.get(0)?, row.get(1)?, row.get(2)?))
275                })?
276                .collect::<Result<Vec<_>, _>>()?;
277            result
278        };
279
280        for (_msg_id, content, fts_rowid) in &fts_data {
281            tx.execute(
282                "INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', ?1, ?2)",
283                params![fts_rowid, content],
284            )?;
285        }
286
287        // Now delete session (CASCADE handles messages + messages_rowid_map)
288        let affected = tx.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
289        if affected == 0 {
290            return Err(MemoryError::SessionNotFound(session_id.to_string()));
291        }
292
293        Ok(())
294    })
295}
296
297/// List recent sessions with message counts, newest first.
298pub fn list_sessions(
299    conn: &Connection,
300    limit: usize,
301    offset: usize,
302) -> Result<Vec<Session>, MemoryError> {
303    let mut stmt = conn.prepare(
304        "SELECT s.id, s.channel, s.created_at, s.updated_at, s.metadata,
305                COUNT(m.id) AS message_count
306         FROM sessions s
307         LEFT JOIN messages m ON m.session_id = s.id
308         GROUP BY s.id
309         ORDER BY s.updated_at DESC
310         LIMIT ?1 OFFSET ?2",
311    )?;
312
313    let sessions = stmt
314        .query_map(params![limit as i64, offset as i64], |row| {
315            let metadata_str: Option<String> = row.get(4)?;
316            let message_count: i64 = row.get(5)?;
317            Ok(Session {
318                id: row.get(0)?,
319                channel: row.get(1)?,
320                created_at: row.get(2)?,
321                updated_at: row.get(3)?,
322                metadata: metadata_str.and_then(|s| serde_json::from_str(&s).ok()),
323                message_count: message_count as u32,
324            })
325        })?
326        .collect::<Result<Vec<_>, _>>()?;
327
328    Ok(sessions)
329}