1use chrono::{DateTime, Utc};
2use uuid::Uuid;
3
4use crate::{Db, DbError};
5
/// Role attached to a chat message.
///
/// Persisted in the `messages.role` column as the lowercase string returned by
/// [`Role::as_str`]. `Eq` and `Hash` are derived so roles can be used as keys in
/// maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Role {
    /// Stored as `"user"`.
    User,
    /// Stored as `"assistant"`.
    Assistant,
    /// Stored as `"system"`.
    System,
    /// Stored as `"tool"`.
    Tool,
}

impl Role {
    /// Returns the string written to the database for this role.
    pub fn as_str(&self) -> &'static str {
        match self {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
            Role::Tool => "tool",
        }
    }

    /// Parses a stored role string; the inverse of [`Role::as_str`].
    ///
    /// Matching is exact and case-sensitive. Returns `None` for any string that
    /// `as_str` never produces.
    pub fn parse(s: &str) -> Option<Self> {
        match s {
            "user" => Some(Role::User),
            "assistant" => Some(Role::Assistant),
            "system" => Some(Role::System),
            "tool" => Some(Role::Tool),
            _ => None,
        }
    }
}
34
/// Kind of a content block inside a message.
///
/// Persisted in the `message_blocks.block_type` column as the string returned
/// by [`BlockType::as_str`]. `Eq` and `Hash` are derived so block types can be
/// used as keys in maps and sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BlockType {
    /// Stored as `"text"`.
    Text,
    /// Stored as `"thinking"`.
    Thinking,
    /// Stored as `"tool_call"`.
    ToolCall,
    /// Stored as `"tool_result"`.
    ToolResult,
}

