1use async_trait::async_trait;
2use chrono::Utc;
3use serde_json;
4use sqlx::{
5 Row,
6 sqlite::{
7 SqliteConnectOptions, SqliteJournalMode, SqlitePool, SqlitePoolOptions, SqliteSynchronous,
8 },
9};
10use std::collections::HashSet;
11use std::path::Path;
12use std::str::FromStr;
13use uuid::Uuid;
14
15use crate::app::Message;
16use crate::app::conversation::{AssistantContent, MessageData, UserContent};
17use crate::events::StreamEvent;
18use crate::session::{
20 Session, SessionConfig, SessionFilter, SessionInfo, SessionOrderBy, SessionState,
21 SessionStatus, SessionStore, SessionStoreError, ToolApprovalPolicy, ToolCallState,
22 ToolCallStatus, ToolCallUpdate, ToolExecutionStats,
23};
24use steer_tools::ToolCall;
25use steer_tools::result::ToolResult;
26
27pub struct SqliteSessionStore {
29 pool: SqlitePool,
30}
31
32impl SqliteSessionStore {
33 pub async fn new(path: &Path) -> Result<Self, SessionStoreError> {
35 if let Some(parent) = path.parent() {
37 std::fs::create_dir_all(parent).map_err(|e| {
38 SessionStoreError::connection(format!("Failed to create directory: {e}"))
39 })?;
40 }
41
42 let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", path.display()))
43 .map_err(|e| SessionStoreError::connection(format!("Invalid SQLite path: {e}")))?
44 .create_if_missing(true)
45 .journal_mode(SqliteJournalMode::Wal)
46 .synchronous(SqliteSynchronous::Normal)
47 .foreign_keys(true);
48
49 let pool = SqlitePoolOptions::new()
50 .max_connections(1) .connect_with(options)
52 .await
53 .map_err(|e| {
54 SessionStoreError::connection(format!("Failed to connect to SQLite: {e}"))
55 })?;
56
57 sqlx::migrate!("migrations/sqlite")
59 .run(&pool)
60 .await
61 .map_err(|e| SessionStoreError::Migration {
62 message: format!("Failed to run migrations: {e}"),
63 })?;
64
65 Ok(Self { pool })
66 }
67
68 fn parse_tool_policy(
70 policy_type: &str,
71 pre_approved_json: &str,
72 ) -> Result<ToolApprovalPolicy, SessionStoreError> {
73 let pre_approved: Vec<String> = serde_json::from_str(pre_approved_json).map_err(|e| {
74 SessionStoreError::serialization(format!("Invalid pre_approved_tools: {e}"))
75 })?;
76
77 match policy_type {
78 "always_ask" => Ok(ToolApprovalPolicy::AlwaysAsk),
79 "pre_approved" => Ok(ToolApprovalPolicy::PreApproved {
80 tools: pre_approved.into_iter().collect(),
81 }),
82 "mixed" => Ok(ToolApprovalPolicy::Mixed {
83 pre_approved: pre_approved.into_iter().collect(),
84 ask_for_others: true,
85 }),
86 _ => Err(SessionStoreError::validation(format!(
87 "Invalid tool policy type: {policy_type}"
88 ))),
89 }
90 }
91
92 fn serialize_tool_policy(policy: &ToolApprovalPolicy) -> (String, String) {
94 match policy {
95 ToolApprovalPolicy::AlwaysAsk => ("always_ask".to_string(), "[]".to_string()),
96 ToolApprovalPolicy::PreApproved { tools } => {
97 let tools_vec: Vec<String> = tools.iter().cloned().collect();
98 (
99 "pre_approved".to_string(),
100 serde_json::to_string(&tools_vec).unwrap(),
101 )
102 }
103 ToolApprovalPolicy::Mixed { pre_approved, .. } => {
104 let tools_vec: Vec<String> = pre_approved.iter().cloned().collect();
105 (
106 "mixed".to_string(),
107 serde_json::to_string(&tools_vec).unwrap(),
108 )
109 }
110 }
111 }
112}
113
114#[async_trait]
115impl SessionStore for SqliteSessionStore {
116 async fn create_session(&self, config: SessionConfig) -> Result<Session, SessionStoreError> {
117 let id = Uuid::new_v4().to_string();
118 let now = Utc::now();
119 let (policy_type, pre_approved_json) =
120 Self::serialize_tool_policy(&config.tool_config.approval_policy);
121 let metadata_json = serde_json::to_string(&config.metadata).map_err(|e| {
122 SessionStoreError::serialization(format!("Failed to serialize metadata: {e}"))
123 })?;
124 let tool_config_json = serde_json::to_string(&config.tool_config).map_err(|e| {
125 SessionStoreError::serialization(format!("Failed to serialize tool_config: {e}"))
126 })?;
127 let workspace_config_json = serde_json::to_string(&config.workspace).map_err(|e| {
128 SessionStoreError::serialization(format!("Failed to serialize workspace_config: {e}"))
129 })?;
130
131 sqlx::query(
132 r#"
133 INSERT INTO sessions (id, created_at, updated_at, status, metadata,
134 tool_policy_type, pre_approved_tools, tool_config, workspace_config, system_prompt)
135 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
136 "#,
137 )
138 .bind(&id)
139 .bind(now)
140 .bind(now)
141 .bind("inactive") .bind(&metadata_json)
143 .bind(&policy_type)
144 .bind(&pre_approved_json)
145 .bind(&tool_config_json)
146 .bind(&workspace_config_json)
147 .bind(&config.system_prompt)
148 .execute(&self.pool)
149 .await
150 .map_err(|e| SessionStoreError::database(format!("Failed to create session: {e}")))?;
151
152 Ok(Session {
153 id: id.clone(),
154 created_at: now,
155 updated_at: now,
156 config,
157 state: SessionState::default(),
158 })
159 }
160
/// Load a complete [`Session`] by id: its configuration, message history,
/// tool calls and derived state.
///
/// Returns `Ok(None)` if no `sessions` row has this id.
///
/// # Errors
///
/// * database errors from any of the queries (session row, messages, tool
///   calls, last event sequence);
/// * serialization errors for undecodable stored JSON (pre-approved tools,
///   metadata, `tool_config`, `workspace_config`, message content, tool call
///   parameters, or a tool call's `payload_json`);
/// * validation errors for an unknown tool policy type or tool call status.
async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionStoreError> {
    let row = sqlx::query(
        r#"
        SELECT id, created_at, updated_at, metadata,
               tool_policy_type, pre_approved_tools, tool_config, workspace_config, system_prompt,
               active_message_id
        FROM sessions
        WHERE id = ?1
        "#,
    )
    .bind(session_id)
    .fetch_optional(&self.pool)
    .await
    .map_err(|e| SessionStoreError::database(format!("Failed to get session: {e}")))?;

    let Some(row) = row else {
        return Ok(None);
    };

    // The approval policy is decoded from its dedicated columns; below it
    // overwrites whatever policy is embedded in the `tool_config` JSON.
    let approval_policy = Self::parse_tool_policy(
        &row.get::<String, _>("tool_policy_type"),
        &row.get::<String, _>("pre_approved_tools"),
    )?;

    let metadata: std::collections::HashMap<String, String> =
        serde_json::from_str(&row.get::<String, _>("metadata"))
            .map_err(|e| SessionStoreError::serialization(format!("Invalid metadata: {e}")))?;

    let tool_config = serde_json::from_str(&row.get::<String, _>("tool_config"))
        .map_err(|e| SessionStoreError::serialization(format!("Invalid tool_config: {e}")))?;
    let workspace_config = serde_json::from_str(&row.get::<String, _>("workspace_config"))
        .map_err(|e| {
            SessionStoreError::serialization(format!("Invalid workspace_config: {e}"))
        })?;

    let mut tool_config: crate::session::SessionToolConfig = tool_config;
    tool_config.approval_policy = approval_policy;

    let system_prompt: Option<String> = row.get("system_prompt");

    let config = SessionConfig {
        workspace: workspace_config,
        tool_config,
        system_prompt,
        metadata,
    };

    let messages = self.get_messages(session_id, None).await?;

    // Rebuild the per-tool-call state map, keyed by tool call id.
    let tool_calls_rows = sqlx::query(
        r#"
        SELECT id, tool_name, parameters, status, result, error, started_at, completed_at, kind, payload_json, error_json
        FROM tool_calls
        WHERE session_id = ?1
        "#,
    )
    .bind(session_id)
    .fetch_all(&self.pool)
    .await
    .map_err(|e| SessionStoreError::database(format!("Failed to load tool calls: {e}")))?;

    let mut tool_calls = std::collections::HashMap::new();
    for row in tool_calls_rows {
        let id: String = row.get("id");
        let status_str: String = row.get("status");
        let error: Option<String> = row.get("error");

        // Map the stored status string back to the enum; a `failed` row with
        // no `error` value gets a placeholder message.
        let status = match status_str.as_str() {
            "pending" => ToolCallStatus::PendingApproval,
            "approved" => ToolCallStatus::Approved,
            "denied" => ToolCallStatus::Denied,
            "executing" => ToolCallStatus::Executing,
            "completed" => ToolCallStatus::Completed,
            "failed" => ToolCallStatus::Failed {
                error: error.unwrap_or_else(|| "Unknown error".to_string()),
            },
            _ => {
                return Err(SessionStoreError::validation(format!(
                    "Invalid tool call status: {status_str}"
                )));
            }
        };

        let tool_call = ToolCall {
            id: id.clone(),
            name: row.get("tool_name"),
            parameters: serde_json::from_str(&row.get::<String, _>("parameters")).map_err(
                |e| SessionStoreError::serialization(format!("Invalid tool parameters: {e}")),
            )?,
        };

        let result: Option<String> = row.get("result");
        let json_result: Option<String> = row.get("payload_json");
        let result_type: Option<String> = row.get("kind");
        let error_json: Option<String> = row.get("error_json");

        // NOTE(review): `_tool_result` is computed and then discarded; the
        // state below always uses `tool_result = None`. Its only observable
        // effect is that a malformed `payload_json` (on a row whose `kind` is
        // set and is not "error") makes the whole load fail. Either wire it
        // into `ToolCallState::result` or remove it.
        let _tool_result = if let Some(kind) = result_type.as_ref() {
            if kind == "error" {
                // A malformed `error_json` is tolerated (mapped to `None`).
                let error_data = error_json.and_then(|json_str| {
                    serde_json::from_str::<serde_json::Value>(&json_str).ok()
                });
                Some(ToolExecutionStats {
                    output: result.clone(),
                    json_output: error_data,
                    result_type: Some("error".to_string()),
                    success: false,
                    execution_time_ms: 0,
                    metadata: std::collections::HashMap::new(),
                })
            } else if let Some(json_str) = json_result {
                let json_value = serde_json::from_str(&json_str).map_err(|e| {
                    SessionStoreError::serialization(format!("Invalid JSON result: {e}"))
                })?;
                Some(ToolExecutionStats {
                    output: result.clone(),
                    json_output: Some(json_value),
                    result_type,
                    success: true,
                    execution_time_ms: 0,
                    metadata: std::collections::HashMap::new(),
                })
            } else {
                Some(ToolExecutionStats {
                    output: result.clone(),
                    json_output: None,
                    result_type,
                    success: true,
                    execution_time_ms: 0,
                    metadata: std::collections::HashMap::new(),
                })
            }
        } else {
            result.map(|r| ToolExecutionStats {
                output: Some(r),
                json_output: None,
                result_type: None,
                success: true,
                execution_time_ms: 0,
                metadata: std::collections::HashMap::new(),
            })
        };

        let tool_result: Option<ToolResult> = None;

        let state = ToolCallState {
            tool_call,
            status,
            started_at: row.get("started_at"),
            completed_at: row.get("completed_at"),
            result: tool_result,
        };

        tool_calls.insert(id, state);
    }

    // MAX() yields NULL (None here) when the session has no events yet.
    let last_sequence: Option<i64> =
        sqlx::query_scalar("SELECT MAX(sequence_num) FROM events WHERE session_id = ?1")
            .bind(session_id)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| {
                SessionStoreError::database(format!("Failed to get last event sequence: {e}"))
            })?;

    // Copy the bash patterns approved in the tool config into session state.
    // NOTE(review): this `let` pattern compiles only while
    // `ToolSpecificConfig` has a single (`Bash`) variant.
    let approved_bash_patterns: HashSet<String> =
        if let Some(bash_config) = config.tool_config.tools.get("bash") {
            let crate::session::state::ToolSpecificConfig::Bash(bash) = bash_config;
            bash.approved_patterns.iter().cloned().collect()
        } else {
            HashSet::new()
        };

    let active_message_id: Option<String> = row.get("active_message_id");

    // `approved_tools`, state `metadata` and `mcp_servers` are not read from
    // the database here and always start out as their defaults.
    let state = SessionState {
        messages,
        tool_calls,
        approved_tools: Default::default(),
        approved_bash_patterns,
        last_event_sequence: last_sequence.unwrap_or(0) as u64,
        metadata: Default::default(),
        active_message_id,
        mcp_servers: Default::default(),
    };

    Ok(Some(Session {
        id: row.get("id"),
        created_at: row.get("created_at"),
        updated_at: row.get("updated_at"),
        config,
        state,
    }))
}
363
364 async fn update_session(&self, session: &Session) -> Result<(), SessionStoreError> {
365 let metadata_json = serde_json::to_string(&session.config.metadata).map_err(|e| {
366 SessionStoreError::serialization(format!("Failed to serialize metadata: {e}"))
367 })?;
368 let tool_config_json = serde_json::to_string(&session.config.tool_config).map_err(|e| {
369 SessionStoreError::serialization(format!("Failed to serialize tool_config: {e}"))
370 })?;
371 let workspace_config_json =
372 serde_json::to_string(&session.config.workspace).map_err(|e| {
373 SessionStoreError::serialization(format!(
374 "Failed to serialize workspace_config: {e}"
375 ))
376 })?;
377 let (policy_type, pre_approved_json) =
378 Self::serialize_tool_policy(&session.config.tool_config.approval_policy);
379
380 sqlx::query(
381 r#"
382 UPDATE sessions
383 SET updated_at = ?2, metadata = ?3,
384 tool_policy_type = ?4, pre_approved_tools = ?5, tool_config = ?6, workspace_config = ?7
385 WHERE id = ?1
386 "#,
387 )
388 .bind(&session.id)
389 .bind(Utc::now())
390 .bind(&metadata_json)
391 .bind(&policy_type)
392 .bind(&pre_approved_json)
393 .bind(&tool_config_json)
394 .bind(&workspace_config_json)
395 .execute(&self.pool)
396 .await
397 .map_err(|e| SessionStoreError::database(format!("Failed to update session: {e}")))?;
398
399 Ok(())
400 }
401
402 async fn delete_session(&self, session_id: &str) -> Result<(), SessionStoreError> {
403 sqlx::query("DELETE FROM sessions WHERE id = ?1")
404 .bind(session_id)
405 .execute(&self.pool)
406 .await
407 .map_err(|e| SessionStoreError::database(format!("Failed to delete session: {e}")))?;
408
409 Ok(())
410 }
411
412 async fn list_sessions(
413 &self,
414 filter: SessionFilter,
415 ) -> Result<Vec<SessionInfo>, SessionStoreError> {
416 let mut query = String::from(
417 r#"
418 SELECT s.id, s.created_at, s.updated_at, s.status, s.metadata,
419 (SELECT e.event_data
420 FROM events e
421 WHERE e.session_id = s.id
422 AND e.event_type IN ('message_complete', 'tool_call_started', 'tool_call_completed', 'tool_call_failed')
423 ORDER BY e.sequence_num DESC
424 LIMIT 1) as last_event_data
425 FROM sessions s
426 WHERE 1=1
427 "#,
428 );
429 let mut bindings: Vec<String> = Vec::new();
430
431 if let Some(created_after) = filter.created_after {
433 query.push_str(&format!(" AND s.created_at >= ?{}", bindings.len() + 1));
434 bindings.push(created_after.to_rfc3339());
435 }
436 if let Some(created_before) = filter.created_before {
437 query.push_str(&format!(" AND s.created_at <= ?{}", bindings.len() + 1));
438 bindings.push(created_before.to_rfc3339());
439 }
440 if let Some(updated_after) = filter.updated_after {
441 query.push_str(&format!(" AND s.updated_at >= ?{}", bindings.len() + 1));
442 bindings.push(updated_after.to_rfc3339());
443 }
444 if let Some(updated_before) = filter.updated_before {
445 query.push_str(&format!(" AND s.updated_at <= ?{}", bindings.len() + 1));
446 bindings.push(updated_before.to_rfc3339());
447 }
448 if let Some(status) = filter.status_filter {
449 let status_str = match status {
450 SessionStatus::Active => "active",
451 SessionStatus::Inactive => "inactive",
452 };
453 query.push_str(&format!(" AND s.status = ?{}", bindings.len() + 1));
454 bindings.push(status_str.to_string());
455 }
456
457 let order_column = match filter.order_by {
459 SessionOrderBy::CreatedAt => "s.created_at",
460 SessionOrderBy::UpdatedAt => "s.updated_at",
461 SessionOrderBy::MessageCount => {
462 query = r#"
464 SELECT s.id, s.created_at, s.updated_at, s.status, s.metadata,
465 (SELECT e.event_data
466 FROM events e
467 WHERE e.session_id = s.id
468 AND e.event_type IN ('message_complete', 'tool_call_started', 'tool_call_completed', 'tool_call_failed')
469 ORDER BY e.sequence_num DESC
470 LIMIT 1) as last_event_data,
471 (SELECT COUNT(*) FROM messages WHERE session_id = s.id) as message_count
472 FROM sessions s
473 WHERE 1=1
474 "#.to_string();
475 "message_count"
476 }
477 };
478
479 let order_direction = match filter.order_direction {
480 crate::session::OrderDirection::Ascending => "ASC",
481 crate::session::OrderDirection::Descending => "DESC",
482 };
483
484 query.push_str(&format!(" ORDER BY {order_column} {order_direction}"));
485
486 if let Some(limit) = filter.limit {
488 query.push_str(&format!(" LIMIT {limit}"));
489 }
490 if let Some(offset) = filter.offset {
491 query.push_str(&format!(" OFFSET {offset}"));
492 }
493
494 let mut q = sqlx::query(&query);
496 for binding in bindings {
497 q = q.bind(binding);
498 }
499
500 let rows = q
501 .fetch_all(&self.pool)
502 .await
503 .map_err(|e| SessionStoreError::database(format!("Failed to list sessions: {e}")))?;
504
505 let mut sessions = Vec::new();
506 for row in rows {
507 let metadata: std::collections::HashMap<String, String> =
508 serde_json::from_str(&row.get::<String, _>("metadata")).map_err(|e| {
509 SessionStoreError::serialization(format!("Invalid metadata: {e}"))
510 })?;
511
512 let message_count: i64 = if matches!(filter.order_by, SessionOrderBy::MessageCount) {
514 row.get("message_count")
515 } else {
516 sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE session_id = ?1")
517 .bind(row.get::<String, _>("id"))
518 .fetch_one(&self.pool)
519 .await
520 .map_err(|e| {
521 SessionStoreError::database(format!("Failed to count messages: {e}"))
522 })?
523 };
524
525 let last_model =
527 if let Some(event_json) = row.get::<Option<String>, _>("last_event_data") {
528 let event: StreamEvent = serde_json::from_str(&event_json).map_err(|e| {
529 SessionStoreError::serialization(format!("Invalid event data: {e}"))
530 })?;
531
532 match event {
533 StreamEvent::MessageComplete { model, .. } => Some(model),
534 StreamEvent::ToolCallStarted { model, .. } => Some(model),
535 StreamEvent::ToolCallCompleted { model, .. } => Some(model),
536 StreamEvent::ToolCallFailed { model, .. } => Some(model),
537 _ => None,
538 }
539 } else {
540 None
541 };
542
543 sessions.push(SessionInfo {
544 id: row.get("id"),
545 created_at: row.get("created_at"),
546 updated_at: row.get("updated_at"),
547 last_model,
548 message_count: message_count as usize,
549 metadata,
550 });
551 }
552
553 Ok(sessions)
554 }
555
556 async fn append_message(
557 &self,
558 session_id: &str,
559 message: &Message,
560 ) -> Result<(), SessionStoreError> {
561 let id = message.id();
562
563 let next_seq: i64 = sqlx::query_scalar(
565 "SELECT COALESCE(MAX(sequence_num), -1) + 1 FROM messages WHERE session_id = ?1",
566 )
567 .bind(session_id)
568 .fetch_one(&self.pool)
569 .await
570 .map_err(|e| SessionStoreError::database(format!("Failed to get next sequence: {e}")))?;
571
572 let (role, content_json) = match &message.data {
574 MessageData::User { content, .. } => {
575 let json = serde_json::to_string(&content).map_err(|e| {
576 SessionStoreError::serialization(format!(
577 "Failed to serialize user content: {e}"
578 ))
579 })?;
580 ("user", json)
581 }
582 MessageData::Assistant { content, .. } => {
583 let json = serde_json::to_string(&content).map_err(|e| {
584 SessionStoreError::serialization(format!(
585 "Failed to serialize assistant content: {e}"
586 ))
587 })?;
588 ("assistant", json)
589 }
590 MessageData::Tool {
591 tool_use_id,
592 result,
593 ..
594 } => {
595 #[derive(serde::Serialize)]
597 struct StoredToolMessage {
598 tool_use_id: String,
599 result: crate::app::conversation::ToolResult,
600 }
601 let stored = StoredToolMessage {
602 tool_use_id: tool_use_id.clone(),
603 result: result.clone(),
604 };
605 let json = serde_json::to_string(&stored).map_err(|e| {
606 SessionStoreError::serialization(format!(
607 "Failed to serialize tool message: {e}"
608 ))
609 })?;
610 ("tool", json)
611 }
612 };
613
614 let parent_message_id = message.parent_message_id();
616
617 sqlx::query(
618 r#"
619 INSERT INTO messages (id, session_id, sequence_num, role, content, created_at, parent_message_id)
620 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
621 "#,
622 )
623 .bind(id)
624 .bind(session_id)
625 .bind(next_seq)
626 .bind(role)
627 .bind(&content_json)
628 .bind(Utc::now())
629 .bind(parent_message_id)
630 .execute(&self.pool)
631 .await
632 .map_err(|e| SessionStoreError::database(format!("Failed to append message: {e}")))?;
633
634 Ok(())
635 }
636
637 async fn get_messages(
638 &self,
639 session_id: &str,
640 after_sequence: Option<u32>,
641 ) -> Result<Vec<Message>, SessionStoreError> {
642 let query = if let Some(seq) = after_sequence {
643 sqlx::query(
644 r#"
645 SELECT id, sequence_num, role, content, created_at, parent_message_id
646 FROM messages
647 WHERE session_id = ?1 AND sequence_num > ?2
648 ORDER BY sequence_num ASC
649 "#,
650 )
651 .bind(session_id)
652 .bind(seq as i64)
653 } else {
654 sqlx::query(
655 r#"
656 SELECT id, sequence_num, role, content, created_at, parent_message_id
657 FROM messages
658 WHERE session_id = ?1
659 ORDER BY sequence_num ASC
660 "#,
661 )
662 .bind(session_id)
663 };
664
665 let rows = query
666 .fetch_all(&self.pool)
667 .await
668 .map_err(|e| SessionStoreError::database(format!("Failed to get messages: {e}")))?;
669
670 let mut messages = Vec::new();
671 for row in rows {
672 let role = row.get::<String, _>("role");
673 let content = row.get::<String, _>("content");
674 let id: String = row.get("id");
675 let created_at = row.get::<chrono::DateTime<chrono::Utc>, _>("created_at");
676
677 let parent_message_id: Option<String> = row.get("parent_message_id");
678
679 let message = match role.as_str() {
681 "user" => {
682 let content: Vec<UserContent> =
683 serde_json::from_str(&content).map_err(|e| {
684 SessionStoreError::serialization(format!(
685 "Failed to deserialize user content: {e}"
686 ))
687 })?;
688 Message {
689 data: MessageData::User { content },
690 timestamp: created_at.timestamp() as u64,
691 id,
692 parent_message_id,
693 }
694 }
695 "assistant" => {
696 let content: Vec<AssistantContent> =
697 serde_json::from_str(&content).map_err(|e| {
698 SessionStoreError::serialization(format!(
699 "Failed to deserialize assistant content: {e}"
700 ))
701 })?;
702 Message {
703 data: MessageData::Assistant { content },
704 timestamp: created_at.timestamp() as u64,
705 id,
706 parent_message_id,
707 }
708 }
709 "tool" => {
710 #[derive(serde::Deserialize)]
711 struct StoredToolMessage {
712 tool_use_id: String,
713 result: crate::app::conversation::ToolResult,
714 }
715 let stored: StoredToolMessage =
716 serde_json::from_str(&content).map_err(|e| {
717 SessionStoreError::serialization(format!(
718 "Failed to deserialize tool message: {e}"
719 ))
720 })?;
721 Message {
722 data: MessageData::Tool {
723 tool_use_id: stored.tool_use_id,
724 result: stored.result,
725 },
726 timestamp: created_at.timestamp() as u64,
727 id: id.clone(),
728 parent_message_id: parent_message_id.clone(),
729 }
730 }
731 _ => {
732 return Err(SessionStoreError::serialization(format!(
733 "Unknown message role: {role}"
734 )));
735 }
736 };
737
738 messages.push(message);
739 }
740
741 Ok(messages)
742 }
743
744 async fn create_tool_call(
745 &self,
746 session_id: &str,
747 tool_call: &ToolCall,
748 ) -> Result<(), SessionStoreError> {
749 let parameters_json = serde_json::to_string(&tool_call.parameters).map_err(|e| {
750 SessionStoreError::serialization(format!("Failed to serialize parameters: {e}"))
751 })?;
752
753 sqlx::query(
754 r#"
755 INSERT INTO tool_calls (id, session_id, tool_name, parameters, status, kind)
756 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
757 "#,
758 )
759 .bind(&tool_call.id)
760 .bind(session_id)
761 .bind(&tool_call.name)
762 .bind(¶meters_json)
763 .bind("pending")
764 .bind("external") .execute(&self.pool)
766 .await
767 .map_err(|e| SessionStoreError::database(format!("Failed to create tool call: {e}")))?;
768
769 Ok(())
770 }
771
772 async fn update_tool_call(
773 &self,
774 tool_call_id: &str,
775 update: ToolCallUpdate,
776 ) -> Result<(), SessionStoreError> {
777 let mut query = String::from("UPDATE tool_calls SET ");
778 let mut updates = Vec::new();
779 let mut bindings: Vec<String> = Vec::new();
780
781 if let Some(status) = update.status {
782 let status_str = match &status {
783 ToolCallStatus::PendingApproval => "pending",
784 ToolCallStatus::Approved => "approved",
785 ToolCallStatus::Denied => "denied",
786 ToolCallStatus::Executing => "executing",
787 ToolCallStatus::Completed => "completed",
788 ToolCallStatus::Failed { .. } => "failed",
789 };
790 updates.push(format!("status = ?{}", bindings.len() + 1));
791 bindings.push(status_str.to_string());
792
793 match status {
795 ToolCallStatus::Executing => {
796 updates.push(format!("started_at = ?{}", bindings.len() + 1));
797 bindings.push(Utc::now().to_rfc3339());
798 }
799 ToolCallStatus::Completed | ToolCallStatus::Failed { .. } => {
800 updates.push(format!("completed_at = ?{}", bindings.len() + 1));
801 bindings.push(Utc::now().to_rfc3339());
802 }
803 _ => {}
804 }
805 }
806
807 if let Some(result) = update.result {
808 let kind = if let Some(rt) = &result.result_type {
810 rt.clone()
811 } else {
812 "external".to_string()
813 };
814
815 updates.push(format!("kind = ?{}", bindings.len() + 1));
816 bindings.push(kind);
817
818 if let Some(output) = &result.output {
820 updates.push(format!("result = ?{}", bindings.len() + 1));
821 bindings.push(output.clone());
822 }
823
824 if let Some(json_output) = &result.json_output {
826 updates.push(format!("payload_json = ?{}", bindings.len() + 1));
827 let json_str = serde_json::to_string(json_output).map_err(|e| {
828 SessionStoreError::serialization(format!(
829 "Failed to serialize JSON result: {e}"
830 ))
831 })?;
832 bindings.push(json_str);
833 } else if let Some(output) = &result.output {
834 updates.push(format!("payload_json = ?{}", bindings.len() + 1));
836 let external_json = serde_json::json!({
837 "tool_name": "unknown",
838 "payload": output
839 });
840 bindings.push(external_json.to_string());
841 }
842 }
843
844 if let Some(error) = update.error {
845 updates.push(format!("kind = ?{}", bindings.len() + 1));
847 bindings.push("error".to_string());
848
849 updates.push(format!("error_json = ?{}", bindings.len() + 1));
851 let error_json = serde_json::json!({
852 "tool_name": "unknown",
853 "message": &error
854 });
855 bindings.push(error_json.to_string());
856
857 updates.push(format!("error = ?{}", bindings.len() + 1));
859 bindings.push(error);
860 }
861
862 if updates.is_empty() {
863 return Ok(());
864 }
865
866 query.push_str(&updates.join(", "));
867 query.push_str(&format!(" WHERE id = ?{}", bindings.len() + 1));
868 bindings.push(tool_call_id.to_string());
869
870 let mut q = sqlx::query(&query);
872 for binding in bindings {
873 q = q.bind(binding);
874 }
875
876 q.execute(&self.pool)
877 .await
878 .map_err(|e| SessionStoreError::database(format!("Failed to update tool call: {e}")))?;
879
880 Ok(())
881 }
882
883 async fn get_pending_tool_calls(
884 &self,
885 session_id: &str,
886 ) -> Result<Vec<ToolCall>, SessionStoreError> {
887 let rows = sqlx::query(
888 r#"
889 SELECT id, tool_name, parameters
890 FROM tool_calls
891 WHERE session_id = ?1 AND status = 'pending'
892 ORDER BY id ASC
893 "#,
894 )
895 .bind(session_id)
896 .fetch_all(&self.pool)
897 .await
898 .map_err(|e| {
899 SessionStoreError::database(format!("Failed to get pending tool calls: {e}"))
900 })?;
901
902 let mut tool_calls = Vec::new();
903 for row in rows {
904 let parameters: serde_json::Value =
905 serde_json::from_str(&row.get::<String, _>("parameters")).map_err(|e| {
906 SessionStoreError::serialization(format!("Invalid parameters: {e}"))
907 })?;
908
909 tool_calls.push(ToolCall {
910 id: row.get("id"),
911 name: row.get("tool_name"),
912 parameters,
913 });
914 }
915
916 Ok(tool_calls)
917 }
918
919 async fn append_event(
920 &self,
921 session_id: &str,
922 event: &StreamEvent,
923 ) -> Result<u64, SessionStoreError> {
924 let event_type = match event {
925 StreamEvent::MessagePart { .. } => "message_part",
926 StreamEvent::MessageComplete { .. } => "message_complete",
927 StreamEvent::ToolCallStarted { .. } => "tool_call_started",
928 StreamEvent::ToolCallCompleted { .. } => "tool_call_completed",
929 StreamEvent::ToolCallFailed { .. } => "tool_call_failed",
930 StreamEvent::ToolApprovalRequired { .. } => "tool_approval_required",
931 StreamEvent::SessionCreated { .. } => "session_created",
932 StreamEvent::SessionResumed { .. } => "session_resumed",
933 StreamEvent::SessionSaved { .. } => "session_saved",
934 StreamEvent::OperationStarted { .. } => "operation_started",
935 StreamEvent::OperationCompleted { .. } => "operation_completed",
936 StreamEvent::OperationCancelled { .. } => "operation_cancelled",
937 StreamEvent::Error { .. } => "error",
938 StreamEvent::WorkspaceChanged => "workspace_changed",
939 StreamEvent::WorkspaceFiles { .. } => "workspace_files",
940 };
941
942 let event_data = serde_json::to_string(event).map_err(|e| {
943 SessionStoreError::serialization(format!("Failed to serialize event: {e}"))
944 })?;
945
946 let next_seq: i64 = sqlx::query_scalar(
948 "SELECT COALESCE(MAX(sequence_num), -1) + 1 FROM events WHERE session_id = ?1",
949 )
950 .bind(session_id)
951 .fetch_one(&self.pool)
952 .await
953 .map_err(|e| SessionStoreError::database(format!("Failed to get next sequence: {e}")))?;
954
955 sqlx::query(
956 r#"
957 INSERT INTO events (session_id, sequence_num, event_type, event_data, created_at)
958 VALUES (?1, ?2, ?3, ?4, ?5)
959 "#,
960 )
961 .bind(session_id)
962 .bind(next_seq)
963 .bind(event_type)
964 .bind(&event_data)
965 .bind(Utc::now())
966 .execute(&self.pool)
967 .await
968 .map_err(|e| SessionStoreError::database(format!("Failed to append event: {e}")))?;
969
970 Ok(next_seq as u64)
971 }
972
973 async fn get_events(
974 &self,
975 session_id: &str,
976 after_sequence: u64,
977 limit: Option<u32>,
978 ) -> Result<Vec<(u64, StreamEvent)>, SessionStoreError> {
979 let query = if let Some(limit) = limit {
980 sqlx::query(
981 r#"
982 SELECT sequence_num, event_data
983 FROM events
984 WHERE session_id = ?1 AND sequence_num > ?2
985 ORDER BY sequence_num ASC
986 LIMIT ?3
987 "#,
988 )
989 .bind(session_id)
990 .bind(after_sequence as i64)
991 .bind(limit as i64)
992 } else {
993 sqlx::query(
994 r#"
995 SELECT sequence_num, event_data
996 FROM events
997 WHERE session_id = ?1 AND sequence_num > ?2
998 ORDER BY sequence_num ASC
999 "#,
1000 )
1001 .bind(session_id)
1002 .bind(after_sequence as i64)
1003 };
1004
1005 let rows = query
1006 .fetch_all(&self.pool)
1007 .await
1008 .map_err(|e| SessionStoreError::database(format!("Failed to get events: {e}")))?;
1009
1010 let mut events = Vec::new();
1011 for row in rows {
1012 let seq: i64 = row.get("sequence_num");
1013 let event: StreamEvent = serde_json::from_str(&row.get::<String, _>("event_data"))
1014 .map_err(|e| {
1015 SessionStoreError::serialization(format!("Invalid event data: {e}"))
1016 })?;
1017
1018 events.push((seq as u64, event));
1019 }
1020
1021 Ok(events)
1022 }
1023
1024 async fn delete_events_before(
1025 &self,
1026 session_id: &str,
1027 before_sequence: u64,
1028 ) -> Result<u64, SessionStoreError> {
1029 let result = sqlx::query("DELETE FROM events WHERE session_id = ?1 AND sequence_num < ?2")
1030 .bind(session_id)
1031 .bind(before_sequence as i64)
1032 .execute(&self.pool)
1033 .await
1034 .map_err(|e| SessionStoreError::database(format!("Failed to delete events: {e}")))?;
1035
1036 Ok(result.rows_affected())
1037 }
1038
1039 async fn update_active_message_id(
1040 &self,
1041 session_id: &str,
1042 message_id: Option<&str>,
1043 ) -> Result<(), SessionStoreError> {
1044 sqlx::query("UPDATE sessions SET active_message_id = ?2, updated_at = ?3 WHERE id = ?1")
1045 .bind(session_id)
1046 .bind(message_id)
1047 .bind(Utc::now())
1048 .execute(&self.pool)
1049 .await
1050 .map_err(|e| {
1051 SessionStoreError::database(format!("Failed to update active_message_id: {e}"))
1052 })?;
1053
1054 Ok(())
1055 }
1056}
1057
#[cfg(test)]
mod tests {
    use crate::api::Model;
    use crate::app::conversation::{AssistantContent, Message, Role, UserContent};
    use crate::events::SessionMetadata;
    use crate::session::ToolVisibility;
    use crate::session::state::WorkspaceConfig;

