steer_core/session/stores/sqlite.rs

1use async_trait::async_trait;
2use chrono::Utc;
3use serde_json;
4use sqlx::{
5    Row,
6    sqlite::{
7        SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteSynchronous,
8    },
9};
10use std::collections::HashSet;
11use std::path::Path;
12use std::str::FromStr;
13use uuid::Uuid;
14
15use crate::app::Message;
16use crate::app::conversation::{AssistantContent, MessageData, UserContent};
17use crate::events::StreamEvent;
18// use crate::session::state::ToolVisibility;
19use crate::session::{
20    Session, SessionConfig, SessionFilter, SessionInfo, SessionOrderBy, SessionState,
21    SessionStatus, SessionStore, SessionStoreError, ToolApprovalPolicy, ToolCallState,
22    ToolCallStatus, ToolCallUpdate, ToolExecutionStats,
23};
24use steer_tools::ToolCall;
25use steer_tools::result::ToolResult;
26
27/// SQLite implementation of SessionStore
28pub struct SqliteSessionStore {
29    pool: SqlitePool,
30}
31
32impl SqliteSessionStore {
33    /// Create a new SQLite session store
34    pub async fn new(path: &Path) -> Result<Self, SessionStoreError> {
35        // Create parent directory if it doesn't exist
36        if let Some(parent) = path.parent() {
37            std::fs::create_dir_all(parent).map_err(|e| {
38                SessionStoreError::connection(format!("Failed to create directory: {e}"))
39            })?;
40        }
41
42        let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", path.display()))
43            .map_err(|e| SessionStoreError::connection(format!("Invalid SQLite path: {e}")))?
44            .create_if_missing(true)
45            .journal_mode(SqliteJournalMode::Wal)
46            .synchronous(SqliteSynchronous::Normal)
47            .foreign_keys(true);
48
49        let pool = SqlitePoolOptions::new()
50            .max_connections(1) // Single connection for local use
51            .connect_with(options)
52            .await
53            .map_err(|e| {
54                SessionStoreError::connection(format!("Failed to connect to SQLite: {e}"))
55            })?;
56
57        // Run migrations
58        sqlx::migrate!("migrations/sqlite")
59            .run(&pool)
60            .await
61            .map_err(|e| SessionStoreError::Migration {
62                message: format!("Failed to run migrations: {e}"),
63            })?;
64
65        Ok(Self { pool })
66    }
67
68    /// Parse tool approval policy from database format
69    fn parse_tool_policy(
70        policy_type: &str,
71        pre_approved_json: &str,
72    ) -> Result<ToolApprovalPolicy, SessionStoreError> {
73        let pre_approved: Vec<String> = serde_json::from_str(pre_approved_json).map_err(|e| {
74            SessionStoreError::serialization(format!("Invalid pre_approved_tools: {e}"))
75        })?;
76
77        match policy_type {
78            "always_ask" => Ok(ToolApprovalPolicy::AlwaysAsk),
79            "pre_approved" => Ok(ToolApprovalPolicy::PreApproved {
80                tools: pre_approved.into_iter().collect(),
81            }),
82            "mixed" => Ok(ToolApprovalPolicy::Mixed {
83                pre_approved: pre_approved.into_iter().collect(),
84                ask_for_others: true,
85            }),
86            _ => Err(SessionStoreError::validation(format!(
87                "Invalid tool policy type: {policy_type}"
88            ))),
89        }
90    }
91
92    /// Convert tool approval policy to database format
93    fn serialize_tool_policy(policy: &ToolApprovalPolicy) -> (String, String) {
94        match policy {
95            ToolApprovalPolicy::AlwaysAsk => ("always_ask".to_string(), "[]".to_string()),
96            ToolApprovalPolicy::PreApproved { tools } => {
97                let tools_vec: Vec<String> = tools.iter().cloned().collect();
98                (
99                    "pre_approved".to_string(),
100                    serde_json::to_string(&tools_vec).unwrap(),
101                )
102            }
103            ToolApprovalPolicy::Mixed { pre_approved, .. } => {
104                let tools_vec: Vec<String> = pre_approved.iter().cloned().collect();
105                (
106                    "mixed".to_string(),
107                    serde_json::to_string(&tools_vec).unwrap(),
108                )
109            }
110        }
111    }
112}
113
114#[async_trait]
115impl SessionStore for SqliteSessionStore {
116    async fn create_session(&self, config: SessionConfig) -> Result<Session, SessionStoreError> {
117        let id = Uuid::new_v4().to_string();
118        let now = Utc::now();
119        let (policy_type, pre_approved_json) =
120            Self::serialize_tool_policy(&config.tool_config.approval_policy);
121        let metadata_json = serde_json::to_string(&config.metadata).map_err(|e| {
122            SessionStoreError::serialization(format!("Failed to serialize metadata: {e}"))
123        })?;
124        let tool_config_json = serde_json::to_string(&config.tool_config).map_err(|e| {
125            SessionStoreError::serialization(format!("Failed to serialize tool_config: {e}"))
126        })?;
127        let workspace_config_json = serde_json::to_string(&config.workspace).map_err(|e| {
128            SessionStoreError::serialization(format!("Failed to serialize workspace_config: {e}"))
129        })?;
130
131        sqlx::query(
132            r#"
133            INSERT INTO sessions (id, created_at, updated_at, status, metadata,
134                                  tool_policy_type, pre_approved_tools, tool_config, workspace_config, system_prompt)
135            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
136            "#,
137        )
138        .bind(&id)
139        .bind(now)
140        .bind(now)
141        .bind("inactive") // New sessions start as inactive
142        .bind(&metadata_json)
143        .bind(&policy_type)
144        .bind(&pre_approved_json)
145        .bind(&tool_config_json)
146        .bind(&workspace_config_json)
147        .bind(&config.system_prompt)
148        .execute(&self.pool)
149        .await
150        .map_err(|e| SessionStoreError::database(format!("Failed to create session: {e}")))?;
151
152        Ok(Session {
153            id: id.clone(),
154            created_at: now,
155            updated_at: now,
156            config,
157            state: SessionState::default(),
158        })
159    }
160
161    async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionStoreError> {
162        let row = sqlx::query(
163            r#"
164            SELECT id, created_at, updated_at, metadata,
165                   tool_policy_type, pre_approved_tools, tool_config, workspace_config, system_prompt,
166                   active_message_id
167            FROM sessions
168            WHERE id = ?1
169            "#,
170        )
171        .bind(session_id)
172        .fetch_optional(&self.pool)
173        .await
174        .map_err(|e| SessionStoreError::database(format!("Failed to get session: {e}")))?;
175
176        let Some(row) = row else {
177            return Ok(None);
178        };
179
180        let approval_policy = Self::parse_tool_policy(
181            &row.get::<String, _>("tool_policy_type"),
182            &row.get::<String, _>("pre_approved_tools"),
183        )?;
184
185        let metadata: std::collections::HashMap<String, String> =
186            serde_json::from_str(&row.get::<String, _>("metadata"))
187                .map_err(|e| SessionStoreError::serialization(format!("Invalid metadata: {e}")))?;
188
189        let tool_config = serde_json::from_str(&row.get::<String, _>("tool_config"))
190            .map_err(|e| SessionStoreError::serialization(format!("Invalid tool_config: {e}")))?;
191        let workspace_config = serde_json::from_str(&row.get::<String, _>("workspace_config"))
192            .map_err(|e| {
193                SessionStoreError::serialization(format!("Invalid workspace_config: {e}"))
194            })?;
195
196        let mut tool_config: crate::session::SessionToolConfig = tool_config;
197        tool_config.approval_policy = approval_policy;
198
199        let system_prompt: Option<String> = row.get("system_prompt");
200
201        let config = SessionConfig {
202            workspace: workspace_config,
203            tool_config,
204            system_prompt,
205            metadata,
206        };
207
208        // Load messages
209        let messages = self.get_messages(session_id, None).await?;
210
211        // Load tool calls
212        let tool_calls_rows = sqlx::query(
213            r#"
214            SELECT id, tool_name, parameters, status, result, error, started_at, completed_at, kind, payload_json, error_json
215            FROM tool_calls
216            WHERE session_id = ?1
217            "#,
218        )
219        .bind(session_id)
220        .fetch_all(&self.pool)
221        .await
222        .map_err(|e| SessionStoreError::database(format!("Failed to load tool calls: {e}")))?;
223
224        let mut tool_calls = std::collections::HashMap::new();
225        for row in tool_calls_rows {
226            let id: String = row.get("id");
227            let status_str: String = row.get("status");
228            let error: Option<String> = row.get("error");
229
230            let status = match status_str.as_str() {
231                "pending" => ToolCallStatus::PendingApproval,
232                "approved" => ToolCallStatus::Approved,
233                "denied" => ToolCallStatus::Denied,
234                "executing" => ToolCallStatus::Executing,
235                "completed" => ToolCallStatus::Completed,
236                "failed" => ToolCallStatus::Failed {
237                    error: error.unwrap_or_else(|| "Unknown error".to_string()),
238                },
239                _ => {
240                    return Err(SessionStoreError::validation(format!(
241                        "Invalid tool call status: {status_str}"
242                    )));
243                }
244            };
245
246            let tool_call = ToolCall {
247                id: id.clone(),
248                name: row.get("tool_name"),
249                parameters: serde_json::from_str(&row.get::<String, _>("parameters")).map_err(
250                    |e| SessionStoreError::serialization(format!("Invalid tool parameters: {e}")),
251                )?,
252            };
253
254            let result: Option<String> = row.get("result");
255            let json_result: Option<String> = row.get("payload_json");
256            let result_type: Option<String> = row.get("kind");
257            let error_json: Option<String> = row.get("error_json");
258
259            let _tool_result = if let Some(kind) = result_type.as_ref() {
260                if kind == "error" {
261                    // Error result - use error_json
262                    let error_data = error_json.and_then(|json_str| {
263                        serde_json::from_str::<serde_json::Value>(&json_str).ok()
264                    });
265                    Some(ToolExecutionStats {
266                        output: result.clone(),
267                        json_output: error_data,
268                        result_type: Some("error".to_string()),
269                        success: false,
270                        execution_time_ms: 0,
271                        metadata: std::collections::HashMap::new(),
272                    })
273                } else if let Some(json_str) = json_result {
274                    // Success result with payload
275                    let json_value = serde_json::from_str(&json_str).map_err(|e| {
276                        SessionStoreError::serialization(format!("Invalid JSON result: {e}"))
277                    })?;
278                    Some(ToolExecutionStats {
279                        output: result.clone(),
280                        json_output: Some(json_value),
281                        result_type,
282                        success: true,
283                        execution_time_ms: 0,
284                        metadata: std::collections::HashMap::new(),
285                    })
286                } else {
287                    // Success without payload
288                    Some(ToolExecutionStats {
289                        output: result.clone(),
290                        json_output: None,
291                        result_type,
292                        success: true,
293                        execution_time_ms: 0,
294                        metadata: std::collections::HashMap::new(),
295                    })
296                }
297            } else {
298                result.map(|r| ToolExecutionStats {
299                    output: Some(r),
300                    json_output: None,
301                    result_type: None,
302                    success: true,
303                    execution_time_ms: 0,
304                    metadata: std::collections::HashMap::new(),
305                })
306            };
307
308            // Skip loading tool results for now - just set to None
309            let tool_result: Option<ToolResult> = None;
310
311            let state = ToolCallState {
312                tool_call,
313                status,
314                started_at: row.get("started_at"),
315                completed_at: row.get("completed_at"),
316                result: tool_result,
317            };
318
319            tool_calls.insert(id, state);
320        }
321
322        // Get the latest event sequence number
323        let last_sequence: Option<i64> =
324            sqlx::query_scalar("SELECT MAX(sequence_num) FROM events WHERE session_id = ?1")
325                .bind(session_id)
326                .fetch_one(&self.pool)
327                .await
328                .map_err(|e| {
329                    SessionStoreError::database(format!("Failed to get last event sequence: {e}"))
330                })?;
331
332        // Extract approved bash patterns from session config
333        let approved_bash_patterns: HashSet<String> =
334            if let Some(bash_config) = config.tool_config.tools.get("bash") {
335                let crate::session::state::ToolSpecificConfig::Bash(bash) = bash_config;
336                bash.approved_patterns.iter().cloned().collect()
337            } else {
338                HashSet::new()
339            };
340
341        // Get active_message_id from the row
342        let active_message_id: Option<String> = row.get("active_message_id");
343
344        let state = SessionState {
345            messages,
346            tool_calls,
347            approved_tools: Default::default(), // TODO: Track approved tools separately if needed
348            approved_bash_patterns,
349            last_event_sequence: last_sequence.unwrap_or(0) as u64,
350            metadata: Default::default(),
351            active_message_id,
352            mcp_servers: Default::default(), // Transient field, rebuilt on activation
353        };
354
355        Ok(Some(Session {
356            id: row.get("id"),
357            created_at: row.get("created_at"),
358            updated_at: row.get("updated_at"),
359            config,
360            state,
361        }))
362    }
363
364    async fn update_session(&self, session: &Session) -> Result<(), SessionStoreError> {
365        let metadata_json = serde_json::to_string(&session.config.metadata).map_err(|e| {
366            SessionStoreError::serialization(format!("Failed to serialize metadata: {e}"))
367        })?;
368        let tool_config_json = serde_json::to_string(&session.config.tool_config).map_err(|e| {
369            SessionStoreError::serialization(format!("Failed to serialize tool_config: {e}"))
370        })?;
371        let workspace_config_json =
372            serde_json::to_string(&session.config.workspace).map_err(|e| {
373                SessionStoreError::serialization(format!(
374                    "Failed to serialize workspace_config: {e}"
375                ))
376            })?;
377        let (policy_type, pre_approved_json) =
378            Self::serialize_tool_policy(&session.config.tool_config.approval_policy);
379
380        sqlx::query(
381            r#"
382            UPDATE sessions
383            SET updated_at = ?2, metadata = ?3,
384                tool_policy_type = ?4, pre_approved_tools = ?5, tool_config = ?6, workspace_config = ?7
385            WHERE id = ?1
386            "#,
387        )
388        .bind(&session.id)
389        .bind(Utc::now())
390        .bind(&metadata_json)
391        .bind(&policy_type)
392        .bind(&pre_approved_json)
393        .bind(&tool_config_json)
394        .bind(&workspace_config_json)
395        .execute(&self.pool)
396        .await
397        .map_err(|e| SessionStoreError::database(format!("Failed to update session: {e}")))?;
398
399        Ok(())
400    }
401
402    async fn delete_session(&self, session_id: &str) -> Result<(), SessionStoreError> {
403        sqlx::query("DELETE FROM sessions WHERE id = ?1")
404            .bind(session_id)
405            .execute(&self.pool)
406            .await
407            .map_err(|e| SessionStoreError::database(format!("Failed to delete session: {e}")))?;
408
409        Ok(())
410    }
411
412    async fn list_sessions(
413        &self,
414        filter: SessionFilter,
415    ) -> Result<Vec<SessionInfo>, SessionStoreError> {
416        let mut query = String::from(
417            r#"
418            SELECT s.id, s.created_at, s.updated_at, s.status, s.metadata,
419                   (SELECT e.event_data
420                    FROM events e
421                    WHERE e.session_id = s.id
422                      AND e.event_type IN ('message_complete', 'tool_call_started', 'tool_call_completed', 'tool_call_failed')
423                    ORDER BY e.sequence_num DESC
424                    LIMIT 1) as last_event_data
425            FROM sessions s
426            WHERE 1=1
427            "#,
428        );
429        let mut bindings: Vec<String> = Vec::new();
430
431        // Apply filters
432        if let Some(created_after) = filter.created_after {
433            query.push_str(&format!(" AND s.created_at >= ?{}", bindings.len() + 1));
434            bindings.push(created_after.to_rfc3339());
435        }
436        if let Some(created_before) = filter.created_before {
437            query.push_str(&format!(" AND s.created_at <= ?{}", bindings.len() + 1));
438            bindings.push(created_before.to_rfc3339());
439        }
440        if let Some(updated_after) = filter.updated_after {
441            query.push_str(&format!(" AND s.updated_at >= ?{}", bindings.len() + 1));
442            bindings.push(updated_after.to_rfc3339());
443        }
444        if let Some(updated_before) = filter.updated_before {
445            query.push_str(&format!(" AND s.updated_at <= ?{}", bindings.len() + 1));
446            bindings.push(updated_before.to_rfc3339());
447        }
448        if let Some(status) = filter.status_filter {
449            let status_str = match status {
450                SessionStatus::Active => "active",
451                SessionStatus::Inactive => "inactive",
452            };
453            query.push_str(&format!(" AND s.status = ?{}", bindings.len() + 1));
454            bindings.push(status_str.to_string());
455        }
456
457        // Add ordering
458        let order_column = match filter.order_by {
459            SessionOrderBy::CreatedAt => "s.created_at",
460            SessionOrderBy::UpdatedAt => "s.updated_at",
461            SessionOrderBy::MessageCount => {
462                // For message count, we'll need a subquery
463                query = r#"
464                    SELECT s.id, s.created_at, s.updated_at, s.status, s.metadata,
465                           (SELECT e.event_data
466                            FROM events e
467                            WHERE e.session_id = s.id
468                              AND e.event_type IN ('message_complete', 'tool_call_started', 'tool_call_completed', 'tool_call_failed')
469                            ORDER BY e.sequence_num DESC
470                            LIMIT 1) as last_event_data,
471                           (SELECT COUNT(*) FROM messages WHERE session_id = s.id) as message_count
472                    FROM sessions s
473                    WHERE 1=1
474                    "#.to_string();
475                "message_count"
476            }
477        };
478
479        let order_direction = match filter.order_direction {
480            crate::session::OrderDirection::Ascending => "ASC",
481            crate::session::OrderDirection::Descending => "DESC",
482        };
483
484        query.push_str(&format!(" ORDER BY {order_column} {order_direction}"));
485
486        // Add pagination
487        if let Some(limit) = filter.limit {
488            query.push_str(&format!(" LIMIT {limit}"));
489        }
490        if let Some(offset) = filter.offset {
491            query.push_str(&format!(" OFFSET {offset}"));
492        }
493
494        // Execute query with dynamic bindings
495        let mut q = sqlx::query(&query);
496        for binding in bindings {
497            q = q.bind(binding);
498        }
499
500        let rows = q
501            .fetch_all(&self.pool)
502            .await
503            .map_err(|e| SessionStoreError::database(format!("Failed to list sessions: {e}")))?;
504
505        let mut sessions = Vec::new();
506        for row in rows {
507            let metadata: std::collections::HashMap<String, String> =
508                serde_json::from_str(&row.get::<String, _>("metadata")).map_err(|e| {
509                    SessionStoreError::serialization(format!("Invalid metadata: {e}"))
510                })?;
511
512            // Count messages for this session (if not already done in query)
513            let message_count: i64 = if matches!(filter.order_by, SessionOrderBy::MessageCount) {
514                row.get("message_count")
515            } else {
516                sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE session_id = ?1")
517                    .bind(row.get::<String, _>("id"))
518                    .fetch_one(&self.pool)
519                    .await
520                    .map_err(|e| {
521                        SessionStoreError::database(format!("Failed to count messages: {e}"))
522                    })?
523            };
524
525            // Extract last model from event data
526            let last_model =
527                if let Some(event_json) = row.get::<Option<String>, _>("last_event_data") {
528                    let event: StreamEvent = serde_json::from_str(&event_json).map_err(|e| {
529                        SessionStoreError::serialization(format!("Invalid event data: {e}"))
530                    })?;
531
532                    match event {
533                        StreamEvent::MessageComplete { model, .. } => Some(model),
534                        StreamEvent::ToolCallStarted { model, .. } => Some(model),
535                        StreamEvent::ToolCallCompleted { model, .. } => Some(model),
536                        StreamEvent::ToolCallFailed { model, .. } => Some(model),
537                        _ => None,
538                    }
539                } else {
540                    None
541                };
542
543            sessions.push(SessionInfo {
544                id: row.get("id"),
545                created_at: row.get("created_at"),
546                updated_at: row.get("updated_at"),
547                last_model,
548                message_count: message_count as usize,
549                metadata,
550            });
551        }
552
553        Ok(sessions)
554    }
555
556    async fn append_message(
557        &self,
558        session_id: &str,
559        message: &Message,
560    ) -> Result<(), SessionStoreError> {
561        let id = message.id();
562
563        // Get the next sequence number
564        let next_seq: i64 = sqlx::query_scalar(
565            "SELECT COALESCE(MAX(sequence_num), -1) + 1 FROM messages WHERE session_id = ?1",
566        )
567        .bind(session_id)
568        .fetch_one(&self.pool)
569        .await
570        .map_err(|e| SessionStoreError::database(format!("Failed to get next sequence: {e}")))?;
571
572        // Serialize the message based on its variant
573        let (role, content_json) = match &message.data {
574            MessageData::User { content, .. } => {
575                let json = serde_json::to_string(&content).map_err(|e| {
576                    SessionStoreError::serialization(format!(
577                        "Failed to serialize user content: {e}"
578                    ))
579                })?;
580                ("user", json)
581            }
582            MessageData::Assistant { content, .. } => {
583                let json = serde_json::to_string(&content).map_err(|e| {
584                    SessionStoreError::serialization(format!(
585                        "Failed to serialize assistant content: {e}"
586                    ))
587                })?;
588                ("assistant", json)
589            }
590            MessageData::Tool {
591                tool_use_id,
592                result,
593                ..
594            } => {
595                // Convert tool result to a format that can be stored
596                #[derive(serde::Serialize)]
597                struct StoredToolMessage {
598                    tool_use_id: String,
599                    result: crate::app::conversation::ToolResult,
600                }
601                let stored = StoredToolMessage {
602                    tool_use_id: tool_use_id.clone(),
603                    result: result.clone(),
604                };
605                let json = serde_json::to_string(&stored).map_err(|e| {
606                    SessionStoreError::serialization(format!(
607                        "Failed to serialize tool message: {e}"
608                    ))
609                })?;
610                ("tool", json)
611            }
612        };
613
614        // Extract parent_message_id from the message
615        let parent_message_id = message.parent_message_id();
616
617        sqlx::query(
618            r#"
619            INSERT INTO messages (id, session_id, sequence_num, role, content, created_at, parent_message_id)
620            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
621            "#,
622        )
623        .bind(id)
624        .bind(session_id)
625        .bind(next_seq)
626        .bind(role)
627        .bind(&content_json)
628        .bind(Utc::now())
629        .bind(parent_message_id)
630        .execute(&self.pool)
631        .await
632        .map_err(|e| SessionStoreError::database(format!("Failed to append message: {e}")))?;
633
634        Ok(())
635    }
636
637    async fn get_messages(
638        &self,
639        session_id: &str,
640        after_sequence: Option<u32>,
641    ) -> Result<Vec<Message>, SessionStoreError> {
642        let query = if let Some(seq) = after_sequence {
643            sqlx::query(
644                r#"
645                SELECT id, sequence_num, role, content, created_at, parent_message_id
646                FROM messages
647                WHERE session_id = ?1 AND sequence_num > ?2
648                ORDER BY sequence_num ASC
649                "#,
650            )
651            .bind(session_id)
652            .bind(seq as i64)
653        } else {
654            sqlx::query(
655                r#"
656                SELECT id, sequence_num, role, content, created_at, parent_message_id
657                FROM messages
658                WHERE session_id = ?1
659                ORDER BY sequence_num ASC
660                "#,
661            )
662            .bind(session_id)
663        };
664
665        let rows = query
666            .fetch_all(&self.pool)
667            .await
668            .map_err(|e| SessionStoreError::database(format!("Failed to get messages: {e}")))?;
669
670        let mut messages = Vec::new();
671        for row in rows {
672            let role = row.get::<String, _>("role");
673            let content = row.get::<String, _>("content");
674            let id: String = row.get("id");
675            let created_at = row.get::<chrono::DateTime<chrono::Utc>, _>("created_at");
676
677            let parent_message_id: Option<String> = row.get("parent_message_id");
678
679            // Deserialize based on role
680            let message = match role.as_str() {
681                "user" => {
682                    let content: Vec<UserContent> =
683                        serde_json::from_str(&content).map_err(|e| {
684                            SessionStoreError::serialization(format!(
685                                "Failed to deserialize user content: {e}"
686                            ))
687                        })?;
688                    Message {
689                        data: MessageData::User { content },
690                        timestamp: created_at.timestamp() as u64,
691                        id,
692                        parent_message_id,
693                    }
694                }
695                "assistant" => {
696                    let content: Vec<AssistantContent> =
697                        serde_json::from_str(&content).map_err(|e| {
698                            SessionStoreError::serialization(format!(
699                                "Failed to deserialize assistant content: {e}"
700                            ))
701                        })?;
702                    Message {
703                        data: MessageData::Assistant { content },
704                        timestamp: created_at.timestamp() as u64,
705                        id,
706                        parent_message_id,
707                    }
708                }
709                "tool" => {
710                    #[derive(serde::Deserialize)]
711                    struct StoredToolMessage {
712                        tool_use_id: String,
713                        result: crate::app::conversation::ToolResult,
714                    }
715                    let stored: StoredToolMessage =
716                        serde_json::from_str(&content).map_err(|e| {
717                            SessionStoreError::serialization(format!(
718                                "Failed to deserialize tool message: {e}"
719                            ))
720                        })?;
721                    Message {
722                        data: MessageData::Tool {
723                            tool_use_id: stored.tool_use_id,
724                            result: stored.result,
725                        },
726                        timestamp: created_at.timestamp() as u64,
727                        id: id.clone(),
728                        parent_message_id: parent_message_id.clone(),
729                    }
730                }
731                _ => {
732                    return Err(SessionStoreError::serialization(format!(
733                        "Unknown message role: {role}"
734                    )));
735                }
736            };
737
738            messages.push(message);
739        }
740
741        Ok(messages)
742    }
743
744    async fn create_tool_call(
745        &self,
746        session_id: &str,
747        tool_call: &ToolCall,
748    ) -> Result<(), SessionStoreError> {
749        let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
750            SessionStoreError::serialization(format!("Failed to serialize parameters: {e}"))
751        })?;
752
753        sqlx::query(
754            r#"
755            INSERT INTO tool_calls (id, session_id, tool_name, parameters, status, kind)
756            VALUES (?1, ?2, ?3, ?4, ?5, ?6)
757            "#,
758        )
759        .bind(&tool_call.id)
760        .bind(session_id)
761        .bind(&tool_call.name)
762        .bind(&parameters_json)
763        .bind("pending")
764        .bind("external") // Default to external, will be updated when result comes in
765        .execute(&self.pool)
766        .await
767        .map_err(|e| SessionStoreError::database(format!("Failed to create tool call: {e}")))?;
768
769        Ok(())
770    }
771
772    async fn update_tool_call(
773        &self,
774        tool_call_id: &str,
775        update: ToolCallUpdate,
776    ) -> Result<(), SessionStoreError> {
777        let mut query = String::from("UPDATE tool_calls SET ");
778        let mut updates = Vec::new();
779        let mut bindings: Vec<String> = Vec::new();
780
781        if let Some(status) = update.status {
782            let status_str = match &status {
783                ToolCallStatus::PendingApproval => "pending",
784                ToolCallStatus::Approved => "approved",
785                ToolCallStatus::Denied => "denied",
786                ToolCallStatus::Executing => "executing",
787                ToolCallStatus::Completed => "completed",
788                ToolCallStatus::Failed { .. } => "failed",
789            };
790            updates.push(format!("status = ?{}", bindings.len() + 1));
791            bindings.push(status_str.to_string());
792
793            // Update timestamps based on status
794            match status {
795                ToolCallStatus::Executing => {
796                    updates.push(format!("started_at = ?{}", bindings.len() + 1));
797                    bindings.push(Utc::now().to_rfc3339());
798                }
799                ToolCallStatus::Completed | ToolCallStatus::Failed { .. } => {
800                    updates.push(format!("completed_at = ?{}", bindings.len() + 1));
801                    bindings.push(Utc::now().to_rfc3339());
802                }
803                _ => {}
804            }
805        }
806
807        if let Some(result) = update.result {
808            // Determine kind from result_type
809            let kind = if let Some(rt) = &result.result_type {
810                rt.clone()
811            } else {
812                "external".to_string()
813            };
814
815            updates.push(format!("kind = ?{}", bindings.len() + 1));
816            bindings.push(kind);
817
818            // Always store the legacy output if available
819            if let Some(output) = &result.output {
820                updates.push(format!("result = ?{}", bindings.len() + 1));
821                bindings.push(output.clone());
822            }
823
824            // Store JSON result if available
825            if let Some(json_output) = &result.json_output {
826                updates.push(format!("payload_json = ?{}", bindings.len() + 1));
827                let json_str = serde_json::to_string(json_output).map_err(|e| {
828                    SessionStoreError::serialization(format!(
829                        "Failed to serialize JSON result: {e}"
830                    ))
831                })?;
832                bindings.push(json_str);
833            } else if let Some(output) = &result.output {
834                // Fallback to wrapping string output as External
835                updates.push(format!("payload_json = ?{}", bindings.len() + 1));
836                let external_json = serde_json::json!({
837                    "tool_name": "unknown",
838                    "payload": output
839                });
840                bindings.push(external_json.to_string());
841            }
842        }
843
844        if let Some(error) = update.error {
845            // Mark as error kind
846            updates.push(format!("kind = ?{}", bindings.len() + 1));
847            bindings.push("error".to_string());
848
849            // Store error in error_json
850            updates.push(format!("error_json = ?{}", bindings.len() + 1));
851            let error_json = serde_json::json!({
852                "tool_name": "unknown",
853                "message": &error
854            });
855            bindings.push(error_json.to_string());
856
857            // Also store in legacy error column
858            updates.push(format!("error = ?{}", bindings.len() + 1));
859            bindings.push(error);
860        }
861
862        if updates.is_empty() {
863            return Ok(());
864        }
865
866        query.push_str(&updates.join(", "));
867        query.push_str(&format!(" WHERE id = ?{}", bindings.len() + 1));
868        bindings.push(tool_call_id.to_string());
869
870        // Execute with dynamic bindings
871        let mut q = sqlx::query(&query);
872        for binding in bindings {
873            q = q.bind(binding);
874        }
875
876        q.execute(&self.pool)
877            .await
878            .map_err(|e| SessionStoreError::database(format!("Failed to update tool call: {e}")))?;
879
880        Ok(())
881    }
882
883    async fn get_pending_tool_calls(
884        &self,
885        session_id: &str,
886    ) -> Result<Vec<ToolCall>, SessionStoreError> {
887        let rows = sqlx::query(
888            r#"
889            SELECT id, tool_name, parameters
890            FROM tool_calls
891            WHERE session_id = ?1 AND status = 'pending'
892            ORDER BY id ASC
893            "#,
894        )
895        .bind(session_id)
896        .fetch_all(&self.pool)
897        .await
898        .map_err(|e| {
899            SessionStoreError::database(format!("Failed to get pending tool calls: {e}"))
900        })?;
901
902        let mut tool_calls = Vec::new();
903        for row in rows {
904            let parameters: serde_json::Value =
905                serde_json::from_str(&row.get::<String, _>("parameters")).map_err(|e| {
906                    SessionStoreError::serialization(format!("Invalid parameters: {e}"))
907                })?;
908
909            tool_calls.push(ToolCall {
910                id: row.get("id"),
911                name: row.get("tool_name"),
912                parameters,
913            });
914        }
915
916        Ok(tool_calls)
917    }
918
919    async fn append_event(
920        &self,
921        session_id: &str,
922        event: &StreamEvent,
923    ) -> Result<u64, SessionStoreError> {
924        let event_type = match event {
925            StreamEvent::MessagePart { .. } => "message_part",
926            StreamEvent::MessageComplete { .. } => "message_complete",
927            StreamEvent::ToolCallStarted { .. } => "tool_call_started",
928            StreamEvent::ToolCallCompleted { .. } => "tool_call_completed",
929            StreamEvent::ToolCallFailed { .. } => "tool_call_failed",
930            StreamEvent::ToolApprovalRequired { .. } => "tool_approval_required",
931            StreamEvent::SessionCreated { .. } => "session_created",
932            StreamEvent::SessionResumed { .. } => "session_resumed",
933            StreamEvent::SessionSaved { .. } => "session_saved",
934            StreamEvent::OperationStarted { .. } => "operation_started",
935            StreamEvent::OperationCompleted { .. } => "operation_completed",
936            StreamEvent::OperationCancelled { .. } => "operation_cancelled",
937            StreamEvent::Error { .. } => "error",
938            StreamEvent::WorkspaceChanged => "workspace_changed",
939            StreamEvent::WorkspaceFiles { .. } => "workspace_files",
940        };
941
942        let event_data = serde_json::to_string(event).map_err(|e| {
943            SessionStoreError::serialization(format!("Failed to serialize event: {e}"))
944        })?;
945
946        // Get the next sequence number
947        let next_seq: i64 = sqlx::query_scalar(
948            "SELECT COALESCE(MAX(sequence_num), -1) + 1 FROM events WHERE session_id = ?1",
949        )
950        .bind(session_id)
951        .fetch_one(&self.pool)
952        .await
953        .map_err(|e| SessionStoreError::database(format!("Failed to get next sequence: {e}")))?;
954
955        sqlx::query(
956            r#"
957            INSERT INTO events (session_id, sequence_num, event_type, event_data, created_at)
958            VALUES (?1, ?2, ?3, ?4, ?5)
959            "#,
960        )
961        .bind(session_id)
962        .bind(next_seq)
963        .bind(event_type)
964        .bind(&event_data)
965        .bind(Utc::now())
966        .execute(&self.pool)
967        .await
968        .map_err(|e| SessionStoreError::database(format!("Failed to append event: {e}")))?;
969
970        Ok(next_seq as u64)
971    }
972
973    async fn get_events(
974        &self,
975        session_id: &str,
976        after_sequence: u64,
977        limit: Option<u32>,
978    ) -> Result<Vec<(u64, StreamEvent)>, SessionStoreError> {
979        let query = if let Some(limit) = limit {
980            sqlx::query(
981                r#"
982                SELECT sequence_num, event_data
983                FROM events
984                WHERE session_id = ?1 AND sequence_num > ?2
985                ORDER BY sequence_num ASC
986                LIMIT ?3
987                "#,
988            )
989            .bind(session_id)
990            .bind(after_sequence as i64)
991            .bind(limit as i64)
992        } else {
993            sqlx::query(
994                r#"
995                SELECT sequence_num, event_data
996                FROM events
997                WHERE session_id = ?1 AND sequence_num > ?2
998                ORDER BY sequence_num ASC
999                "#,
1000            )
1001            .bind(session_id)
1002            .bind(after_sequence as i64)
1003        };
1004
1005        let rows = query
1006            .fetch_all(&self.pool)
1007            .await
1008            .map_err(|e| SessionStoreError::database(format!("Failed to get events: {e}")))?;
1009
1010        let mut events = Vec::new();
1011        for row in rows {
1012            let seq: i64 = row.get("sequence_num");
1013            let event: StreamEvent = serde_json::from_str(&row.get::<String, _>("event_data"))
1014                .map_err(|e| {
1015                    SessionStoreError::serialization(format!("Invalid event data: {e}"))
1016                })?;
1017
1018            events.push((seq as u64, event));
1019        }
1020
1021        Ok(events)
1022    }
1023
1024    async fn delete_events_before(
1025        &self,
1026        session_id: &str,
1027        before_sequence: u64,
1028    ) -> Result<u64, SessionStoreError> {
1029        let result = sqlx::query("DELETE FROM events WHERE session_id = ?1 AND sequence_num < ?2")
1030            .bind(session_id)
1031            .bind(before_sequence as i64)
1032            .execute(&self.pool)
1033            .await
1034            .map_err(|e| SessionStoreError::database(format!("Failed to delete events: {e}")))?;
1035
1036        Ok(result.rows_affected())
1037    }
1038
1039    async fn update_active_message_id(
1040        &self,
1041        session_id: &str,
1042        message_id: Option<&str>,
1043    ) -> Result<(), SessionStoreError> {
1044        sqlx::query("UPDATE sessions SET active_message_id = ?2, updated_at = ?3 WHERE id = ?1")
1045            .bind(session_id)
1046            .bind(message_id)
1047            .bind(Utc::now())
1048            .execute(&self.pool)
1049            .await
1050            .map_err(|e| {
1051                SessionStoreError::database(format!("Failed to update active_message_id: {e}"))
1052            })?;
1053
1054        Ok(())
1055    }
1056}
1057
#[cfg(test)]
mod tests {
    // Tests that exercise `SqliteSessionStore` against a real SQLite database file
    // ("test.db") created inside a fresh temporary directory for each test.
    use crate::api::Model;
    use crate::app::conversation::{AssistantContent, Message, Role, UserContent};
    use crate::events::SessionMetadata;
    use crate::session::ToolVisibility;
    use crate::session::state::WorkspaceConfig;

