use std::future::Future;
use std::path::Path;
use std::pin::Pin;
use std::time::Duration;

use async_trait::async_trait;
use chrono::Utc;
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use sqlx::{Pool, Row, Sqlite, SqlitePool};
use tower::{BoxError, Service};

use crate::codec::{items_to_messages, messages_to_items};
use crate::error::Result;
use crate::items::RunItem;
use crate::memory::Session;
use crate::sessions::{History, LoadSession, SaveSession};
77
78pub struct SqliteSession {
84 session_id: String,
85 pool: Pool<Sqlite>,
86}
87
88impl SqliteSession {
89 pub async fn new(session_id: impl Into<String>, db_path: impl AsRef<Path>) -> Result<Self> {
99 let session_id = session_id.into();
100 let db_url = format!("sqlite:{}", db_path.as_ref().display());
101
102 let pool = SqlitePool::connect(&db_url).await?;
104
105 Self::run_migrations(&pool).await?;
107
108 Ok(Self { session_id, pool })
109 }
110
111 pub async fn new_default(session_id: impl Into<String>) -> Result<Self> {
113 Self::new(session_id, "sessions.db").await
114 }
115
116 pub async fn new_in_memory(session_id: impl Into<String>) -> Result<Self> {
121 let session_id = session_id.into();
122 let pool = SqlitePool::connect("sqlite::memory:").await?;
123
124 Self::run_migrations(&pool).await?;
126
127 Ok(Self { session_id, pool })
128 }
129
130 async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
132 sqlx::query(
134 r#"
135 CREATE TABLE IF NOT EXISTS sessions (
136 id INTEGER PRIMARY KEY AUTOINCREMENT,
137 session_id TEXT NOT NULL,
138 item_type TEXT NOT NULL,
139 item_data TEXT NOT NULL,
140 created_at TEXT NOT NULL,
141 sequence_num INTEGER NOT NULL,
142 UNIQUE(session_id, sequence_num)
143 )
144 "#,
145 )
146 .execute(pool)
147 .await?;
148
149 sqlx::query(
151 r#"
152 CREATE INDEX IF NOT EXISTS idx_session_id
153 ON sessions(session_id, sequence_num)
154 "#,
155 )
156 .execute(pool)
157 .await?;
158
159 Ok(())
160 }
161
162 fn serialize_item(item: &RunItem) -> Result<String> {
164 Ok(serde_json::to_string(item)?)
165 }
166
167 fn deserialize_item(data: &str) -> Result<RunItem> {
169 Ok(serde_json::from_str(data)?)
170 }
171
172 fn get_item_type(item: &RunItem) -> &'static str {
174 match item {
175 RunItem::Message(_) => "message",
176 RunItem::ToolCall(_) => "tool_call",
177 RunItem::ToolOutput(_) => "tool_output",
178 RunItem::Handoff(_) => "handoff",
179 }
180 }
181}
182
183#[async_trait]
184impl Session for SqliteSession {
185 fn session_id(&self) -> &str {
186 &self.session_id
187 }
188
189 async fn get_items(&self, limit: Option<usize>) -> Result<Vec<RunItem>> {
190 let query = if let Some(limit) = limit {
191 sqlx::query(
192 r#"
193 SELECT item_data
194 FROM sessions
195 WHERE session_id = ?
196 ORDER BY sequence_num DESC
197 LIMIT ?
198 "#,
199 )
200 .bind(&self.session_id)
201 .bind(limit as i64)
202 } else {
203 sqlx::query(
204 r#"
205 SELECT item_data
206 FROM sessions
207 WHERE session_id = ?
208 ORDER BY sequence_num ASC
209 "#,
210 )
211 .bind(&self.session_id)
212 };
213
214 let rows = query.fetch_all(&self.pool).await?;
215
216 let mut items = Vec::new();
217 for row in rows {
218 let data: String = row.get("item_data");
219 items.push(Self::deserialize_item(&data)?);
220 }
221
222 if limit.is_some() {
224 items.reverse();
225 }
226
227 Ok(items)
228 }
229
230 async fn add_items(&self, items: Vec<RunItem>) -> Result<()> {
231 let max_seq: Option<i64> = sqlx::query_scalar(
233 r#"
234 SELECT MAX(sequence_num)
235 FROM sessions
236 WHERE session_id = ?
237 "#,
238 )
239 .bind(&self.session_id)
240 .fetch_one(&self.pool)
241 .await?;
242
243 let mut sequence_num = max_seq.unwrap_or(0) + 1;
244
245 for item in items {
247 let item_type = Self::get_item_type(&item);
248 let item_data = Self::serialize_item(&item)?;
249 let created_at = Utc::now().to_rfc3339();
250
251 sqlx::query(
252 r#"
253 INSERT INTO sessions (session_id, item_type, item_data, created_at, sequence_num)
254 VALUES (?, ?, ?, ?, ?)
255 "#,
256 )
257 .bind(&self.session_id)
258 .bind(item_type)
259 .bind(item_data)
260 .bind(created_at)
261 .bind(sequence_num)
262 .execute(&self.pool)
263 .await?;
264
265 sequence_num += 1;
266 }
267
268 Ok(())
269 }
270
271 async fn pop_item(&self) -> Result<Option<RunItem>> {
272 let mut tx = self.pool.begin().await?;
274
275 let row = sqlx::query(
277 r#"
278 SELECT id, item_data
279 FROM sessions
280 WHERE session_id = ?
281 ORDER BY sequence_num DESC
282 LIMIT 1
283 "#,
284 )
285 .bind(&self.session_id)
286 .fetch_optional(&mut *tx)
287 .await?;
288
289 if let Some(row) = row {
290 let id: i64 = row.get("id");
291 let data: String = row.get("item_data");
292
293 sqlx::query("DELETE FROM sessions WHERE id = ?")
295 .bind(id)
296 .execute(&mut *tx)
297 .await?;
298
299 tx.commit().await?;
301
302 Ok(Some(Self::deserialize_item(&data)?))
303 } else {
304 Ok(None)
305 }
306 }
307
308 async fn clear_session(&self) -> Result<()> {
309 sqlx::query("DELETE FROM sessions WHERE session_id = ?")
310 .bind(&self.session_id)
311 .execute(&self.pool)
312 .await?;
313
314 Ok(())
315 }
316}
317
318impl std::fmt::Debug for SqliteSession {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 f.debug_struct("SqliteSession")
321 .field("session_id", &self.session_id)
322 .finish()
323 }
324}
325
326#[derive(Clone)]
328pub struct SqliteSessionStore {
329 pool: SqlitePool,
330}
331
332impl SqliteSessionStore {
333 pub async fn new(db_path: impl AsRef<Path>) -> Result<Self> {
335 let db_url = format!("sqlite:{}", db_path.as_ref().display());
336 let pool = SqlitePool::connect(&db_url).await?;
337 Self::run_migrations(&pool).await?;
338 Ok(Self { pool })
339 }
340
341 pub async fn new_in_memory() -> Result<Self> {
343 let pool = SqlitePool::connect("sqlite::memory:").await?;
344 Self::run_migrations(&pool).await?;
345 Ok(Self { pool })
346 }
347
348 async fn run_migrations(pool: &Pool<Sqlite>) -> Result<()> {
349 sqlx::query(
351 r#"
352 CREATE TABLE IF NOT EXISTS sessions (
353 id INTEGER PRIMARY KEY AUTOINCREMENT,
354 session_id TEXT NOT NULL,
355 item_type TEXT NOT NULL,
356 item_data TEXT NOT NULL,
357 created_at TEXT NOT NULL,
358 sequence_num INTEGER NOT NULL,
359 UNIQUE(session_id, sequence_num)
360 )
361 "#,
362 )
363 .execute(pool)
364 .await?;
365
366 sqlx::query(
367 r#"
368 CREATE INDEX IF NOT EXISTS idx_session_id
369 ON sessions(session_id, sequence_num)
370 "#,
371 )
372 .execute(pool)
373 .await?;
374
375 Ok(())
376 }
377}
378
379impl Service<LoadSession> for SqliteSessionStore {
380 type Response = History;
381 type Error = BoxError;
382 type Future =
383 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
384
385 fn poll_ready(
386 &mut self,
387 _cx: &mut std::task::Context<'_>,
388 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
389 std::task::Poll::Ready(Ok(()))
390 }
391
392 fn call(&mut self, req: LoadSession) -> Self::Future {
393 let pool = self.pool.clone();
394 let sid = req.id.0.clone();
395 Box::pin(async move {
396 let rows = sqlx::query(
397 r#"
398 SELECT item_data
399 FROM sessions
400 WHERE session_id = ?
401 ORDER BY sequence_num ASC
402 "#,
403 )
404 .bind(&sid)
405 .fetch_all(&pool)
406 .await?;
407 let mut items: Vec<RunItem> = Vec::with_capacity(rows.len());
408 for row in rows {
409 let data: String = row.get("item_data");
410 let item: RunItem = serde_json::from_str(&data)?;
411 items.push(item);
412 }
413 let messages = items_to_messages(&items);
414 Ok(messages)
415 })
416 }
417}
418
419impl Service<SaveSession> for SqliteSessionStore {
420 type Response = ();
421 type Error = BoxError;
422 type Future =
423 Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
424
425 fn poll_ready(
426 &mut self,
427 _cx: &mut std::task::Context<'_>,
428 ) -> std::task::Poll<std::result::Result<(), Self::Error>> {
429 std::task::Poll::Ready(Ok(()))
430 }
431
432 fn call(&mut self, req: SaveSession) -> Self::Future {
433 let pool = self.pool.clone();
434 let sid = req.id.0.clone();
435 let history = req.history.clone();
436 Box::pin(async move {
437 let items = messages_to_items(&history).map_err(|e| -> BoxError { e.into() })?;
439
440 sqlx::query("DELETE FROM sessions WHERE session_id = ?")
442 .bind(&sid)
443 .execute(&pool)
444 .await?;
445
446 let mut sequence_num: i64 = 1;
448 for item in items {
449 let item_type = match &item {
450 RunItem::Message(_) => "message",
451 RunItem::ToolCall(_) => "tool_call",
452 RunItem::ToolOutput(_) => "tool_output",
453 RunItem::Handoff(_) => "handoff",
454 };
455 let item_data = serde_json::to_string(&item)?;
456 let created_at = chrono::Utc::now().to_rfc3339();
457 sqlx::query(
458 r#"
459 INSERT INTO sessions (session_id, item_type, item_data, created_at, sequence_num)
460 VALUES (?, ?, ?, ?, ?)
461 "#,
462 )
463 .bind(&sid)
464 .bind(item_type)
465 .bind(item_data)
466 .bind(created_at)
467 .bind(sequence_num)
468 .execute(&pool)
469 .await?;
470 sequence_num += 1;
471 }
472 Ok(())
473 })
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::items::{HandoffItem, MessageItem, Role, ToolCallItem, ToolOutputItem};
    use chrono::Utc;

    /// Builds a `RunItem::Message` with the given id, role and content.
    fn message(id: &str, role: Role, content: &str) -> RunItem {
        RunItem::Message(MessageItem {
            id: id.to_string(),
            role,
            content: content.to_string(),
            created_at: Utc::now(),
        })
    }

    /// Returns the content of a message item. Panics on any other variant, so assertions
    /// built on it cannot pass vacuously.
    fn content_of(item: &RunItem) -> &str {
        match item {
            RunItem::Message(msg) => &msg.content,
            _ => panic!("Expected Message item"),
        }
    }

    /// Builds a user chat message with plain-text content.
    fn user_message(text: &str) -> async_openai::types::ChatCompletionRequestMessage {
        async_openai::types::ChatCompletionRequestMessage::User(
            async_openai::types::ChatCompletionRequestUserMessageArgs::default()
                .content(text)
                .build()
                .unwrap(),
        )
    }

    #[tokio::test]
    async fn test_sqlite_session_basic() {
        let session = SqliteSession::new_in_memory("test_session").await.unwrap();

        assert_eq!(session.session_id(), "test_session");

        let items = vec![
            message("1", Role::User, "Hello"),
            message("2", Role::Assistant, "Hi there!"),
        ];
        session.add_items(items).await.unwrap();

        let retrieved = session.get_items(None).await.unwrap();
        assert_eq!(retrieved.len(), 2);

        match &retrieved[0] {
            RunItem::Message(msg) => {
                assert_eq!(msg.content, "Hello");
                assert_eq!(msg.role, Role::User);
            }
            _ => panic!("Expected Message item"),
        }
        assert_eq!(content_of(&retrieved[1]), "Hi there!");
    }

    #[tokio::test]
    async fn test_sqlite_session_with_limit() {
        let session = SqliteSession::new_in_memory("test_limit").await.unwrap();

        let items: Vec<RunItem> = (0..5)
            .map(|i| message(&i.to_string(), Role::User, &format!("Message {}", i)))
            .collect();
        session.add_items(items).await.unwrap();

        // The most recent `limit` items come back, still in chronological order.
        let limited = session.get_items(Some(2)).await.unwrap();
        assert_eq!(limited.len(), 2);
        assert_eq!(content_of(&limited[0]), "Message 3");
        assert_eq!(content_of(&limited[1]), "Message 4");

        // A limit larger than the history returns everything; a zero limit returns nothing.
        let all = session.get_items(Some(10)).await.unwrap();
        assert_eq!(all.len(), 5);
        assert_eq!(content_of(&all[0]), "Message 0");
        assert!(session.get_items(Some(0)).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_sqlite_session_pop() {
        let session = SqliteSession::new_in_memory("test_pop").await.unwrap();

        // Popping an empty session yields nothing.
        assert!(session.pop_item().await.unwrap().is_none());

        session
            .add_items(vec![
                message("1", Role::User, "First"),
                message("2", Role::User, "Second"),
            ])
            .await
            .unwrap();

        let popped = session
            .pop_item()
            .await
            .unwrap()
            .expect("pop_item should return the newest item");
        assert_eq!(content_of(&popped), "Second");

        let remaining = session.get_items(None).await.unwrap();
        assert_eq!(remaining.len(), 1);
        assert_eq!(content_of(&remaining[0]), "First");
    }

    #[tokio::test]
    async fn test_sqlite_session_clear() {
        let session = SqliteSession::new_in_memory("test_clear").await.unwrap();

        session
            .add_items(vec![message("1", Role::User, "Test")])
            .await
            .unwrap();

        session.clear_session().await.unwrap();

        let remaining = session.get_items(None).await.unwrap();
        assert!(remaining.is_empty());
    }

    #[tokio::test]
    async fn test_sqlite_session_complex_items() {
        let session = SqliteSession::new_in_memory("test_complex").await.unwrap();

        let items = vec![
            message("1", Role::User, "Calculate something"),
            RunItem::ToolCall(ToolCallItem {
                id: "2".to_string(),
                tool_name: "calculator".to_string(),
                arguments: serde_json::json!({"a": 1, "b": 2}),
                created_at: Utc::now(),
            }),
            RunItem::ToolOutput(ToolOutputItem {
                id: "3".to_string(),
                tool_call_id: "2".to_string(),
                output: serde_json::json!(3),
                error: None,
                created_at: Utc::now(),
            }),
            RunItem::Handoff(HandoffItem {
                id: "4".to_string(),
                from_agent: "Main".to_string(),
                to_agent: "Specialist".to_string(),
                reason: Some("Complex calculation".to_string()),
                created_at: Utc::now(),
            }),
        ];

        session.add_items(items).await.unwrap();

        let retrieved = session.get_items(None).await.unwrap();
        assert_eq!(retrieved.len(), 4);

        assert!(matches!(retrieved[0], RunItem::Message(_)));
        assert!(matches!(retrieved[1], RunItem::ToolCall(_)));
        assert!(matches!(retrieved[2], RunItem::ToolOutput(_)));
        assert!(matches!(retrieved[3], RunItem::Handoff(_)));
    }

    /// Each in-memory session owns a separate database, so this exercises per-instance
    /// behaviour only; isolation inside one shared database is covered by
    /// `test_sqlite_session_store_isolates_and_replaces_sessions`.
    #[tokio::test]
    async fn test_multiple_sessions() {
        let session1 = SqliteSession::new_in_memory("user1").await.unwrap();
        let session2 = SqliteSession::new_in_memory("user2").await.unwrap();

        session1
            .add_items(vec![message("1", Role::User, "Session 1 message")])
            .await
            .unwrap();

        session2
            .add_items(vec![message("2", Role::User, "Session 2 message")])
            .await
            .unwrap();

        let items1 = session1.get_items(None).await.unwrap();
        let items2 = session2.get_items(None).await.unwrap();

        assert_eq!(items1.len(), 1);
        assert_eq!(items2.len(), 1);

        assert_eq!(content_of(&items1[0]), "Session 1 message");
        assert_eq!(content_of(&items2[0]), "Session 2 message");
    }

    #[tokio::test]
    async fn test_sqlite_session_store_load_save_roundtrip() {
        use async_openai::types::{
            ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
            ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
            ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
            ChatCompletionToolType, FunctionCall,
        };

        let store = SqliteSessionStore::new_in_memory().await.unwrap();
        let session_id = crate::sessions::SessionId("s_sqlite".into());

        let sys = ChatCompletionRequestSystemMessageArgs::default()
            .content("sys")
            .build()
            .unwrap();
        let usr = ChatCompletionRequestUserMessageArgs::default()
            .content("hi")
            .build()
            .unwrap();
        let tc = ChatCompletionMessageToolCall {
            id: "c1".to_string(),
            r#type: ChatCompletionToolType::Function,
            function: FunctionCall {
                name: "calc".to_string(),
                arguments: "{\"a\":1}".to_string(),
            },
        };
        let asst = ChatCompletionRequestAssistantMessageArgs::default()
            .content("")
            .tool_calls(vec![tc])
            .build()
            .unwrap();
        let tool = ChatCompletionRequestToolMessageArgs::default()
            .content("{\"sum\":2}")
            .tool_call_id("c1")
            .build()
            .unwrap();

        let history = vec![
            ChatCompletionRequestMessage::System(sys),
            ChatCompletionRequestMessage::User(usr),
            ChatCompletionRequestMessage::Assistant(asst),
            ChatCompletionRequestMessage::Tool(tool),
        ];

        let mut save_store = store.clone();
        Service::call(
            &mut save_store,
            SaveSession {
                id: session_id.clone(),
                history: history.clone(),
            },
        )
        .await
        .unwrap();

        let mut load_store = store.clone();
        let loaded = Service::call(&mut load_store, LoadSession { id: session_id })
            .await
            .unwrap();

        assert_eq!(loaded.len(), history.len());
        if let ChatCompletionRequestMessage::Tool(t) = &loaded[3] {
            if let async_openai::types::ChatCompletionRequestToolMessageContent::Text(txt) =
                &t.content
            {
                let v: serde_json::Value = serde_json::from_str(txt).unwrap();
                assert_eq!(v, serde_json::json!({"sum":2}));
            } else {
                panic!("expected text content");
            }
        } else {
            panic!("expected tool message at index 3");
        }
    }

    /// Two session ids in one store (one shared database) must not see each other's rows,
    /// and saving again must replace rather than append to the stored history.
    #[tokio::test]
    async fn test_sqlite_session_store_isolates_and_replaces_sessions() {
        let mut store = SqliteSessionStore::new_in_memory().await.unwrap();
        let alice = crate::sessions::SessionId("alice".into());
        let bob = crate::sessions::SessionId("bob".into());

        Service::call(
            &mut store,
            SaveSession {
                id: alice.clone(),
                history: vec![user_message("a1"), user_message("a2")],
            },
        )
        .await
        .unwrap();

        Service::call(
            &mut store,
            SaveSession {
                id: bob.clone(),
                history: vec![user_message("b1")],
            },
        )
        .await
        .unwrap();

        // Overwrite alice's history with three messages: appending would yield five.
        Service::call(
            &mut store,
            SaveSession {
                id: alice.clone(),
                history: vec![user_message("a3"), user_message("a4"), user_message("a5")],
            },
        )
        .await
        .unwrap();

        let alice_loaded = Service::call(&mut store, LoadSession { id: alice })
            .await
            .unwrap();
        let bob_loaded = Service::call(&mut store, LoadSession { id: bob })
            .await
            .unwrap();

        assert_eq!(alice_loaded.len(), 3);
        assert_eq!(bob_loaded.len(), 1);
    }
}