impl BlockType {
    /// Returns the string written to the database for this block type.
    pub fn as_str(&self) -> &'static str {
        match self {
            BlockType::Text => "text",
            BlockType::Thinking => "thinking",
            BlockType::ToolCall => "tool_call",
            BlockType::ToolResult => "tool_result",
        }
    }

    /// Parses a stored block-type string; the inverse of [`BlockType::as_str`].
    ///
    /// Matching is exact and case-sensitive. Returns `None` for any string that
    /// `as_str` never produces.
    pub fn parse(s: &str) -> Option<Self> {
        match s {
            "text" => Some(BlockType::Text),
            "thinking" => Some(BlockType::Thinking),
            "tool_call" => Some(BlockType::ToolCall),
            "tool_result" => Some(BlockType::ToolResult),
            _ => None,
        }
    }
}
63
/// A chat message as returned by [`Db::create_message`] and
/// [`Db::get_session_messages`], together with its content blocks.
#[derive(Debug, Clone)]
pub struct Message {
    /// Primary key; `create_message` generates a random UUID v4 string.
    pub id: String,
    /// Id of the session this message belongs to.
    pub session_id: String,
    /// Parsed from the stored role string; unrecognised strings read back as `Role::User`.
    pub role: Role,
    /// Message text, if any.
    pub content: Option<String>,
    /// Model identifier recorded with the message, if any.
    pub model_id: Option<String>,
    /// Provider name recorded with the message, if any.
    pub provider_name: Option<String>,
    /// Creation time; persisted as an RFC 3339 string in `messages.created_at`.
    pub created_at: DateTime<Utc>,
    /// Input token count set by `update_message_usage`; `None` until then.
    pub token_usage_in: Option<i64>,
    /// Output token count set by `update_message_usage`; `None` until then.
    pub token_usage_out: Option<i64>,
    /// Latency value set by `update_message_usage`; `None` until then.
    pub latency_ms: Option<i64>,
    /// Id stored in `messages.parent_message_id`, if any.
    pub parent_message_id: Option<String>,
    /// Content blocks; ordered by `sequence` when loaded from the database.
    pub blocks: Vec<MessageBlock>,
}
79
/// One content block (text, thinking, tool call, ...) belonging to a [`Message`].
#[derive(Debug, Clone)]
pub struct MessageBlock {
    /// Primary key; `create_message` generates a random UUID v4 string.
    pub id: String,
    /// Id of the owning message.
    pub message_id: String,
    /// Parsed from the stored type string; unrecognised strings read back as `BlockType::Text`.
    pub block_type: BlockType,
    /// Block payload; `None` after `compress_thinking_blocks` has cleared it.
    pub content: Option<String>,
    /// Position of the block within its message; blocks are loaded in ascending order.
    pub sequence: i32,
    /// Set to `true` by `compress_thinking_blocks` when it discards `content`.
    pub compressed: bool,
}
89
/// Borrowed input describing a message to insert with [`Db::create_message`].
///
/// Every field borrows from the caller, so the struct is `Copy` and cheap to
/// pass around. `Debug` is derived so it can be logged like the other public types.
#[derive(Debug, Clone, Copy)]
pub struct NewMessage<'a> {
    /// Id of the session the message is added to.
    pub session_id: &'a str,
    /// Role string, written verbatim to `messages.role` (see `Role::as_str` for
    /// the recognised values).
    pub role: &'a str,
    /// Message text, if any.
    pub content: Option<&'a str>,
    /// Model identifier to record, if any.
    pub model_id: Option<&'a str>,
    /// Provider name to record, if any.
    pub provider_name: Option<&'a str>,
    /// Value for `messages.parent_message_id`, if any.
    pub parent_message_id: Option<&'a str>,
}
98
/// Borrowed input describing one content block passed to [`Db::create_message`].
///
/// Every field is borrowed or `Copy`, so the struct itself is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct NewBlock<'a> {
    /// Block type string, written verbatim to `message_blocks.block_type`
    /// (see `BlockType::as_str` for the recognised values).
    pub block_type: &'a str,
    /// Block payload.
    pub content: &'a str,
    /// Position of the block within its message.
    pub sequence: i32,
}
104
105impl Db {
106 pub fn create_message(
107 &self,
108 msg: &NewMessage,
109 blocks: &[NewBlock],
110 ) -> Result<Message, DbError> {
111 let id = Uuid::new_v4().to_string();
112 let now = Utc::now();
113 let now_str = now.to_rfc3339();
114
115 self.conn.execute(
116 "INSERT INTO messages (id, session_id, role, content, model_id, provider_name, created_at, parent_message_id)
117 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
118 rusqlite::params![
119 id,
120 msg.session_id,
121 msg.role,
122 msg.content,
123 msg.model_id,
124 msg.provider_name,
125 now_str,
126 msg.parent_message_id,
127 ],
128 )?;
129
130 self.conn.execute(
132 "UPDATE sessions SET updated_at = ?1 WHERE id = ?2",
133 rusqlite::params![now_str, msg.session_id],
134 )?;
135
136 let mut message_blocks = Vec::new();
137 for block in blocks {
138 let block_id = Uuid::new_v4().to_string();
139 self.conn.execute(
140 "INSERT INTO message_blocks (id, message_id, block_type, content, sequence)
141 VALUES (?1, ?2, ?3, ?4, ?5)",
142 rusqlite::params![
143 block_id,
144 id,
145 block.block_type,
146 block.content,
147 block.sequence
148 ],
149 )?;
150 message_blocks.push(MessageBlock {
151 id: block_id,
152 message_id: id.clone(),
153 block_type: BlockType::parse(block.block_type).unwrap_or(BlockType::Text),
154 content: Some(block.content.to_string()),
155 sequence: block.sequence,
156 compressed: false,
157 });
158 }
159
160 Ok(Message {
161 id,
162 session_id: msg.session_id.to_string(),
163 role: Role::parse(msg.role).unwrap_or(Role::User),
164 content: msg.content.map(|s| s.to_string()),
165 model_id: msg.model_id.map(|s| s.to_string()),
166 provider_name: msg.provider_name.map(|s| s.to_string()),
167 created_at: now,
168 token_usage_in: None,
169 token_usage_out: None,
170 latency_ms: None,
171 parent_message_id: msg.parent_message_id.map(|s| s.to_string()),
172 blocks: message_blocks,
173 })
174 }
175
176 pub fn update_message_usage(
177 &self,
178 message_id: &str,
179 tokens_in: i64,
180 tokens_out: i64,
181 latency_ms: i64,
182 ) -> Result<(), DbError> {
183 let rows = self.conn.execute(
184 "UPDATE messages SET token_usage_in = ?1, token_usage_out = ?2, latency_ms = ?3 WHERE id = ?4",
185 rusqlite::params![tokens_in, tokens_out, latency_ms, message_id],
186 )?;
187 if rows == 0 {
188 return Err(DbError::NotFound(format!("message {message_id}")));
189 }
190 Ok(())
191 }
192
193 pub fn get_session_messages(&self, session_id: &str) -> Result<Vec<Message>, DbError> {
194 let mut stmt = self.conn.prepare(
195 "SELECT id, session_id, role, content, model_id, provider_name, created_at,
196 token_usage_in, token_usage_out, latency_ms, parent_message_id
197 FROM messages WHERE session_id = ?1 ORDER BY created_at ASC",
198 )?;
199
200 let rows = stmt.query_map(rusqlite::params![session_id], |row| {
201 let created_str: String = row.get(6)?;
202 Ok((
203 row.get::<_, String>(0)?,
204 row.get::<_, String>(1)?,
205 row.get::<_, String>(2)?,
206 row.get::<_, Option<String>>(3)?,
207 row.get::<_, Option<String>>(4)?,
208 row.get::<_, Option<String>>(5)?,
209 created_str,
210 row.get::<_, Option<i64>>(7)?,
211 row.get::<_, Option<i64>>(8)?,
212 row.get::<_, Option<i64>>(9)?,
213 row.get::<_, Option<String>>(10)?,
214 ))
215 })?;
216
217 let mut messages = Vec::new();
218 for row in rows {
219 let r = row?;
220 let blocks = self.get_message_blocks(&r.0)?;
221 messages.push(Message {
222 id: r.0,
223 session_id: r.1,
224 role: Role::parse(&r.2).unwrap_or(Role::User),
225 content: r.3,
226 model_id: r.4,
227 provider_name: r.5,
228 created_at: DateTime::parse_from_rfc3339(&r.6)
229 .unwrap_or_default()
230 .with_timezone(&Utc),
231 token_usage_in: r.7,
232 token_usage_out: r.8,
233 latency_ms: r.9,
234 parent_message_id: r.10,
235 blocks,
236 });
237 }
238
239 Ok(messages)
240 }
241
242 fn get_message_blocks(&self, message_id: &str) -> Result<Vec<MessageBlock>, DbError> {
243 let mut stmt = self.conn.prepare(
244 "SELECT id, message_id, block_type, content, sequence, compressed
245 FROM message_blocks WHERE message_id = ?1 ORDER BY sequence ASC",
246 )?;
247
248 let rows = stmt.query_map(rusqlite::params![message_id], |row| {
249 Ok((
250 row.get::<_, String>(0)?,
251 row.get::<_, String>(1)?,
252 row.get::<_, String>(2)?,
253 row.get::<_, Option<String>>(3)?,
254 row.get::<_, i32>(4)?,
255 row.get::<_, bool>(5)?,
256 ))
257 })?;
258
259 let mut blocks = Vec::new();
260 for row in rows {
261 let r = row?;
262 blocks.push(MessageBlock {
263 id: r.0,
264 message_id: r.1,
265 block_type: BlockType::parse(&r.2).unwrap_or(BlockType::Text),
266 content: r.3,
267 sequence: r.4,
268 compressed: r.5,
269 });
270 }
271
272 Ok(blocks)
273 }
274
275 pub fn delete_message(&self, message_id: &str) -> Result<(), DbError> {
277 self.conn.execute(
278 "DELETE FROM messages WHERE id = ?1",
279 rusqlite::params![message_id],
280 )?;
281 Ok(())
282 }
283
284 pub fn compress_thinking_blocks(&self, older_than_days: i64) -> Result<usize, DbError> {
285 let cutoff = Utc::now() - chrono::Duration::days(older_than_days);
286 let cutoff_str = cutoff.to_rfc3339();
287
288 let count = self.conn.execute(
289 "UPDATE message_blocks SET content = NULL, compressed = 1
290 WHERE block_type = 'thinking' AND compressed = 0
291 AND message_id IN (
292 SELECT id FROM messages WHERE created_at < ?1
293 )",
294 rusqlite::params![cutoff_str],
295 )?;
296
297 Ok(count)
298 }
299}
300
#[cfg(test)]
mod tests {
    use crate::messages::{BlockType, MessageBlock, NewBlock, NewMessage};
    use crate::Db;

