1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use super::{Session, SessionConfig, SessionInfo};
7use crate::app::Message;
8use crate::events::StreamEvent;
9use steer_tools::ToolCall;
10
/// Asynchronous persistence interface for sessions and the records attached
/// to them.
///
/// The methods fall into five groups: session lifecycle, messages, tool
/// calls, stream events, and the active-message pointer. Every method reports
/// failure through [`SessionStoreError`].
///
/// NOTE(review): no implementation is visible from this file, so the
/// per-method notes below describe what the signatures imply; confirm the
/// exact semantics against the concrete stores.
#[async_trait]
pub trait SessionStore: Send + Sync {
    // --- Session lifecycle ---

    /// Creates a session from `config` and returns it.
    async fn create_session(&self, config: SessionConfig) -> Result<Session, SessionStoreError>;

    /// Looks up a session by id.
    ///
    /// The `Option` in the return type lets an implementation report a
    /// missing session as `Ok(None)` rather than as an error.
    async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionStoreError>;

    /// Persists the given `session`.
    async fn update_session(&self, session: &Session) -> Result<(), SessionStoreError>;

    /// Deletes the session identified by `session_id`.
    async fn delete_session(&self, session_id: &str) -> Result<(), SessionStoreError>;

    /// Lists session summaries selected by `filter`.
    async fn list_sessions(
        &self,
        filter: SessionFilter,
    ) -> Result<Vec<SessionInfo>, SessionStoreError>;

    // --- Messages ---

    /// Appends `message` to the session identified by `session_id`.
    async fn append_message(
        &self,
        session_id: &str,
        message: &Message,
    ) -> Result<(), SessionStoreError>;

    /// Returns messages of a session, optionally restricted by
    /// `after_sequence`.
    ///
    /// NOTE(review): presumably `Some(n)` returns only messages whose sequence
    /// number is greater than `n` and `None` returns all of them — confirm.
    async fn get_messages(
        &self,
        session_id: &str,
        after_sequence: Option<u32>,
    ) -> Result<Vec<Message>, SessionStoreError>;

    // --- Tool calls ---

    /// Records `tool_call` against the session identified by `session_id`.
    async fn create_tool_call(
        &self,
        session_id: &str,
        tool_call: &ToolCall,
    ) -> Result<(), SessionStoreError>;

    /// Applies `update` to the tool call identified by `tool_call_id`.
    async fn update_tool_call(
        &self,
        tool_call_id: &str,
        update: ToolCallUpdate,
    ) -> Result<(), SessionStoreError>;

    /// Returns the pending tool calls of a session.
    async fn get_pending_tool_calls(
        &self,
        session_id: &str,
    ) -> Result<Vec<ToolCall>, SessionStoreError>;

    // --- Stream events ---
    // Event sequence values are `u64`, whereas message sequences are `u32`.

    /// Appends `event` to the session's event log.
    ///
    /// NOTE(review): the returned `u64` is presumably the sequence number
    /// assigned to the event — confirm.
    async fn append_event(
        &self,
        session_id: &str,
        event: &StreamEvent,
    ) -> Result<u64, SessionStoreError>;

    /// Returns events of a session as `(sequence, event)` pairs.
    ///
    /// NOTE(review): presumably only events after `after_sequence` are
    /// returned, capped at `limit` when it is `Some` — confirm.
    async fn get_events(
        &self,
        session_id: &str,
        after_sequence: u64,
        limit: Option<u32>,
    ) -> Result<Vec<(u64, StreamEvent)>, SessionStoreError>;

    /// Deletes events of a session that precede `before_sequence`.
    ///
    /// NOTE(review): the returned `u64` is presumably the number of events
    /// removed — confirm.
    async fn delete_events_before(
        &self,
        session_id: &str,
        before_sequence: u64,
    ) -> Result<u64, SessionStoreError>;

    // --- Active message pointer ---

    /// Sets the session's active message id; `None` is accepted as well as
    /// `Some(id)`.
    ///
    /// NOTE(review): `None` presumably clears the pointer — confirm.
    async fn update_active_message_id(
        &self,
        session_id: &str,
        message_id: Option<&str>,
    ) -> Result<(), SessionStoreError>;
}
77
/// Selection, paging and ordering options passed to
/// `SessionStore::list_sessions`.
///
/// `Default` yields a filter with every optional criterion unset, an empty
/// metadata map, and ordering by `SessionOrderBy::CreatedAt` in
/// `OrderDirection::Descending` (the `#[default]` variants of those enums).
///
/// NOTE(review): how each criterion is applied (inclusive vs. exclusive
/// bounds, exact vs. partial metadata match) is up to the store
/// implementation — confirm there.
#[derive(Debug, Clone, Default)]
pub struct SessionFilter {
    /// Creation-time lower bound.
    pub created_after: Option<DateTime<Utc>>,
    /// Creation-time upper bound.
    pub created_before: Option<DateTime<Utc>>,

