1use crate::storage::{
6 Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
7 CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
8 ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
9 StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::Connection;
14use std::path::Path;
15use std::str::FromStr;
16use tokio::sync::Mutex;
17use uuid::Uuid;
18
/// Session and checkpoint storage backed by a libSQL connection.
///
/// Every operation goes through the single connection held here, which is
/// guarded by an async mutex.
pub struct LocalStorage {
    /// The database connection; each storage method locks it before issuing SQL.
    conn: Mutex<Connection>,
}
26
27impl LocalStorage {
28 pub async fn new(db_path: &str) -> Result<Self, StorageError> {
30 if let Some(parent) = Path::new(db_path).parent() {
32 std::fs::create_dir_all(parent).map_err(|e| {
33 StorageError::Connection(format!("Failed to create database directory: {}", e))
34 })?;
35 }
36
37 let db = libsql::Builder::new_local(db_path)
38 .build()
39 .await
40 .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
41
42 let conn = db.connect().map_err(|e| {
43 StorageError::Connection(format!("Failed to connect to database: {}", e))
44 })?;
45
46 let storage = Self {
47 conn: Mutex::new(conn),
48 };
49 storage.init_schema().await?;
50
51 Ok(storage)
52 }
53
54 pub async fn from_connection(conn: Connection) -> Result<Self, StorageError> {
58 let storage = Self {
59 conn: Mutex::new(conn),
60 };
61 storage.init_schema().await?;
62 Ok(storage)
63 }
64
65 async fn init_schema(&self) -> Result<(), StorageError> {
71 let conn = self.conn.lock().await;
72
73 conn.execute(
75 "CREATE TABLE IF NOT EXISTS sessions (
76 id TEXT PRIMARY KEY,
77 title TEXT NOT NULL,
78 agent_id TEXT,
79 visibility TEXT NOT NULL DEFAULT 'PRIVATE',
80 status TEXT DEFAULT 'ACTIVE',
81 cwd TEXT,
82 created_at TEXT NOT NULL,
83 updated_at TEXT NOT NULL
84 )",
85 (),
86 )
87 .await
88 .map_err(|e| StorageError::Internal(e.to_string()))?;
89
90 conn.execute(
92 "CREATE TABLE IF NOT EXISTS checkpoints (
93 id TEXT PRIMARY KEY,
94 session_id TEXT NOT NULL,
95 status TEXT,
96 execution_depth INTEGER,
97 parent_id TEXT,
98 state TEXT,
99 created_at TEXT NOT NULL,
100 updated_at TEXT NOT NULL,
101 FOREIGN KEY(session_id) REFERENCES sessions(id),
102 FOREIGN KEY(parent_id) REFERENCES checkpoints(id)
103 )",
104 (),
105 )
106 .await
107 .map_err(|e| StorageError::Internal(e.to_string()))?;
108
109 let _ = conn
112 .execute(
113 "ALTER TABLE sessions ADD COLUMN status TEXT DEFAULT 'ACTIVE'",
114 (),
115 )
116 .await;
117 let _ = conn
118 .execute("ALTER TABLE sessions ADD COLUMN cwd TEXT", ())
119 .await;
120
121 conn.execute(
123 "CREATE INDEX IF NOT EXISTS idx_checkpoints_session_id ON checkpoints(session_id)",
124 (),
125 )
126 .await
127 .map_err(|e| StorageError::Internal(e.to_string()))?;
128
129 Ok(())
130 }
131
132 pub fn connection(&self) -> &Mutex<Connection> {
136 &self.conn
137 }
138
139 async fn get_latest_checkpoint_for_session_inner(
143 conn: &Connection,
144 session_id: Uuid,
145 ) -> Result<Checkpoint, StorageError> {
146 let mut rows = conn
147 .query(
148 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
149 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
150 [session_id.to_string()],
151 )
152 .await
153 .map_err(|e| StorageError::Internal(e.to_string()))?;
154
155 if let Ok(Some(row)) = rows.next().await {
156 let id: String = row
157 .get(0)
158 .map_err(|e| StorageError::Internal(e.to_string()))?;
159 let session_id: String = row
160 .get(1)
161 .map_err(|e| StorageError::Internal(e.to_string()))?;
162 let parent_id: Option<String> = row.get(2).ok();
163 let state: Option<String> = row.get(3).ok();
164 let created_at: String = row
165 .get(4)
166 .map_err(|e| StorageError::Internal(e.to_string()))?;
167 let updated_at: String = row
168 .get(5)
169 .map_err(|e| StorageError::Internal(e.to_string()))?;
170
171 let state: CheckpointState = if let Some(state_str) = state {
172 serde_json::from_str(&state_str).unwrap_or_default()
173 } else {
174 CheckpointState::default()
175 };
176
177 Ok(Checkpoint {
178 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
179 session_id: Uuid::from_str(&session_id)
180 .map_err(|e| StorageError::Internal(e.to_string()))?,
181 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
182 state,
183 created_at: parse_datetime(&created_at)?,
184 updated_at: parse_datetime(&updated_at)?,
185 })
186 } else {
187 Err(StorageError::NotFound(format!(
188 "No checkpoints found for session {}",
189 session_id
190 )))
191 }
192 }
193}
194
195#[async_trait]
196impl SessionStorage for LocalStorage {
197 async fn list_sessions(
198 &self,
199 query: &ListSessionsQuery,
200 ) -> Result<ListSessionsResult, StorageError> {
201 let limit = query.limit.unwrap_or(100);
202 let offset = query.offset.unwrap_or(0);
203
204 let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at,
205 (SELECT COUNT(*) FROM checkpoints c WHERE c.session_id = s.id) as checkpoint_count,
206 (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY created_at DESC LIMIT 1) as active_checkpoint_id
207 FROM sessions s WHERE 1=1".to_string();
208
209 if let Some(status) = &query.status {
214 sql.push_str(&format!(" AND s.status = '{}'", status));
215 }
216 if let Some(visibility) = &query.visibility {
217 sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
218 }
219 if query.search.is_some() {
220 sql.push_str(" AND s.title LIKE '%' || ? || '%'");
221 }
222
223 sql.push_str(&format!(
224 " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
225 limit, offset
226 ));
227
228 let conn = self.conn.lock().await;
229 let mut rows = if let Some(search) = &query.search {
230 conn.query(&sql, [search.as_str()])
231 .await
232 .map_err(|e| StorageError::Internal(e.to_string()))?
233 } else {
234 conn.query(&sql, ())
235 .await
236 .map_err(|e| StorageError::Internal(e.to_string()))?
237 };
238
239 let mut sessions = Vec::new();
240 while let Ok(Some(row)) = rows.next().await {
241 let id: String = row
242 .get(0)
243 .map_err(|e| StorageError::Internal(e.to_string()))?;
244 let title: String = row
245 .get(1)
246 .map_err(|e| StorageError::Internal(e.to_string()))?;
247 let visibility: String = row
248 .get(2)
249 .map_err(|e| StorageError::Internal(e.to_string()))?;
250 let status: String = row
251 .get(3)
252 .map_err(|e| StorageError::Internal(e.to_string()))?;
253 let cwd: Option<String> = row.get(4).ok();
254 let created_at: String = row
255 .get(5)
256 .map_err(|e| StorageError::Internal(e.to_string()))?;
257 let updated_at: String = row
258 .get(6)
259 .map_err(|e| StorageError::Internal(e.to_string()))?;
260 let checkpoint_count: i64 = row.get(7).unwrap_or(0);
261 let active_checkpoint_id: Option<String> = row.get(8).ok();
262
263 sessions.push(SessionSummary {
264 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
265 title,
266 visibility: parse_visibility(&visibility),
267 status: parse_status(&status),
268 cwd,
269 created_at: parse_datetime(&created_at)?,
270 updated_at: parse_datetime(&updated_at)?,
271 message_count: checkpoint_count as u32,
272 active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
273 last_message_at: None,
274 });
275 }
276
277 Ok(ListSessionsResult {
278 sessions,
279 total: None,
280 })
281 }
282
283 async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
284 let conn = self.conn.lock().await;
285 let mut rows = conn
286 .query(
287 "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
288 [session_id.to_string()],
289 )
290 .await
291 .map_err(|e| StorageError::Internal(e.to_string()))?;
292
293 if let Ok(Some(row)) = rows.next().await {
294 let id: String = row
295 .get(0)
296 .map_err(|e| StorageError::Internal(e.to_string()))?;
297 let title: String = row
298 .get(1)
299 .map_err(|e| StorageError::Internal(e.to_string()))?;
300 let visibility: String = row
301 .get(2)
302 .map_err(|e| StorageError::Internal(e.to_string()))?;
303 let status: String = row
304 .get(3)
305 .map_err(|e| StorageError::Internal(e.to_string()))?;
306 let cwd: Option<String> = row.get(4).ok();
307 let created_at: String = row
308 .get(5)
309 .map_err(|e| StorageError::Internal(e.to_string()))?;
310 let updated_at: String = row
311 .get(6)
312 .map_err(|e| StorageError::Internal(e.to_string()))?;
313
314 let active_checkpoint =
316 Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
317 .await
318 .ok();
319
320 Ok(Session {
321 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
322 title,
323 visibility: parse_visibility(&visibility),
324 status: parse_status(&status),
325 cwd,
326 created_at: parse_datetime(&created_at)?,
327 updated_at: parse_datetime(&updated_at)?,
328 active_checkpoint,
329 })
330 } else {
331 Err(StorageError::NotFound(format!(
332 "Session {} not found",
333 session_id
334 )))
335 }
336 }
337
338 async fn create_session(
339 &self,
340 request: &CreateSessionRequest,
341 ) -> Result<CreateSessionResult, StorageError> {
342 let now = Utc::now();
343 let session_id = Uuid::new_v4();
344 let checkpoint_id = Uuid::new_v4();
345
346 let conn = self.conn.lock().await;
347
348 conn.execute(
350 "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
351 (
352 session_id.to_string(),
353 request.title.as_str(),
354 request.visibility.to_string(),
355 "ACTIVE",
356 request.cwd.as_deref(),
357 now.to_rfc3339(),
358 now.to_rfc3339(),
359 ),
360 )
361 .await
362 .map_err(|e| StorageError::Internal(e.to_string()))?;
363
364 let state_json = serde_json::to_string(&request.initial_state)
366 .map_err(|e| StorageError::Internal(e.to_string()))?;
367
368 conn.execute(
369 "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
370 (
371 checkpoint_id.to_string(),
372 session_id.to_string(),
373 None::<String>,
374 state_json,
375 now.to_rfc3339(),
376 now.to_rfc3339(),
377 ),
378 )
379 .await
380 .map_err(|e| StorageError::Internal(e.to_string()))?;
381
382 Ok(CreateSessionResult {
383 session_id,
384 checkpoint: Checkpoint {
385 id: checkpoint_id,
386 session_id,
387 parent_id: None,
388 state: request.initial_state.clone(),
389 created_at: now,
390 updated_at: now,
391 },
392 })
393 }
394
395 async fn update_session(
396 &self,
397 session_id: Uuid,
398 request: &UpdateSessionRequest,
399 ) -> Result<Session, StorageError> {
400 let now = Utc::now();
401
402 {
403 let conn = self.conn.lock().await;
404
405 if let Some(title) = &request.title {
407 conn.execute(
408 "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
409 (title.as_str(), now.to_rfc3339(), session_id.to_string()),
410 )
411 .await
412 .map_err(|e| StorageError::Internal(e.to_string()))?;
413 }
414 if let Some(visibility) = &request.visibility {
415 conn.execute(
416 "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
417 (
418 visibility.to_string(),
419 now.to_rfc3339(),
420 session_id.to_string(),
421 ),
422 )
423 .await
424 .map_err(|e| StorageError::Internal(e.to_string()))?;
425 }
426 }
427
428 self.get_session(session_id).await
429 }
430
431 async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
432 let now = Utc::now();
434 let conn = self.conn.lock().await;
435 conn.execute(
436 "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
437 (now.to_rfc3339(), session_id.to_string()),
438 )
439 .await
440 .map_err(|e| StorageError::Internal(e.to_string()))?;
441 Ok(())
442 }
443
444 async fn list_checkpoints(
445 &self,
446 session_id: Uuid,
447 query: &ListCheckpointsQuery,
448 ) -> Result<ListCheckpointsResult, StorageError> {
449 let limit = query.limit.unwrap_or(100);
450 let offset = query.offset.unwrap_or(0);
451
452 let sql = format!(
453 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
454 WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
455 limit, offset
456 );
457
458 let conn = self.conn.lock().await;
459 let mut rows = conn
460 .query(&sql, [session_id.to_string()])
461 .await
462 .map_err(|e| StorageError::Internal(e.to_string()))?;
463
464 let mut checkpoints = Vec::new();
465 while let Ok(Some(row)) = rows.next().await {
466 let id: String = row
467 .get(0)
468 .map_err(|e| StorageError::Internal(e.to_string()))?;
469 let session_id: String = row
470 .get(1)
471 .map_err(|e| StorageError::Internal(e.to_string()))?;
472 let parent_id: Option<String> = row.get(2).ok();
473 let state: Option<String> = row.get(3).ok();
474 let created_at: String = row
475 .get(4)
476 .map_err(|e| StorageError::Internal(e.to_string()))?;
477 let updated_at: String = row
478 .get(5)
479 .map_err(|e| StorageError::Internal(e.to_string()))?;
480
481 let state: CheckpointState = if let Some(state_str) = state {
482 serde_json::from_str(&state_str).unwrap_or_default()
483 } else {
484 CheckpointState::default()
485 };
486
487 checkpoints.push(CheckpointSummary {
488 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
489 session_id: Uuid::from_str(&session_id)
490 .map_err(|e| StorageError::Internal(e.to_string()))?,
491 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
492 message_count: state.messages.len() as u32,
493 created_at: parse_datetime(&created_at)?,
494 updated_at: parse_datetime(&updated_at)?,
495 });
496 }
497
498 Ok(ListCheckpointsResult {
499 checkpoints,
500 total: None,
501 })
502 }
503
504 async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
505 let conn = self.conn.lock().await;
506 let mut rows = conn
507 .query(
508 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
509 [checkpoint_id.to_string()],
510 )
511 .await
512 .map_err(|e| StorageError::Internal(e.to_string()))?;
513
514 if let Ok(Some(row)) = rows.next().await {
515 let id: String = row
516 .get(0)
517 .map_err(|e| StorageError::Internal(e.to_string()))?;
518 let session_id: String = row
519 .get(1)
520 .map_err(|e| StorageError::Internal(e.to_string()))?;
521 let parent_id: Option<String> = row.get(2).ok();
522 let state: Option<String> = row.get(3).ok();
523 let created_at: String = row
524 .get(4)
525 .map_err(|e| StorageError::Internal(e.to_string()))?;
526 let updated_at: String = row
527 .get(5)
528 .map_err(|e| StorageError::Internal(e.to_string()))?;
529
530 let state: CheckpointState = if let Some(state_str) = state {
531 serde_json::from_str(&state_str).unwrap_or_default()
532 } else {
533 CheckpointState::default()
534 };
535
536 Ok(Checkpoint {
537 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
538 session_id: Uuid::from_str(&session_id)
539 .map_err(|e| StorageError::Internal(e.to_string()))?,
540 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
541 state,
542 created_at: parse_datetime(&created_at)?,
543 updated_at: parse_datetime(&updated_at)?,
544 })
545 } else {
546 Err(StorageError::NotFound(format!(
547 "Checkpoint {} not found",
548 checkpoint_id
549 )))
550 }
551 }
552
553 async fn create_checkpoint(
554 &self,
555 session_id: Uuid,
556 request: &CreateCheckpointRequest,
557 ) -> Result<Checkpoint, StorageError> {
558 let now = Utc::now();
559 let checkpoint_id = Uuid::new_v4();
560
561 let state_json = serde_json::to_string(&request.state)
562 .map_err(|e| StorageError::Internal(e.to_string()))?;
563
564 let conn = self.conn.lock().await;
565
566 conn.execute(
567 "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
568 (
569 checkpoint_id.to_string(),
570 session_id.to_string(),
571 request.parent_id.map(|id| id.to_string()),
572 state_json,
573 now.to_rfc3339(),
574 now.to_rfc3339(),
575 ),
576 )
577 .await
578 .map_err(|e| StorageError::Internal(e.to_string()))?;
579
580 conn.execute(
582 "UPDATE sessions SET updated_at = ? WHERE id = ?",
583 (now.to_rfc3339(), session_id.to_string()),
584 )
585 .await
586 .map_err(|e| StorageError::Internal(e.to_string()))?;
587
588 Ok(Checkpoint {
589 id: checkpoint_id,
590 session_id,
591 parent_id: request.parent_id,
592 state: request.state.clone(),
593 created_at: now,
594 updated_at: now,
595 })
596 }
597}
598
599fn parse_visibility(s: &str) -> SessionVisibility {
601 match s.to_uppercase().as_str() {
602 "PUBLIC" => SessionVisibility::Public,
603 _ => SessionVisibility::Private,
604 }
605}
606
607fn parse_status(s: &str) -> SessionStatus {
608 match s.to_uppercase().as_str() {
609 "DELETED" => SessionStatus::Deleted,
610 _ => SessionStatus::Active,
611 }
612}
613
614fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
615 DateTime::parse_from_rfc3339(s)
616 .map(|dt| dt.with_timezone(&Utc))
617 .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
618}