1use crate::db::with_transaction;
4use crate::error::MemoryError;
5use crate::types::{Message, Role, Session};
6use rusqlite::{params, Connection};
7
8pub fn create_session(
12 conn: &Connection,
13 channel: &str,
14 metadata: Option<&serde_json::Value>,
15) -> Result<String, MemoryError> {
16 let id = uuid::Uuid::new_v4().to_string();
17 let metadata_str = metadata.map(|m| m.to_string());
18 conn.execute(
19 "INSERT INTO sessions (id, channel, metadata) VALUES (?1, ?2, ?3)",
20 params![id, channel, metadata_str],
21 )?;
22 Ok(id)
23}
24
25pub fn add_message(
29 conn: &Connection,
30 session_id: &str,
31 role: Role,
32 content: &str,
33 token_count: Option<u32>,
34 metadata: Option<&serde_json::Value>,
35) -> Result<i64, MemoryError> {
36 let exists: bool = conn.query_row(
38 "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
39 params![session_id],
40 |row| row.get(0),
41 )?;
42 if !exists {
43 return Err(MemoryError::SessionNotFound(session_id.to_string()));
44 }
45
46 let metadata_str = metadata.map(|m| m.to_string());
47 with_transaction(conn, |tx| {
48 tx.execute(
49 "INSERT INTO messages (session_id, role, content, token_count, metadata) VALUES (?1, ?2, ?3, ?4, ?5)",
50 params![session_id, role.as_str(), content, token_count, metadata_str],
51 )?;
52 let msg_id = tx.last_insert_rowid();
53
54 tx.execute(
55 "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
56 params![session_id],
57 )?;
58
59 Ok(msg_id)
60 })
61}
62
63pub fn get_recent_messages(
65 conn: &Connection,
66 session_id: &str,
67 limit: usize,
68) -> Result<Vec<Message>, MemoryError> {
69 let mut stmt = conn.prepare(
70 "SELECT id, session_id, role, content, token_count, created_at, metadata
71 FROM messages
72 WHERE session_id = ?1
73 ORDER BY created_at DESC, id DESC
74 LIMIT ?2",
75 )?;
76
77 let mut messages: Vec<Message> = stmt
78 .query_map(params![session_id, limit as i64], |row| {
79 let role_str: String = row.get(2)?;
80 let metadata_str: Option<String> = row.get(6)?;
81 Ok(Message {
82 id: row.get(0)?,
83 session_id: row.get(1)?,
84 role: Role::from_str_value(&role_str).unwrap_or(Role::User),
85 content: row.get(3)?,
86 token_count: row.get(4)?,
87 created_at: row.get(5)?,
88 metadata: metadata_str.and_then(|s| serde_json::from_str(&s).ok()),
89 })
90 })?
91 .collect::<Result<Vec<_>, _>>()?;
92
93 messages.reverse();
95 Ok(messages)
96}
97
98pub fn get_messages_within_budget(
108 conn: &Connection,
109 session_id: &str,
110 max_tokens: u32,
111) -> Result<Vec<Message>, MemoryError> {
112 let mut stmt = conn.prepare(
113 "SELECT id, session_id, role, content, token_count, created_at, metadata
114 FROM messages
115 WHERE session_id = ?1
116 ORDER BY created_at DESC, id DESC",
117 )?;
118
119 let all_messages: Vec<Message> = stmt
120 .query_map(params![session_id], |row| {
121 let role_str: String = row.get(2)?;
122 let metadata_str: Option<String> = row.get(6)?;
123 Ok(Message {
124 id: row.get(0)?,
125 session_id: row.get(1)?,
126 role: Role::from_str_value(&role_str).unwrap_or(Role::User),
127 content: row.get(3)?,
128 token_count: row.get(4)?,
129 created_at: row.get(5)?,
130 metadata: metadata_str.and_then(|s| serde_json::from_str(&s).ok()),
131 })
132 })?
133 .collect::<Result<Vec<_>, _>>()?;
134
135 let mut collected = Vec::new();
136 let mut total_tokens: u32 = 0;
137
138 for msg in all_messages {
139 let msg_tokens = msg.token_count.unwrap_or(0);
140 if total_tokens + msg_tokens > max_tokens && !collected.is_empty() {
141 break;
142 }
143 total_tokens += msg_tokens;
144 collected.push(msg);
145 }
146
147 collected.reverse();
149 Ok(collected)
150}
151
152pub fn session_token_count(conn: &Connection, session_id: &str) -> Result<u64, MemoryError> {
154 let count: i64 = conn.query_row(
155 "SELECT COALESCE(SUM(token_count), 0) FROM messages WHERE session_id = ?1",
156 params![session_id],
157 |row| row.get(0),
158 )?;
159 Ok(count as u64)
160}
161
162pub fn add_message_with_embedding(
167 conn: &Connection,
168 session_id: &str,
169 role: Role,
170 content: &str,
171 token_count: Option<u32>,
172 metadata: Option<&serde_json::Value>,
173 embedding_bytes: &[u8],
174) -> Result<i64, MemoryError> {
175 add_message_with_embedding_q8(conn, session_id, role, content, token_count, metadata, embedding_bytes, None)
176}
177
178#[allow(clippy::too_many_arguments)]
180pub fn add_message_with_embedding_q8(
181 conn: &Connection,
182 session_id: &str,
183 role: Role,
184 content: &str,
185 token_count: Option<u32>,
186 metadata: Option<&serde_json::Value>,
187 embedding_bytes: &[u8],
188 q8_bytes: Option<&[u8]>,
189) -> Result<i64, MemoryError> {
190 let exists: bool = conn.query_row(
192 "SELECT EXISTS(SELECT 1 FROM sessions WHERE id = ?1)",
193 params![session_id],
194 |row| row.get(0),
195 )?;
196 if !exists {
197 return Err(MemoryError::SessionNotFound(session_id.to_string()));
198 }
199
200 let metadata_str = metadata.map(|m| m.to_string());
201 with_transaction(conn, |tx| {
202 tx.execute(
204 "INSERT INTO messages (session_id, role, content, token_count, metadata, embedding, embedding_q8) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
205 params![session_id, role.as_str(), content, token_count, metadata_str, embedding_bytes, q8_bytes],
206 )?;
207 let msg_id = tx.last_insert_rowid();
208
209 tx.execute(
211 "INSERT INTO messages_rowid_map (message_id) VALUES (?1)",
212 params![msg_id],
213 )?;
214 let fts_rowid = tx.last_insert_rowid();
215
216 tx.execute(
218 "INSERT INTO messages_fts(rowid, content) VALUES (?1, ?2)",
219 params![fts_rowid, content],
220 )?;
221
222 tx.execute(
224 "UPDATE sessions SET updated_at = datetime('now') WHERE id = ?1",
225 params![session_id],
226 )?;
227
228 Ok(msg_id)
229 })
230}
231
232pub fn delete_message_fts(conn: &Connection, message_id: i64) -> Result<(), MemoryError> {
234 let result: Result<(String, i64), _> = conn.query_row(
236 "SELECT m.content, mm.rowid
237 FROM messages m
238 JOIN messages_rowid_map mm ON mm.message_id = m.id
239 WHERE m.id = ?1",
240 params![message_id],
241 |row| Ok((row.get(0)?, row.get(1)?)),
242 );
243
244 if let Ok((content, fts_rowid)) = result {
245 conn.execute(
246 "INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', ?1, ?2)",
247 params![fts_rowid, content],
248 )?;
249 conn.execute(
250 "DELETE FROM messages_rowid_map WHERE message_id = ?1",
251 params![message_id],
252 )?;
253 }
254 Ok(())
257}
258
259pub fn delete_session(conn: &Connection, session_id: &str) -> Result<(), MemoryError> {
263 with_transaction(conn, |tx| {
264 let fts_data: Vec<(i64, String, i64)> = {
266 let mut stmt = tx.prepare(
267 "SELECT m.id, m.content, mm.rowid
268 FROM messages m
269 JOIN messages_rowid_map mm ON mm.message_id = m.id
270 WHERE m.session_id = ?1",
271 )?;
272 let result = stmt
273 .query_map(params![session_id], |row| {
274 Ok((row.get(0)?, row.get(1)?, row.get(2)?))
275 })?
276 .collect::<Result<Vec<_>, _>>()?;
277 result
278 };
279
280 for (_msg_id, content, fts_rowid) in &fts_data {
281 tx.execute(
282 "INSERT INTO messages_fts(messages_fts, rowid, content) VALUES('delete', ?1, ?2)",
283 params![fts_rowid, content],
284 )?;
285 }
286
287 let affected = tx.execute("DELETE FROM sessions WHERE id = ?1", params![session_id])?;
289 if affected == 0 {
290 return Err(MemoryError::SessionNotFound(session_id.to_string()));
291 }
292
293 Ok(())
294 })
295}
296
297pub fn list_sessions(
299 conn: &Connection,
300 limit: usize,
301 offset: usize,
302) -> Result<Vec<Session>, MemoryError> {
303 let mut stmt = conn.prepare(
304 "SELECT s.id, s.channel, s.created_at, s.updated_at, s.metadata,
305 COUNT(m.id) AS message_count
306 FROM sessions s
307 LEFT JOIN messages m ON m.session_id = s.id
308 GROUP BY s.id
309 ORDER BY s.updated_at DESC
310 LIMIT ?1 OFFSET ?2",
311 )?;
312
313 let sessions = stmt
314 .query_map(params![limit as i64, offset as i64], |row| {
315 let metadata_str: Option<String> = row.get(4)?;
316 let message_count: i64 = row.get(5)?;
317 Ok(Session {
318 id: row.get(0)?,
319 channel: row.get(1)?,
320 created_at: row.get(2)?,
321 updated_at: row.get(3)?,
322 metadata: metadata_str.and_then(|s| serde_json::from_str(&s).ok()),
323 message_count: message_count as u32,
324 })
325 })?
326 .collect::<Result<Vec<_>, _>>()?;
327
328 Ok(sessions)
329}