1use crate::storage::{
6 Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
7 CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
8 ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
9 StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::{Connection, Database};
14use std::path::Path;
15use std::str::FromStr;
16use tempfile::TempDir;
17use uuid::Uuid;
18
/// Session storage backed by a local libsql (SQLite) database.
///
/// Every storage operation opens a fresh connection from `db` via `connection()`.
pub struct LocalStorage {
    /// Handle to the opened libsql database; connections are created from it on demand.
    db: Database,
    /// Temporary directory that holds the database file when the storage was created with
    /// `":memory:"`; `None` for file-backed databases. Only held to keep the directory alive.
    ///
    /// Keep this field declared after `db`: struct fields are dropped in declaration order,
    /// so the database handle is released before the directory is removed.
    _temp_dir: Option<TempDir>,
}
28
29impl LocalStorage {
30 pub async fn new(db_path: &str) -> Result<Self, StorageError> {
32 let (resolved_path, temp_dir) = if db_path == ":memory:" {
33 let temp_dir = tempfile::tempdir().map_err(|e| {
36 StorageError::Connection(format!("Failed to create temp directory: {}", e))
37 })?;
38 (temp_dir.path().join("local-storage.db"), Some(temp_dir))
39 } else {
40 (Path::new(db_path).to_path_buf(), None)
41 };
42
43 if let Some(parent) = resolved_path.parent() {
45 std::fs::create_dir_all(parent).map_err(|e| {
46 StorageError::Connection(format!("Failed to create database directory: {}", e))
47 })?;
48 }
49
50 let db = libsql::Builder::new_local(&resolved_path)
51 .build()
52 .await
53 .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
54
55 let storage = Self {
56 db,
57 _temp_dir: temp_dir,
58 };
59 storage.init_schema().await?;
60
61 Ok(storage)
62 }
63
64 pub async fn from_db_and_connection(
69 db: Database,
70 _conn: Connection,
71 ) -> Result<Self, StorageError> {
72 let storage = Self {
73 db,
74 _temp_dir: None,
75 };
76 storage.init_schema().await?;
77 Ok(storage)
78 }
79
80 #[deprecated(note = "use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)")]
85 pub async fn from_connection(_conn: Connection) -> Result<Self, StorageError> {
86 Err(StorageError::Connection(
87 "LocalStorage::from_connection is unsupported without the owning libsql::Database; use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)".to_string(),
88 ))
89 }
90
91 async fn init_schema(&self) -> Result<(), StorageError> {
93 let conn = self.connection()?;
94 super::migrations::run_migrations(&conn)
95 .await
96 .map_err(StorageError::Internal)
97 }
98
99 pub(crate) fn connection(&self) -> Result<Connection, StorageError> {
101 self.db
102 .connect()
103 .map_err(|e| StorageError::Connection(format!("Failed to connect to database: {}", e)))
104 }
105
106 async fn get_latest_checkpoint_for_session_inner(
108 conn: &Connection,
109 session_id: Uuid,
110 ) -> Result<Checkpoint, StorageError> {
111 let mut rows = conn
112 .query(
113 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
114 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
115 [session_id.to_string()],
116 )
117 .await
118 .map_err(|e| StorageError::Internal(e.to_string()))?;
119
120 if let Ok(Some(row)) = rows.next().await {
121 let id: String = row
122 .get(0)
123 .map_err(|e| StorageError::Internal(e.to_string()))?;
124 let session_id: String = row
125 .get(1)
126 .map_err(|e| StorageError::Internal(e.to_string()))?;
127 let parent_id: Option<String> = row.get(2).ok();
128 let state: Option<String> = row.get(3).ok();
129 let created_at: String = row
130 .get(4)
131 .map_err(|e| StorageError::Internal(e.to_string()))?;
132 let updated_at: String = row
133 .get(5)
134 .map_err(|e| StorageError::Internal(e.to_string()))?;
135
136 let state: CheckpointState = if let Some(state_str) = state {
137 serde_json::from_str(&state_str).unwrap_or_default()
138 } else {
139 CheckpointState::default()
140 };
141
142 Ok(Checkpoint {
143 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
144 session_id: Uuid::from_str(&session_id)
145 .map_err(|e| StorageError::Internal(e.to_string()))?,
146 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
147 state,
148 created_at: parse_datetime(&created_at)?,
149 updated_at: parse_datetime(&updated_at)?,
150 })
151 } else {
152 Err(StorageError::NotFound(format!(
153 "No checkpoints found for session {}",
154 session_id
155 )))
156 }
157 }
158}
159
160#[async_trait]
161impl SessionStorage for LocalStorage {
162 async fn list_sessions(
163 &self,
164 query: &ListSessionsQuery,
165 ) -> Result<ListSessionsResult, StorageError> {
166 let limit = query.limit.unwrap_or(100);
167 let offset = query.offset.unwrap_or(0);
168
169 let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at,
170 (SELECT COUNT(*) FROM checkpoints c WHERE c.session_id = s.id) as checkpoint_count,
171 (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY created_at DESC LIMIT 1) as active_checkpoint_id
172 FROM sessions s WHERE 1=1".to_string();
173
174 if let Some(status) = &query.status {
179 sql.push_str(&format!(" AND s.status = '{}'", status));
180 }
181 if let Some(visibility) = &query.visibility {
182 sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
183 }
184 if query.search.is_some() {
185 sql.push_str(" AND s.title LIKE '%' || ? || '%'");
186 }
187
188 sql.push_str(&format!(
189 " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
190 limit, offset
191 ));
192
193 let conn = self.connection()?;
194 let mut rows = if let Some(search) = &query.search {
195 conn.query(&sql, [search.as_str()])
196 .await
197 .map_err(|e| StorageError::Internal(e.to_string()))?
198 } else {
199 conn.query(&sql, ())
200 .await
201 .map_err(|e| StorageError::Internal(e.to_string()))?
202 };
203
204 let mut sessions = Vec::new();
205 while let Ok(Some(row)) = rows.next().await {
206 let id: String = row
207 .get(0)
208 .map_err(|e| StorageError::Internal(e.to_string()))?;
209 let title: String = row
210 .get(1)
211 .map_err(|e| StorageError::Internal(e.to_string()))?;
212 let visibility: String = row
213 .get(2)
214 .map_err(|e| StorageError::Internal(e.to_string()))?;
215 let status: String = row
216 .get(3)
217 .map_err(|e| StorageError::Internal(e.to_string()))?;
218 let cwd: Option<String> = row.get(4).ok();
219 let created_at: String = row
220 .get(5)
221 .map_err(|e| StorageError::Internal(e.to_string()))?;
222 let updated_at: String = row
223 .get(6)
224 .map_err(|e| StorageError::Internal(e.to_string()))?;
225 let checkpoint_count: i64 = row.get(7).unwrap_or(0);
226 let active_checkpoint_id: Option<String> = row.get(8).ok();
227
228 sessions.push(SessionSummary {
229 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
230 title,
231 visibility: parse_visibility(&visibility),
232 status: parse_status(&status),
233 cwd,
234 created_at: parse_datetime(&created_at)?,
235 updated_at: parse_datetime(&updated_at)?,
236 message_count: checkpoint_count as u32,
237 active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
238 last_message_at: None,
239 });
240 }
241
242 Ok(ListSessionsResult {
243 sessions,
244 total: None,
245 })
246 }
247
248 async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
249 let conn = self.connection()?;
250 let mut rows = conn
251 .query(
252 "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
253 [session_id.to_string()],
254 )
255 .await
256 .map_err(|e| StorageError::Internal(e.to_string()))?;
257
258 if let Ok(Some(row)) = rows.next().await {
259 let id: String = row
260 .get(0)
261 .map_err(|e| StorageError::Internal(e.to_string()))?;
262 let title: String = row
263 .get(1)
264 .map_err(|e| StorageError::Internal(e.to_string()))?;
265 let visibility: String = row
266 .get(2)
267 .map_err(|e| StorageError::Internal(e.to_string()))?;
268 let status: String = row
269 .get(3)
270 .map_err(|e| StorageError::Internal(e.to_string()))?;
271 let cwd: Option<String> = row.get(4).ok();
272 let created_at: String = row
273 .get(5)
274 .map_err(|e| StorageError::Internal(e.to_string()))?;
275 let updated_at: String = row
276 .get(6)
277 .map_err(|e| StorageError::Internal(e.to_string()))?;
278
279 let active_checkpoint =
281 Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
282 .await
283 .ok();
284
285 Ok(Session {
286 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
287 title,
288 visibility: parse_visibility(&visibility),
289 status: parse_status(&status),
290 cwd,
291 created_at: parse_datetime(&created_at)?,
292 updated_at: parse_datetime(&updated_at)?,
293 active_checkpoint,
294 })
295 } else {
296 Err(StorageError::NotFound(format!(
297 "Session {} not found",
298 session_id
299 )))
300 }
301 }
302
303 async fn create_session(
304 &self,
305 request: &CreateSessionRequest,
306 ) -> Result<CreateSessionResult, StorageError> {
307 let now = Utc::now();
308 let session_id = Uuid::new_v4();
309 let checkpoint_id = Uuid::new_v4();
310
311 let conn = self.connection()?;
312
313 conn.execute(
315 "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
316 (
317 session_id.to_string(),
318 request.title.as_str(),
319 request.visibility.to_string(),
320 "ACTIVE",
321 request.cwd.as_deref(),
322 now.to_rfc3339(),
323 now.to_rfc3339(),
324 ),
325 )
326 .await
327 .map_err(|e| StorageError::Internal(e.to_string()))?;
328
329 let state_json = serde_json::to_string(&request.initial_state)
331 .map_err(|e| StorageError::Internal(e.to_string()))?;
332
333 conn.execute(
334 "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
335 (
336 checkpoint_id.to_string(),
337 session_id.to_string(),
338 None::<String>,
339 state_json,
340 now.to_rfc3339(),
341 now.to_rfc3339(),
342 ),
343 )
344 .await
345 .map_err(|e| StorageError::Internal(e.to_string()))?;
346
347 Ok(CreateSessionResult {
348 session_id,
349 checkpoint: Checkpoint {
350 id: checkpoint_id,
351 session_id,
352 parent_id: None,
353 state: request.initial_state.clone(),
354 created_at: now,
355 updated_at: now,
356 },
357 })
358 }
359
360 async fn update_session(
361 &self,
362 session_id: Uuid,
363 request: &UpdateSessionRequest,
364 ) -> Result<Session, StorageError> {
365 let now = Utc::now();
366
367 {
368 let conn = self.connection()?;
369
370 if let Some(title) = &request.title {
372 conn.execute(
373 "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
374 (title.as_str(), now.to_rfc3339(), session_id.to_string()),
375 )
376 .await
377 .map_err(|e| StorageError::Internal(e.to_string()))?;
378 }
379 if let Some(visibility) = &request.visibility {
380 conn.execute(
381 "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
382 (
383 visibility.to_string(),
384 now.to_rfc3339(),
385 session_id.to_string(),
386 ),
387 )
388 .await
389 .map_err(|e| StorageError::Internal(e.to_string()))?;
390 }
391 }
392
393 self.get_session(session_id).await
394 }
395
396 async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
397 let now = Utc::now();
399 let conn = self.connection()?;
400 conn.execute(
401 "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
402 (now.to_rfc3339(), session_id.to_string()),
403 )
404 .await
405 .map_err(|e| StorageError::Internal(e.to_string()))?;
406 Ok(())
407 }
408
409 async fn list_checkpoints(
410 &self,
411 session_id: Uuid,
412 query: &ListCheckpointsQuery,
413 ) -> Result<ListCheckpointsResult, StorageError> {
414 let limit = query.limit.unwrap_or(100);
415 let offset = query.offset.unwrap_or(0);
416
417 let sql = format!(
418 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
419 WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
420 limit, offset
421 );
422
423 let conn = self.connection()?;
424 let mut rows = conn
425 .query(&sql, [session_id.to_string()])
426 .await
427 .map_err(|e| StorageError::Internal(e.to_string()))?;
428
429 let mut checkpoints = Vec::new();
430 while let Ok(Some(row)) = rows.next().await {
431 let id: String = row
432 .get(0)
433 .map_err(|e| StorageError::Internal(e.to_string()))?;
434 let session_id: String = row
435 .get(1)
436 .map_err(|e| StorageError::Internal(e.to_string()))?;
437 let parent_id: Option<String> = row.get(2).ok();
438 let state: Option<String> = row.get(3).ok();
439 let created_at: String = row
440 .get(4)
441 .map_err(|e| StorageError::Internal(e.to_string()))?;
442 let updated_at: String = row
443 .get(5)
444 .map_err(|e| StorageError::Internal(e.to_string()))?;
445
446 let state: CheckpointState = if let Some(state_str) = state {
447 serde_json::from_str(&state_str).unwrap_or_default()
448 } else {
449 CheckpointState::default()
450 };
451
452 checkpoints.push(CheckpointSummary {
453 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
454 session_id: Uuid::from_str(&session_id)
455 .map_err(|e| StorageError::Internal(e.to_string()))?,
456 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
457 message_count: state.messages.len() as u32,
458 created_at: parse_datetime(&created_at)?,
459 updated_at: parse_datetime(&updated_at)?,
460 });
461 }
462
463 Ok(ListCheckpointsResult {
464 checkpoints,
465 total: None,
466 })
467 }
468
469 async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
470 let conn = self.connection()?;
471 let mut rows = conn
472 .query(
473 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
474 [checkpoint_id.to_string()],
475 )
476 .await
477 .map_err(|e| StorageError::Internal(e.to_string()))?;
478
479 if let Ok(Some(row)) = rows.next().await {
480 let id: String = row
481 .get(0)
482 .map_err(|e| StorageError::Internal(e.to_string()))?;
483 let session_id: String = row
484 .get(1)
485 .map_err(|e| StorageError::Internal(e.to_string()))?;
486 let parent_id: Option<String> = row.get(2).ok();
487 let state: Option<String> = row.get(3).ok();
488 let created_at: String = row
489 .get(4)
490 .map_err(|e| StorageError::Internal(e.to_string()))?;
491 let updated_at: String = row
492 .get(5)
493 .map_err(|e| StorageError::Internal(e.to_string()))?;
494
495 let state: CheckpointState = if let Some(state_str) = state {
496 serde_json::from_str(&state_str).unwrap_or_default()
497 } else {
498 CheckpointState::default()
499 };
500
501 Ok(Checkpoint {
502 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
503 session_id: Uuid::from_str(&session_id)
504 .map_err(|e| StorageError::Internal(e.to_string()))?,
505 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
506 state,
507 created_at: parse_datetime(&created_at)?,
508 updated_at: parse_datetime(&updated_at)?,
509 })
510 } else {
511 Err(StorageError::NotFound(format!(
512 "Checkpoint {} not found",
513 checkpoint_id
514 )))
515 }
516 }
517
518 async fn create_checkpoint(
519 &self,
520 session_id: Uuid,
521 request: &CreateCheckpointRequest,
522 ) -> Result<Checkpoint, StorageError> {
523 let now = Utc::now();
524 let checkpoint_id = Uuid::new_v4();
525
526 let state_json = serde_json::to_string(&request.state)
527 .map_err(|e| StorageError::Internal(e.to_string()))?;
528
529 let conn = self.connection()?;
530
531 conn.execute(
532 "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
533 (
534 checkpoint_id.to_string(),
535 session_id.to_string(),
536 request.parent_id.map(|id| id.to_string()),
537 state_json,
538 now.to_rfc3339(),
539 now.to_rfc3339(),
540 ),
541 )
542 .await
543 .map_err(|e| StorageError::Internal(e.to_string()))?;
544
545 conn.execute(
547 "UPDATE sessions SET updated_at = ? WHERE id = ?",
548 (now.to_rfc3339(), session_id.to_string()),
549 )
550 .await
551 .map_err(|e| StorageError::Internal(e.to_string()))?;
552
553 Ok(Checkpoint {
554 id: checkpoint_id,
555 session_id,
556 parent_id: request.parent_id,
557 state: request.state.clone(),
558 created_at: now,
559 updated_at: now,
560 })
561 }
562}
563
564fn parse_visibility(s: &str) -> SessionVisibility {
566 match s.to_uppercase().as_str() {
567 "PUBLIC" => SessionVisibility::Public,
568 _ => SessionVisibility::Private,
569 }
570}
571
572fn parse_status(s: &str) -> SessionStatus {
573 match s.to_uppercase().as_str() {
574 "DELETED" => SessionStatus::Deleted,
575 _ => SessionStatus::Active,
576 }
577}
578
579fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
580 DateTime::parse_from_rfc3339(s)
581 .map(|dt| dt.with_timezone(&Utc))
582 .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
583}