    /// Update-time lower bound.
    pub updated_after: Option<DateTime<Utc>>,
    /// Update-time upper bound.
    pub updated_before: Option<DateTime<Utc>>,

    /// Metadata key/value pairs to filter on.
    pub metadata_filters: HashMap<String, String>,

    /// Restricts results to one status when set.
    pub status_filter: Option<SessionStatus>,

    /// Paging limit; presumably the maximum number of results — confirm.
    pub limit: Option<u32>,
    /// Paging offset; presumably the number of results to skip — confirm.
    pub offset: Option<u32>,

    /// Attribute the results are ordered by.
    pub order_by: SessionOrderBy,
    /// Direction of the ordering.
    pub order_direction: OrderDirection,
}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106#[serde(rename_all = "snake_case")]
107pub enum SessionStatus {
108 Active,
110 Inactive,
112}
113
/// Attribute by which listed sessions are ordered.
///
/// Defaults to [`SessionOrderBy::CreatedAt`]. `PartialEq`/`Eq` are derived so
/// that values can be compared with `==` instead of requiring `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SessionOrderBy {
    /// Order by creation time (the default).
    #[default]
    CreatedAt,
    /// Order by last-update time.
    UpdatedAt,
    /// Order by message count.
    MessageCount,
}
122
/// Direction in which an ordering is applied.
///
/// Defaults to [`OrderDirection::Descending`]. `PartialEq`/`Eq` are derived
/// so that values can be compared with `==` instead of requiring `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum OrderDirection {
    /// Descending order (the default).
    #[default]
    Descending,
    /// Ascending order.
    Ascending,
}
130
/// Change set passed to `SessionStore::update_tool_call`.
///
/// Each field is optional.
///
/// NOTE(review): presumably a `None` field leaves the stored value
/// untouched — confirm against the store implementations.
#[derive(Debug, Clone)]
pub struct ToolCallUpdate {
    /// New status of the tool call, if any.
    pub status: Option<super::ToolCallStatus>,
    /// Execution statistics to record, if any.
    pub result: Option<super::ToolExecutionStats>,
    /// Error text to record, if any.
    pub error: Option<String>,
}
138
139impl ToolCallUpdate {
140 pub fn set_status(status: super::ToolCallStatus) -> Self {
141 Self {
142 status: Some(status),
143 result: None,
144 error: None,
145 }
146 }
147
148 pub fn set_result(result: super::ToolExecutionStats) -> Self {
149 Self {
150 status: Some(super::ToolCallStatus::Completed),
151 result: Some(result),
152 error: None,
153 }
154 }
155
156 pub fn set_error(error: String) -> Self {
157 Self {
158 status: Some(super::ToolCallStatus::Failed {
159 error: error.clone(),
160 }),
161 result: None,
162 error: Some(error),
163 }
164 }
165}
166
/// One page of messages, as returned by
/// `SessionStoreExt::get_messages_paginated`.
#[derive(Debug, Clone)]
pub struct MessagePage {
    /// Messages in this page.
    pub messages: Vec<Message>,
    /// `true` when the store returned more messages than fit in the page.
    pub has_more: bool,
    /// Cursor for requesting the following page; `get_messages_paginated`
    /// only sets it when `has_more` is `true` and the page is not empty.
    pub next_cursor: Option<MessageCursor>,
}
174
/// Position marker used to resume message pagination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageCursor {
    /// Sequence value that `get_messages_paginated` forwards to
    /// `SessionStore::get_messages` as `after_sequence`.
    pub sequence_num: u32,
    /// Id of the last message of the page that produced this cursor, when
    /// known. It is not read by the pagination code in this file.
    pub message_id: Option<String>,
}
181
/// Asynchronous storage interface for archived sessions.
///
/// Archives are addressed by a `String` id handed out by
/// [`ArchiveStore::archive_session`]. Every method reports failure through
/// [`SessionStoreError`].
#[async_trait]
pub trait ArchiveStore: Send + Sync {
    /// Archives `session` and returns the id of the resulting archive.
    async fn archive_session(&self, session: &Session) -> Result<String, SessionStoreError>;

    /// Returns the session stored under `archive_id`.
    async fn restore_session(&self, archive_id: &str) -> Result<Session, SessionStoreError>;

    /// Deletes the archive stored under `archive_id`.
    async fn delete_archive(&self, archive_id: &str) -> Result<(), SessionStoreError>;
}
189
/// Error type returned by every `SessionStore` and `ArchiveStore` method.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute. All payloads are plain strings; no source error is chained.
#[derive(Debug, thiserror::Error)]
pub enum SessionStoreError {
    /// No session exists with the given id.
    #[error("Session not found: {session_id}")]
    SessionNotFound { session_id: String },

    /// No tool call exists with the given id.
    #[error("Tool call not found: {tool_call_id}")]
    ToolCallNotFound { tool_call_id: String },

