Skip to main content

stakpak_api/local/
storage.rs

1//! Local SQLite storage implementation
2//!
3//! Implements SessionStorage using local SQLite database.
4
5use crate::storage::{
6    Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
7    CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
8    ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
9    StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::{Connection, Database};
14use std::path::Path;
15use std::str::FromStr;
16use tempfile::TempDir;
17use uuid::Uuid;
18
/// Local SQLite storage implementation
///
/// Uses a shared `Database` handle and opens a fresh `Connection` per
/// operation to avoid libsql connection-sharing hazards under concurrency.
pub struct LocalStorage {
    /// Database handle that `connection()` calls `connect()` on for every operation.
    db: Database,
    /// Owns temporary backing storage for in-memory mode and cleans it on drop.
    _temp_dir: Option<TempDir>,
}
28
29impl LocalStorage {
30    /// Create a new local storage instance
31    pub async fn new(db_path: &str) -> Result<Self, StorageError> {
32        let (resolved_path, temp_dir) = if db_path == ":memory:" {
33            // libsql in-memory databases are connection-scoped; use a temporary
34            // directory and clean it automatically when storage is dropped.
35            let temp_dir = tempfile::tempdir().map_err(|e| {
36                StorageError::Connection(format!("Failed to create temp directory: {}", e))
37            })?;
38            (temp_dir.path().join("local-storage.db"), Some(temp_dir))
39        } else {
40            (Path::new(db_path).to_path_buf(), None)
41        };
42
43        // Ensure parent directory exists.
44        if let Some(parent) = resolved_path.parent() {
45            std::fs::create_dir_all(parent).map_err(|e| {
46                StorageError::Connection(format!("Failed to create database directory: {}", e))
47            })?;
48        }
49
50        let db = libsql::Builder::new_local(&resolved_path)
51            .build()
52            .await
53            .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
54
55        let storage = Self {
56            db,
57            _temp_dir: temp_dir,
58        };
59        storage.init_schema().await?;
60
61        Ok(storage)
62    }
63
64    /// Create from an existing database + connection pair.
65    ///
66    /// The provided connection is intentionally ignored; a fresh connection is
67    /// opened per operation.
68    pub async fn from_db_and_connection(
69        db: Database,
70        _conn: Connection,
71    ) -> Result<Self, StorageError> {
72        let storage = Self {
73            db,
74            _temp_dir: None,
75        };
76        storage.init_schema().await?;
77        Ok(storage)
78    }
79
80    /// Create from an existing connection.
81    ///
82    /// Deprecated because a bare `Connection` does not guarantee that its owning
83    /// `libsql::Database` outlives the connection, which can cause crashes.
84    #[deprecated(note = "use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)")]
85    pub async fn from_connection(_conn: Connection) -> Result<Self, StorageError> {
86        Err(StorageError::Connection(
87            "LocalStorage::from_connection is unsupported without the owning libsql::Database; use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)".to_string(),
88        ))
89    }
90
91    /// Initialize database schema by running migrations
92    async fn init_schema(&self) -> Result<(), StorageError> {
93        let conn = self.connection()?;
94        super::migrations::run_migrations(&conn)
95            .await
96            .map_err(StorageError::Internal)
97    }
98
99    /// Open a fresh connection for a single storage operation.
100    pub(crate) fn connection(&self) -> Result<Connection, StorageError> {
101        self.db
102            .connect()
103            .map_err(|e| StorageError::Connection(format!("Failed to connect to database: {}", e)))
104    }
105
106    /// Get the latest checkpoint for a session using a caller-provided connection.
107    async fn get_latest_checkpoint_for_session_inner(
108        conn: &Connection,
109        session_id: Uuid,
110    ) -> Result<Checkpoint, StorageError> {
111        let mut rows = conn
112            .query(
113                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
114                 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
115                [session_id.to_string()],
116            )
117            .await
118            .map_err(|e| StorageError::Internal(e.to_string()))?;
119
120        if let Ok(Some(row)) = rows.next().await {
121            let id: String = row
122                .get(0)
123                .map_err(|e| StorageError::Internal(e.to_string()))?;
124            let session_id: String = row
125                .get(1)
126                .map_err(|e| StorageError::Internal(e.to_string()))?;
127            let parent_id: Option<String> = row.get(2).ok();
128            let state: Option<String> = row.get(3).ok();
129            let created_at: String = row
130                .get(4)
131                .map_err(|e| StorageError::Internal(e.to_string()))?;
132            let updated_at: String = row
133                .get(5)
134                .map_err(|e| StorageError::Internal(e.to_string()))?;
135
136            let state: CheckpointState = if let Some(state_str) = state {
137                serde_json::from_str(&state_str).unwrap_or_default()
138            } else {
139                CheckpointState::default()
140            };
141
142            Ok(Checkpoint {
143                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
144                session_id: Uuid::from_str(&session_id)
145                    .map_err(|e| StorageError::Internal(e.to_string()))?,
146                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
147                state,
148                created_at: parse_datetime(&created_at)?,
149                updated_at: parse_datetime(&updated_at)?,
150            })
151        } else {
152            Err(StorageError::NotFound(format!(
153                "No checkpoints found for session {}",
154                session_id
155            )))
156        }
157    }
158}
159
160#[async_trait]
161impl SessionStorage for LocalStorage {
162    async fn list_sessions(
163        &self,
164        query: &ListSessionsQuery,
165    ) -> Result<ListSessionsResult, StorageError> {
166        let limit = query.limit.unwrap_or(100);
167        let offset = query.offset.unwrap_or(0);
168
169        let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at, 
170            (SELECT COUNT(*) FROM checkpoints c WHERE c.session_id = s.id) as checkpoint_count,
171            (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY created_at DESC LIMIT 1) as active_checkpoint_id
172            FROM sessions s WHERE 1=1".to_string();
173
174        // Use parameterized values for enum filters (safe because they come from
175        // our own Display impls, but we keep this consistent with the rest of
176        // the codebase).  The search term is the only free-form user input and
177        // is handled with a parameter below.
178        if let Some(status) = &query.status {
179            sql.push_str(&format!(" AND s.status = '{}'", status));
180        }
181        if let Some(visibility) = &query.visibility {
182            sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
183        }
184        if query.search.is_some() {
185            sql.push_str(" AND s.title LIKE '%' || ? || '%'");
186        }
187
188        sql.push_str(&format!(
189            " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
190            limit, offset
191        ));
192
193        let conn = self.connection()?;
194        let mut rows = if let Some(search) = &query.search {
195            conn.query(&sql, [search.as_str()])
196                .await
197                .map_err(|e| StorageError::Internal(e.to_string()))?
198        } else {
199            conn.query(&sql, ())
200                .await
201                .map_err(|e| StorageError::Internal(e.to_string()))?
202        };
203
204        let mut sessions = Vec::new();
205        while let Ok(Some(row)) = rows.next().await {
206            let id: String = row
207                .get(0)
208                .map_err(|e| StorageError::Internal(e.to_string()))?;
209            let title: String = row
210                .get(1)
211                .map_err(|e| StorageError::Internal(e.to_string()))?;
212            let visibility: String = row
213                .get(2)
214                .map_err(|e| StorageError::Internal(e.to_string()))?;
215            let status: String = row
216                .get(3)
217                .map_err(|e| StorageError::Internal(e.to_string()))?;
218            let cwd: Option<String> = row.get(4).ok();
219            let created_at: String = row
220                .get(5)
221                .map_err(|e| StorageError::Internal(e.to_string()))?;
222            let updated_at: String = row
223                .get(6)
224                .map_err(|e| StorageError::Internal(e.to_string()))?;
225            let checkpoint_count: i64 = row.get(7).unwrap_or(0);
226            let active_checkpoint_id: Option<String> = row.get(8).ok();
227
228            sessions.push(SessionSummary {
229                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
230                title,
231                visibility: parse_visibility(&visibility),
232                status: parse_status(&status),
233                cwd,
234                created_at: parse_datetime(&created_at)?,
235                updated_at: parse_datetime(&updated_at)?,
236                message_count: checkpoint_count as u32,
237                active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
238                last_message_at: None,
239            });
240        }
241
242        Ok(ListSessionsResult {
243            sessions,
244            total: None,
245        })
246    }
247
248    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
249        let conn = self.connection()?;
250        let mut rows = conn
251            .query(
252                "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
253                [session_id.to_string()],
254            )
255            .await
256            .map_err(|e| StorageError::Internal(e.to_string()))?;
257
258        if let Ok(Some(row)) = rows.next().await {
259            let id: String = row
260                .get(0)
261                .map_err(|e| StorageError::Internal(e.to_string()))?;
262            let title: String = row
263                .get(1)
264                .map_err(|e| StorageError::Internal(e.to_string()))?;
265            let visibility: String = row
266                .get(2)
267                .map_err(|e| StorageError::Internal(e.to_string()))?;
268            let status: String = row
269                .get(3)
270                .map_err(|e| StorageError::Internal(e.to_string()))?;
271            let cwd: Option<String> = row.get(4).ok();
272            let created_at: String = row
273                .get(5)
274                .map_err(|e| StorageError::Internal(e.to_string()))?;
275            let updated_at: String = row
276                .get(6)
277                .map_err(|e| StorageError::Internal(e.to_string()))?;
278
279            // Get the latest checkpoint (reuse the same lock)
280            let active_checkpoint =
281                Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
282                    .await
283                    .ok();
284
285            Ok(Session {
286                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
287                title,
288                visibility: parse_visibility(&visibility),
289                status: parse_status(&status),
290                cwd,
291                created_at: parse_datetime(&created_at)?,
292                updated_at: parse_datetime(&updated_at)?,
293                active_checkpoint,
294            })
295        } else {
296            Err(StorageError::NotFound(format!(
297                "Session {} not found",
298                session_id
299            )))
300        }
301    }
302
303    async fn create_session(
304        &self,
305        request: &CreateSessionRequest,
306    ) -> Result<CreateSessionResult, StorageError> {
307        let now = Utc::now();
308        let session_id = Uuid::new_v4();
309        let checkpoint_id = Uuid::new_v4();
310
311        let conn = self.connection()?;
312
313        // Create session
314        conn.execute(
315            "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
316            (
317                session_id.to_string(),
318                request.title.as_str(),
319                request.visibility.to_string(),
320                "ACTIVE",
321                request.cwd.as_deref(),
322                now.to_rfc3339(),
323                now.to_rfc3339(),
324            ),
325        )
326        .await
327        .map_err(|e| StorageError::Internal(e.to_string()))?;
328
329        // Create initial checkpoint
330        let state_json = serde_json::to_string(&request.initial_state)
331            .map_err(|e| StorageError::Internal(e.to_string()))?;
332
333        conn.execute(
334            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
335            (
336                checkpoint_id.to_string(),
337                session_id.to_string(),
338                None::<String>,
339                state_json,
340                now.to_rfc3339(),
341                now.to_rfc3339(),
342            ),
343        )
344        .await
345        .map_err(|e| StorageError::Internal(e.to_string()))?;
346
347        Ok(CreateSessionResult {
348            session_id,
349            checkpoint: Checkpoint {
350                id: checkpoint_id,
351                session_id,
352                parent_id: None,
353                state: request.initial_state.clone(),
354                created_at: now,
355                updated_at: now,
356            },
357        })
358    }
359
360    async fn update_session(
361        &self,
362        session_id: Uuid,
363        request: &UpdateSessionRequest,
364    ) -> Result<Session, StorageError> {
365        let now = Utc::now();
366
367        {
368            let conn = self.connection()?;
369
370            // Update fields individually since libsql doesn't support dynamic params easily
371            if let Some(title) = &request.title {
372                conn.execute(
373                    "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
374                    (title.as_str(), now.to_rfc3339(), session_id.to_string()),
375                )
376                .await
377                .map_err(|e| StorageError::Internal(e.to_string()))?;
378            }
379            if let Some(visibility) = &request.visibility {
380                conn.execute(
381                    "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
382                    (
383                        visibility.to_string(),
384                        now.to_rfc3339(),
385                        session_id.to_string(),
386                    ),
387                )
388                .await
389                .map_err(|e| StorageError::Internal(e.to_string()))?;
390            }
391        }
392
393        self.get_session(session_id).await
394    }
395
396    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
397        // Mark as deleted instead of actually deleting
398        let now = Utc::now();
399        let conn = self.connection()?;
400        conn.execute(
401            "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
402            (now.to_rfc3339(), session_id.to_string()),
403        )
404        .await
405        .map_err(|e| StorageError::Internal(e.to_string()))?;
406        Ok(())
407    }
408
409    async fn list_checkpoints(
410        &self,
411        session_id: Uuid,
412        query: &ListCheckpointsQuery,
413    ) -> Result<ListCheckpointsResult, StorageError> {
414        let limit = query.limit.unwrap_or(100);
415        let offset = query.offset.unwrap_or(0);
416
417        let sql = format!(
418            "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
419             WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
420            limit, offset
421        );
422
423        let conn = self.connection()?;
424        let mut rows = conn
425            .query(&sql, [session_id.to_string()])
426            .await
427            .map_err(|e| StorageError::Internal(e.to_string()))?;
428
429        let mut checkpoints = Vec::new();
430        while let Ok(Some(row)) = rows.next().await {
431            let id: String = row
432                .get(0)
433                .map_err(|e| StorageError::Internal(e.to_string()))?;
434            let session_id: String = row
435                .get(1)
436                .map_err(|e| StorageError::Internal(e.to_string()))?;
437            let parent_id: Option<String> = row.get(2).ok();
438            let state: Option<String> = row.get(3).ok();
439            let created_at: String = row
440                .get(4)
441                .map_err(|e| StorageError::Internal(e.to_string()))?;
442            let updated_at: String = row
443                .get(5)
444                .map_err(|e| StorageError::Internal(e.to_string()))?;
445
446            let state: CheckpointState = if let Some(state_str) = state {
447                serde_json::from_str(&state_str).unwrap_or_default()
448            } else {
449                CheckpointState::default()
450            };
451
452            checkpoints.push(CheckpointSummary {
453                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
454                session_id: Uuid::from_str(&session_id)
455                    .map_err(|e| StorageError::Internal(e.to_string()))?,
456                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
457                message_count: state.messages.len() as u32,
458                created_at: parse_datetime(&created_at)?,
459                updated_at: parse_datetime(&updated_at)?,
460            });
461        }
462
463        Ok(ListCheckpointsResult {
464            checkpoints,
465            total: None,
466        })
467    }
468
469    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
470        let conn = self.connection()?;
471        let mut rows = conn
472            .query(
473                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
474                [checkpoint_id.to_string()],
475            )
476            .await
477            .map_err(|e| StorageError::Internal(e.to_string()))?;
478
479        if let Ok(Some(row)) = rows.next().await {
480            let id: String = row
481                .get(0)
482                .map_err(|e| StorageError::Internal(e.to_string()))?;
483            let session_id: String = row
484                .get(1)
485                .map_err(|e| StorageError::Internal(e.to_string()))?;
486            let parent_id: Option<String> = row.get(2).ok();
487            let state: Option<String> = row.get(3).ok();
488            let created_at: String = row
489                .get(4)
490                .map_err(|e| StorageError::Internal(e.to_string()))?;
491            let updated_at: String = row
492                .get(5)
493                .map_err(|e| StorageError::Internal(e.to_string()))?;
494
495            let state: CheckpointState = if let Some(state_str) = state {
496                serde_json::from_str(&state_str).unwrap_or_default()
497            } else {
498                CheckpointState::default()
499            };
500
501            Ok(Checkpoint {
502                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
503                session_id: Uuid::from_str(&session_id)
504                    .map_err(|e| StorageError::Internal(e.to_string()))?,
505                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
506                state,
507                created_at: parse_datetime(&created_at)?,
508                updated_at: parse_datetime(&updated_at)?,
509            })
510        } else {
511            Err(StorageError::NotFound(format!(
512                "Checkpoint {} not found",
513                checkpoint_id
514            )))
515        }
516    }
517
518    async fn create_checkpoint(
519        &self,
520        session_id: Uuid,
521        request: &CreateCheckpointRequest,
522    ) -> Result<Checkpoint, StorageError> {
523        let now = Utc::now();
524        let checkpoint_id = Uuid::new_v4();
525
526        let state_json = serde_json::to_string(&request.state)
527            .map_err(|e| StorageError::Internal(e.to_string()))?;
528
529        let conn = self.connection()?;
530
531        conn.execute(
532            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
533            (
534                checkpoint_id.to_string(),
535                session_id.to_string(),
536                request.parent_id.map(|id| id.to_string()),
537                state_json,
538                now.to_rfc3339(),
539                now.to_rfc3339(),
540            ),
541        )
542        .await
543        .map_err(|e| StorageError::Internal(e.to_string()))?;
544
545        // Update session's updated_at
546        conn.execute(
547            "UPDATE sessions SET updated_at = ? WHERE id = ?",
548            (now.to_rfc3339(), session_id.to_string()),
549        )
550        .await
551        .map_err(|e| StorageError::Internal(e.to_string()))?;
552
553        Ok(Checkpoint {
554            id: checkpoint_id,
555            session_id,
556            parent_id: request.parent_id,
557            state: request.state.clone(),
558            created_at: now,
559            updated_at: now,
560        })
561    }
562}
563
564// Helper functions
565fn parse_visibility(s: &str) -> SessionVisibility {
566    match s.to_uppercase().as_str() {
567        "PUBLIC" => SessionVisibility::Public,
568        _ => SessionVisibility::Private,
569    }
570}
571
572fn parse_status(s: &str) -> SessionStatus {
573    match s.to_uppercase().as_str() {
574        "DELETED" => SessionStatus::Deleted,
575        _ => SessionStatus::Active,
576    }
577}
578
579fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
580    DateTime::parse_from_rfc3339(s)
581        .map(|dt| dt.with_timezone(&Utc))
582        .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
583}