Skip to main content

tuillem_db/
messages.rs

1use chrono::{DateTime, Utc};
2use uuid::Uuid;
3
4use crate::{Db, DbError};
5
/// Author of a chat message.
///
/// [`Role::as_str`] and [`Role::parse`] convert to and from the lowercase
/// names read from the `messages.role` column.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}

impl Role {
    /// Returns the canonical lowercase name of this role.
    pub fn as_str(&self) -> &'static str {
        match self {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
            Role::Tool => "tool",
        }
    }

    /// Parses a canonical lowercase role name.
    ///
    /// Matching is exact and case-sensitive: `"user"` parses, `"User"` does
    /// not. Any unrecognised input yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        match s {
            "user" => Some(Role::User),
            "assistant" => Some(Role::Assistant),
            "system" => Some(Role::System),
            "tool" => Some(Role::Tool),
            _ => None,
        }
    }
}
34
/// Kind of content held by a [`MessageBlock`].
///
/// [`BlockType::as_str`] and [`BlockType::parse`] convert to and from the
/// names read from the `message_blocks.block_type` column. The SQL in
/// `Db::compress_thinking_blocks` matches the literal `'thinking'`, so that
/// string must stay in sync with `BlockType::Thinking.as_str()`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BlockType {
    Text,
    Thinking,
    ToolCall,
    ToolResult,
}

