steer_core/session/stores/sqlite.rs

1use async_trait::async_trait;
2use chrono::Utc;
3use serde_json;
4use sqlx::{
5    Row,
6    sqlite::{
7        SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteSynchronous,
8    },
9};
10use std::collections::HashSet;
11use std::path::Path;
12use std::str::FromStr;
13use uuid::Uuid;
14
15use crate::app::Message;
16use crate::app::conversation::{AssistantContent, MessageData, UserContent};
17use crate::events::StreamEvent;
18// use crate::session::state::ToolVisibility;
19use crate::session::{
20    Session, SessionConfig, SessionFilter, SessionInfo, SessionOrderBy, SessionState,
21    SessionStatus, SessionStore, SessionStoreError, ToolApprovalPolicy, ToolCallState,
22    ToolCallStatus, ToolCallUpdate, ToolExecutionStats,
23};
24use steer_tools::ToolCall;
25use steer_tools::result::ToolResult;
26
/// SQLite implementation of [`SessionStore`].
///
/// Wraps a connection pool; the constructor defined for this type opens it with a
/// single connection.
pub struct SqliteSessionStore {
    /// Pool that the store's queries run against.
    pool: SqlitePool,
}
31
32impl SqliteSessionStore {
33    /// Create a new SQLite session store
34    pub async fn new(path: &Path) -> Result<Self, SessionStoreError> {
35        // Create parent directory if it doesn't exist
36        if let Some(parent) = path.parent() {
37            std::fs::create_dir_all(parent).map_err(|e| {
38                SessionStoreError::connection(format!("Failed to create directory: {e}"))
39            })?;
40        }
41
42        let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", path.display()))
43            .map_err(|e| SessionStoreError::connection(format!("Invalid SQLite path: {e}")))?
44            .create_if_missing(true)
45            .journal_mode(SqliteJournalMode::Wal)
46            .synchronous(SqliteSynchronous::Normal)
47            .foreign_keys(true);
48
49        let pool = SqlitePoolOptions::new()
50            .max_connections(1) // Single connection for local use
51            .connect_with(options)
52            .await
53            .map_err(|e| {
54                SessionStoreError::connection(format!("Failed to connect to SQLite: {e}"))
55            })?;
56
57        // Run migrations
58        sqlx::migrate!("migrations/sqlite")
59            .run(&pool)
60            .await
61            .map_err(|e| SessionStoreError::Migration {
62                message: format!("Failed to run migrations: {e}"),
63            })?;
64
65        Ok(Self { pool })
66    }
67
68    /// Parse tool approval policy from database format
69    fn parse_tool_policy(
70        policy_type: &str,
71        pre_approved_json: &str,
72    ) -> Result<ToolApprovalPolicy, SessionStoreError> {
73        let pre_approved: Vec<String> = serde_json::from_str(pre_approved_json).map_err(|e| {
74            SessionStoreError::serialization(format!("Invalid pre_approved_tools: {e}"))
75        })?;
76
77        match policy_type {
78            "always_ask" => Ok(ToolApprovalPolicy::AlwaysAsk),
79            "pre_approved" => Ok(ToolApprovalPolicy::PreApproved {
80                tools: pre_approved.into_iter().collect(),
81            }),
82            "mixed" => Ok(ToolApprovalPolicy::Mixed {
83                pre_approved: pre_approved.into_iter().collect(),
84                ask_for_others: true,
85            }),
86            _ => Err(SessionStoreError::validation(format!(
87                "Invalid tool policy type: {policy_type}"
88            ))),
89        }
90    }
91
92    /// Convert tool approval policy to database format
93    fn serialize_tool_policy(policy: &ToolApprovalPolicy) -> (String, String) {
94        match policy {
95            ToolApprovalPolicy::AlwaysAsk => ("always_ask".to_string(), "[]".to_string()),
96            ToolApprovalPolicy::PreApproved { tools } => {
97                let tools_vec: Vec<String> = tools.iter().cloned().collect();
98                (
99                    "pre_approved".to_string(),
100                    serde_json::to_string(&tools_vec).unwrap(),
101                )
102            }
103            ToolApprovalPolicy::Mixed { pre_approved, .. } => {
104                let tools_vec: Vec<String> = pre_approved.iter().cloned().collect();
105                (
106                    "mixed".to_string(),
107                    serde_json::to_string(&tools_vec).unwrap(),
108                )
109            }
110        }
111    }
112}
113
114#[async_trait]
115impl SessionStore for SqliteSessionStore {
116    async fn create_session(&self, config: SessionConfig) -> Result<Session, SessionStoreError> {
117        let id = Uuid::new_v4().to_string();
118        let now = Utc::now();
119        let (policy_type, pre_approved_json) =
120            Self::serialize_tool_policy(&config.tool_config.approval_policy);
121        let metadata_json = serde_json::to_string(&config.metadata).map_err(|e| {
122            SessionStoreError::serialization(format!("Failed to serialize metadata: {e}"))
123        })?;
124        let tool_config_json = serde_json::to_string(&config.tool_config).map_err(|e| {
125            SessionStoreError::serialization(format!("Failed to serialize tool_config: {e}"))
126        })?;
127        let workspace_config_json = serde_json::to_string(&config.workspace).map_err(|e| {
128            SessionStoreError::serialization(format!("Failed to serialize workspace_config: {e}"))
129        })?;
130
131        sqlx::query(
132            r#"
133            INSERT INTO sessions (id, created_at, updated_at, status, metadata,
134                                  tool_policy_type, pre_approved_tools, tool_config, workspace_config, system_prompt)
135            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
136            "#,
137        )
138        .bind(&id)
139        .bind(now)
140        .bind(now)
141        .bind("inactive") // New sessions start as inactive
142        .bind(&metadata_json)
143        .bind(&policy_type)
144        .bind(&pre_approved_json)
145        .bind(&tool_config_json)
146        .bind(&workspace_config_json)
147        .bind(&config.system_prompt)
148        .execute(&self.pool)
149        .await
150        .map_err(|e| SessionStoreError::database(format!("Failed to create session: {e}")))?;
151
152        Ok(Session {
153            id: id.clone(),
154            created_at: now,
155            updated_at: now,
156            config,
157            state: SessionState::default(),
158        })
159    }
160
161    async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionStoreError> {
162        let row = sqlx::query(
163            r#"
164            SELECT id, created_at, updated_at, metadata,
165                   tool_policy_type, pre_approved_tools, tool_config, workspace_config, system_prompt,
166                   active_message_id
167            FROM sessions
168            WHERE id = ?1
169            "#,
170        )
171        .bind(session_id)
172        .fetch_optional(&self.pool)
173        .await
174        .map_err(|e| SessionStoreError::database(format!("Failed to get session: {e}")))?;
175
176        let Some(row) = row else {
177            return Ok(None);
178        };
179
180        let approval_policy = Self::parse_tool_policy(
181            &row.get::<String, _>("tool_policy_type"),
182            &row.get::<String, _>("pre_approved_tools"),
183        )?;
184
185        let metadata: std::collections::HashMap<String, String> =
186            serde_json::from_str(&row.get::<String, _>("metadata"))
187                .map_err(|e| SessionStoreError::serialization(format!("Invalid metadata: {e}")))?;
188
189        let tool_config = serde_json::from_str(&row.get::<String, _>("tool_config"))
190            .map_err(|e| SessionStoreError::serialization(format!("Invalid tool_config: {e}")))?;
191        let workspace_config = serde_json::from_str(&row.get::<String, _>("workspace_config"))
192            .map_err(|e| {
193                SessionStoreError::serialization(format!("Invalid workspace_config: {e}"))
194            })?;
195
196        let mut tool_config: crate::session::SessionToolConfig = tool_config;
197        tool_config.approval_policy = approval_policy;
198
199        let system_prompt: Option<String> = row.get("system_prompt");
200
201        let config = SessionConfig {
202            workspace: workspace_config,
203            tool_config,
204            system_prompt,
205            metadata,
206        };
207
208        // Load messages
209        let messages = self.get_messages(session_id, None).await?;
210
211        // Load tool calls
212        let tool_calls_rows = sqlx::query(
213            r#"
214            SELECT id, tool_name, parameters, status, result, error, started_at, completed_at, kind, payload_json, error_json
215            FROM tool_calls
216            WHERE session_id = ?1
217            "#,
218        )
219        .bind(session_id)
220        .fetch_all(&self.pool)
221        .await
222        .map_err(|e| SessionStoreError::database(format!("Failed to load tool calls: {e}")))?;
223
224        let mut tool_calls = std::collections::HashMap::new();
225        for row in tool_calls_rows {
226            let id: String = row.get("id");
227            let status_str: String = row.get("status");
228            let error: Option<String> = row.get("error");
229
230            let status = match status_str.as_str() {
231                "pending" => ToolCallStatus::PendingApproval,
232                "approved" => ToolCallStatus::Approved,
233                "denied" => ToolCallStatus::Denied,
234                "executing" => ToolCallStatus::Executing,
235                "completed" => ToolCallStatus::Completed,
236                "failed" => ToolCallStatus::Failed {
237                    error: error.unwrap_or_else(|| "Unknown error".to_string()),
238                },
239                _ => {
240                    return Err(SessionStoreError::validation(format!(
241                        "Invalid tool call status: {status_str}"
242                    )));
243                }
244            };
245
246            let tool_call = ToolCall {
247                id: id.clone(),
248                name: row.get("tool_name"),
249                parameters: serde_json::from_str(&row.get::<String, _>("parameters")).map_err(
250                    |e| SessionStoreError::serialization(format!("Invalid tool parameters: {e}")),
251                )?,
252            };
253
254            let result: Option<String> = row.get("result");
255            let json_result: Option<String> = row.get("payload_json");
256            let result_type: Option<String> = row.get("kind");
257            let error_json: Option<String> = row.get("error_json");
258
259            let _tool_result = if let Some(kind) = result_type.as_ref() {
260                if kind == "error" {
261                    // Error result - use error_json
262                    let error_data = error_json.and_then(|json_str| {
263                        serde_json::from_str::<serde_json::Value>(&json_str).ok()
264                    });
265                    Some(ToolExecutionStats {
266                        output: result.clone(),
267                        json_output: error_data,
268                        result_type: Some("error".to_string()),
269                        success: false,
270                        execution_time_ms: 0,
271                        metadata: std::collections::HashMap::new(),
272                    })
273                } else if let Some(json_str) = json_result {
274                    // Success result with payload
275                    let json_value = serde_json::from_str(&json_str).map_err(|e| {
276                        SessionStoreError::serialization(format!("Invalid JSON result: {e}"))
277                    })?;
278                    Some(ToolExecutionStats {
279                        output: result.clone(),
280                        json_output: Some(json_value),
281                        result_type,
282                        success: true,
283                        execution_time_ms: 0,
284                        metadata: std::collections::HashMap::new(),
285                    })
286                } else {
287                    // Success without payload
288                    Some(ToolExecutionStats {
289                        output: result.clone(),
290                        json_output: None,
291                        result_type,
292                        success: true,
293                        execution_time_ms: 0,
294                        metadata: std::collections::HashMap::new(),
295                    })
296                }
297            } else {
298                result.map(|r| ToolExecutionStats {
299                    output: Some(r),
300                    json_output: None,
301                    result_type: None,
302                    success: true,
303                    execution_time_ms: 0,
304                    metadata: std::collections::HashMap::new(),
305                })
306            };
307
308            // Skip loading tool results for now - just set to None
309            let tool_result: Option<ToolResult> = None;
310
311            let state = ToolCallState {
312                tool_call,
313                status,
314                started_at: row.get("started_at"),
315                completed_at: row.get("completed_at"),
316                result: tool_result,
317            };
318
319            tool_calls.insert(id, state);
320        }
321
322        // Get the latest event sequence number
323        let last_sequence: Option<i64> =
324            sqlx::query_scalar("SELECT MAX(sequence_num) FROM events WHERE session_id = ?1")
325                .bind(session_id)
326                .fetch_one(&self.pool)
327                .await
328                .map_err(|e| {
329                    SessionStoreError::database(format!("Failed to get last event sequence: {e}"))
330                })?;
331
332        // Extract approved bash patterns from session config
333        let approved_bash_patterns: HashSet<String> =
334            if let Some(bash_config) = config.tool_config.tools.get("bash") {
335                let crate::session::state::ToolSpecificConfig::Bash(bash) = bash_config;
336                bash.approved_patterns.iter().cloned().collect()
337            } else {
338                HashSet::new()
339            };
340
341        // Get active_message_id from the row
342        let active_message_id: Option<String> = row.get("active_message_id");
343
344        let state = SessionState {
345            messages,
346            tool_calls,
347            approved_tools: Default::default(), // TODO: Track approved tools separately if needed
348            approved_bash_patterns,
349            last_event_sequence: last_sequence.unwrap_or(0) as u64,
350            metadata: Default::default(),
351            active_message_id,
352        };
353
354        Ok(Some(Session {
355            id: row.get("id"),
356            created_at: row.get("created_at"),
357            updated_at: row.get("updated_at"),
358            config,
359            state,
360        }))
361    }
362
363    async fn update_session(&self, session: &Session) -> Result<(), SessionStoreError> {
364        let metadata_json = serde_json::to_string(&session.config.metadata).map_err(|e| {
365            SessionStoreError::serialization(format!("Failed to serialize metadata: {e}"))
366        })?;
367        let tool_config_json = serde_json::to_string(&session.config.tool_config).map_err(|e| {
368            SessionStoreError::serialization(format!("Failed to serialize tool_config: {e}"))
369        })?;
370        let workspace_config_json =
371            serde_json::to_string(&session.config.workspace).map_err(|e| {
372                SessionStoreError::serialization(format!(
373                    "Failed to serialize workspace_config: {e}"
374                ))
375            })?;
376        let (policy_type, pre_approved_json) =
377            Self::serialize_tool_policy(&session.config.tool_config.approval_policy);
378
379        sqlx::query(
380            r#"
381            UPDATE sessions
382            SET updated_at = ?2, metadata = ?3,
383                tool_policy_type = ?4, pre_approved_tools = ?5, tool_config = ?6, workspace_config = ?7
384            WHERE id = ?1
385            "#,
386        )
387        .bind(&session.id)
388        .bind(Utc::now())
389        .bind(&metadata_json)
390        .bind(&policy_type)
391        .bind(&pre_approved_json)
392        .bind(&tool_config_json)
393        .bind(&workspace_config_json)
394        .execute(&self.pool)
395        .await
396        .map_err(|e| SessionStoreError::database(format!("Failed to update session: {e}")))?;
397
398        Ok(())
399    }
400
401    async fn delete_session(&self, session_id: &str) -> Result<(), SessionStoreError> {
402        sqlx::query("DELETE FROM sessions WHERE id = ?1")
403            .bind(session_id)
404            .execute(&self.pool)
405            .await
406            .map_err(|e| SessionStoreError::database(format!("Failed to delete session: {e}")))?;
407
408        Ok(())
409    }
410
411    async fn list_sessions(
412        &self,
413        filter: SessionFilter,
414    ) -> Result<Vec<SessionInfo>, SessionStoreError> {
415        let mut query = String::from(
416            r#"
417            SELECT s.id, s.created_at, s.updated_at, s.status, s.metadata,
418                   (SELECT e.event_data
419                    FROM events e
420                    WHERE e.session_id = s.id
421                      AND e.event_type IN ('message_complete', 'tool_call_started', 'tool_call_completed', 'tool_call_failed')
422                    ORDER BY e.sequence_num DESC
423                    LIMIT 1) as last_event_data
424            FROM sessions s
425            WHERE 1=1
426            "#,
427        );
428        let mut bindings: Vec<String> = Vec::new();
429
430        // Apply filters
431        if let Some(created_after) = filter.created_after {
432            query.push_str(&format!(" AND s.created_at >= ?{}", bindings.len() + 1));
433            bindings.push(created_after.to_rfc3339());
434        }
435        if let Some(created_before) = filter.created_before {
436            query.push_str(&format!(" AND s.created_at <= ?{}", bindings.len() + 1));
437            bindings.push(created_before.to_rfc3339());
438        }
439        if let Some(updated_after) = filter.updated_after {
440            query.push_str(&format!(" AND s.updated_at >= ?{}", bindings.len() + 1));
441            bindings.push(updated_after.to_rfc3339());
442        }
443        if let Some(updated_before) = filter.updated_before {
444            query.push_str(&format!(" AND s.updated_at <= ?{}", bindings.len() + 1));
445            bindings.push(updated_before.to_rfc3339());
446        }
447        if let Some(status) = filter.status_filter {
448            let status_str = match status {
449                SessionStatus::Active => "active",
450                SessionStatus::Inactive => "inactive",
451            };
452            query.push_str(&format!(" AND s.status = ?{}", bindings.len() + 1));
453            bindings.push(status_str.to_string());
454        }
455
456        // Add ordering
457        let order_column = match filter.order_by {
458            SessionOrderBy::CreatedAt => "s.created_at",
459            SessionOrderBy::UpdatedAt => "s.updated_at",
460            SessionOrderBy::MessageCount => {
461                // For message count, we'll need a subquery
462                query = r#"
463                    SELECT s.id, s.created_at, s.updated_at, s.status, s.metadata,
464                           (SELECT e.event_data
465                            FROM events e
466                            WHERE e.session_id = s.id
467                              AND e.event_type IN ('message_complete', 'tool_call_started', 'tool_call_completed', 'tool_call_failed')
468                            ORDER BY e.sequence_num DESC
469                            LIMIT 1) as last_event_data,
470                           (SELECT COUNT(*) FROM messages WHERE session_id = s.id) as message_count
471                    FROM sessions s
472                    WHERE 1=1
473                    "#.to_string();
474                "message_count"
475            }
476        };
477
478        let order_direction = match filter.order_direction {
479            crate::session::OrderDirection::Ascending => "ASC",
480            crate::session::OrderDirection::Descending => "DESC",
481        };
482
483        query.push_str(&format!(" ORDER BY {order_column} {order_direction}"));
484
485        // Add pagination
486        if let Some(limit) = filter.limit {
487            query.push_str(&format!(" LIMIT {limit}"));
488        }
489        if let Some(offset) = filter.offset {
490            query.push_str(&format!(" OFFSET {offset}"));
491        }
492
493        // Execute query with dynamic bindings
494        let mut q = sqlx::query(&query);
495        for binding in bindings {
496            q = q.bind(binding);
497        }
498
499        let rows = q
500            .fetch_all(&self.pool)
501            .await
502            .map_err(|e| SessionStoreError::database(format!("Failed to list sessions: {e}")))?;
503
504        let mut sessions = Vec::new();
505        for row in rows {
506            let metadata: std::collections::HashMap<String, String> =
507                serde_json::from_str(&row.get::<String, _>("metadata")).map_err(|e| {
508                    SessionStoreError::serialization(format!("Invalid metadata: {e}"))
509                })?;
510
511            // Count messages for this session (if not already done in query)
512            let message_count: i64 = if matches!(filter.order_by, SessionOrderBy::MessageCount) {
513                row.get("message_count")
514            } else {
515                sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE session_id = ?1")
516                    .bind(row.get::<String, _>("id"))
517                    .fetch_one(&self.pool)
518                    .await
519                    .map_err(|e| {
520                        SessionStoreError::database(format!("Failed to count messages: {e}"))
521                    })?
522            };
523
524            // Extract last model from event data
525            let last_model =
526                if let Some(event_json) = row.get::<Option<String>, _>("last_event_data") {
527                    let event: StreamEvent = serde_json::from_str(&event_json).map_err(|e| {
528                        SessionStoreError::serialization(format!("Invalid event data: {e}"))
529                    })?;
530
531                    match event {
532                        StreamEvent::MessageComplete { model, .. } => Some(model),
533                        StreamEvent::ToolCallStarted { model, .. } => Some(model),
534                        StreamEvent::ToolCallCompleted { model, .. } => Some(model),
535                        StreamEvent::ToolCallFailed { model, .. } => Some(model),
536                        _ => None,
537                    }
538                } else {
539                    None
540                };
541
542            sessions.push(SessionInfo {
543                id: row.get("id"),
544                created_at: row.get("created_at"),
545                updated_at: row.get("updated_at"),
546                last_model,
547                message_count: message_count as usize,
548                metadata,
549            });
550        }
551
552        Ok(sessions)
553    }
554
555    async fn append_message(
556        &self,
557        session_id: &str,
558        message: &Message,
559    ) -> Result<(), SessionStoreError> {
560        let id = message.id();
561
562        // Get the next sequence number
563        let next_seq: i64 = sqlx::query_scalar(
564            "SELECT COALESCE(MAX(sequence_num), -1) + 1 FROM messages WHERE session_id = ?1",
565        )
566        .bind(session_id)
567        .fetch_one(&self.pool)
568        .await
569        .map_err(|e| SessionStoreError::database(format!("Failed to get next sequence: {e}")))?;
570
571        // Serialize the message based on its variant
572        let (role, content_json) = match &message.data {
573            MessageData::User { content, .. } => {
574                let json = serde_json::to_string(&content).map_err(|e| {
575                    SessionStoreError::serialization(format!(
576                        "Failed to serialize user content: {e}"
577                    ))
578                })?;
579                ("user", json)
580            }
581            MessageData::Assistant { content, .. } => {
582                let json = serde_json::to_string(&content).map_err(|e| {
583                    SessionStoreError::serialization(format!(
584                        "Failed to serialize assistant content: {e}"
585                    ))
586                })?;
587                ("assistant", json)
588            }
589            MessageData::Tool {
590                tool_use_id,
591                result,
592                ..
593            } => {
594                // Convert tool result to a format that can be stored
595                #[derive(serde::Serialize)]
596                struct StoredToolMessage {
597                    tool_use_id: String,
598                    result: crate::app::conversation::ToolResult,
599                }
600                let stored = StoredToolMessage {
601                    tool_use_id: tool_use_id.clone(),
602                    result: result.clone(),
603                };
604                let json = serde_json::to_string(&stored).map_err(|e| {
605                    SessionStoreError::serialization(format!(
606                        "Failed to serialize tool message: {e}"
607                    ))
608                })?;
609                ("tool", json)
610            }
611        };
612
613        // Extract parent_message_id from the message
614        let parent_message_id = message.parent_message_id();
615
616        sqlx::query(
617            r#"
618            INSERT INTO messages (id, session_id, sequence_num, role, content, created_at, parent_message_id)
619            VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
620            "#,
621        )
622        .bind(id)
623        .bind(session_id)
624        .bind(next_seq)
625        .bind(role)
626        .bind(&content_json)
627        .bind(Utc::now())
628        .bind(parent_message_id)
629        .execute(&self.pool)
630        .await
631        .map_err(|e| SessionStoreError::database(format!("Failed to append message: {e}")))?;
632
633        Ok(())
634    }
635
636    async fn get_messages(
637        &self,
638        session_id: &str,
639        after_sequence: Option<u32>,
640    ) -> Result<Vec<Message>, SessionStoreError> {
641        let query = if let Some(seq) = after_sequence {
642            sqlx::query(
643                r#"
644                SELECT id, sequence_num, role, content, created_at, parent_message_id
645                FROM messages
646                WHERE session_id = ?1 AND sequence_num > ?2
647                ORDER BY sequence_num ASC
648                "#,
649            )
650            .bind(session_id)
651            .bind(seq as i64)
652        } else {
653            sqlx::query(
654                r#"
655                SELECT id, sequence_num, role, content, created_at, parent_message_id
656                FROM messages
657                WHERE session_id = ?1
658                ORDER BY sequence_num ASC
659                "#,
660            )
661            .bind(session_id)
662        };
663
664        let rows = query
665            .fetch_all(&self.pool)
666            .await
667            .map_err(|e| SessionStoreError::database(format!("Failed to get messages: {e}")))?;
668
669        let mut messages = Vec::new();
670        for row in rows {
671            let role = row.get::<String, _>("role");
672            let content = row.get::<String, _>("content");
673            let id: String = row.get("id");
674            let created_at = row.get::<chrono::DateTime<chrono::Utc>, _>("created_at");
675
676            let parent_message_id: Option<String> = row.get("parent_message_id");
677
678            // Deserialize based on role
679            let message = match role.as_str() {
680                "user" => {
681                    let content: Vec<UserContent> =
682                        serde_json::from_str(&content).map_err(|e| {
683                            SessionStoreError::serialization(format!(
684                                "Failed to deserialize user content: {e}"
685                            ))
686                        })?;
687                    Message {
688                        data: MessageData::User { content },
689                        timestamp: created_at.timestamp() as u64,
690                        id,
691                        parent_message_id,
692                    }
693                }
694                "assistant" => {
695                    let content: Vec<AssistantContent> =
696                        serde_json::from_str(&content).map_err(|e| {
697                            SessionStoreError::serialization(format!(
698                                "Failed to deserialize assistant content: {e}"
699                            ))
700                        })?;
701                    Message {
702                        data: MessageData::Assistant { content },
703                        timestamp: created_at.timestamp() as u64,
704                        id,
705                        parent_message_id,
706                    }
707                }
708                "tool" => {
709                    #[derive(serde::Deserialize)]
710                    struct StoredToolMessage {
711                        tool_use_id: String,
712                        result: crate::app::conversation::ToolResult,
713                    }
714                    let stored: StoredToolMessage =
715                        serde_json::from_str(&content).map_err(|e| {
716                            SessionStoreError::serialization(format!(
717                                "Failed to deserialize tool message: {e}"
718                            ))
719                        })?;
720                    Message {
721                        data: MessageData::Tool {
722                            tool_use_id: stored.tool_use_id,
723                            result: stored.result,
724                        },
725                        timestamp: created_at.timestamp() as u64,
726                        id: id.clone(),
727                        parent_message_id: parent_message_id.clone(),
728                    }
729                }
730                _ => {
731                    return Err(SessionStoreError::serialization(format!(
732                        "Unknown message role: {role}"
733                    )));
734                }
735            };
736
737            messages.push(message);
738        }
739
740        Ok(messages)
741    }
742
743    async fn create_tool_call(
744        &self,
745        session_id: &str,
746        tool_call: &ToolCall,
747    ) -> Result<(), SessionStoreError> {
748        let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
749            SessionStoreError::serialization(format!("Failed to serialize parameters: {e}"))
750        })?;
751
752        sqlx::query(
753            r#"
754            INSERT INTO tool_calls (id, session_id, tool_name, parameters, status, kind)
755            VALUES (?1, ?2, ?3, ?4, ?5, ?6)
756            "#,
757        )
758        .bind(&tool_call.id)
759        .bind(session_id)
760        .bind(&tool_call.name)
761        .bind(&parameters_json)
762        .bind("pending")
763        .bind("external") // Default to external, will be updated when result comes in
764        .execute(&self.pool)
765        .await
766        .map_err(|e| SessionStoreError::database(format!("Failed to create tool call: {e}")))?;
767
768        Ok(())
769    }
770
771    async fn update_tool_call(
772        &self,
773        tool_call_id: &str,
774        update: ToolCallUpdate,
775    ) -> Result<(), SessionStoreError> {
776        let mut query = String::from("UPDATE tool_calls SET ");
777        let mut updates = Vec::new();
778        let mut bindings: Vec<String> = Vec::new();
779
780        if let Some(status) = update.status {
781            let status_str = match &status {
782                ToolCallStatus::PendingApproval => "pending",
783                ToolCallStatus::Approved => "approved",
784                ToolCallStatus::Denied => "denied",
785                ToolCallStatus::Executing => "executing",
786                ToolCallStatus::Completed => "completed",
787                ToolCallStatus::Failed { .. } => "failed",
788            };
789            updates.push(format!("status = ?{}", bindings.len() + 1));
790            bindings.push(status_str.to_string());
791
792            // Update timestamps based on status
793            match status {
794                ToolCallStatus::Executing => {
795                    updates.push(format!("started_at = ?{}", bindings.len() + 1));
796                    bindings.push(Utc::now().to_rfc3339());
797                }
798                ToolCallStatus::Completed | ToolCallStatus::Failed { .. } => {
799                    updates.push(format!("completed_at = ?{}", bindings.len() + 1));
800                    bindings.push(Utc::now().to_rfc3339());
801                }
802                _ => {}
803            }
804        }
805
806        if let Some(result) = update.result {
807            // Determine kind from result_type
808            let kind = if let Some(rt) = &result.result_type {
809                rt.clone()
810            } else {
811                "external".to_string()
812            };
813
814            updates.push(format!("kind = ?{}", bindings.len() + 1));
815            bindings.push(kind);
816
817            // Always store the legacy output if available
818            if let Some(output) = &result.output {
819                updates.push(format!("result = ?{}", bindings.len() + 1));
820                bindings.push(output.clone());
821            }
822
823            // Store JSON result if available
824            if let Some(json_output) = &result.json_output {
825                updates.push(format!("payload_json = ?{}", bindings.len() + 1));
826                let json_str = serde_json::to_string(json_output).map_err(|e| {
827                    SessionStoreError::serialization(format!(
828                        "Failed to serialize JSON result: {e}"
829                    ))
830                })?;
831                bindings.push(json_str);
832            } else if let Some(output) = &result.output {
833                // Fallback to wrapping string output as External
834                updates.push(format!("payload_json = ?{}", bindings.len() + 1));
835                let external_json = serde_json::json!({
836                    "tool_name": "unknown",
837                    "payload": output
838                });
839                bindings.push(external_json.to_string());
840            }
841        }
842
843        if let Some(error) = update.error {
844            // Mark as error kind
845            updates.push(format!("kind = ?{}", bindings.len() + 1));
846            bindings.push("error".to_string());
847
848            // Store error in error_json
849            updates.push(format!("error_json = ?{}", bindings.len() + 1));
850            let error_json = serde_json::json!({
851                "tool_name": "unknown",
852                "message": &error
853            });
854            bindings.push(error_json.to_string());
855
856            // Also store in legacy error column
857            updates.push(format!("error = ?{}", bindings.len() + 1));
858            bindings.push(error);
859        }
860
861        if updates.is_empty() {
862            return Ok(());
863        }
864
865        query.push_str(&updates.join(", "));
866        query.push_str(&format!(" WHERE id = ?{}", bindings.len() + 1));
867        bindings.push(tool_call_id.to_string());
868
869        // Execute with dynamic bindings
870        let mut q = sqlx::query(&query);
871        for binding in bindings {
872            q = q.bind(binding);
873        }
874
875        q.execute(&self.pool)
876            .await
877            .map_err(|e| SessionStoreError::database(format!("Failed to update tool call: {e}")))?;
878
879        Ok(())
880    }
881
882    async fn get_pending_tool_calls(
883        &self,
884        session_id: &str,
885    ) -> Result<Vec<ToolCall>, SessionStoreError> {
886        let rows = sqlx::query(
887            r#"
888            SELECT id, tool_name, parameters
889            FROM tool_calls
890            WHERE session_id = ?1 AND status = 'pending'
891            ORDER BY id ASC
892            "#,
893        )
894        .bind(session_id)
895        .fetch_all(&self.pool)
896        .await
897        .map_err(|e| {
898            SessionStoreError::database(format!("Failed to get pending tool calls: {e}"))
899        })?;
900
901        let mut tool_calls = Vec::new();
902        for row in rows {
903            let parameters: serde_json::Value =
904                serde_json::from_str(&row.get::<String, _>("parameters")).map_err(|e| {
905                    SessionStoreError::serialization(format!("Invalid parameters: {e}"))
906                })?;
907
908            tool_calls.push(ToolCall {
909                id: row.get("id"),
910                name: row.get("tool_name"),
911                parameters,
912            });
913        }
914
915        Ok(tool_calls)
916    }
917
918    async fn append_event(
919        &self,
920        session_id: &str,
921        event: &StreamEvent,
922    ) -> Result<u64, SessionStoreError> {
923        let event_type = match event {
924            StreamEvent::MessagePart { .. } => "message_part",
925            StreamEvent::MessageComplete { .. } => "message_complete",
926            StreamEvent::ToolCallStarted { .. } => "tool_call_started",
927            StreamEvent::ToolCallCompleted { .. } => "tool_call_completed",
928            StreamEvent::ToolCallFailed { .. } => "tool_call_failed",
929            StreamEvent::ToolApprovalRequired { .. } => "tool_approval_required",
930            StreamEvent::SessionCreated { .. } => "session_created",
931            StreamEvent::SessionResumed { .. } => "session_resumed",
932            StreamEvent::SessionSaved { .. } => "session_saved",
933            StreamEvent::OperationStarted { .. } => "operation_started",
934            StreamEvent::OperationCompleted { .. } => "operation_completed",
935            StreamEvent::OperationCancelled { .. } => "operation_cancelled",
936            StreamEvent::Error { .. } => "error",
937            StreamEvent::WorkspaceChanged => "workspace_changed",
938            StreamEvent::WorkspaceFiles { .. } => "workspace_files",
939        };
940
941        let event_data = serde_json::to_string(event).map_err(|e| {
942            SessionStoreError::serialization(format!("Failed to serialize event: {e}"))
943        })?;
944
945        // Get the next sequence number
946        let next_seq: i64 = sqlx::query_scalar(
947            "SELECT COALESCE(MAX(sequence_num), -1) + 1 FROM events WHERE session_id = ?1",
948        )
949        .bind(session_id)
950        .fetch_one(&self.pool)
951        .await
952        .map_err(|e| SessionStoreError::database(format!("Failed to get next sequence: {e}")))?;
953
954        sqlx::query(
955            r#"
956            INSERT INTO events (session_id, sequence_num, event_type, event_data, created_at)
957            VALUES (?1, ?2, ?3, ?4, ?5)
958            "#,
959        )
960        .bind(session_id)
961        .bind(next_seq)
962        .bind(event_type)
963        .bind(&event_data)
964        .bind(Utc::now())
965        .execute(&self.pool)
966        .await
967        .map_err(|e| SessionStoreError::database(format!("Failed to append event: {e}")))?;
968
969        Ok(next_seq as u64)
970    }
971
972    async fn get_events(
973        &self,
974        session_id: &str,
975        after_sequence: u64,
976        limit: Option<u32>,
977    ) -> Result<Vec<(u64, StreamEvent)>, SessionStoreError> {
978        let query = if let Some(limit) = limit {
979            sqlx::query(
980                r#"
981                SELECT sequence_num, event_data
982                FROM events
983                WHERE session_id = ?1 AND sequence_num > ?2
984                ORDER BY sequence_num ASC
985                LIMIT ?3
986                "#,
987            )
988            .bind(session_id)
989            .bind(after_sequence as i64)
990            .bind(limit as i64)
991        } else {
992            sqlx::query(
993                r#"
994                SELECT sequence_num, event_data
995                FROM events
996                WHERE session_id = ?1 AND sequence_num > ?2
997                ORDER BY sequence_num ASC
998                "#,
999            )
1000            .bind(session_id)
1001            .bind(after_sequence as i64)
1002        };
1003
1004        let rows = query
1005            .fetch_all(&self.pool)
1006            .await
1007            .map_err(|e| SessionStoreError::database(format!("Failed to get events: {e}")))?;
1008
1009        let mut events = Vec::new();
1010        for row in rows {
1011            let seq: i64 = row.get("sequence_num");
1012            let event: StreamEvent = serde_json::from_str(&row.get::<String, _>("event_data"))
1013                .map_err(|e| {
1014                    SessionStoreError::serialization(format!("Invalid event data: {e}"))
1015                })?;
1016
1017            events.push((seq as u64, event));
1018        }
1019
1020        Ok(events)
1021    }
1022
1023    async fn delete_events_before(
1024        &self,
1025        session_id: &str,
1026        before_sequence: u64,
1027    ) -> Result<u64, SessionStoreError> {
1028        let result = sqlx::query("DELETE FROM events WHERE session_id = ?1 AND sequence_num < ?2")
1029            .bind(session_id)
1030            .bind(before_sequence as i64)
1031            .execute(&self.pool)
1032            .await
1033            .map_err(|e| SessionStoreError::database(format!("Failed to delete events: {e}")))?;
1034
1035        Ok(result.rows_affected())
1036    }
1037
1038    async fn update_active_message_id(
1039        &self,
1040        session_id: &str,
1041        message_id: Option<&str>,
1042    ) -> Result<(), SessionStoreError> {
1043        sqlx::query("UPDATE sessions SET active_message_id = ?2, updated_at = ?3 WHERE id = ?1")
1044            .bind(session_id)
1045            .bind(message_id)
1046            .bind(Utc::now())
1047            .execute(&self.pool)
1048            .await
1049            .map_err(|e| {
1050                SessionStoreError::database(format!("Failed to update active_message_id: {e}"))
1051            })?;
1052
1053        Ok(())
1054    }
1055}
1056
#[cfg(test)]
mod tests {
    //! Integration tests for `SqliteSessionStore`, each run against a fresh
    //! database file in a temporary directory.