    /// Failure reported by the underlying database.
    #[error("Database error: {message}")]
    Database { message: String },

    /// Failure while serializing or deserializing stored data.
    #[error("Serialization error: {message}")]
    Serialization { message: String },

    /// Failure related to a transaction.
    #[error("Transaction error: {message}")]
    Transaction { message: String },

    /// Input rejected as invalid.
    #[error("Validation error: {message}")]
    Validation { message: String },

    /// Failure related to the storage connection.
    #[error("Connection error: {message}")]
    Connection { message: String },

    /// Failure related to a migration.
    #[error("Migration error: {message}")]
    Migration { message: String },

    /// A storage constraint was violated.
    #[error("Constraint violation: {message}")]
    ConstraintViolation { message: String },

    /// Failure that fits none of the other categories.
    #[error("Internal error: {message}")]
    Internal { message: String },

    /// Generic "not found" for an arbitrary `entity` kind and `id`.
    #[error("{entity} not found: {id}")]
    NotFound { entity: String, id: String },
}
226
227impl SessionStoreError {
228 pub fn database<S: Into<String>>(message: S) -> Self {
229 Self::Database {
230 message: message.into(),
231 }
232 }
233
234 pub fn serialization<S: Into<String>>(message: S) -> Self {
235 Self::Serialization {
236 message: message.into(),
237 }
238 }
239
240 pub fn transaction<S: Into<String>>(message: S) -> Self {
241 Self::Transaction {
242 message: message.into(),
243 }
244 }
245
246 pub fn validation<S: Into<String>>(message: S) -> Self {
247 Self::Validation {
248 message: message.into(),
249 }
250 }
251
252 pub fn connection<S: Into<String>>(message: S) -> Self {
253 Self::Connection {
254 message: message.into(),
255 }
256 }
257
258 pub fn internal<S: Into<String>>(message: S) -> Self {
259 Self::Internal {
260 message: message.into(),
261 }
262 }
263}
264
265#[async_trait]
267pub trait SessionStoreExt: SessionStore {
268 async fn get_messages_paginated(
270 &self,
271 session_id: &str,
272 page_size: u32,
273 cursor: Option<MessageCursor>,
274 ) -> Result<MessagePage, SessionStoreError> {
275 let after_sequence = cursor.map(|c| c.sequence_num);
276 let messages = self.get_messages(session_id, after_sequence).await?;
277
278 let has_more = messages.len() > page_size as usize;
279 let messages = if has_more {
280 messages.into_iter().take(page_size as usize).collect()
281 } else {
282 messages
283 };
284
285 let next_cursor = if has_more && !messages.is_empty() {
286 Some(MessageCursor {
287 sequence_num: messages.len() as u32,
288 message_id: messages.last().map(|m| m.id().to_string()),
289 })
290 } else {
291 None
292 };
293
294 Ok(MessagePage {
295 messages,
296 has_more,
297 next_cursor,
298 })
299 }
300
301 async fn archive_session(
303 &self,
304 session_id: &str,
305 archive_store: &dyn ArchiveStore,
306 ) -> Result<String, SessionStoreError> {
307 let session = self.get_session(session_id).await?.ok_or_else(|| {
308 SessionStoreError::SessionNotFound {
309 session_id: session_id.to_string(),
310 }
311 })?;
312
313 let archive_id = archive_store.archive_session(&session).await?;
314 self.delete_session(session_id).await?;
315
316 Ok(archive_id)
317 }
318}
319
/// Blanket implementation: every [`SessionStore`] — including unsized ones
/// such as `dyn SessionStore`, thanks to `?Sized` — gets the
/// [`SessionStoreExt`] default methods.
impl<T: SessionStore + ?Sized> SessionStoreExt for T {}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Fields named explicitly override the defaults of `SessionFilter`.
    #[test]
    fn test_session_filter_creation() {
        let filter = SessionFilter {
            order_by: SessionOrderBy::UpdatedAt,
            limit: Some(10),
            ..SessionFilter::default()
        };

        assert_eq!(Some(10), filter.limit);
        match filter.order_by {
            SessionOrderBy::UpdatedAt => {}
            other => panic!("unexpected order_by: {:?}", other),
        }
    }

    /// `set_error` populates both the status and the error text.
    #[test]
    fn test_tool_call_update() {
        let ToolCallUpdate { status, error, .. } =
            ToolCallUpdate::set_error(String::from("Test error"));

        assert!(status.is_some());
        // Covers both "is present" and "has the expected text".
        assert_eq!(error.as_deref(), Some("Test error"));
    }

    /// A cursor keeps the values it was built with.
    #[test]
    fn test_message_cursor() {
        let MessageCursor {
            sequence_num,
            message_id,
        } = MessageCursor {
            sequence_num: 5,
            message_id: Some(String::from("msg-123")),
        };

        assert_eq!(5, sequence_num);
        assert_eq!(message_id.as_deref(), Some("msg-123"));
    }
}
358}