impl BlockType {
    /// Returns the canonical snake_case name of this block type.
    pub fn as_str(&self) -> &'static str {
        match self {
            BlockType::Text => "text",
            BlockType::Thinking => "thinking",
            BlockType::ToolCall => "tool_call",
            BlockType::ToolResult => "tool_result",
        }
    }

    /// Parses a canonical snake_case block type name.
    ///
    /// Matching is exact and case-sensitive. Any unrecognised input yields
    /// `None`.
    pub fn parse(s: &str) -> Option<Self> {
        match s {
            "text" => Some(BlockType::Text),
            "thinking" => Some(BlockType::Thinking),
            "tool_call" => Some(BlockType::ToolCall),
            "tool_result" => Some(BlockType::ToolResult),
            _ => None,
        }
    }
}
63
64#[derive(Debug, Clone)]
65pub struct Message {
66    pub id: String,
67    pub session_id: String,
68    pub role: Role,
69    pub content: Option<String>,
70    pub model_id: Option<String>,
71    pub provider_name: Option<String>,
72    pub created_at: DateTime<Utc>,
73    pub token_usage_in: Option<i64>,
74    pub token_usage_out: Option<i64>,
75    pub latency_ms: Option<i64>,
76    pub parent_message_id: Option<String>,
77    pub blocks: Vec<MessageBlock>,
78}
79
80#[derive(Debug, Clone)]
81pub struct MessageBlock {
82    pub id: String,
83    pub message_id: String,
84    pub block_type: BlockType,
85    pub content: Option<String>,
86    pub sequence: i32,
87    pub compressed: bool,
88}
89
/// Borrowed input for `Db::create_message`.
///
/// `role` is stored verbatim. It is expected to be one of the names produced
/// by `Role::as_str`; other names are read back as `Role::User`.
#[derive(Debug, Clone, Copy)]
pub struct NewMessage<'a> {
    /// Id of the session that will own the message.
    pub session_id: &'a str,
    /// Lowercase role name such as `"user"` or `"assistant"`.
    pub role: &'a str,
    pub content: Option<&'a str>,
    pub model_id: Option<&'a str>,
    pub provider_name: Option<&'a str>,
    pub parent_message_id: Option<&'a str>,
}
98
/// Borrowed input describing one block passed to `Db::create_message`.
///
/// `block_type` is stored verbatim. It is expected to be one of the names
/// produced by `BlockType::as_str`; other names are read back as
/// `BlockType::Text`.
#[derive(Debug, Clone, Copy)]
pub struct NewBlock<'a> {
    pub block_type: &'a str,
    pub content: &'a str,
    /// Position within the message; blocks are read back in ascending order.
    pub sequence: i32,
}
104
105impl Db {
106    pub fn create_message(
107        &self,
108        msg: &NewMessage,
109        blocks: &[NewBlock],
110    ) -> Result<Message, DbError> {
111        let id = Uuid::new_v4().to_string();
112        let now = Utc::now();
113        let now_str = now.to_rfc3339();
114
115        self.conn.execute(
116            "INSERT INTO messages (id, session_id, role, content, model_id, provider_name, created_at, parent_message_id)
117             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
118            rusqlite::params![
119                id,
120                msg.session_id,
121                msg.role,
122                msg.content,
123                msg.model_id,
124                msg.provider_name,
125                now_str,
126                msg.parent_message_id,
127            ],
128        )?;
129
130        // Update session updated_at
131        self.conn.execute(
132            "UPDATE sessions SET updated_at = ?1 WHERE id = ?2",
133            rusqlite::params![now_str, msg.session_id],
134        )?;
135
136        let mut message_blocks = Vec::new();
137        for block in blocks {
138            let block_id = Uuid::new_v4().to_string();
139            self.conn.execute(
140                "INSERT INTO message_blocks (id, message_id, block_type, content, sequence)
141                 VALUES (?1, ?2, ?3, ?4, ?5)",
142                rusqlite::params![
143                    block_id,
144                    id,
145                    block.block_type,
146                    block.content,
147                    block.sequence
148                ],
149            )?;
150            message_blocks.push(MessageBlock {
151                id: block_id,
152                message_id: id.clone(),
153                block_type: BlockType::parse(block.block_type).unwrap_or(BlockType::Text),
154                content: Some(block.content.to_string()),
155                sequence: block.sequence,
156                compressed: false,
157            });
158        }
159
160        Ok(Message {
161            id,
162            session_id: msg.session_id.to_string(),
163            role: Role::parse(msg.role).unwrap_or(Role::User),
164            content: msg.content.map(|s| s.to_string()),
165            model_id: msg.model_id.map(|s| s.to_string()),
166            provider_name: msg.provider_name.map(|s| s.to_string()),
167            created_at: now,
168            token_usage_in: None,
169            token_usage_out: None,
170            latency_ms: None,
171            parent_message_id: msg.parent_message_id.map(|s| s.to_string()),
172            blocks: message_blocks,
173        })
174    }
175
176    pub fn update_message_usage(
177        &self,
178        message_id: &str,
179        tokens_in: i64,
180        tokens_out: i64,
181        latency_ms: i64,
182    ) -> Result<(), DbError> {
183        let rows = self.conn.execute(
184            "UPDATE messages SET token_usage_in = ?1, token_usage_out = ?2, latency_ms = ?3 WHERE id = ?4",
185            rusqlite::params![tokens_in, tokens_out, latency_ms, message_id],
186        )?;
187        if rows == 0 {
188            return Err(DbError::NotFound(format!("message {message_id}")));
189        }
190        Ok(())
191    }
192
193    pub fn get_session_messages(&self, session_id: &str) -> Result<Vec<Message>, DbError> {
194        let mut stmt = self.conn.prepare(
195            "SELECT id, session_id, role, content, model_id, provider_name, created_at,
196                    token_usage_in, token_usage_out, latency_ms, parent_message_id
197             FROM messages WHERE session_id = ?1 ORDER BY created_at ASC",
198        )?;
199
200        let rows = stmt.query_map(rusqlite::params![session_id], |row| {
201            let created_str: String = row.get(6)?;
202            Ok((
203                row.get::<_, String>(0)?,
204                row.get::<_, String>(1)?,
205                row.get::<_, String>(2)?,
206                row.get::<_, Option<String>>(3)?,
207                row.get::<_, Option<String>>(4)?,
208                row.get::<_, Option<String>>(5)?,
209                created_str,
210                row.get::<_, Option<i64>>(7)?,
211                row.get::<_, Option<i64>>(8)?,
212                row.get::<_, Option<i64>>(9)?,
213                row.get::<_, Option<String>>(10)?,
214            ))
215        })?;
216
217        let mut messages = Vec::new();
218        for row in rows {
219            let r = row?;
220            let blocks = self.get_message_blocks(&r.0)?;
221            messages.push(Message {
222                id: r.0,
223                session_id: r.1,
224                role: Role::parse(&r.2).unwrap_or(Role::User),
225                content: r.3,
226                model_id: r.4,
227                provider_name: r.5,
228                created_at: DateTime::parse_from_rfc3339(&r.6)
229                    .unwrap_or_default()
230                    .with_timezone(&Utc),
231                token_usage_in: r.7,
232                token_usage_out: r.8,
233                latency_ms: r.9,
234                parent_message_id: r.10,
235                blocks,
236            });
237        }
238
239        Ok(messages)
240    }
241
242    fn get_message_blocks(&self, message_id: &str) -> Result<Vec<MessageBlock>, DbError> {
243        let mut stmt = self.conn.prepare(
244            "SELECT id, message_id, block_type, content, sequence, compressed
245             FROM message_blocks WHERE message_id = ?1 ORDER BY sequence ASC",
246        )?;
247
248        let rows = stmt.query_map(rusqlite::params![message_id], |row| {
249            Ok((
250                row.get::<_, String>(0)?,
251                row.get::<_, String>(1)?,
252                row.get::<_, String>(2)?,
253                row.get::<_, Option<String>>(3)?,
254                row.get::<_, i32>(4)?,
255                row.get::<_, bool>(5)?,
256            ))
257        })?;
258
259        let mut blocks = Vec::new();
260        for row in rows {
261            let r = row?;
262            blocks.push(MessageBlock {
263                id: r.0,
264                message_id: r.1,
265                block_type: BlockType::parse(&r.2).unwrap_or(BlockType::Text),
266                content: r.3,
267                sequence: r.4,
268                compressed: r.5,
269            });
270        }
271
272        Ok(blocks)
273    }
274
275    /// Delete a message and its blocks by ID.
276    pub fn delete_message(&self, message_id: &str) -> Result<(), DbError> {
277        self.conn.execute(
278            "DELETE FROM messages WHERE id = ?1",
279            rusqlite::params![message_id],
280        )?;
281        Ok(())
282    }
283
284    pub fn compress_thinking_blocks(&self, older_than_days: i64) -> Result<usize, DbError> {
285        let cutoff = Utc::now() - chrono::Duration::days(older_than_days);
286        let cutoff_str = cutoff.to_rfc3339();
287
288        let count = self.conn.execute(
289            "UPDATE message_blocks SET content = NULL, compressed = 1
290             WHERE block_type = 'thinking' AND compressed = 0
291             AND message_id IN (
292                 SELECT id FROM messages WHERE created_at < ?1
293             )",
294            rusqlite::params![cutoff_str],
295        )?;
296
297        Ok(count)
298    }
299}
300
#[cfg(test)]
mod tests {
    use crate::messages::{BlockType, NewBlock, NewMessage, Role};
    use crate::Db;

