1use crate::storage::{
6 Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
7 CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
8 ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
9 StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::{Connection, Database};
14use std::path::Path;
15use std::str::FromStr;
16use tempfile::TempDir;
17use uuid::Uuid;
18
/// Session storage backed by a local libsql database.
///
/// Connections are opened from `db` on demand (see `connect_raw`).
pub struct LocalStorage {
    /// Database handle that every connection is opened from.
    db: Database,
    /// Temporary directory that contains the database file when the store was
    /// created with the `":memory:"` path; `None` for a caller-supplied path.
    /// The field is never read — presumably it is held only so the directory
    /// is not removed while `db` is in use (confirm against `tempfile::TempDir`).
    _temp_dir: Option<TempDir>,
}
28
29impl LocalStorage {
30 pub async fn new(db_path: &str) -> Result<Self, StorageError> {
32 let (resolved_path, temp_dir) = if db_path == ":memory:" {
33 let temp_dir = tempfile::tempdir().map_err(|e| {
36 StorageError::Connection(format!("Failed to create temp directory: {}", e))
37 })?;
38 (temp_dir.path().join("local-storage.db"), Some(temp_dir))
39 } else {
40 (Path::new(db_path).to_path_buf(), None)
41 };
42
43 if let Some(parent) = resolved_path.parent() {
45 std::fs::create_dir_all(parent).map_err(|e| {
46 StorageError::Connection(format!("Failed to create database directory: {}", e))
47 })?;
48 }
49
50 let db = libsql::Builder::new_local(&resolved_path)
51 .build()
52 .await
53 .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
54
55 let storage = Self {
56 db,
57 _temp_dir: temp_dir,
58 };
59 storage.configure_database_pragmas().await?;
60 storage.init_schema().await?;
61
62 Ok(storage)
63 }
64
65 pub async fn from_db_and_connection(
70 db: Database,
71 _conn: Connection,
72 ) -> Result<Self, StorageError> {
73 let storage = Self {
74 db,
75 _temp_dir: None,
76 };
77 storage.configure_database_pragmas().await?;
78 storage.init_schema().await?;
79 Ok(storage)
80 }
81
82 async fn configure_database_pragmas(&self) -> Result<(), StorageError> {
87 let conn = self.connect_raw()?;
88 stakpak_shared::sqlite::apply_database_pragmas(&conn)
89 .await
90 .map_err(|e| StorageError::Connection(e.to_string()))?;
91 Ok(())
92 }
93
94 #[deprecated(note = "use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)")]
99 pub async fn from_connection(_conn: Connection) -> Result<Self, StorageError> {
100 Err(StorageError::Connection(
101 "LocalStorage::from_connection is unsupported without the owning libsql::Database; use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)".to_string(),
102 ))
103 }
104
105 async fn init_schema(&self) -> Result<(), StorageError> {
107 let conn = self.connection().await?;
108 super::migrations::run_migrations(&conn)
109 .await
110 .map_err(StorageError::Internal)
111 }
112
113 pub(crate) fn connect_raw(&self) -> Result<Connection, StorageError> {
115 self.db
116 .connect()
117 .map_err(|e| StorageError::Connection(format!("Failed to connect to database: {}", e)))
118 }
119
120 pub(crate) async fn connection(&self) -> Result<Connection, StorageError> {
124 let conn = self.connect_raw()?;
125 stakpak_shared::sqlite::apply_connection_pragmas(&conn)
126 .await
127 .map_err(|e| StorageError::Connection(e.to_string()))?;
128 Ok(conn)
129 }
130
131 async fn get_latest_checkpoint_for_session_inner(
133 conn: &Connection,
134 session_id: Uuid,
135 ) -> Result<Checkpoint, StorageError> {
136 let mut rows = conn
137 .query(
138 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
139 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
140 [session_id.to_string()],
141 )
142 .await
143 .map_err(|e| StorageError::Internal(e.to_string()))?;
144
145 if let Ok(Some(row)) = rows.next().await {
146 let id: String = row
147 .get(0)
148 .map_err(|e| StorageError::Internal(e.to_string()))?;
149 let session_id: String = row
150 .get(1)
151 .map_err(|e| StorageError::Internal(e.to_string()))?;
152 let parent_id: Option<String> = row.get(2).ok();
153 let state: Option<String> = row.get(3).ok();
154 let created_at: String = row
155 .get(4)
156 .map_err(|e| StorageError::Internal(e.to_string()))?;
157 let updated_at: String = row
158 .get(5)
159 .map_err(|e| StorageError::Internal(e.to_string()))?;
160
161 let state: CheckpointState = if let Some(state_str) = state {
162 serde_json::from_str(&state_str).unwrap_or_default()
163 } else {
164 CheckpointState::default()
165 };
166
167 Ok(Checkpoint {
168 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
169 session_id: Uuid::from_str(&session_id)
170 .map_err(|e| StorageError::Internal(e.to_string()))?,
171 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
172 state,
173 created_at: parse_datetime(&created_at)?,
174 updated_at: parse_datetime(&updated_at)?,
175 })
176 } else {
177 Err(StorageError::NotFound(format!(
178 "No checkpoints found for session {}",
179 session_id
180 )))
181 }
182 }
183}
184
185#[async_trait]
186impl SessionStorage for LocalStorage {
187 async fn list_sessions(
188 &self,
189 query: &ListSessionsQuery,
190 ) -> Result<ListSessionsResult, StorageError> {
191 let limit = query.limit.unwrap_or(100);
192 let offset = query.offset.unwrap_or(0);
193
194 let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at,
195 (SELECT COUNT(*) FROM checkpoints c WHERE c.session_id = s.id) as checkpoint_count,
196 (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY created_at DESC LIMIT 1) as active_checkpoint_id
197 FROM sessions s WHERE 1=1".to_string();
198
199 if let Some(status) = &query.status {
204 sql.push_str(&format!(" AND s.status = '{}'", status));
205 }
206 if let Some(visibility) = &query.visibility {
207 sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
208 }
209 if query.search.is_some() {
210 sql.push_str(" AND s.title LIKE '%' || ? || '%'");
211 }
212
213 sql.push_str(&format!(
214 " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
215 limit, offset
216 ));
217
218 let conn = self.connection().await?;
219 let mut rows = if let Some(search) = &query.search {
220 conn.query(&sql, [search.as_str()])
221 .await
222 .map_err(|e| StorageError::Internal(e.to_string()))?
223 } else {
224 conn.query(&sql, ())
225 .await
226 .map_err(|e| StorageError::Internal(e.to_string()))?
227 };
228
229 let mut sessions = Vec::new();
230 while let Ok(Some(row)) = rows.next().await {
231 let id: String = row
232 .get(0)
233 .map_err(|e| StorageError::Internal(e.to_string()))?;
234 let title: String = row
235 .get(1)
236 .map_err(|e| StorageError::Internal(e.to_string()))?;
237 let visibility: String = row
238 .get(2)
239 .map_err(|e| StorageError::Internal(e.to_string()))?;
240 let status: String = row
241 .get(3)
242 .map_err(|e| StorageError::Internal(e.to_string()))?;
243 let cwd: Option<String> = row.get(4).ok();
244 let created_at: String = row
245 .get(5)
246 .map_err(|e| StorageError::Internal(e.to_string()))?;
247 let updated_at: String = row
248 .get(6)
249 .map_err(|e| StorageError::Internal(e.to_string()))?;
250 let checkpoint_count: i64 = row.get(7).unwrap_or(0);
251 let active_checkpoint_id: Option<String> = row.get(8).ok();
252
253 sessions.push(SessionSummary {
254 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
255 title,
256 visibility: parse_visibility(&visibility),
257 status: parse_status(&status),
258 cwd,
259 created_at: parse_datetime(&created_at)?,
260 updated_at: parse_datetime(&updated_at)?,
261 message_count: checkpoint_count as u32,
262 active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
263 last_message_at: None,
264 });
265 }
266
267 Ok(ListSessionsResult {
268 sessions,
269 total: None,
270 })
271 }
272
273 async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
274 let conn = self.connection().await?;
275 let mut rows = conn
276 .query(
277 "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
278 [session_id.to_string()],
279 )
280 .await
281 .map_err(|e| StorageError::Internal(e.to_string()))?;
282
283 if let Ok(Some(row)) = rows.next().await {
284 let id: String = row
285 .get(0)
286 .map_err(|e| StorageError::Internal(e.to_string()))?;
287 let title: String = row
288 .get(1)
289 .map_err(|e| StorageError::Internal(e.to_string()))?;
290 let visibility: String = row
291 .get(2)
292 .map_err(|e| StorageError::Internal(e.to_string()))?;
293 let status: String = row
294 .get(3)
295 .map_err(|e| StorageError::Internal(e.to_string()))?;
296 let cwd: Option<String> = row.get(4).ok();
297 let created_at: String = row
298 .get(5)
299 .map_err(|e| StorageError::Internal(e.to_string()))?;
300 let updated_at: String = row
301 .get(6)
302 .map_err(|e| StorageError::Internal(e.to_string()))?;
303
304 let active_checkpoint =
306 Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
307 .await
308 .ok();
309
310 Ok(Session {
311 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
312 title,
313 visibility: parse_visibility(&visibility),
314 status: parse_status(&status),
315 cwd,
316 created_at: parse_datetime(&created_at)?,
317 updated_at: parse_datetime(&updated_at)?,
318 active_checkpoint,
319 })
320 } else {
321 Err(StorageError::NotFound(format!(
322 "Session {} not found",
323 session_id
324 )))
325 }
326 }
327
328 async fn create_session(
329 &self,
330 request: &CreateSessionRequest,
331 ) -> Result<CreateSessionResult, StorageError> {
332 let now = Utc::now();
333 let session_id = Uuid::new_v4();
334 let checkpoint_id = Uuid::new_v4();
335
336 let conn = self.connection().await?;
337
338 conn.execute(
340 "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
341 (
342 session_id.to_string(),
343 request.title.as_str(),
344 request.visibility.to_string(),
345 "ACTIVE",
346 request.cwd.as_deref(),
347 now.to_rfc3339(),
348 now.to_rfc3339(),
349 ),
350 )
351 .await
352 .map_err(|e| StorageError::Internal(e.to_string()))?;
353
354 let state_json = serde_json::to_string(&request.initial_state)
356 .map_err(|e| StorageError::Internal(e.to_string()))?;
357
358 conn.execute(
359 "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
360 (
361 checkpoint_id.to_string(),
362 session_id.to_string(),
363 None::<String>,
364 state_json,
365 now.to_rfc3339(),
366 now.to_rfc3339(),
367 ),
368 )
369 .await
370 .map_err(|e| StorageError::Internal(e.to_string()))?;
371
372 Ok(CreateSessionResult {
373 session_id,
374 checkpoint: Checkpoint {
375 id: checkpoint_id,
376 session_id,
377 parent_id: None,
378 state: request.initial_state.clone(),
379 created_at: now,
380 updated_at: now,
381 },
382 })
383 }
384
385 async fn update_session(
386 &self,
387 session_id: Uuid,
388 request: &UpdateSessionRequest,
389 ) -> Result<Session, StorageError> {
390 let now = Utc::now();
391
392 {
393 let conn = self.connection().await?;
394
395 if let Some(title) = &request.title {
397 conn.execute(
398 "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
399 (title.as_str(), now.to_rfc3339(), session_id.to_string()),
400 )
401 .await
402 .map_err(|e| StorageError::Internal(e.to_string()))?;
403 }
404 if let Some(visibility) = &request.visibility {
405 conn.execute(
406 "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
407 (
408 visibility.to_string(),
409 now.to_rfc3339(),
410 session_id.to_string(),
411 ),
412 )
413 .await
414 .map_err(|e| StorageError::Internal(e.to_string()))?;
415 }
416 }
417
418 self.get_session(session_id).await
419 }
420
421 async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
422 let now = Utc::now();
424 let conn = self.connection().await?;
425 conn.execute(
426 "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
427 (now.to_rfc3339(), session_id.to_string()),
428 )
429 .await
430 .map_err(|e| StorageError::Internal(e.to_string()))?;
431 Ok(())
432 }
433
434 async fn list_checkpoints(
435 &self,
436 session_id: Uuid,
437 query: &ListCheckpointsQuery,
438 ) -> Result<ListCheckpointsResult, StorageError> {
439 let limit = query.limit.unwrap_or(100);
440 let offset = query.offset.unwrap_or(0);
441
442 let sql = format!(
443 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints
444 WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
445 limit, offset
446 );
447
448 let conn = self.connection().await?;
449 let mut rows = conn
450 .query(&sql, [session_id.to_string()])
451 .await
452 .map_err(|e| StorageError::Internal(e.to_string()))?;
453
454 let mut checkpoints = Vec::new();
455 while let Ok(Some(row)) = rows.next().await {
456 let id: String = row
457 .get(0)
458 .map_err(|e| StorageError::Internal(e.to_string()))?;
459 let session_id: String = row
460 .get(1)
461 .map_err(|e| StorageError::Internal(e.to_string()))?;
462 let parent_id: Option<String> = row.get(2).ok();
463 let state: Option<String> = row.get(3).ok();
464 let created_at: String = row
465 .get(4)
466 .map_err(|e| StorageError::Internal(e.to_string()))?;
467 let updated_at: String = row
468 .get(5)
469 .map_err(|e| StorageError::Internal(e.to_string()))?;
470
471 let state: CheckpointState = if let Some(state_str) = state {
472 serde_json::from_str(&state_str).unwrap_or_default()
473 } else {
474 CheckpointState::default()
475 };
476
477 checkpoints.push(CheckpointSummary {
478 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
479 session_id: Uuid::from_str(&session_id)
480 .map_err(|e| StorageError::Internal(e.to_string()))?,
481 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
482 message_count: state.messages.len() as u32,
483 created_at: parse_datetime(&created_at)?,
484 updated_at: parse_datetime(&updated_at)?,
485 });
486 }
487
488 Ok(ListCheckpointsResult {
489 checkpoints,
490 total: None,
491 })
492 }
493
494 async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
495 let conn = self.connection().await?;
496 let mut rows = conn
497 .query(
498 "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
499 [checkpoint_id.to_string()],
500 )
501 .await
502 .map_err(|e| StorageError::Internal(e.to_string()))?;
503
504 if let Ok(Some(row)) = rows.next().await {
505 let id: String = row
506 .get(0)
507 .map_err(|e| StorageError::Internal(e.to_string()))?;
508 let session_id: String = row
509 .get(1)
510 .map_err(|e| StorageError::Internal(e.to_string()))?;
511 let parent_id: Option<String> = row.get(2).ok();
512 let state: Option<String> = row.get(3).ok();
513 let created_at: String = row
514 .get(4)
515 .map_err(|e| StorageError::Internal(e.to_string()))?;
516 let updated_at: String = row
517 .get(5)
518 .map_err(|e| StorageError::Internal(e.to_string()))?;
519
520 let state: CheckpointState = if let Some(state_str) = state {
521 serde_json::from_str(&state_str).unwrap_or_default()
522 } else {
523 CheckpointState::default()
524 };
525
526 Ok(Checkpoint {
527 id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
528 session_id: Uuid::from_str(&session_id)
529 .map_err(|e| StorageError::Internal(e.to_string()))?,
530 parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
531 state,
532 created_at: parse_datetime(&created_at)?,
533 updated_at: parse_datetime(&updated_at)?,
534 })
535 } else {
536 Err(StorageError::NotFound(format!(
537 "Checkpoint {} not found",
538 checkpoint_id
539 )))
540 }
541 }
542
543 async fn create_checkpoint(
544 &self,
545 session_id: Uuid,
546 request: &CreateCheckpointRequest,
547 ) -> Result<Checkpoint, StorageError> {
548 let now = Utc::now();
549 let checkpoint_id = Uuid::new_v4();
550
551 let state_json = serde_json::to_string(&request.state)
552 .map_err(|e| StorageError::Internal(e.to_string()))?;
553
554 let conn = self.connection().await?;
555
556 conn.execute(
557 "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
558 (
559 checkpoint_id.to_string(),
560 session_id.to_string(),
561 request.parent_id.map(|id| id.to_string()),
562 state_json,
563 now.to_rfc3339(),
564 now.to_rfc3339(),
565 ),
566 )
567 .await
568 .map_err(|e| StorageError::Internal(e.to_string()))?;
569
570 conn.execute(
572 "UPDATE sessions SET updated_at = ? WHERE id = ?",
573 (now.to_rfc3339(), session_id.to_string()),
574 )
575 .await
576 .map_err(|e| StorageError::Internal(e.to_string()))?;
577
578 Ok(Checkpoint {
579 id: checkpoint_id,
580 session_id,
581 parent_id: request.parent_id,
582 state: request.state.clone(),
583 created_at: now,
584 updated_at: now,
585 })
586 }
587}
588
589fn parse_visibility(s: &str) -> SessionVisibility {
591 match s.to_uppercase().as_str() {
592 "PUBLIC" => SessionVisibility::Public,
593 _ => SessionVisibility::Private,
594 }
595}
596
597fn parse_status(s: &str) -> SessionStatus {
598 match s.to_uppercase().as_str() {
599 "DELETED" => SessionStatus::Deleted,
600 _ => SessionStatus::Active,
601 }
602}
603
604fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
605 DateTime::parse_from_rfc3339(s)
606 .map(|dt| dt.with_timezone(&Utc))
607 .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
608}