    use super::*;
    use tempfile::TempDir;

    /// Open a store backed by `test.db` in a new temporary directory.
    ///
    /// The `TempDir` is returned alongside the store so each test can hold it
    /// (as `_temp`) for the test's whole duration.
    async fn create_test_store() -> (SqliteSessionStore, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let store = SqliteSessionStore::new(&db_path).await.unwrap();
        (store, temp_dir)
    }

    /// Build a session config with an always-ask approval policy, all tools
    /// visible, a default workspace, no system prompt, and empty metadata.
    fn create_test_session_config() -> SessionConfig {
        let tool_config = crate::session::SessionToolConfig {
            approval_policy: ToolApprovalPolicy::AlwaysAsk,
            visibility: ToolVisibility::All,
            ..Default::default()
        };

        SessionConfig {
            workspace: WorkspaceConfig::default(),
            tool_config,
            system_prompt: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// A created session can be fetched back with its id, approval policy and
    /// (default, local) workspace configuration intact.
    #[tokio::test]
    async fn test_create_and_get_session() {
        let (store, _temp) = create_test_store().await;

        let tool_config = crate::session::SessionToolConfig {
            approval_policy: ToolApprovalPolicy::AlwaysAsk,
            ..Default::default()
        };

        let config = SessionConfig {
            workspace: WorkspaceConfig::default(),
            tool_config,
            system_prompt: None,
            metadata: Default::default(),
        };

        let session = store.create_session(config.clone()).await.unwrap();
        assert!(!session.id.is_empty());

        let fetched_session = store.get_session(&session.id).await.unwrap().unwrap();
        assert_eq!(session.id, fetched_session.id);
        assert!(matches!(
            fetched_session.config.tool_config.approval_policy,
            ToolApprovalPolicy::AlwaysAsk
        ));
        assert!(matches!(
            fetched_session.config.workspace,
            WorkspaceConfig::Local { .. }
        ));
    }

    /// An appended user message round-trips through `get_messages` with its role.
    #[tokio::test]
    async fn test_message_operations() {
        let (store, _temp) = create_test_store().await;

        let config = create_test_session_config();
        let session = store.create_session(config).await.unwrap();

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123456789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };

        store.append_message(&session.id, &message).await.unwrap();

        let messages = store.get_messages(&session.id, None).await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role(), Role::User);
    }