    fn setup() -> Db {
        Db::open_in_memory().unwrap()
    }

    /// Builds a `NewMessage` with only a session, role and text content.
    fn plain_message<'a>(session_id: &'a str, role: &'a str, content: &'a str) -> NewMessage<'a> {
        NewMessage {
            session_id,
            role,
            content: Some(content),
            model_id: None,
            provider_name: None,
            parent_message_id: None,
        }
    }

    #[test]
    fn test_create_message_with_blocks() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        let msg = NewMessage {
            session_id: &session.id,
            role: "user",
            content: Some("Hello world"),
            model_id: Some("claude-3"),
            provider_name: Some("anthropic"),
            parent_message_id: None,
        };

        let blocks = vec![
            NewBlock {
                block_type: "text",
                content: "Hello world",
                sequence: 0,
            },
            NewBlock {
                block_type: "thinking",
                content: "Let me think...",
                sequence: 1,
            },
        ];

        let message = db.create_message(&msg, &blocks).unwrap();
        assert_eq!(message.content, Some("Hello world".to_string()));
        assert_eq!(message.blocks.len(), 2);
        assert_eq!(message.blocks[0].sequence, 0);
        assert_eq!(message.blocks[1].sequence, 1);

        // The persisted copy must read back with the same blocks in order.
        let stored = db.get_session_messages(&session.id).unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].id, message.id);
        assert_eq!(stored[0].blocks.len(), 2);
        assert_eq!(stored[0].blocks[0].block_type, BlockType::Text);
        assert_eq!(stored[0].blocks[1].block_type, BlockType::Thinking);
        assert_eq!(stored[0].blocks[1].content, Some("Let me think...".to_string()));
    }

    #[test]
    fn test_get_session_messages_ordered() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        db.create_message(&plain_message(&session.id, "user", "First"), &[])
            .unwrap();
        db.create_message(&plain_message(&session.id, "assistant", "Second"), &[])
            .unwrap();

        let messages = db.get_session_messages(&session.id).unwrap();
        assert_eq!(messages.len(), 2);
        assert_eq!(messages[0].content, Some("First".to_string()));
        assert_eq!(messages[1].content, Some("Second".to_string()));
    }

    #[test]
    fn test_update_message_usage() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        let message = db
            .create_message(&plain_message(&session.id, "assistant", "Response"), &[])
            .unwrap();

        db.update_message_usage(&message.id, 100, 200, 1500)
            .unwrap();

        let messages = db.get_session_messages(&session.id).unwrap();
        assert_eq!(messages[0].token_usage_in, Some(100));
        assert_eq!(messages[0].token_usage_out, Some(200));
        assert_eq!(messages[0].latency_ms, Some(1500));
    }

    #[test]
    fn test_update_message_usage_unknown_id_is_error() {
        let db = setup();
        assert!(db.update_message_usage("does-not-exist", 1, 2, 3).is_err());
    }

    #[test]
    fn test_compress_thinking_blocks() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        let blocks = vec![
            NewBlock {
                block_type: "text",
                content: "visible",
                sequence: 0,
            },
            NewBlock {
                block_type: "thinking",
                content: "internal thoughts",
                sequence: 1,
            },
        ];

        db.create_message(&plain_message(&session.id, "assistant", "Response"), &blocks)
            .unwrap();

        // Compress blocks older than 0 days (i.e., all)
        let count = db.compress_thinking_blocks(0).unwrap();
        assert_eq!(count, 1);

        // Already-compressed blocks are skipped on a second pass.
        assert_eq!(db.compress_thinking_blocks(0).unwrap(), 0);

        let messages = db.get_session_messages(&session.id).unwrap();
        let thinking_block = messages[0]
            .blocks
            .iter()
            .find(|b| b.block_type == BlockType::Thinking)
            .unwrap();
        assert!(thinking_block.compressed);
        assert!(thinking_block.content.is_none());

        // Text block should be unaffected
        let text_block = messages[0]
            .blocks
            .iter()
            .find(|b| b.block_type == BlockType::Text)
            .unwrap();
        assert!(!text_block.compressed);
        assert_eq!(text_block.content, Some("visible".to_string()));
    }

    #[test]
    fn test_delete_message_removes_only_that_message() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        let keep = db
            .create_message(&plain_message(&session.id, "user", "keep"), &[])
            .unwrap();
        let doomed_blocks = [NewBlock {
            block_type: "text",
            content: "gone",
            sequence: 0,
        }];
        let doomed = db
            .create_message(&plain_message(&session.id, "assistant", "gone"), &doomed_blocks)
            .unwrap();

        db.delete_message(&doomed.id).unwrap();

        let messages = db.get_session_messages(&session.id).unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].id, keep.id);
    }

    #[test]
    fn test_delete_message_unknown_id_is_noop() {
        let db = setup();
        db.delete_message("does-not-exist").unwrap();
    }

    #[test]
    fn test_role_and_block_type_round_trip() {
        for role in &[Role::User, Role::Assistant, Role::System, Role::Tool] {
            assert_eq!(Role::parse(role.as_str()), Some(role.clone()));
        }
        for kind in &[
            BlockType::Text,
            BlockType::Thinking,
            BlockType::ToolCall,
            BlockType::ToolResult,
        ] {
            assert_eq!(BlockType::parse(kind.as_str()), Some(kind.clone()));
        }
        assert_eq!(Role::parse("User"), None);
        assert_eq!(BlockType::parse("tool-call"), None);
    }

    #[test]
    fn test_cascade_delete() {
        let db = setup();
        let session = db.create_session("Test").unwrap();

        let blocks = vec![NewBlock {
            block_type: "text",
            content: "Hello",
            sequence: 0,
        }];
        db.create_message(&plain_message(&session.id, "user", "Hello"), &blocks)
            .unwrap();

        // Delete session should cascade to messages and blocks
        db.delete_session(&session.id).unwrap();

        let messages = db.get_session_messages(&session.id).unwrap();
        assert!(messages.is_empty());
    }
}