    /// Opens a fresh in-memory database for a single test.
    fn setup() -> Db {
        Db::open_in_memory().unwrap()
    }

    /// Builds a `NewMessage` with text content and no model, provider or parent.
    fn text_message<'a>(session_id: &'a str, role: &'a str, content: &'a str) -> NewMessage<'a> {
        NewMessage {
            session_id,
            role,
            content: Some(content),
            model_id: None,
            provider_name: None,
            parent_message_id: None,
        }
    }

    /// Shorthand constructor for a `NewBlock`.
    fn new_block<'a>(block_type: &'a str, content: &'a str, sequence: i32) -> NewBlock<'a> {
        NewBlock {
            block_type,
            content,
            sequence,
        }
    }

    /// Returns the first block of the requested type, panicking if there is none.
    fn block_of(blocks: &[MessageBlock], kind: BlockType) -> &MessageBlock {
        blocks.iter().find(|b| b.block_type == kind).unwrap()
    }

    #[test]
    fn test_create_message_with_blocks() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        let msg = NewMessage {
            model_id: Some("claude-3"),
            provider_name: Some("anthropic"),
            ..text_message(&session.id, "user", "Hello world")
        };
        let blocks = [
            new_block("text", "Hello world", 0),
            new_block("thinking", "Let me think...", 1),
        ];

        let message = db.create_message(&msg, &blocks).unwrap();
        assert_eq!(message.content.as_deref(), Some("Hello world"));
        assert_eq!(message.blocks.len(), 2);
        let sequences: Vec<i32> = message.blocks.iter().map(|b| b.sequence).collect();
        assert_eq!(sequences, [0, 1]);
    }

    #[test]
    fn test_get_session_messages_ordered() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        for (role, content) in [("user", "First"), ("assistant", "Second")] {
            db.create_message(&text_message(&session.id, role, content), &[])
                .unwrap();
        }

        let messages = db.get_session_messages(&session.id).unwrap();
        let contents: Vec<Option<&str>> =
            messages.iter().map(|m| m.content.as_deref()).collect();
        assert_eq!(contents, [Some("First"), Some("Second")]);
    }

    #[test]
    fn test_update_message_usage() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        let message = db
            .create_message(&text_message(&session.id, "assistant", "Response"), &[])
            .unwrap();
        db.update_message_usage(&message.id, 100, 200, 1500).unwrap();

        let messages = db.get_session_messages(&session.id).unwrap();
        let first = &messages[0];
        let usage = (first.token_usage_in, first.token_usage_out, first.latency_ms);
        assert_eq!(usage, (Some(100), Some(200), Some(1500)));
    }

    #[test]
    fn test_compress_thinking_blocks() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        let blocks = [
            new_block("text", "visible", 0),
            new_block("thinking", "internal thoughts", 1),
        ];
        db.create_message(&text_message(&session.id, "assistant", "Response"), &blocks)
            .unwrap();

        assert_eq!(db.compress_thinking_blocks(0).unwrap(), 1);

        let messages = db.get_session_messages(&session.id).unwrap();

        let thinking = block_of(&messages[0].blocks, BlockType::Thinking);
        assert!(thinking.compressed);
        assert!(thinking.content.is_none());

        let text = block_of(&messages[0].blocks, BlockType::Text);
        assert!(!text.compressed);
        assert_eq!(text.content.as_deref(), Some("visible"));
    }

    #[test]
    fn test_cascade_delete() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        db.create_message(
            &text_message(&session.id, "user", "Hello"),
            &[new_block("text", "Hello", 0)],
        )
        .unwrap();

        db.delete_session(&session.id).unwrap();

        assert!(db.get_session_messages(&session.id).unwrap().is_empty());
    }
}
477}