    use super::*;
    use tempfile::TempDir;

    /// Opens a fresh store backed by `test.db` inside a new temporary directory.
    ///
    /// The `TempDir` is handed back alongside the store so the caller keeps the
    /// directory (and the database file in it) alive for the duration of the test.
    async fn create_test_store() -> (SqliteSessionStore, TempDir) {
        let scratch = TempDir::new().unwrap();
        let store = SqliteSessionStore::new(&scratch.path().join("test.db"))
            .await
            .unwrap();
        (store, scratch)
    }

    /// Baseline config shared by most tests.
    ///
    /// It uses the default workspace, an `AlwaysAsk` approval policy with every
    /// tool visible, no system prompt and empty metadata.
    fn create_test_session_config() -> SessionConfig {
        SessionConfig {
            workspace: WorkspaceConfig::default(),
            tool_config: crate::session::SessionToolConfig {
                approval_policy: ToolApprovalPolicy::AlwaysAsk,
                visibility: ToolVisibility::All,
                ..Default::default()
            },
            system_prompt: None,
            metadata: std::collections::HashMap::new(),
        }
    }

    #[tokio::test]
    async fn test_create_and_get_session() {
        let (store, _temp) = create_test_store().await;

        // Unlike the shared helper, `visibility` is left at its default here.
        let config = SessionConfig {
            workspace: WorkspaceConfig::default(),
            tool_config: crate::session::SessionToolConfig {
                approval_policy: ToolApprovalPolicy::AlwaysAsk,
                ..Default::default()
            },
            system_prompt: None,
            metadata: Default::default(),
        };

        let created = store.create_session(config).await.unwrap();
        assert!(!created.id.is_empty());

        // Round-trip through the database and check the config survived intact.
        let loaded = store.get_session(&created.id).await.unwrap().unwrap();
        assert_eq!(created.id, loaded.id);
        assert!(matches!(
            loaded.config.tool_config.approval_policy,
            ToolApprovalPolicy::AlwaysAsk
        ));
        assert!(matches!(
            loaded.config.workspace,
            WorkspaceConfig::Local { .. }
        ));
    }

