steer_core/session/store.rs

1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use super::{Session, SessionConfig, SessionInfo};
7use crate::app::Message;
8use crate::events::StreamEvent;
9use steer_tools::ToolCall;
10
11/// Database-agnostic session store trait
12#[async_trait]
13pub trait SessionStore: Send + Sync {
14    // Session lifecycle
15    async fn create_session(&self, config: SessionConfig) -> Result<Session, SessionStoreError>;
16    async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionStoreError>;
17    async fn update_session(&self, session: &Session) -> Result<(), SessionStoreError>;
18    async fn delete_session(&self, session_id: &str) -> Result<(), SessionStoreError>;
19    async fn list_sessions(
20        &self,
21        filter: SessionFilter,
22    ) -> Result<Vec<SessionInfo>, SessionStoreError>;
23
24    // Message operations
25    async fn append_message(
26        &self,
27        session_id: &str,
28        message: &Message,
29    ) -> Result<(), SessionStoreError>;
30    async fn get_messages(
31        &self,
32        session_id: &str,
33        after_sequence: Option<u32>,
34    ) -> Result<Vec<Message>, SessionStoreError>;
35
36    // Tool operations
37    async fn create_tool_call(
38        &self,
39        session_id: &str,
40        tool_call: &ToolCall,
41    ) -> Result<(), SessionStoreError>;
42    async fn update_tool_call(
43        &self,
44        tool_call_id: &str,
45        update: ToolCallUpdate,
46    ) -> Result<(), SessionStoreError>;
47    async fn get_pending_tool_calls(
48        &self,
49        session_id: &str,
50    ) -> Result<Vec<ToolCall>, SessionStoreError>;
51
52    // Event streaming
53    async fn append_event(
54        &self,
55        session_id: &str,
56        event: &StreamEvent,
57    ) -> Result<u64, SessionStoreError>;
58    async fn get_events(
59        &self,
60        session_id: &str,
61        after_sequence: u64,
62        limit: Option<u32>,
63    ) -> Result<Vec<(u64, StreamEvent)>, SessionStoreError>;
64    async fn delete_events_before(
65        &self,
66        session_id: &str,
67        before_sequence: u64,
68    ) -> Result<u64, SessionStoreError>;
69
70    // Active message tracking
71    async fn update_active_message_id(
72        &self,
73        session_id: &str,
74        message_id: Option<&str>,
75    ) -> Result<(), SessionStoreError>;
76}
77
78/// Filter for listing sessions
79#[derive(Debug, Clone, Default)]
80pub struct SessionFilter {
81    /// Filter by creation date range
82    pub created_after: Option<DateTime<Utc>>,
83    pub created_before: Option<DateTime<Utc>>,
84
85    /// Filter by last update date range
86    pub updated_after: Option<DateTime<Utc>>,
87    pub updated_before: Option<DateTime<Utc>>,
88
89    /// Filter by metadata key-value pairs
90    pub metadata_filters: HashMap<String, String>,
91
92    /// Filter by session status
93    pub status_filter: Option<SessionStatus>,
94
95    /// Pagination
96    pub limit: Option<u32>,
97    pub offset: Option<u32>,
98
99    /// Ordering
100    pub order_by: SessionOrderBy,
101    pub order_direction: OrderDirection,
102}
103
104/// Session status for filtering
105#[derive(Debug, Clone, Serialize, Deserialize)]
106#[serde(rename_all = "snake_case")]
107pub enum SessionStatus {
108    /// Session has an active app instance running in a live environment
109    Active,
110    /// Session has no active app instance (environment down or app not running)
111    Inactive,
112}
113
/// Sort key for session listings.
///
/// Defaults to [`SessionOrderBy::CreatedAt`]. `PartialEq`/`Eq` are derived so
/// callers can compare values with `==` instead of `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SessionOrderBy {
    /// Order by session creation time (the default).
    #[default]
    CreatedAt,
    /// Order by last update time.
    UpdatedAt,
    /// Order by number of messages in the session.
    MessageCount,
}
122
/// Sort direction for session listings.
///
/// Defaults to [`OrderDirection::Descending`]. `PartialEq`/`Eq` are derived so
/// callers can compare values with `==`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum OrderDirection {
    /// Largest values first (the default).
    #[default]
    Descending,
    /// Smallest values first.
    Ascending,
}
130
131/// Tool call update operations
132#[derive(Debug, Clone)]
133pub struct ToolCallUpdate {
134    pub status: Option<super::ToolCallStatus>,
135    pub result: Option<super::ToolExecutionStats>,
136    pub error: Option<String>,
137}
138
139impl ToolCallUpdate {
140    pub fn set_status(status: super::ToolCallStatus) -> Self {
141        Self {
142            status: Some(status),
143            result: None,
144            error: None,
145        }
146    }
147
148    pub fn set_result(result: super::ToolExecutionStats) -> Self {
149        Self {
150            status: Some(super::ToolCallStatus::Completed),
151            result: Some(result),
152            error: None,
153        }
154    }
155
156    pub fn set_error(error: String) -> Self {
157        Self {
158            status: Some(super::ToolCallStatus::Failed {
159                error: error.clone(),
160            }),
161            result: None,
162            error: Some(error),
163        }
164    }
165}
166
167/// Pagination support for messages
168#[derive(Debug, Clone)]
169pub struct MessagePage {
170    pub messages: Vec<Message>,
171    pub has_more: bool,
172    pub next_cursor: Option<MessageCursor>,
173}
174
175/// Cursor for stable message pagination
176#[derive(Debug, Clone, Serialize, Deserialize)]
177pub struct MessageCursor {
178    pub sequence_num: u32,
179    pub message_id: Option<String>,
180}
181
182/// Archive store trait for cold storage
183#[async_trait]
184pub trait ArchiveStore: Send + Sync {
185    async fn archive_session(&self, session: &Session) -> Result<String, SessionStoreError>;
186    async fn restore_session(&self, archive_id: &str) -> Result<Session, SessionStoreError>;
187    async fn delete_archive(&self, archive_id: &str) -> Result<(), SessionStoreError>;
188}
189
190/// Session store error types
191#[derive(Debug, thiserror::Error)]
192pub enum SessionStoreError {
193    #[error("Session not found: {session_id}")]
194    SessionNotFound { session_id: String },
195
196    #[error("Tool call not found: {tool_call_id}")]
197    ToolCallNotFound { tool_call_id: String },
198
199    #[error("Database error: {message}")]
200    Database { message: String },
201
202    #[error("Serialization error: {message}")]
203    Serialization { message: String },
204
205    #[error("Transaction error: {message}")]
206    Transaction { message: String },
207
208    #[error("Validation error: {message}")]
209    Validation { message: String },
210
211    #[error("Connection error: {message}")]
212    Connection { message: String },
213
214    #[error("Migration error: {message}")]
215    Migration { message: String },
216
217    #[error("Constraint violation: {message}")]
218    ConstraintViolation { message: String },
219
220    #[error("Internal error: {message}")]
221    Internal { message: String },
222
223    #[error("{entity} not found: {id}")]
224    NotFound { entity: String, id: String },
225}
226
227impl SessionStoreError {
228    pub fn database<S: Into<String>>(message: S) -> Self {
229        Self::Database {
230            message: message.into(),
231        }
232    }
233
234    pub fn serialization<S: Into<String>>(message: S) -> Self {
235        Self::Serialization {
236            message: message.into(),
237        }
238    }
239
240    pub fn transaction<S: Into<String>>(message: S) -> Self {
241        Self::Transaction {
242            message: message.into(),
243        }
244    }
245
246    pub fn validation<S: Into<String>>(message: S) -> Self {
247        Self::Validation {
248            message: message.into(),
249        }
250    }
251
252    pub fn connection<S: Into<String>>(message: S) -> Self {
253        Self::Connection {
254            message: message.into(),
255        }
256    }
257
258    pub fn internal<S: Into<String>>(message: S) -> Self {
259        Self::Internal {
260            message: message.into(),
261        }
262    }
263}
264
265/// Extension trait for SessionStore with additional convenience methods
266#[async_trait]
267pub trait SessionStoreExt: SessionStore {
268    /// Get messages with pagination support
269    async fn get_messages_paginated(
270        &self,
271        session_id: &str,
272        page_size: u32,
273        cursor: Option<MessageCursor>,
274    ) -> Result<MessagePage, SessionStoreError> {
275        let after_sequence = cursor.map(|c| c.sequence_num);
276        let messages = self.get_messages(session_id, after_sequence).await?;
277
278        let has_more = messages.len() > page_size as usize;
279        let messages = if has_more {
280            messages.into_iter().take(page_size as usize).collect()
281        } else {
282            messages
283        };
284
285        let next_cursor = if has_more && !messages.is_empty() {
286            Some(MessageCursor {
287                sequence_num: messages.len() as u32,
288                message_id: messages.last().map(|m| m.id().to_string()),
289            })
290        } else {
291            None
292        };
293
294        Ok(MessagePage {
295            messages,
296            has_more,
297            next_cursor,
298        })
299    }
300
301    /// Archive a completed session
302    async fn archive_session(
303        &self,
304        session_id: &str,
305        archive_store: &dyn ArchiveStore,
306    ) -> Result<String, SessionStoreError> {
307        let session = self.get_session(session_id).await?.ok_or_else(|| {
308            SessionStoreError::SessionNotFound {
309                session_id: session_id.to_string(),
310            }
311        })?;
312
313        let archive_id = archive_store.archive_session(&session).await?;
314        self.delete_session(session_id).await?;
315
316        Ok(archive_id)
317    }
318}
319
320// Blanket implementation for all SessionStore implementors
321impl<T: SessionStore + ?Sized> SessionStoreExt for T {}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Struct-update syntax keeps the explicit fields and defaults the rest.
    #[test]
    fn test_session_filter_creation() {
        let filter = SessionFilter {
            order_by: SessionOrderBy::UpdatedAt,
            limit: Some(10),
            ..SessionFilter::default()
        };

        assert!(matches!(filter.order_by, SessionOrderBy::UpdatedAt));
        assert_eq!(filter.limit, Some(10));
    }

    /// `set_error` fills in both the status and the error message.
    #[test]
    fn test_tool_call_update() {
        let ToolCallUpdate { status, error, .. } =
            ToolCallUpdate::set_error("Test error".to_string());

        assert!(status.is_some());
        assert_eq!(error.as_deref(), Some("Test error"));
    }

    /// Cursor fields round-trip unchanged.
    #[test]
    fn test_message_cursor() {
        let MessageCursor {
            sequence_num,
            message_id,
        } = MessageCursor {
            sequence_num: 5,
            message_id: Some("msg-123".to_string()),
        };

        assert_eq!(sequence_num, 5);
        assert_eq!(message_id.as_deref(), Some("msg-123"));
    }
}