    /// A new tool call is listed as pending until its status is updated to
    /// `Completed`, after which it no longer appears in the pending list.
    #[tokio::test]
    async fn test_tool_call_operations() {
        let (store, _temp) = create_test_store().await;

        let config = create_test_session_config();
        let session = store.create_session(config).await.unwrap();

        let tool_call = ToolCall {
            id: "tc1".to_string(),
            name: "test_tool".to_string(),
            parameters: serde_json::json!({"param": "value"}),
        };

        store
            .create_tool_call(&session.id, &tool_call)
            .await
            .unwrap();

        let pending = store.get_pending_tool_calls(&session.id).await.unwrap();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].name, "test_tool");

        let update = ToolCallUpdate::set_status(ToolCallStatus::Completed);
        store.update_tool_call(&tool_call.id, update).await.unwrap();

        let pending_after = store.get_pending_tool_calls(&session.id).await.unwrap();
        assert_eq!(pending_after.len(), 0);
    }

    /// The first appended event gets sequence 0, and `get_events` only returns
    /// events strictly after the requested sequence number.
    #[tokio::test]
    async fn test_event_streaming() {
        let (store, _temp) = create_test_store().await;

        let config = create_test_session_config();
        let session = store.create_session(config).await.unwrap();

        let event = StreamEvent::SessionCreated {
            session_id: session.id.clone(),
            metadata: SessionMetadata {
                model: Model::Claude3_5Sonnet20241022,
                created_at: session.created_at,
                metadata: session.config.metadata,
            },
        };

        let seq = store.append_event(&session.id, &event).await.unwrap();
        assert_eq!(seq, 0);

        // Get events after sequence 0 (should be empty since we only have sequence 0)
        let events = store.get_events(&session.id, 0, None).await.unwrap();
        assert_eq!(events.len(), 0);

        // Get all events including sequence 0: `get_events` casts `after_sequence`
        // with `as i64`, so `u64::MAX` becomes -1 and `sequence_num > -1` matches 0.
        let all_events = store.get_events(&session.id, u64::MAX, None).await.unwrap();
        assert_eq!(all_events.len(), 1);
        assert_eq!(all_events[0].0, 0);
    }

    /// With three sessions stored, a filter with `limit: Some(2)` returns two.
    #[tokio::test]
    async fn test_session_listing() {
        let (store, _temp) = create_test_store().await;

        // Create multiple sessions
        for i in 0..3 {
            let mut config = create_test_session_config();
            config.metadata.insert("index".to_string(), i.to_string());
            store.create_session(config).await.unwrap();
        }

        let filter = SessionFilter {
            limit: Some(2),
            order_by: SessionOrderBy::CreatedAt,
            ..Default::default()
        };

        let sessions = store.list_sessions(filter).await.unwrap();
        assert_eq!(sessions.len(), 2);
    }

    /// `list_sessions` reports the model of the most recent model-bearing event as
    /// `last_model`; events without a model field leave it unchanged.
    #[tokio::test]
    async fn test_last_model_tracking() {
        let (store, _temp) = create_test_store().await;

        let config = create_test_session_config();
        let session = store.create_session(config).await.unwrap();

        // Initially, no events means no last model
        let sessions = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].last_model, None);

        // Add a MessageComplete event with Claude model
        let claude_model = Model::Claude3_5Sonnet20241022;
        let message_event = StreamEvent::MessageComplete {
            message: Message {
                data: MessageData::Assistant {
                    content: vec![AssistantContent::Text {
                        text: "Hello from Claude".to_string(),
                    }],
                },
                timestamp: 123456789,
                id: "msg1".to_string(),
                parent_message_id: None,
            },
            usage: None,
            metadata: std::collections::HashMap::new(),
            model: claude_model,
        };
        store
            .append_event(&session.id, &message_event)
            .await
            .unwrap();

        // Check that last_model is now Claude
        let sessions = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(sessions[0].last_model, Some(claude_model));

        // Add a ToolCallFailed event with GPT model (more recent)
        let gpt_model = Model::Gpt4_1_20250414;
        let tool_event = StreamEvent::ToolCallFailed {
            tool_call_id: "tool1".to_string(),
            error: "Test error".to_string(),
            metadata: std::collections::HashMap::new(),
            model: gpt_model,
        };
        store.append_event(&session.id, &tool_event).await.unwrap();

        // Check that last_model is now GPT (the most recent)
        let sessions = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(sessions[0].last_model, Some(gpt_model));

        // Add an event without a model field (shouldn't change last_model)
        let session_event = StreamEvent::SessionSaved {
            session_id: session.id.clone(),
        };
        store
            .append_event(&session.id, &session_event)
            .await
            .unwrap();

        // Check that last_model is still GPT
        let sessions = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(sessions[0].last_model, Some(gpt_model));
    }
}