    use crate::api::Model;
    use crate::app::conversation::{AssistantContent, Message, Role, UserContent};
    use crate::events::SessionMetadata;
    use crate::session::ToolVisibility;
    use crate::session::state::WorkspaceConfig;

    use super::*;
    use tempfile::TempDir;

    /// Opens a store backed by `test.db` in a new temp dir. The `TempDir` is
    /// returned so it outlives the store and is cleaned up when the test ends.
    async fn create_test_store() -> (SqliteSessionStore, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let store = SqliteSessionStore::new(&db_path).await.unwrap();
        (store, temp_dir)
    }

    /// Session config with an always-ask approval policy and all tools visible.
    fn create_test_session_config() -> SessionConfig {
        let tool_config = crate::session::SessionToolConfig {
            approval_policy: ToolApprovalPolicy::AlwaysAsk,
            visibility: ToolVisibility::All,
            ..Default::default()
        };

        SessionConfig {
            workspace: WorkspaceConfig::default(),
            tool_config,
            system_prompt: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_create_and_get_session() {
        let (store, _temp) = create_test_store().await;

        let tool_config = crate::session::SessionToolConfig {
            approval_policy: ToolApprovalPolicy::AlwaysAsk,
            ..Default::default()
        };

        let config = SessionConfig {
            workspace: WorkspaceConfig::default(),
            tool_config,
            system_prompt: None,
            metadata: Default::default(),
        };

        let session = store.create_session(config.clone()).await.unwrap();
        assert!(!session.id.is_empty());

        let fetched_session = store.get_session(&session.id).await.unwrap().unwrap();
        assert_eq!(session.id, fetched_session.id);
        assert!(matches!(
            fetched_session.config.tool_config.approval_policy,
            ToolApprovalPolicy::AlwaysAsk
        ));
        assert!(matches!(
            fetched_session.config.workspace,
            WorkspaceConfig::Local { .. }
        ));
    }

    #[tokio::test]
    async fn test_message_operations() {
        let (store, _temp) = create_test_store().await;

        let config = create_test_session_config();
        let session = store.create_session(config).await.unwrap();

        let message = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: "Hello".to_string(),
                }],
            },
            timestamp: 123456789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };

        store.append_message(&session.id, &message).await.unwrap();

        let messages = store.get_messages(&session.id, None).await.unwrap();
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role(), Role::User);
    }

    #[tokio::test]
    async fn test_tool_call_operations() {
        let (store, _temp) = create_test_store().await;

        let config = create_test_session_config();
        let session = store.create_session(config).await.unwrap();

        let tool_call = ToolCall {
            id: "tc1".to_string(),
            name: "test_tool".to_string(),
            parameters: serde_json::json!({"param": "value"}),
        };

        store
            .create_tool_call(&session.id, &tool_call)
            .await
            .unwrap();

        let pending = store.get_pending_tool_calls(&session.id).await.unwrap();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].name, "test_tool");

        let update = ToolCallUpdate::set_status(ToolCallStatus::Completed);
        store.update_tool_call(&tool_call.id, update).await.unwrap();

        let pending_after = store.get_pending_tool_calls(&session.id).await.unwrap();
        assert_eq!(pending_after.len(), 0);
    }

    #[tokio::test]
    async fn test_event_streaming() {
        let (store, _temp) = create_test_store().await;

        let config = create_test_session_config();
        let session = store.create_session(config).await.unwrap();

        let event = StreamEvent::SessionCreated {
            session_id: session.id.clone(),
            metadata: SessionMetadata {
                model: Model::Claude3_5Sonnet20241022,
                created_at: session.created_at,
                metadata: session.config.metadata,
            },
        };

        let seq = store.append_event(&session.id, &event).await.unwrap();
        assert_eq!(seq, 0);

        // Get events after sequence 0 (should be empty since we only have sequence 0)
        let events = store.get_events(&session.id, 0, None).await.unwrap();
        assert_eq!(events.len(), 0);

        // `after_sequence` is exclusive, so sequence 0 is only reachable by passing
        // u64::MAX, which the store converts with `as i64` to -1.
        let all_events = store.get_events(&session.id, u64::MAX, None).await.unwrap();
        assert_eq!(all_events.len(), 1);
        assert_eq!(all_events[0].0, 0);

        // Sequence numbers increase by one per append within a session.
        let saved = StreamEvent::SessionSaved {
            session_id: session.id.clone(),
        };
        let seq = store.append_event(&session.id, &saved).await.unwrap();
        assert_eq!(seq, 1);

        // `limit` caps the result, keeping ascending sequence order.
        let first_only = store
            .get_events(&session.id, u64::MAX, Some(1))
            .await
            .unwrap();
        assert_eq!(first_only.len(), 1);
        assert_eq!(first_only[0].0, 0);

        // Deleting before sequence 1 removes only sequence 0.
        let deleted = store.delete_events_before(&session.id, 1).await.unwrap();
        assert_eq!(deleted, 1);
        let remaining = store.get_events(&session.id, u64::MAX, None).await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0].0, 1);
    }

    #[tokio::test]
    async fn test_session_listing() {
        let (store, _temp) = create_test_store().await;

        // Create multiple sessions
        for i in 0..3 {
            let mut config = create_test_session_config();
            config.metadata.insert("index".to_string(), i.to_string());
            store.create_session(config).await.unwrap();
        }

        let filter = SessionFilter {
            limit: Some(2),
            order_by: SessionOrderBy::CreatedAt,
            ..Default::default()
        };

        let sessions = store.list_sessions(filter).await.unwrap();
        assert_eq!(sessions.len(), 2);
    }

    #[tokio::test]
    async fn test_last_model_tracking() {
        let (store, _temp) = create_test_store().await;

        let config = create_test_session_config();
        let session = store.create_session(config).await.unwrap();

        // Initially, no events means no last model
        let sessions = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].last_model, None);

        // Add a MessageComplete event with Claude model
        let claude_model = Model::Claude3_5Sonnet20241022;
        let message_event = StreamEvent::MessageComplete {
            message: Message {
                data: MessageData::Assistant {
                    content: vec![AssistantContent::Text {
                        text: "Hello from Claude".to_string(),
                    }],
                },
                timestamp: 123456789,
                id: "msg1".to_string(),
                parent_message_id: None,
            },
            usage: None,
            metadata: std::collections::HashMap::new(),
            model: claude_model,
        };
        store
            .append_event(&session.id, &message_event)
            .await
            .unwrap();

        // Check that last_model is now Claude
        let sessions = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(sessions[0].last_model, Some(claude_model));

        // Add a ToolCallFailed event with GPT model (more recent)
        let gpt_model = Model::Gpt4_1_20250414;
        let tool_event = StreamEvent::ToolCallFailed {
            tool_call_id: "tool1".to_string(),
            error: "Test error".to_string(),
            metadata: std::collections::HashMap::new(),
            model: gpt_model,
        };
        store.append_event(&session.id, &tool_event).await.unwrap();

        // Check that last_model is now GPT (the most recent)
        let sessions = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(sessions[0].last_model, Some(gpt_model));

        // Add an event without a model field (shouldn't change last_model)
        let session_event = StreamEvent::SessionSaved {
            session_id: session.id.clone(),
        };
        store
            .append_event(&session.id, &session_event)
            .await
            .unwrap();

        // Check that last_model is still GPT
        let sessions = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(sessions[0].last_model, Some(gpt_model));
    }
}