turul_mcp_session_storage/
traits.rs1use async_trait::async_trait;
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::collections::HashMap;
15use std::time::SystemTime;
16use uuid::Uuid;
17
18use turul_mcp_protocol::{ClientCapabilities, ServerCapabilities};
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct SessionInfo {
25 pub session_id: String,
27 pub client_capabilities: Option<ClientCapabilities>,
29 pub server_capabilities: Option<ServerCapabilities>,
31 pub state: HashMap<String, Value>,
33 pub created_at: u64,
35 pub last_activity: u64,
37 pub is_initialized: bool,
39 pub metadata: HashMap<String, Value>,
41}
42
43impl Default for SessionInfo {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl SessionInfo {
50 pub fn new() -> Self {
52 let now = chrono::Utc::now().timestamp_millis() as u64;
53 Self {
54 session_id: Uuid::now_v7().to_string(),
55 client_capabilities: None,
56 server_capabilities: None,
57 state: HashMap::new(),
58 created_at: now,
59 last_activity: now,
60 is_initialized: false,
61 metadata: HashMap::new(),
62 }
63 }
64
65 pub fn with_id(session_id: String) -> Self {
67 let now = chrono::Utc::now().timestamp_millis() as u64;
68 Self {
69 session_id,
70 client_capabilities: None,
71 server_capabilities: None,
72 state: HashMap::new(),
73 created_at: now,
74 last_activity: now,
75 is_initialized: false,
76 metadata: HashMap::new(),
77 }
78 }
79
80 pub fn touch(&mut self) {
82 self.last_activity = chrono::Utc::now().timestamp_millis() as u64;
83 }
84
85 pub fn is_expired(&self, timeout_minutes: u64) -> bool {
87 let now = chrono::Utc::now().timestamp_millis() as u64;
88 let timeout_millis = timeout_minutes * 60 * 1000;
89 now - self.last_activity > timeout_millis
90 }
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct SseEvent {
96 pub id: u64,
98 pub timestamp: u64,
100 pub event_type: String,
102 pub data: Value,
104 pub retry: Option<u32>,
106}
107
108impl SseEvent {
109 pub fn new(event_type: String, data: Value) -> Self {
111 Self {
112 id: 0, timestamp: chrono::Utc::now().timestamp_millis() as u64,
114 event_type,
115 data,
116 retry: None,
117 }
118 }
119
120 pub fn format(&self) -> String {
130 if self.event_type == "ping" || self.event_type == "keepalive" {
133 return ": keepalive\n\n".to_string();
134 }
135
136 let mut result = String::new();
137
138 result.push_str(&format!("id: {}\n", self.id));
140
141 result.push_str("event: message\n");
143
144 if let Ok(data_str) = serde_json::to_string(&self.data) {
146 result.push_str(&format!("data: {}\n", data_str));
147 } else {
148 result.push_str("data: {}\n");
149 }
150
151 if let Some(retry) = self.retry {
153 result.push_str(&format!("retry: {}\n", retry));
154 }
155
156 result.push('\n');
158
159 result
160 }
161}
162
163#[async_trait]
165pub trait SessionStorage: Send + Sync {
166 type Error: std::error::Error + Send + Sync + 'static;
168
169 fn backend_name(&self) -> &'static str;
171
172 async fn create_session(
188 &self,
189 capabilities: ServerCapabilities,
190 ) -> Result<SessionInfo, Self::Error>;
191
192 async fn create_session_with_id(
224 &self,
225 session_id: String,
226 capabilities: ServerCapabilities,
227 ) -> Result<SessionInfo, Self::Error>;
228
229 async fn get_session(&self, session_id: &str) -> Result<Option<SessionInfo>, Self::Error>;
231
232 async fn update_session(&self, session_info: SessionInfo) -> Result<(), Self::Error>;
234
235 async fn set_session_state(
237 &self,
238 session_id: &str,
239 key: &str,
240 value: Value,
241 ) -> Result<(), Self::Error>;
242
243 async fn get_session_state(
245 &self,
246 session_id: &str,
247 key: &str,
248 ) -> Result<Option<Value>, Self::Error>;
249
250 async fn remove_session_state(
252 &self,
253 session_id: &str,
254 key: &str,
255 ) -> Result<Option<Value>, Self::Error>;
256
257 async fn delete_session(&self, session_id: &str) -> Result<bool, Self::Error>;
259
260 async fn list_sessions(&self) -> Result<Vec<String>, Self::Error>;
262
263 async fn store_event(&self, session_id: &str, event: SseEvent)
269 -> Result<SseEvent, Self::Error>;
270
271 async fn get_events_after(
273 &self,
274 session_id: &str,
275 after_event_id: u64,
276 ) -> Result<Vec<SseEvent>, Self::Error>;
277
278 async fn get_recent_events(
280 &self,
281 session_id: &str,
282 limit: usize,
283 ) -> Result<Vec<SseEvent>, Self::Error>;
284
285 async fn delete_events_before(
287 &self,
288 session_id: &str,
289 before_event_id: u64,
290 ) -> Result<u64, Self::Error>;
291
292 async fn expire_sessions(&self, older_than: SystemTime) -> Result<Vec<String>, Self::Error>;
298
299 async fn session_count(&self) -> Result<usize, Self::Error>;
301
302 async fn event_count(&self) -> Result<usize, Self::Error>;
304
305 async fn maintenance(&self) -> Result<(), Self::Error>;
307}
308
/// Convenience result type whose error is any boxed, thread-safe error value.
pub type SessionResult<T> = Result<T, Box<dyn std::error::Error + Send + Sync + 'static>>;
311
312#[derive(Debug, thiserror::Error)]
314pub enum SessionStorageError {
315 #[error("Session not found: {0}")]
316 SessionNotFound(String),
317
318 #[error("Maximum sessions limit reached: {0}")]
319 MaxSessionsReached(usize),
320
321 #[error("Maximum events limit reached: {0}")]
322 MaxEventsReached(usize),
323
324 #[error("Database error: {0}")]
325 DatabaseError(String),
326
327 #[error("Serialization error: {0}")]
328 SerializationError(String),
329
330 #[error("Connection error: {0}")]
331 ConnectionError(String),
332
333 #[error("Migration error: {0}")]
334 MigrationError(String),
335
336 #[error("AWS SDK error: {0}")]
337 AwsError(String),
338
339 #[error("AWS configuration error: {0}")]
340 AwsConfigurationError(String),
341
342 #[error("DynamoDB table does not exist: {0}")]
343 TableNotFound(String),
344
345 #[error("Invalid session data: {0}")]
346 InvalidData(String),
347
348 #[error("Concurrent modification error: {0}")]
349 ConcurrentModification(String),
350
351 #[error("Generic storage error: {0}")]
352 Generic(String),
353}
354
355impl From<serde_json::Error> for SessionStorageError {
357 fn from(err: serde_json::Error) -> Self {
358 SessionStorageError::SerializationError(err.to_string())
359 }
360}
361
// NOTE(review): this conversion exists only with the `sqlite` feature; confirm whether
// postgres-only builds also rely on converting raw `sqlx::Error` values.
#[cfg(feature = "sqlite")]
impl From<sqlx::Error> for SessionStorageError {
    /// Wraps a raw `sqlx` failure as `DatabaseError`, keeping only its display text.
    fn from(err: sqlx::Error) -> Self {
        let message = err.to_string();
        Self::DatabaseError(message)
    }
}
368
369impl From<crate::in_memory::InMemoryError> for SessionStorageError {
371 fn from(err: crate::in_memory::InMemoryError) -> Self {
372 match err {
373 crate::in_memory::InMemoryError::SessionNotFound(id) => {
374 SessionStorageError::SessionNotFound(id)
375 }
376 crate::in_memory::InMemoryError::MaxSessionsReached(limit) => {
377 SessionStorageError::MaxSessionsReached(limit)
378 }
379 crate::in_memory::InMemoryError::MaxEventsReached(limit) => {
380 SessionStorageError::MaxEventsReached(limit)
381 }
382 crate::in_memory::InMemoryError::SerializationError(e) => {
383 SessionStorageError::SerializationError(e.to_string())
384 }
385 }
386 }
387}
388
#[cfg(feature = "sqlite")]
impl From<crate::sqlite::SqliteError> for SessionStorageError {
    /// Maps each SQLite backend error onto the corresponding unified variant;
    /// wrapped library errors are reduced to their display text.
    fn from(err: crate::sqlite::SqliteError) -> Self {
        use crate::sqlite::SqliteError as Source;

        match err {
            Source::Database(inner) => Self::DatabaseError(inner.to_string()),
            Source::Serialization(inner) => Self::SerializationError(inner.to_string()),
            Source::SessionNotFound(id) => Self::SessionNotFound(id),
            Source::Connection(message) => Self::ConnectionError(message),
            Source::Migration(message) => Self::MigrationError(message),
        }
    }
}
407
#[cfg(feature = "postgres")]
impl From<crate::postgres::PostgresError> for SessionStorageError {
    /// Maps each PostgreSQL backend error onto the corresponding unified variant;
    /// wrapped library errors are reduced to their display text.
    fn from(err: crate::postgres::PostgresError) -> Self {
        use crate::postgres::PostgresError as Source;

        match err {
            Source::Database(inner) => Self::DatabaseError(inner.to_string()),
            Source::Serialization(inner) => Self::SerializationError(inner.to_string()),
            Source::SessionNotFound(id) => Self::SessionNotFound(id),
            Source::Connection(message) => Self::ConnectionError(message),
            Source::Migration(message) => Self::MigrationError(message),
            Source::ConcurrentModification(message) => Self::ConcurrentModification(message),
        }
    }
}
431
#[cfg(feature = "dynamodb")]
impl From<crate::dynamodb::DynamoDbError> for SessionStorageError {
    /// Maps each DynamoDB backend error onto the corresponding unified variant;
    /// serialization errors are reduced to their display text.
    fn from(err: crate::dynamodb::DynamoDbError) -> Self {
        use crate::dynamodb::DynamoDbError as Source;

        match err {
            Source::AwsError(message) => Self::AwsError(message),
            Source::SerializationError(inner) => Self::SerializationError(inner.to_string()),
            Source::SessionNotFound(id) => Self::SessionNotFound(id),
            Source::InvalidSessionData(message) => Self::InvalidData(message),
            Source::TableNotFound(table) => Self::TableNotFound(table),
            Source::ConfigError(message) => Self::AwsConfigurationError(message),
        }
    }
}
455
456pub type BoxedSessionStorage = dyn SessionStorage<Error = SessionStorageError>;
458
459pub trait SessionStorageBuilder {
461 type Storage: SessionStorage;
462 type Config;
463 type Error: std::error::Error + Send + Sync + 'static;
464
465 fn build(config: Self::Config) -> Result<Self::Storage, Self::Error>;
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_info_creation() {
        let session = SessionInfo::new();
        assert!(!session.session_id.is_empty());
        assert!(!session.is_initialized);
        assert!(session.state.is_empty());
    }

    #[test]
    fn test_session_with_id_keeps_id() {
        let session = SessionInfo::with_id("session-abc".to_string());
        assert_eq!(session.session_id, "session-abc");
        assert!(!session.is_initialized);
        assert_eq!(session.created_at, session.last_activity);
        assert!(session.metadata.is_empty());
    }

    #[test]
    fn test_session_expiration() {
        let mut session = SessionInfo::new();
        assert!(!session.is_expired(30));

        // Backdate the last activity just past the 30-minute window.
        session.last_activity = chrono::Utc::now().timestamp_millis() as u64 - (31 * 60 * 1000);
        assert!(session.is_expired(30));
    }

    #[test]
    fn test_touch_refreshes_activity() {
        let mut session = SessionInfo::new();
        session.last_activity = 0;
        session.touch();
        assert!(session.last_activity > 0);
        assert!(!session.is_expired(1));
    }

    #[test]
    fn test_sse_event_formatting() {
        let event = SseEvent {
            id: 123,
            timestamp: 1234567890,
            event_type: "notifications/progress".to_string(),
            data: serde_json::json!({"message": "test"}),
            retry: Some(1000),
        };

        let formatted = event.format();
        assert!(formatted.contains("id: 123"));
        // The SSE event name is always "message", whatever the event_type.
        assert!(formatted.contains("event: message"));
        assert!(formatted.contains("retry: 1000"));
        assert!(formatted.contains("data: {\"message\":\"test\"}"));
        assert!(formatted.ends_with("\n\n"));

        let keepalive = SseEvent {
            id: 0,
            timestamp: 1234567890,
            event_type: "ping".to_string(),
            data: serde_json::json!({"type": "keepalive"}),
            retry: None,
        };

        let keepalive_formatted = keepalive.format();
        assert_eq!(keepalive_formatted, ": keepalive\n\n");
        // A keepalive is a bare SSE comment: no id, event or data fields.
        assert!(keepalive_formatted.starts_with(':'));
        assert!(!keepalive_formatted.contains("id:"));
        assert!(!keepalive_formatted.contains("event:"));
        assert!(!keepalive_formatted.contains("data:"));
    }

    #[test]
    fn test_sse_event_keepalive_alias_and_no_retry() {
        let keepalive = SseEvent::new("keepalive".to_string(), serde_json::Value::Null);
        assert_eq!(keepalive.format(), ": keepalive\n\n");

        let event = SseEvent::new("notifications/message".to_string(), serde_json::Value::Null);
        assert_eq!(event.id, 0);
        assert!(event.retry.is_none());
        assert_eq!(event.format(), "id: 0\nevent: message\ndata: null\n\n");
    }
}