1use async_trait::async_trait;
2use chrono::Utc;
3use serde_json;
4use sqlx::{
5 Row,
6 sqlite::{
7 SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteSynchronous,
8 },
9};
10use std::collections::HashSet;
11use std::path::Path;
12use std::str::FromStr;
13use uuid::Uuid;
14
15use crate::app::Message;
16use crate::app::conversation::{AssistantContent, MessageData, UserContent};
17use crate::events::StreamEvent;
18use crate::session::{
20 Session, SessionConfig, SessionFilter, SessionInfo, SessionOrderBy, SessionState,
21 SessionStatus, SessionStore, SessionStoreError, ToolApprovalPolicy, ToolCallState,
22 ToolCallStatus, ToolCallUpdate, ToolExecutionStats,
23};
24use steer_tools::ToolCall;
25use steer_tools::result::ToolResult;
26
27pub struct SqliteSessionStore {
29 pool: SqlitePool,
30}
31
32impl SqliteSessionStore {
33 pub async fn new(path: &Path) -> Result<Self, SessionStoreError> {
35 if let Some(parent) = path.parent() {
37 std::fs::create_dir_all(parent).map_err(|e| {
38 SessionStoreError::connection(format!("Failed to create directory: {e}"))
39 })?;
40 }
41
42 let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", path.display()))
43 .map_err(|e| SessionStoreError::connection(format!("Invalid SQLite path: {e}")))?
44 .create_if_missing(true)
45 .journal_mode(SqliteJournalMode::Wal)
46 .synchronous(SqliteSynchronous::Normal)
47 .foreign_keys(true);
48
49 let pool = SqlitePoolOptions::new()
50 .max_connections(1) .connect_with(options)
52 .await
53 .map_err(|e| {
54 SessionStoreError::connection(format!("Failed to connect to SQLite: {e}"))
55 })?;
56
57 sqlx::migrate!("migrations/sqlite")
59 .run(&pool)
60 .await
61 .map_err(|e| SessionStoreError::Migration {
62 message: format!("Failed to run migrations: {e}"),
63 })?;
64
65 Ok(Self { pool })
66 }
67
68 fn parse_tool_policy(
70 policy_type: &str,
71 pre_approved_json: &str,
72 ) -> Result<ToolApprovalPolicy, SessionStoreError> {
73 let pre_approved: Vec<String> = serde_json::from_str(pre_approved_json).map_err(|e| {
74 SessionStoreError::serialization(format!("Invalid pre_approved_tools: {e}"))
75 })?;
76
77 match policy_type {
78 "always_ask" => Ok(ToolApprovalPolicy::AlwaysAsk),
79 "pre_approved" => Ok(ToolApprovalPolicy::PreApproved {
80 tools: pre_approved.into_iter().collect(),
81 }),
82 "mixed" => Ok(ToolApprovalPolicy::Mixed {
83 pre_approved: pre_approved.into_iter().collect(),
84 ask_for_others: true,
85 }),
86 _ => Err(SessionStoreError::validation(format!(
87 "Invalid tool policy type: {policy_type}"
88 ))),
89 }
90 }
91
92 fn serialize_tool_policy(policy: &ToolApprovalPolicy) -> (String, String) {
94 match policy {
95 ToolApprovalPolicy::AlwaysAsk => ("always_ask".to_string(), "[]".to_string()),
96 ToolApprovalPolicy::PreApproved { tools } => {
97 let tools_vec: Vec<String> = tools.iter().cloned().collect();
98 (
99 "pre_approved".to_string(),
100 serde_json::to_string(&tools_vec).unwrap(),
101 )
102 }
103 ToolApprovalPolicy::Mixed { pre_approved, .. } => {
104 let tools_vec: Vec<String> = pre_approved.iter().cloned().collect();
105 (
106 "mixed".to_string(),
107 serde_json::to_string(&tools_vec).unwrap(),
108 )
109 }
110 }
111 }
112}
113
114#[async_trait]
115impl SessionStore for SqliteSessionStore {
116 async fn create_session(&self, config: SessionConfig) -> Result<Session, SessionStoreError> {
117 let id = Uuid::new_v4().to_string();
118 let now = Utc::now();
119 let (policy_type, pre_approved_json) =
120 Self::serialize_tool_policy(&config.tool_config.approval_policy);
121 let metadata_json = serde_json::to_string(&config.metadata).map_err(|e| {
122 SessionStoreError::serialization(format!("Failed to serialize metadata: {e}"))
123 })?;
124 let tool_config_json = serde_json::to_string(&config.tool_config).map_err(|e| {
125 SessionStoreError::serialization(format!("Failed to serialize tool_config: {e}"))
126 })?;
127 let workspace_config_json = serde_json::to_string(&config.workspace).map_err(|e| {
128 SessionStoreError::serialization(format!("Failed to serialize workspace_config: {e}"))
129 })?;
130
131 sqlx::query(
132 r#"
133 INSERT INTO sessions (id, created_at, updated_at, status, metadata,
134 tool_policy_type, pre_approved_tools, tool_config, workspace_config, system_prompt)
135 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
136 "#,
137 )
138 .bind(&id)
139 .bind(now)
140 .bind(now)
141 .bind("inactive") .bind(&metadata_json)
143 .bind(&policy_type)
144 .bind(&pre_approved_json)
145 .bind(&tool_config_json)
146 .bind(&workspace_config_json)
147 .bind(&config.system_prompt)
148 .execute(&self.pool)
149 .await
150 .map_err(|e| SessionStoreError::database(format!("Failed to create session: {e}")))?;
151
152 Ok(Session {
153 id: id.clone(),
154 created_at: now,
155 updated_at: now,
156 config,
157 state: SessionState::default(),
158 })
159 }
160
161 async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionStoreError> {
162 let row = sqlx::query(
163 r#"
164 SELECT id, created_at, updated_at, metadata,
165 tool_policy_type, pre_approved_tools, tool_config, workspace_config, system_prompt,
166 active_message_id
167 FROM sessions
168 WHERE id = ?1
169 "#,
170 )
171 .bind(session_id)
172 .fetch_optional(&self.pool)
173 .await
174 .map_err(|e| SessionStoreError::database(format!("Failed to get session: {e}")))?;
175
176 let Some(row) = row else {
177 return Ok(None);
178 };
179
180 let approval_policy = Self::parse_tool_policy(
181 &row.get::<String, _>("tool_policy_type"),
182 &row.get::<String, _>("pre_approved_tools"),
183 )?;
184
185 let metadata: std::collections::HashMap<String, String> =
186 serde_json::from_str(&row.get::<String, _>("metadata"))
187 .map_err(|e| SessionStoreError::serialization(format!("Invalid metadata: {e}")))?;
188
189 let tool_config = serde_json::from_str(&row.get::<String, _>("tool_config"))
190 .map_err(|e| SessionStoreError::serialization(format!("Invalid tool_config: {e}")))?;
191 let workspace_config = serde_json::from_str(&row.get::<String, _>("workspace_config"))
192 .map_err(|e| {
193 SessionStoreError::serialization(format!("Invalid workspace_config: {e}"))
194 })?;
195
196 let mut tool_config: crate::session::SessionToolConfig = tool_config;
197 tool_config.approval_policy = approval_policy;
198
199 let system_prompt: Option<String> = row.get("system_prompt");
200
201 let config = SessionConfig {
202 workspace: workspace_config,
203 tool_config,
204 system_prompt,
205 metadata,
206 };
207
208 let messages = self.get_messages(session_id, None).await?;
210
211 let tool_calls_rows = sqlx::query(
213 r#"
214 SELECT id, tool_name, parameters, status, result, error, started_at, completed_at, kind, payload_json, error_json
215 FROM tool_calls
216 WHERE session_id = ?1
217 "#,
218 )
219 .bind(session_id)
220 .fetch_all(&self.pool)
221 .await
222 .map_err(|e| SessionStoreError::database(format!("Failed to load tool calls: {e}")))?;
223
224 let mut tool_calls = std::collections::HashMap::new();
225 for row in tool_calls_rows {
226 let id: String = row.get("id");
227 let status_str: String = row.get("status");
228 let error: Option<String> = row.get("error");
229
230 let status = match status_str.as_str() {
231 "pending" => ToolCallStatus::PendingApproval,
232 "approved" => ToolCallStatus::Approved,
233 "denied" => ToolCallStatus::Denied,
234 "executing" => ToolCallStatus::Executing,
235 "completed" => ToolCallStatus::Completed,
236 "failed" => ToolCallStatus::Failed {
237 error: error.unwrap_or_else(|| "Unknown error".to_string()),
238 },
239 _ => {
240 return Err(SessionStoreError::validation(format!(
241 "Invalid tool call status: {status_str}"
242 )));
243 }
244 };
245
246 let tool_call = ToolCall {
247 id: id.clone(),
248 name: row.get("tool_name"),
249 parameters: serde_json::from_str(&row.get::<String, _>("parameters")).map_err(
250 |e| SessionStoreError::serialization(format!("Invalid tool parameters: {e}")),
251 )?,
252 };
253
254 let result: Option<String> = row.get("result");
255 let json_result: Option<String> = row.get("payload_json");
256 let result_type: Option<String> = row.get("kind");
257 let error_json: Option<String> = row.get("error_json");
258
259 let _tool_result = if let Some(kind) = result_type.as_ref() {
260 if kind == "error" {
261 let error_data = error_json.and_then(|json_str| {
263 serde_json::from_str::<serde_json::Value>(&json_str).ok()
264 });
265 Some(ToolExecutionStats {
266 output: result.clone(),
267 json_output: error_data,
268 result_type: Some("error".to_string()),
269 success: false,
270 execution_time_ms: 0,
271 metadata: std::collections::HashMap::new(),
272 })
273 } else if let Some(json_str) = json_result {
274 let json_value = serde_json::from_str(&json_str).map_err(|e| {
276 SessionStoreError::serialization(format!("Invalid JSON result: {e}"))
277 })?;
278 Some(ToolExecutionStats {
279 output: result.clone(),
280 json_output: Some(json_value),
281 result_type,
282 success: true,
283 execution_time_ms: 0,
284 metadata: std::collections::HashMap::new(),
285 })
286 } else {
287 Some(ToolExecutionStats {
289 output: result.clone(),
290 json_output: None,
291 result_type,
292 success: true,
293 execution_time_ms: 0,
294 metadata: std::collections::HashMap::new(),
295 })
296 }
297 } else {
298 result.map(|r| ToolExecutionStats {
299 output: Some(r),
300 json_output: None,
301 result_type: None,
302 success: true,
303 execution_time_ms: 0,
304 metadata: std::collections::HashMap::new(),
305 })
306 };
307
308 let tool_result: Option<ToolResult> = None;
310
311 let state = ToolCallState {
312 tool_call,
313 status,
314 started_at: row.get("started_at"),
315 completed_at: row.get("completed_at"),
316 result: tool_result,
317 };
318
319 tool_calls.insert(id, state);
320 }
321
322 let last_sequence: Option<i64> =
324 sqlx::query_scalar("SELECT MAX(sequence_num) FROM events WHERE session_id = ?1")
325 .bind(session_id)
326 .fetch_one(&self.pool)
327 .await
328 .map_err(|e| {
329 SessionStoreError::database(format!("Failed to get last event sequence: {e}"))
330 })?;
331
332 let approved_bash_patterns: HashSet<String> =
334 if let Some(bash_config) = config.tool_config.tools.get("bash") {
335 let crate::session::state::ToolSpecificConfig::Bash(bash) = bash_config;
336 bash.approved_patterns.iter().cloned().collect()
337 } else {
338 HashSet::new()
339 };
340
341 let active_message_id: Option<String> = row.get("active_message_id");
343
344 let state = SessionState {
345 messages,
346 tool_calls,
347 approved_tools: Default::default(), approved_bash_patterns,
349 last_event_sequence: last_sequence.unwrap_or(0) as u64,
350 metadata: Default::default(),
351 active_message_id,
352 };
353
354 Ok(Some(Session {
355 id: row.get("id"),
356 created_at: row.get("created_at"),
357 updated_at: row.get("updated_at"),
358 config,
359 state,
360 }))
361 }
362
363 async fn update_session(&self, session: &Session) -> Result<(), SessionStoreError> {
364 let metadata_json = serde_json::to_string(&session.config.metadata).map_err(|e| {
365 SessionStoreError::serialization(format!("Failed to serialize metadata: {e}"))
366 })?;
367 let tool_config_json = serde_json::to_string(&session.config.tool_config).map_err(|e| {
368 SessionStoreError::serialization(format!("Failed to serialize tool_config: {e}"))
369 })?;
370 let workspace_config_json =
371 serde_json::to_string(&session.config.workspace).map_err(|e| {
372 SessionStoreError::serialization(format!(
373 "Failed to serialize workspace_config: {e}"
374 ))
375 })?;
376 let (policy_type, pre_approved_json) =
377 Self::serialize_tool_policy(&session.config.tool_config.approval_policy);
378
379 sqlx::query(
380 r#"
381 UPDATE sessions
382 SET updated_at = ?2, metadata = ?3,
383 tool_policy_type = ?4, pre_approved_tools = ?5, tool_config = ?6, workspace_config = ?7
384 WHERE id = ?1
385 "#,
386 )
387 .bind(&session.id)
388 .bind(Utc::now())
389 .bind(&metadata_json)
390 .bind(&policy_type)
391 .bind(&pre_approved_json)
392 .bind(&tool_config_json)
393 .bind(&workspace_config_json)
394 .execute(&self.pool)
395 .await
396 .map_err(|e| SessionStoreError::database(format!("Failed to update session: {e}")))?;
397
398 Ok(())
399 }
400
401 async fn delete_session(&self, session_id: &str) -> Result<(), SessionStoreError> {
402 sqlx::query("DELETE FROM sessions WHERE id = ?1")
403 .bind(session_id)
404 .execute(&self.pool)
405 .await
406 .map_err(|e| SessionStoreError::database(format!("Failed to delete session: {e}")))?;
407
408 Ok(())
409 }
410
411 async fn list_sessions(
412 &self,
413 filter: SessionFilter,
414 ) -> Result<Vec<SessionInfo>, SessionStoreError> {
415 let mut query = String::from(
416 r#"
417 SELECT s.id, s.created_at, s.updated_at, s.status, s.metadata,
418 (SELECT e.event_data
419 FROM events e
420 WHERE e.session_id = s.id
421 AND e.event_type IN ('message_complete', 'tool_call_started', 'tool_call_completed', 'tool_call_failed')
422 ORDER BY e.sequence_num DESC
423 LIMIT 1) as last_event_data
424 FROM sessions s
425 WHERE 1=1
426 "#,
427 );
428 let mut bindings: Vec<String> = Vec::new();
429
430 if let Some(created_after) = filter.created_after {
432 query.push_str(&format!(" AND s.created_at >= ?{}", bindings.len() + 1));
433 bindings.push(created_after.to_rfc3339());
434 }
435 if let Some(created_before) = filter.created_before {
436 query.push_str(&format!(" AND s.created_at <= ?{}", bindings.len() + 1));
437 bindings.push(created_before.to_rfc3339());
438 }
439 if let Some(updated_after) = filter.updated_after {
440 query.push_str(&format!(" AND s.updated_at >= ?{}", bindings.len() + 1));
441 bindings.push(updated_after.to_rfc3339());
442 }
443 if let Some(updated_before) = filter.updated_before {
444 query.push_str(&format!(" AND s.updated_at <= ?{}", bindings.len() + 1));
445 bindings.push(updated_before.to_rfc3339());
446 }
447 if let Some(status) = filter.status_filter {
448 let status_str = match status {
449 SessionStatus::Active => "active",
450 SessionStatus::Inactive => "inactive",
451 };
452 query.push_str(&format!(" AND s.status = ?{}", bindings.len() + 1));
453 bindings.push(status_str.to_string());
454 }
455
456 let order_column = match filter.order_by {
458 SessionOrderBy::CreatedAt => "s.created_at",
459 SessionOrderBy::UpdatedAt => "s.updated_at",
460 SessionOrderBy::MessageCount => {
461 query = r#"
463 SELECT s.id, s.created_at, s.updated_at, s.status, s.metadata,
464 (SELECT e.event_data
465 FROM events e
466 WHERE e.session_id = s.id
467 AND e.event_type IN ('message_complete', 'tool_call_started', 'tool_call_completed', 'tool_call_failed')
468 ORDER BY e.sequence_num DESC
469 LIMIT 1) as last_event_data,
470 (SELECT COUNT(*) FROM messages WHERE session_id = s.id) as message_count
471 FROM sessions s
472 WHERE 1=1
473 "#.to_string();
474 "message_count"
475 }
476 };
477
478 let order_direction = match filter.order_direction {
479 crate::session::OrderDirection::Ascending => "ASC",
480 crate::session::OrderDirection::Descending => "DESC",
481 };
482
483 query.push_str(&format!(" ORDER BY {order_column} {order_direction}"));
484
485 if let Some(limit) = filter.limit {
487 query.push_str(&format!(" LIMIT {limit}"));
488 }
489 if let Some(offset) = filter.offset {
490 query.push_str(&format!(" OFFSET {offset}"));
491 }
492
493 let mut q = sqlx::query(&query);
495 for binding in bindings {
496 q = q.bind(binding);
497 }
498
499 let rows = q
500 .fetch_all(&self.pool)
501 .await
502 .map_err(|e| SessionStoreError::database(format!("Failed to list sessions: {e}")))?;
503
504 let mut sessions = Vec::new();
505 for row in rows {
506 let metadata: std::collections::HashMap<String, String> =
507 serde_json::from_str(&row.get::<String, _>("metadata")).map_err(|e| {
508 SessionStoreError::serialization(format!("Invalid metadata: {e}"))
509 })?;
510
511 let message_count: i64 = if matches!(filter.order_by, SessionOrderBy::MessageCount) {
513 row.get("message_count")
514 } else {
515 sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE session_id = ?1")
516 .bind(row.get::<String, _>("id"))
517 .fetch_one(&self.pool)
518 .await
519 .map_err(|e| {
520 SessionStoreError::database(format!("Failed to count messages: {e}"))
521 })?
522 };
523
524 let last_model =
526 if let Some(event_json) = row.get::<Option<String>, _>("last_event_data") {
527 let event: StreamEvent = serde_json::from_str(&event_json).map_err(|e| {
528 SessionStoreError::serialization(format!("Invalid event data: {e}"))
529 })?;
530
531 match event {
532 StreamEvent::MessageComplete { model, .. } => Some(model),
533 StreamEvent::ToolCallStarted { model, .. } => Some(model),
534 StreamEvent::ToolCallCompleted { model, .. } => Some(model),
535 StreamEvent::ToolCallFailed { model, .. } => Some(model),
536 _ => None,
537 }
538 } else {
539 None
540 };
541
542 sessions.push(SessionInfo {
543 id: row.get("id"),
544 created_at: row.get("created_at"),
545 updated_at: row.get("updated_at"),
546 last_model,
547 message_count: message_count as usize,
548 metadata,
549 });
550 }
551
552 Ok(sessions)
553 }
554
555 async fn append_message(
556 &self,
557 session_id: &str,
558 message: &Message,
559 ) -> Result<(), SessionStoreError> {
560 let id = message.id();
561
562 let next_seq: i64 = sqlx::query_scalar(
564 "SELECT COALESCE(MAX(sequence_num), -1) + 1 FROM messages WHERE session_id = ?1",
565 )
566 .bind(session_id)
567 .fetch_one(&self.pool)
568 .await
569 .map_err(|e| SessionStoreError::database(format!("Failed to get next sequence: {e}")))?;
570
571 let (role, content_json) = match &message.data {
573 MessageData::User { content, .. } => {
574 let json = serde_json::to_string(&content).map_err(|e| {
575 SessionStoreError::serialization(format!(
576 "Failed to serialize user content: {e}"
577 ))
578 })?;
579 ("user", json)
580 }
581 MessageData::Assistant { content, .. } => {
582 let json = serde_json::to_string(&content).map_err(|e| {
583 SessionStoreError::serialization(format!(
584 "Failed to serialize assistant content: {e}"
585 ))
586 })?;
587 ("assistant", json)
588 }
589 MessageData::Tool {
590 tool_use_id,
591 result,
592 ..
593 } => {
594 #[derive(serde::Serialize)]
596 struct StoredToolMessage {
597 tool_use_id: String,
598 result: crate::app::conversation::ToolResult,
599 }
600 let stored = StoredToolMessage {
601 tool_use_id: tool_use_id.clone(),
602 result: result.clone(),
603 };
604 let json = serde_json::to_string(&stored).map_err(|e| {
605 SessionStoreError::serialization(format!(
606 "Failed to serialize tool message: {e}"
607 ))
608 })?;
609 ("tool", json)
610 }
611 };
612
613 let parent_message_id = message.parent_message_id();
615
616 sqlx::query(
617 r#"
618 INSERT INTO messages (id, session_id, sequence_num, role, content, created_at, parent_message_id)
619 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
620 "#,
621 )
622 .bind(id)
623 .bind(session_id)
624 .bind(next_seq)
625 .bind(role)
626 .bind(&content_json)
627 .bind(Utc::now())
628 .bind(parent_message_id)
629 .execute(&self.pool)
630 .await
631 .map_err(|e| SessionStoreError::database(format!("Failed to append message: {e}")))?;
632
633 Ok(())
634 }
635
636 async fn get_messages(
637 &self,
638 session_id: &str,
639 after_sequence: Option<u32>,
640 ) -> Result<Vec<Message>, SessionStoreError> {
641 let query = if let Some(seq) = after_sequence {
642 sqlx::query(
643 r#"
644 SELECT id, sequence_num, role, content, created_at, parent_message_id
645 FROM messages
646 WHERE session_id = ?1 AND sequence_num > ?2
647 ORDER BY sequence_num ASC
648 "#,
649 )
650 .bind(session_id)
651 .bind(seq as i64)
652 } else {
653 sqlx::query(
654 r#"
655 SELECT id, sequence_num, role, content, created_at, parent_message_id
656 FROM messages
657 WHERE session_id = ?1
658 ORDER BY sequence_num ASC
659 "#,
660 )
661 .bind(session_id)
662 };
663
664 let rows = query
665 .fetch_all(&self.pool)
666 .await
667 .map_err(|e| SessionStoreError::database(format!("Failed to get messages: {e}")))?;
668
669 let mut messages = Vec::new();
670 for row in rows {
671 let role = row.get::<String, _>("role");
672 let content = row.get::<String, _>("content");
673 let id: String = row.get("id");
674 let created_at = row.get::<chrono::DateTime<chrono::Utc>, _>("created_at");
675
676 let parent_message_id: Option<String> = row.get("parent_message_id");
677
678 let message = match role.as_str() {
680 "user" => {
681 let content: Vec<UserContent> =
682 serde_json::from_str(&content).map_err(|e| {
683 SessionStoreError::serialization(format!(
684 "Failed to deserialize user content: {e}"
685 ))
686 })?;
687 Message {
688 data: MessageData::User { content },
689 timestamp: created_at.timestamp() as u64,
690 id,
691 parent_message_id,
692 }
693 }
694 "assistant" => {
695 let content: Vec<AssistantContent> =
696 serde_json::from_str(&content).map_err(|e| {
697 SessionStoreError::serialization(format!(
698 "Failed to deserialize assistant content: {e}"
699 ))
700 })?;
701 Message {
702 data: MessageData::Assistant { content },
703 timestamp: created_at.timestamp() as u64,
704 id,
705 parent_message_id,
706 }
707 }
708 "tool" => {
709 #[derive(serde::Deserialize)]
710 struct StoredToolMessage {
711 tool_use_id: String,
712 result: crate::app::conversation::ToolResult,
713 }
714 let stored: StoredToolMessage =
715 serde_json::from_str(&content).map_err(|e| {
716 SessionStoreError::serialization(format!(
717 "Failed to deserialize tool message: {e}"
718 ))
719 })?;
720 Message {
721 data: MessageData::Tool {
722 tool_use_id: stored.tool_use_id,
723 result: stored.result,
724 },
725 timestamp: created_at.timestamp() as u64,
726 id: id.clone(),
727 parent_message_id: parent_message_id.clone(),
728 }
729 }
730 _ => {
731 return Err(SessionStoreError::serialization(format!(
732 "Unknown message role: {role}"
733 )));
734 }
735 };
736
737 messages.push(message);
738 }
739
740 Ok(messages)
741 }
742
743 async fn create_tool_call(
744 &self,
745 session_id: &str,
746 tool_call: &ToolCall,
747 ) -> Result<(), SessionStoreError> {
748 let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
749 SessionStoreError::serialization(format!("Failed to serialize parameters: {e}"))
750 })?;
751
752 sqlx::query(
753 r#"
754 INSERT INTO tool_calls (id, session_id, tool_name, parameters, status, kind)
755 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
756 "#,
757 )
758 .bind(&tool_call.id)
759 .bind(session_id)
760 .bind(&tool_call.name)
761 .bind(¶meters_json)
762 .bind("pending")
763 .bind("external") .execute(&self.pool)
765 .await
766 .map_err(|e| SessionStoreError::database(format!("Failed to create tool call: {e}")))?;
767
768 Ok(())
769 }
770
771 async fn update_tool_call(
772 &self,
773 tool_call_id: &str,
774 update: ToolCallUpdate,
775 ) -> Result<(), SessionStoreError> {
776 let mut query = String::from("UPDATE tool_calls SET ");
777 let mut updates = Vec::new();
778 let mut bindings: Vec<String> = Vec::new();
779
780 if let Some(status) = update.status {
781 let status_str = match &status {
782 ToolCallStatus::PendingApproval => "pending",
783 ToolCallStatus::Approved => "approved",
784 ToolCallStatus::Denied => "denied",
785 ToolCallStatus::Executing => "executing",
786 ToolCallStatus::Completed => "completed",
787 ToolCallStatus::Failed { .. } => "failed",
788 };
789 updates.push(format!("status = ?{}", bindings.len() + 1));
790 bindings.push(status_str.to_string());
791
792 match status {
794 ToolCallStatus::Executing => {
795 updates.push(format!("started_at = ?{}", bindings.len() + 1));
796 bindings.push(Utc::now().to_rfc3339());
797 }
798 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } => {
799 updates.push(format!("completed_at = ?{}", bindings.len() + 1));
800 bindings.push(Utc::now().to_rfc3339());
801 }
802 _ => {}
803 }
804 }
805
806 if let Some(result) = update.result {
807 let kind = if let Some(rt) = &result.result_type {
809 rt.clone()
810 } else {
811 "external".to_string()
812 };
813
814 updates.push(format!("kind = ?{}", bindings.len() + 1));
815 bindings.push(kind);
816
817 if let Some(output) = &result.output {
819 updates.push(format!("result = ?{}", bindings.len() + 1));
820 bindings.push(output.clone());
821 }
822
823 if let Some(json_output) = &result.json_output {
825 updates.push(format!("payload_json = ?{}", bindings.len() + 1));
826 let json_str = serde_json::to_string(json_output).map_err(|e| {
827 SessionStoreError::serialization(format!(
828 "Failed to serialize JSON result: {e}"
829 ))
830 })?;
831 bindings.push(json_str);
832 } else if let Some(output) = &result.output {
833 updates.push(format!("payload_json = ?{}", bindings.len() + 1));
835 let external_json = serde_json::json!({
836 "tool_name": "unknown",
837 "payload": output
838 });
839 bindings.push(external_json.to_string());
840 }
841 }
842
843 if let Some(error) = update.error {
844 updates.push(format!("kind = ?{}", bindings.len() + 1));
846 bindings.push("error".to_string());
847
848 updates.push(format!("error_json = ?{}", bindings.len() + 1));
850 let error_json = serde_json::json!({
851 "tool_name": "unknown",
852 "message": &error
853 });
854 bindings.push(error_json.to_string());
855
856 updates.push(format!("error = ?{}", bindings.len() + 1));
858 bindings.push(error);
859 }
860
861 if updates.is_empty() {
862 return Ok(());
863 }
864
865 query.push_str(&updates.join(", "));
866 query.push_str(&format!(" WHERE id = ?{}", bindings.len() + 1));
867 bindings.push(tool_call_id.to_string());
868
869 let mut q = sqlx::query(&query);
871 for binding in bindings {
872 q = q.bind(binding);
873 }
874
875 q.execute(&self.pool)
876 .await
877 .map_err(|e| SessionStoreError::database(format!("Failed to update tool call: {e}")))?;
878
879 Ok(())
880 }
881
882 async fn get_pending_tool_calls(
883 &self,
884 session_id: &str,
885 ) -> Result<Vec<ToolCall>, SessionStoreError> {
886 let rows = sqlx::query(
887 r#"
888 SELECT id, tool_name, parameters
889 FROM tool_calls
890 WHERE session_id = ?1 AND status = 'pending'
891 ORDER BY id ASC
892 "#,
893 )
894 .bind(session_id)
895 .fetch_all(&self.pool)
896 .await
897 .map_err(|e| {
898 SessionStoreError::database(format!("Failed to get pending tool calls: {e}"))
899 })?;
900
901 let mut tool_calls = Vec::new();
902 for row in rows {
903 let parameters: serde_json::Value =
904 serde_json::from_str(&row.get::<String, _>("parameters")).map_err(|e| {
905 SessionStoreError::serialization(format!("Invalid parameters: {e}"))
906 })?;
907
908 tool_calls.push(ToolCall {
909 id: row.get("id"),
910 name: row.get("tool_name"),
911 parameters,
912 });
913 }
914
915 Ok(tool_calls)
916 }
917
918 async fn append_event(
919 &self,
920 session_id: &str,
921 event: &StreamEvent,
922 ) -> Result<u64, SessionStoreError> {
923 let event_type = match event {
924 StreamEvent::MessagePart { .. } => "message_part",
925 StreamEvent::MessageComplete { .. } => "message_complete",
926 StreamEvent::ToolCallStarted { .. } => "tool_call_started",
927 StreamEvent::ToolCallCompleted { .. } => "tool_call_completed",
928 StreamEvent::ToolCallFailed { .. } => "tool_call_failed",
929 StreamEvent::ToolApprovalRequired { .. } => "tool_approval_required",
930 StreamEvent::SessionCreated { .. } => "session_created",
931 StreamEvent::SessionResumed { .. } => "session_resumed",
932 StreamEvent::SessionSaved { .. } => "session_saved",
933 StreamEvent::OperationStarted { .. } => "operation_started",
934 StreamEvent::OperationCompleted { .. } => "operation_completed",
935 StreamEvent::OperationCancelled { .. } => "operation_cancelled",
936 StreamEvent::Error { .. } => "error",
937 StreamEvent::WorkspaceChanged => "workspace_changed",
938 StreamEvent::WorkspaceFiles { .. } => "workspace_files",
939 };
940
941 let event_data = serde_json::to_string(event).map_err(|e| {
942 SessionStoreError::serialization(format!("Failed to serialize event: {e}"))
943 })?;
944
945 let next_seq: i64 = sqlx::query_scalar(
947 "SELECT COALESCE(MAX(sequence_num), -1) + 1 FROM events WHERE session_id = ?1",
948 )
949 .bind(session_id)
950 .fetch_one(&self.pool)
951 .await
952 .map_err(|e| SessionStoreError::database(format!("Failed to get next sequence: {e}")))?;
953
954 sqlx::query(
955 r#"
956 INSERT INTO events (session_id, sequence_num, event_type, event_data, created_at)
957 VALUES (?1, ?2, ?3, ?4, ?5)
958 "#,
959 )
960 .bind(session_id)
961 .bind(next_seq)
962 .bind(event_type)
963 .bind(&event_data)
964 .bind(Utc::now())
965 .execute(&self.pool)
966 .await
967 .map_err(|e| SessionStoreError::database(format!("Failed to append event: {e}")))?;
968
969 Ok(next_seq as u64)
970 }
971
972 async fn get_events(
973 &self,
974 session_id: &str,
975 after_sequence: u64,
976 limit: Option<u32>,
977 ) -> Result<Vec<(u64, StreamEvent)>, SessionStoreError> {
978 let query = if let Some(limit) = limit {
979 sqlx::query(
980 r#"
981 SELECT sequence_num, event_data
982 FROM events
983 WHERE session_id = ?1 AND sequence_num > ?2
984 ORDER BY sequence_num ASC
985 LIMIT ?3
986 "#,
987 )
988 .bind(session_id)
989 .bind(after_sequence as i64)
990 .bind(limit as i64)
991 } else {
992 sqlx::query(
993 r#"
994 SELECT sequence_num, event_data
995 FROM events
996 WHERE session_id = ?1 AND sequence_num > ?2
997 ORDER BY sequence_num ASC
998 "#,
999 )
1000 .bind(session_id)
1001 .bind(after_sequence as i64)
1002 };
1003
1004 let rows = query
1005 .fetch_all(&self.pool)
1006 .await
1007 .map_err(|e| SessionStoreError::database(format!("Failed to get events: {e}")))?;
1008
1009 let mut events = Vec::new();
1010 for row in rows {
1011 let seq: i64 = row.get("sequence_num");
1012 let event: StreamEvent = serde_json::from_str(&row.get::<String, _>("event_data"))
1013 .map_err(|e| {
1014 SessionStoreError::serialization(format!("Invalid event data: {e}"))
1015 })?;
1016
1017 events.push((seq as u64, event));
1018 }
1019
1020 Ok(events)
1021 }
1022
1023 async fn delete_events_before(
1024 &self,
1025 session_id: &str,
1026 before_sequence: u64,
1027 ) -> Result<u64, SessionStoreError> {
1028 let result = sqlx::query("DELETE FROM events WHERE session_id = ?1 AND sequence_num < ?2")
1029 .bind(session_id)
1030 .bind(before_sequence as i64)
1031 .execute(&self.pool)
1032 .await
1033 .map_err(|e| SessionStoreError::database(format!("Failed to delete events: {e}")))?;
1034
1035 Ok(result.rows_affected())
1036 }
1037
1038 async fn update_active_message_id(
1039 &self,
1040 session_id: &str,
1041 message_id: Option<&str>,
1042 ) -> Result<(), SessionStoreError> {
1043 sqlx::query("UPDATE sessions SET active_message_id = ?2, updated_at = ?3 WHERE id = ?1")
1044 .bind(session_id)
1045 .bind(message_id)
1046 .bind(Utc::now())
1047 .execute(&self.pool)
1048 .await
1049 .map_err(|e| {
1050 SessionStoreError::database(format!("Failed to update active_message_id: {e}"))
1051 })?;
1052
1053 Ok(())
1054 }
1055}
1056
#[cfg(test)]
mod tests {
    //! Tests for `SqliteSessionStore`. Every test opens its own database file
    //! inside a fresh temporary directory, so tests never share state.

