1use crate::storage::{
6 Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
7 CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
8 ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
9 StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::Connection;
14use std::path::Path;
15use std::str::FromStr;
16use tokio::sync::Mutex;
17use uuid::Uuid;
18
/// Session storage backed by a local libsql database.
///
/// All operations go through the single [`Connection`] held here, which is
/// guarded by an async (`tokio`) mutex so only one operation uses it at a time.
pub struct LocalStorage {
    /// The one database connection shared by every storage operation.
    conn: Mutex<Connection>,
}
26
27impl LocalStorage {
28 pub async fn new(db_path: &str) -> Result<Self, StorageError> {
30 if let Some(parent) = Path::new(db_path).parent() {
32 std::fs::create_dir_all(parent).map_err(|e| {
33 StorageError::Connection(format!("Failed to create database directory: {}", e))
34 })?;
35 }
36
37 let db = libsql::Builder::new_local(db_path)
38 .build()
39 .await
40 .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
41
42 let conn = db.connect().map_err(|e| {
43 StorageError::Connection(format!("Failed to connect to database: {}", e))
44 })?;
45
46 let storage = Self {
47 conn: Mutex::new(conn),
48 };
49 storage.init_schema().await?;
50
51 Ok(storage)
52 }
53
54 pub async fn from_connection(conn: Connection) -> Result<Self, StorageError> {
58 let storage = Self {
59 conn: Mutex::new(conn),
60 };
61 storage.init_schema().await?;
62 Ok(storage)
63 }
64
65 async fn init_schema(&self) -> Result<(), StorageError> {
67 let conn = self.conn.lock().await;
68 super::migrations::run_migrations(&conn)
69 .await
70 .map_err(StorageError::Internal)
71 }
72
73 pub fn connection(&self) -> &Mutex<Connection> {
77 &self.conn
78 }
79
80 async fn get_latest_checkpoint_for_session_inner(
84 conn: &Connection,
85 session_id: Uuid,
86 ) -> Result<Checkpoint, StorageError> {
87 let mut rows = conn
88 .query(
89 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
90 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
91 [session_id.to_string()],
92 )
93 .await
94 .map_err(|e| StorageError::Internal(e.to_string()))?;
95
96 if let Ok(Some(row)) = rows.next().await {
97 let id: String = row
98 .get(0)
99 .map_err(|e| StorageError::Internal(e.to_string()))?;
100 let session_id: String = row
101 .get(1)
102 .map_err(|e| StorageError::Internal(e.to_string()))?;
103 let parent_id: Option<String> = row.get(2).ok();
104 let state: Option<String> = row.get(3).ok();
105 let created_at: String = row
106 .get(4)
107 .map_err(|e| StorageError::Internal(e.to_string()))?;
108 let updated_at: String = row
109 .get(5)
110 .map_err(|e| StorageError::Internal(e.to_string()))?;
111
112 let state: CheckpointState = if let Some(state_str) = state {
113 serde_json::from_str(&state_str).unwrap_or_default()
114 } else {
115 CheckpointState::default()
116 };
117
118 Ok(Checkpoint {
119 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
120 session_id: Uuid::from_str(&session_id)
121 .map_err(|e| StorageError::Internal(e.to_string()))?,
122 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
123 state,
124 created_at: parse_datetime(&created_at)?,
125 updated_at: parse_datetime(&updated_at)?,
126 })
127 } else {
128 Err(StorageError::NotFound(format!(
129 "No checkpoints found for session {}",
130 session_id
131 )))
132 }
133 }
134}
135
136#[async_trait]
137impl SessionStorage for LocalStorage {
138 async fn list_sessions(
139 &self,
140 query: &ListSessionsQuery,
141 ) -> Result<ListSessionsResult, StorageError> {
142 let limit = query.limit.unwrap_or(100);
143 let offset = query.offset.unwrap_or(0);
144
145 let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at,
146 (SELECT COUNT(*) FROM checkpoints c WHERE c.session_id = s.id) as checkpoint_count,
147 (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY created_at DESC LIMIT 1) as active_checkpoint_id
148 FROM sessions s WHERE 1=1".to_string();
149
150 if let Some(status) = &query.status {
155 sql.push_str(&format!(" AND s.status = '{}'", status));
156 }
157 if let Some(visibility) = &query.visibility {
158 sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
159 }
160 if query.search.is_some() {
161 sql.push_str(" AND s.title LIKE '%' || ? || '%'");
162 }
163
164 sql.push_str(&format!(
165 " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
166 limit, offset
167 ));
168
169 let conn = self.conn.lock().await;
170 let mut rows = if let Some(search) = &query.search {
171 conn.query(&sql, [search.as_str()])
172 .await
173 .map_err(|e| StorageError::Internal(e.to_string()))?
174 } else {
175 conn.query(&sql, ())
176 .await
177 .map_err(|e| StorageError::Internal(e.to_string()))?
178 };
179
180 let mut sessions = Vec::new();
181 while let Ok(Some(row)) = rows.next().await {
182 let id: String = row
183 .get(0)
184 .map_err(|e| StorageError::Internal(e.to_string()))?;
185 let title: String = row
186 .get(1)
187 .map_err(|e| StorageError::Internal(e.to_string()))?;
188 let visibility: String = row
189 .get(2)
190 .map_err(|e| StorageError::Internal(e.to_string()))?;
191 let status: String = row
192 .get(3)
193 .map_err(|e| StorageError::Internal(e.to_string()))?;
194 let cwd: Option<String> = row.get(4).ok();
195 let created_at: String = row
196 .get(5)
197 .map_err(|e| StorageError::Internal(e.to_string()))?;
198 let updated_at: String = row
199 .get(6)
200 .map_err(|e| StorageError::Internal(e.to_string()))?;
201 let checkpoint_count: i64 = row.get(7).unwrap_or(0);
202 let active_checkpoint_id: Option<String> = row.get(8).ok();
203
204 sessions.push(SessionSummary {
205 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
206 title,
207 visibility: parse_visibility(&visibility),
208 status: parse_status(&status),
209 cwd,
210 created_at: parse_datetime(&created_at)?,
211 updated_at: parse_datetime(&updated_at)?,
212 message_count: checkpoint_count as u32,
213 active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
214 last_message_at: None,
215 });
216 }
217
218 Ok(ListSessionsResult {
219 sessions,
220 total: None,
221 })
222 }
223
224 async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
225 let conn = self.conn.lock().await;
226 let mut rows = conn
227 .query(
228 "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
229 [session_id.to_string()],
230 )
231 .await
232 .map_err(|e| StorageError::Internal(e.to_string()))?;
233
234 if let Ok(Some(row)) = rows.next().await {
235 let id: String = row
236 .get(0)
237 .map_err(|e| StorageError::Internal(e.to_string()))?;
238 let title: String = row
239 .get(1)
240 .map_err(|e| StorageError::Internal(e.to_string()))?;
241 let visibility: String = row
242 .get(2)
243 .map_err(|e| StorageError::Internal(e.to_string()))?;
244 let status: String = row
245 .get(3)
246 .map_err(|e| StorageError::Internal(e.to_string()))?;
247 let cwd: Option<String> = row.get(4).ok();
248 let created_at: String = row
249 .get(5)
250 .map_err(|e| StorageError::Internal(e.to_string()))?;
251 let updated_at: String = row
252 .get(6)
253 .map_err(|e| StorageError::Internal(e.to_string()))?;
254
255 let active_checkpoint =
257 Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
258 .await
259 .ok();
260
261 Ok(Session {
262 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
263 title,
264 visibility: parse_visibility(&visibility),
265 status: parse_status(&status),
266 cwd,
267 created_at: parse_datetime(&created_at)?,
268 updated_at: parse_datetime(&updated_at)?,
269 active_checkpoint,
270 })
271 } else {
272 Err(StorageError::NotFound(format!(
273 "Session {} not found",
274 session_id
275 )))
276 }
277 }
278
279 async fn create_session(
280 &self,
281 request: &CreateSessionRequest,
282 ) -> Result<CreateSessionResult, StorageError> {
283 let now = Utc::now();
284 let session_id = Uuid::new_v4();
285 let checkpoint_id = Uuid::new_v4();
286
287 let conn = self.conn.lock().await;
288
289 conn.execute(
291 "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
292 (
293 session_id.to_string(),
294 request.title.as_str(),
295 request.visibility.to_string(),
296 "ACTIVE",
297 request.cwd.as_deref(),
298 now.to_rfc3339(),
299 now.to_rfc3339(),
300 ),
301 )
302 .await
303 .map_err(|e| StorageError::Internal(e.to_string()))?;
304
305 let state_json = serde_json::to_string(&request.initial_state)
307 .map_err(|e| StorageError::Internal(e.to_string()))?;
308
309 conn.execute(
310 "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
311 (
312 checkpoint_id.to_string(),
313 session_id.to_string(),
314 None::<String>,
315 state_json,
316 now.to_rfc3339(),
317 now.to_rfc3339(),
318 ),
319 )
320 .await
321 .map_err(|e| StorageError::Internal(e.to_string()))?;
322
323 Ok(CreateSessionResult {
324 session_id,
325 checkpoint: Checkpoint {
326 id: checkpoint_id,
327 session_id,
328 parent_id: None,
329 state: request.initial_state.clone(),
330 created_at: now,
331 updated_at: now,
332 },
333 })
334 }
335
336 async fn update_session(
337 &self,
338 session_id: Uuid,
339 request: &UpdateSessionRequest,
340 ) -> Result<Session, StorageError> {
341 let now = Utc::now();
342
343 {
344 let conn = self.conn.lock().await;
345
346 if let Some(title) = &request.title {
348 conn.execute(
349 "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
350 (title.as_str(), now.to_rfc3339(), session_id.to_string()),
351 )
352 .await
353 .map_err(|e| StorageError::Internal(e.to_string()))?;
354 }
355 if let Some(visibility) = &request.visibility {
356 conn.execute(
357 "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
358 (
359 visibility.to_string(),
360 now.to_rfc3339(),
361 session_id.to_string(),
362 ),
363 )
364 .await
365 .map_err(|e| StorageError::Internal(e.to_string()))?;
366 }
367 }
368
369 self.get_session(session_id).await
370 }
371
372 async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
373 let now = Utc::now();
375 let conn = self.conn.lock().await;
376 conn.execute(
377 "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
378 (now.to_rfc3339(), session_id.to_string()),
379 )
380 .await
381 .map_err(|e| StorageError::Internal(e.to_string()))?;
382 Ok(())
383 }
384
385 async fn list_checkpoints(
386 &self,
387 session_id: Uuid,
388 query: &ListCheckpointsQuery,
389 ) -> Result<ListCheckpointsResult, StorageError> {
390 let limit = query.limit.unwrap_or(100);
391 let offset = query.offset.unwrap_or(0);
392
393 let sql = format!(
394 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
395 WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
396 limit, offset
397 );
398
399 let conn = self.conn.lock().await;
400 let mut rows = conn
401 .query(&sql, [session_id.to_string()])
402 .await
403 .map_err(|e| StorageError::Internal(e.to_string()))?;
404
405 let mut checkpoints = Vec::new();
406 while let Ok(Some(row)) = rows.next().await {
407 let id: String = row
408 .get(0)
409 .map_err(|e| StorageError::Internal(e.to_string()))?;
410 let session_id: String = row
411 .get(1)
412 .map_err(|e| StorageError::Internal(e.to_string()))?;
413 let parent_id: Option<String> = row.get(2).ok();
414 let state: Option<String> = row.get(3).ok();
415 let created_at: String = row
416 .get(4)
417 .map_err(|e| StorageError::Internal(e.to_string()))?;
418 let updated_at: String = row
419 .get(5)
420 .map_err(|e| StorageError::Internal(e.to_string()))?;
421
422 let state: CheckpointState = if let Some(state_str) = state {
423 serde_json::from_str(&state_str).unwrap_or_default()
424 } else {
425 CheckpointState::default()
426 };
427
428 checkpoints.push(CheckpointSummary {
429 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
430 session_id: Uuid::from_str(&session_id)
431 .map_err(|e| StorageError::Internal(e.to_string()))?,
432 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
433 message_count: state.messages.len() as u32,
434 created_at: parse_datetime(&created_at)?,
435 updated_at: parse_datetime(&updated_at)?,
436 });
437 }
438
439 Ok(ListCheckpointsResult {
440 checkpoints,
441 total: None,
442 })
443 }
444
445 async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
446 let conn = self.conn.lock().await;
447 let mut rows = conn
448 .query(
449 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
450 [checkpoint_id.to_string()],
451 )
452 .await
453 .map_err(|e| StorageError::Internal(e.to_string()))?;
454
455 if let Ok(Some(row)) = rows.next().await {
456 let id: String = row
457 .get(0)
458 .map_err(|e| StorageError::Internal(e.to_string()))?;
459 let session_id: String = row
460 .get(1)
461 .map_err(|e| StorageError::Internal(e.to_string()))?;
462 let parent_id: Option<String> = row.get(2).ok();
463 let state: Option<String> = row.get(3).ok();
464 let created_at: String = row
465 .get(4)
466 .map_err(|e| StorageError::Internal(e.to_string()))?;
467 let updated_at: String = row
468 .get(5)
469 .map_err(|e| StorageError::Internal(e.to_string()))?;
470
471 let state: CheckpointState = if let Some(state_str) = state {
472 serde_json::from_str(&state_str).unwrap_or_default()
473 } else {
474 CheckpointState::default()
475 };
476
477 Ok(Checkpoint {
478 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
479 session_id: Uuid::from_str(&session_id)
480 .map_err(|e| StorageError::Internal(e.to_string()))?,
481 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
482 state,
483 created_at: parse_datetime(&created_at)?,
484 updated_at: parse_datetime(&updated_at)?,
485 })
486 } else {
487 Err(StorageError::NotFound(format!(
488 "Checkpoint {} not found",
489 checkpoint_id
490 )))
491 }
492 }
493
494 async fn create_checkpoint(
495 &self,
496 session_id: Uuid,
497 request: &CreateCheckpointRequest,
498 ) -> Result<Checkpoint, StorageError> {
499 let now = Utc::now();
500 let checkpoint_id = Uuid::new_v4();
501
502 let state_json = serde_json::to_string(&request.state)
503 .map_err(|e| StorageError::Internal(e.to_string()))?;
504
505 let conn = self.conn.lock().await;
506
507 conn.execute(
508 "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
509 (
510 checkpoint_id.to_string(),
511 session_id.to_string(),
512 request.parent_id.map(|id| id.to_string()),
513 state_json,
514 now.to_rfc3339(),
515 now.to_rfc3339(),
516 ),
517 )
518 .await
519 .map_err(|e| StorageError::Internal(e.to_string()))?;
520
521 conn.execute(
523 "UPDATE sessions SET updated_at = ? WHERE id = ?",
524 (now.to_rfc3339(), session_id.to_string()),
525 )
526 .await
527 .map_err(|e| StorageError::Internal(e.to_string()))?;
528
529 Ok(Checkpoint {
530 id: checkpoint_id,
531 session_id,
532 parent_id: request.parent_id,
533 state: request.state.clone(),
534 created_at: now,
535 updated_at: now,
536 })
537 }
538}
539
540fn parse_visibility(s: &str) -> SessionVisibility {
542 match s.to_uppercase().as_str() {
543 "PUBLIC" => SessionVisibility::Public,
544 _ => SessionVisibility::Private,
545 }
546}
547
548fn parse_status(s: &str) -> SessionStatus {
549 match s.to_uppercase().as_str() {
550 "DELETED" => SessionStatus::Deleted,
551 _ => SessionStatus::Active,
552 }
553}
554
555fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
556 DateTime::parse_from_rfc3339(s)
557 .map(|dt| dt.with_timezone(&Utc))
558 .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
559}