    #[tokio::test]
    async fn test_message_operations() {
        let (store, _temp) = create_test_store().await;
        let session = store
            .create_session(create_test_session_config())
            .await
            .unwrap();

        let greeting = UserContent::Text {
            text: "Hello".to_string(),
        };
        let message = Message {
            data: MessageData::User {
                content: vec![greeting],
            },
            timestamp: 123456789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };
        store.append_message(&session.id, &message).await.unwrap();

        let stored = store.get_messages(&session.id, None).await.unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].role(), Role::User);
    }

    #[tokio::test]
    async fn test_tool_call_operations() {
        let (store, _temp) = create_test_store().await;
        let session = store
            .create_session(create_test_session_config())
            .await
            .unwrap();

        let call = ToolCall {
            id: "tc1".to_string(),
            name: "test_tool".to_string(),
            parameters: serde_json::json!({"param": "value"}),
        };
        store.create_tool_call(&session.id, &call).await.unwrap();

        // A freshly created call is reported as pending...
        let pending_before = store.get_pending_tool_calls(&session.id).await.unwrap();
        assert_eq!(pending_before.len(), 1);
        assert_eq!(pending_before[0].name, "test_tool");

        // ...and drops out of the pending set once it is marked completed.
        store
            .update_tool_call(
                &call.id,
                ToolCallUpdate::set_status(ToolCallStatus::Completed),
            )
            .await
            .unwrap();
        let pending_after = store.get_pending_tool_calls(&session.id).await.unwrap();
        assert_eq!(pending_after.len(), 0);
    }

    #[tokio::test]
    async fn test_event_streaming() {
        let (store, _temp) = create_test_store().await;
        let session = store
            .create_session(create_test_session_config())
            .await
            .unwrap();
        let session_id = session.id.clone();

        let created_event = StreamEvent::SessionCreated {
            session_id: session_id.clone(),
            metadata: SessionMetadata {
                model: Model::Claude3_5Sonnet20241022,
                created_at: session.created_at,
                metadata: session.config.metadata,
            },
        };

        // The first event appended to a session is assigned sequence number 0.
        let seq = store.append_event(&session_id, &created_event).await.unwrap();
        assert_eq!(seq, 0);

        // Querying with a bound of 0 yields nothing when the only event is #0.
        let from_zero = store.get_events(&session_id, 0, None).await.unwrap();
        assert_eq!(from_zero.len(), 0);

        // NOTE(review): a bound of u64::MAX returns every event. This presumably
        // relies on get_events converting the bound to a signed value
        // (u64::MAX -> -1). Confirm against get_events before changing either side.
        let everything = store.get_events(&session_id, u64::MAX, None).await.unwrap();
        assert_eq!(everything.len(), 1);
        assert_eq!(everything[0].0, 0);
    }

    #[tokio::test]
    async fn test_session_listing() {
        let (store, _temp) = create_test_store().await;

        // Three sessions, each tagged with its creation index in metadata.
        for index in 0..3 {
            let mut config = create_test_session_config();
            config.metadata.insert("index".to_string(), index.to_string());
            store.create_session(config).await.unwrap();
        }

        // The limit must cap the result even though more sessions exist.
        let listed = store
            .list_sessions(SessionFilter {
                limit: Some(2),
                order_by: SessionOrderBy::CreatedAt,
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(listed.len(), 2);
    }

    #[tokio::test]
    async fn test_last_model_tracking() {
        let (store, _temp) = create_test_store().await;
        let session = store
            .create_session(create_test_session_config())
            .await
            .unwrap();

        // No model-bearing events have been appended yet, so nothing is recorded.
        let listed = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].last_model, None);

        // A completed assistant message records its model.
        let claude_model = Model::Claude3_5Sonnet20241022;
        let reply = Message {
            data: MessageData::Assistant {
                content: vec![AssistantContent::Text {
                    text: "Hello from Claude".to_string(),
                }],
            },
            timestamp: 123456789,
            id: "msg1".to_string(),
            parent_message_id: None,
        };
        store
            .append_event(
                &session.id,
                &StreamEvent::MessageComplete {
                    message: reply,
                    usage: None,
                    metadata: std::collections::HashMap::new(),
                    model: claude_model,
                },
            )
            .await
            .unwrap();
        let listed = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(listed[0].last_model, Some(claude_model));

        // A later failed tool call carrying a different model replaces it.
        let gpt_model = Model::Gpt4_1_20250414;
        store
            .append_event(
                &session.id,
                &StreamEvent::ToolCallFailed {
                    tool_call_id: "tool1".to_string(),
                    error: "Test error".to_string(),
                    metadata: std::collections::HashMap::new(),
                    model: gpt_model,
                },
            )
            .await
            .unwrap();
        let listed = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(listed[0].last_model, Some(gpt_model));

        // An event with no model field leaves the last recorded model untouched.
        store
            .append_event(
                &session.id,
                &StreamEvent::SessionSaved {
                    session_id: session.id.clone(),
                },
            )
            .await
            .unwrap();
        let listed = store.list_sessions(SessionFilter::default()).await.unwrap();
        assert_eq!(listed[0].last_model, Some(gpt_model));
    }
}