    use crate::api::Model;
    use crate::app::conversation::{AssistantContent, Message, Role, UserContent};
    use crate::events::SessionMetadata;
    use crate::session::ToolVisibility;
    use crate::session::state::WorkspaceConfig;

    use super::*;
    use tempfile::TempDir;

    /// Opens a store backed by `test.db` in a new temporary directory.
    ///
    /// The `TempDir` guard is returned together with the store; callers keep it
    /// bound for the duration of the test so the directory outlives the store.
    async fn create_test_store() -> (SqliteSessionStore, TempDir) {
        let dir = TempDir::new().unwrap();
        let db_file = dir.path().join("test.db");
        let store = SqliteSessionStore::new(&db_file).await.unwrap();
        (store, dir)
    }

    /// Session config shared by most tests: tools always require approval,
    /// every tool is visible, there is no system prompt, and metadata is empty.
    fn create_test_session_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::default(),
            tool_config: crate::session::SessionToolConfig {
                approval_policy: ToolApprovalPolicy::AlwaysAsk,
                visibility: ToolVisibility::All,
                ..Default::default()
            },
            system_prompt: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    /// Returns `last_model` of the first session reported by an unfiltered listing.
    async fn first_session_last_model(store: &SqliteSessionStore) -> Option<Model> {
        let listing = store.list_sessions(SessionFilter::default()).await.unwrap();
        listing[0].last_model
    }

    #[tokio::test]
    async fn test_create_and_get_session() {
        let (store, _temp) = create_test_store().await;

        // Unlike `create_test_session_config`, tool visibility stays at its default.
        let config = SessionConfig {
            workspace: WorkspaceConfig::default(),
            tool_config: crate::session::SessionToolConfig {
                approval_policy: ToolApprovalPolicy::AlwaysAsk,
                ..Default::default()
            },
            system_prompt: None,
            metadata: Default::default(),
        };

        let created = store.create_session(config).await.unwrap();
        assert!(!created.id.is_empty());

        // Round-trip: the persisted session comes back with the same id,
        // approval policy and workspace kind.
        let loaded = store.get_session(&created.id).await.unwrap().unwrap();
        assert_eq!(created.id, loaded.id);
        assert!(matches!(
            loaded.config.tool_config.approval_policy,
            ToolApprovalPolicy::AlwaysAsk
        ));
        assert!(matches!(loaded.config.workspace, WorkspaceConfig::Local { .. }));
    }

    #[tokio::test]
    async fn test_message_operations() {
        let (store, _temp) = create_test_store().await;
        let session = store
            .create_session(create_test_session_config())
            .await
            .unwrap();

        let hello = Message {
            data: MessageData::User {
                content: vec![UserContent::Text {
                    text: String::from("Hello"),
                }],
            },
            timestamp: 123456789,
            id: String::from("msg1"),
            parent_message_id: None,
        };
        store.append_message(&session.id, &hello).await.unwrap();

        let stored = store.get_messages(&session.id, None).await.unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].role(), Role::User);
    }

    #[tokio::test]
    async fn test_tool_call_operations() {
        let (store, _temp) = create_test_store().await;
        let session = store
            .create_session(create_test_session_config())
            .await
            .unwrap();

        let call = ToolCall {
            id: String::from("tc1"),
            name: String::from("test_tool"),
            parameters: serde_json::json!({"param": "value"}),
        };
        store.create_tool_call(&session.id, &call).await.unwrap();

        // A freshly created call is reported as pending.
        let pending = store.get_pending_tool_calls(&session.id).await.unwrap();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].name, "test_tool");

        // Once marked completed it no longer shows up as pending.
        store
            .update_tool_call(
                &call.id,
                ToolCallUpdate::set_status(ToolCallStatus::Completed),
            )
            .await
            .unwrap();

        let still_pending = store.get_pending_tool_calls(&session.id).await.unwrap();
        assert_eq!(still_pending.len(), 0);
    }

    #[tokio::test]
    async fn test_event_streaming() {
        let (store, _temp) = create_test_store().await;
        let session = store
            .create_session(create_test_session_config())
            .await
            .unwrap();

        let created = StreamEvent::SessionCreated {
            session_id: session.id.clone(),
            metadata: SessionMetadata {
                model: Model::Claude3_5Sonnet20241022,
                created_at: session.created_at,
                metadata: session.config.metadata,
            },
        };

        // The first event appended to a session gets sequence number 0.
        let first_seq = store.append_event(&session.id, &created).await.unwrap();
        assert_eq!(first_seq, 0);

        // NOTE(review): the second argument of `get_events` appears to act as an
        // exclusive lower bound (asking for events after 0 yields nothing), and
        // `u64::MAX` returning every event suggests it wraps to -1 when bound as
        // `i64` — confirm against the `get_events` implementation.
        let after_zero = store.get_events(&session.id, 0, None).await.unwrap();
        assert_eq!(after_zero.len(), 0);

        let everything = store
            .get_events(&session.id, u64::MAX, None)
            .await
            .unwrap();
        assert_eq!(everything.len(), 1);
        assert_eq!(everything[0].0, 0);
    }

    #[tokio::test]
    async fn test_session_listing() {
        let (store, _temp) = create_test_store().await;

        // Three sessions, each tagged with its creation index in metadata.
        for index in 0..3 {
            let mut config = create_test_session_config();
            config
                .metadata
                .insert("index".to_string(), index.to_string());
            store.create_session(config).await.unwrap();
        }

        // The limit caps the listing below the number of stored sessions.
        let listed = store
            .list_sessions(SessionFilter {
                limit: Some(2),
                order_by: SessionOrderBy::CreatedAt,
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn test_last_model_tracking() {
        let (store, _temp) = create_test_store().await;
        let session = store
            .create_session(create_test_session_config())
            .await
            .unwrap();

        // A brand-new session has no recorded model.
        let listing = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(listing.len(), 1);
        assert_eq!(listing[0].last_model, None);

        // A completed assistant message records the model that produced it.
        let claude = Model::Claude3_5Sonnet20241022;
        let completed = StreamEvent::MessageComplete {
            message: Message {
                data: MessageData::Assistant {
                    content: vec![AssistantContent::Text {
                        text: String::from("Hello from Claude"),
                    }],
                },
                timestamp: 123456789,
                id: String::from("msg1"),
                parent_message_id: None,
            },
            usage: None,
            metadata: std::collections::HashMap::new(),
            model: claude,
        };
        store.append_event(&session.id, &completed).await.unwrap();
        assert_eq!(first_session_last_model(&store).await, Some(claude));

        // A failed tool call also carries a model and replaces the previous one.
        let gpt = Model::Gpt4_1_20250414;
        let failed = StreamEvent::ToolCallFailed {
            tool_call_id: String::from("tool1"),
            error: String::from("Test error"),
            metadata: std::collections::HashMap::new(),
            model: gpt,
        };
        store.append_event(&session.id, &failed).await.unwrap();
        assert_eq!(first_session_last_model(&store).await, Some(gpt));

        // An event with no model field leaves the recorded model untouched.
        let saved = StreamEvent::SessionSaved {
            session_id: session.id.clone(),
        };
        store.append_event(&session.id, &saved).await.unwrap();
        assert_eq!(first_session_last_model(&store).await, Some(gpt));
    }
}