Skip to main content

stakpak_api/local/
storage.rs

//! Local SQLite storage implementation.
//!
//! Implements `SessionStorage` on top of a local SQLite database accessed through libsql.
4
5use crate::storage::{
6    Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
7    CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
8    ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
9    StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::{Connection, Database};
14use std::path::Path;
15use std::str::FromStr;
16use tempfile::TempDir;
17use uuid::Uuid;
18
/// Local SQLite storage implementation
///
/// Uses a shared `Database` handle and opens a fresh `Connection` per
/// operation to avoid libsql connection-sharing hazards under concurrency.
///
/// Field order matters: Rust drops struct fields in declaration order, so
/// `db` is released before `_temp_dir` removes the temporary directory that
/// backs an in-memory (`":memory:"`) database.
pub struct LocalStorage {
    /// Shared libsql database handle; every operation opens its own
    /// connection from it.
    db: Database,
    /// Owns temporary backing storage for in-memory mode and cleans it on drop.
    /// Keep this declared after `db` so the database is dropped first.
    _temp_dir: Option<TempDir>,
}
28
29impl LocalStorage {
30    /// Create a new local storage instance
31    pub async fn new(db_path: &str) -> Result<Self, StorageError> {
32        let (resolved_path, temp_dir) = if db_path == ":memory:" {
33            // libsql in-memory databases are connection-scoped; use a temporary
34            // directory and clean it automatically when storage is dropped.
35            let temp_dir = tempfile::tempdir().map_err(|e| {
36                StorageError::Connection(format!("Failed to create temp directory: {}", e))
37            })?;
38            (temp_dir.path().join("local-storage.db"), Some(temp_dir))
39        } else {
40            (Path::new(db_path).to_path_buf(), None)
41        };
42
43        // Ensure parent directory exists.
44        if let Some(parent) = resolved_path.parent() {
45            std::fs::create_dir_all(parent).map_err(|e| {
46                StorageError::Connection(format!("Failed to create database directory: {}", e))
47            })?;
48        }
49
50        let db = libsql::Builder::new_local(&resolved_path)
51            .build()
52            .await
53            .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
54
55        let storage = Self {
56            db,
57            _temp_dir: temp_dir,
58        };
59        storage.configure_database_pragmas().await?;
60        storage.init_schema().await?;
61
62        Ok(storage)
63    }
64
65    /// Create from an existing database + connection pair.
66    ///
67    /// The provided connection is intentionally ignored; a fresh connection is
68    /// opened per operation.
69    pub async fn from_db_and_connection(
70        db: Database,
71        _conn: Connection,
72    ) -> Result<Self, StorageError> {
73        let storage = Self {
74            db,
75            _temp_dir: None,
76        };
77        storage.configure_database_pragmas().await?;
78        storage.init_schema().await?;
79        Ok(storage)
80    }
81
82    /// Set database-level PRAGMAs that persist across connections.
83    ///
84    /// Per-connection PRAGMAs (busy_timeout, synchronous) are applied in
85    /// `connection()` on every fresh connection instead.
86    async fn configure_database_pragmas(&self) -> Result<(), StorageError> {
87        let conn = self.connect_raw()?;
88        stakpak_shared::sqlite::apply_database_pragmas(&conn)
89            .await
90            .map_err(|e| StorageError::Connection(e.to_string()))?;
91        Ok(())
92    }
93
94    /// Create from an existing connection.
95    ///
96    /// Deprecated because a bare `Connection` does not guarantee that its owning
97    /// `libsql::Database` outlives the connection, which can cause crashes.
98    #[deprecated(note = "use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)")]
99    pub async fn from_connection(_conn: Connection) -> Result<Self, StorageError> {
100        Err(StorageError::Connection(
101            "LocalStorage::from_connection is unsupported without the owning libsql::Database; use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)".to_string(),
102        ))
103    }
104
105    /// Initialize database schema by running migrations
106    async fn init_schema(&self) -> Result<(), StorageError> {
107        let conn = self.connection().await?;
108        super::migrations::run_migrations(&conn)
109            .await
110            .map_err(StorageError::Internal)
111    }
112
113    /// Open a raw connection without per-connection PRAGMAs.
114    pub(crate) fn connect_raw(&self) -> Result<Connection, StorageError> {
115        self.db
116            .connect()
117            .map_err(|e| StorageError::Connection(format!("Failed to connect to database: {}", e)))
118    }
119
120    /// Open a fresh connection with per-connection PRAGMAs applied.
121    ///
122    /// See [`stakpak_shared::sqlite::apply_connection_pragmas`] for details.
123    pub(crate) async fn connection(&self) -> Result<Connection, StorageError> {
124        let conn = self.connect_raw()?;
125        stakpak_shared::sqlite::apply_connection_pragmas(&conn)
126            .await
127            .map_err(|e| StorageError::Connection(e.to_string()))?;
128        Ok(conn)
129    }
130
131    /// Get the latest checkpoint for a session using a caller-provided connection.
132    async fn get_latest_checkpoint_for_session_inner(
133        conn: &Connection,
134        session_id: Uuid,
135    ) -> Result<Checkpoint, StorageError> {
136        let mut rows = conn
137            .query(
138                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
139                 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
140                [session_id.to_string()],
141            )
142            .await
143            .map_err(|e| StorageError::Internal(e.to_string()))?;
144
145        if let Ok(Some(row)) = rows.next().await {
146            let id: String = row
147                .get(0)
148                .map_err(|e| StorageError::Internal(e.to_string()))?;
149            let session_id: String = row
150                .get(1)
151                .map_err(|e| StorageError::Internal(e.to_string()))?;
152            let parent_id: Option<String> = row.get(2).ok();
153            let state: Option<String> = row.get(3).ok();
154            let created_at: String = row
155                .get(4)
156                .map_err(|e| StorageError::Internal(e.to_string()))?;
157            let updated_at: String = row
158                .get(5)
159                .map_err(|e| StorageError::Internal(e.to_string()))?;
160
161            let state: CheckpointState = if let Some(state_str) = state {
162                serde_json::from_str(&state_str).unwrap_or_default()
163            } else {
164                CheckpointState::default()
165            };
166
167            Ok(Checkpoint {
168                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
169                session_id: Uuid::from_str(&session_id)
170                    .map_err(|e| StorageError::Internal(e.to_string()))?,
171                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
172                state,
173                created_at: parse_datetime(&created_at)?,
174                updated_at: parse_datetime(&updated_at)?,
175            })
176        } else {
177            Err(StorageError::NotFound(format!(
178                "No checkpoints found for session {}",
179                session_id
180            )))
181        }
182    }
183}
184
185#[async_trait]
186impl SessionStorage for LocalStorage {
187    async fn list_sessions(
188        &self,
189        query: &ListSessionsQuery,
190    ) -> Result<ListSessionsResult, StorageError> {
191        let limit = query.limit.unwrap_or(100);
192        let offset = query.offset.unwrap_or(0);
193
194        let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at, 
195            (SELECT COUNT(*) FROM checkpoints c WHERE c.session_id = s.id) as checkpoint_count,
196            (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY created_at DESC LIMIT 1) as active_checkpoint_id
197            FROM sessions s WHERE 1=1".to_string();
198
199        // Use parameterized values for enum filters (safe because they come from
200        // our own Display impls, but we keep this consistent with the rest of
201        // the codebase).  The search term is the only free-form user input and
202        // is handled with a parameter below.
203        if let Some(status) = &query.status {
204            sql.push_str(&format!(" AND s.status = '{}'", status));
205        }
206        if let Some(visibility) = &query.visibility {
207            sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
208        }
209        if query.search.is_some() {
210            sql.push_str(" AND s.title LIKE '%' || ? || '%'");
211        }
212
213        sql.push_str(&format!(
214            " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
215            limit, offset
216        ));
217
218        let conn = self.connection().await?;
219        let mut rows = if let Some(search) = &query.search {
220            conn.query(&sql, [search.as_str()])
221                .await
222                .map_err(|e| StorageError::Internal(e.to_string()))?
223        } else {
224            conn.query(&sql, ())
225                .await
226                .map_err(|e| StorageError::Internal(e.to_string()))?
227        };
228
229        let mut sessions = Vec::new();
230        while let Ok(Some(row)) = rows.next().await {
231            let id: String = row
232                .get(0)
233                .map_err(|e| StorageError::Internal(e.to_string()))?;
234            let title: String = row
235                .get(1)
236                .map_err(|e| StorageError::Internal(e.to_string()))?;
237            let visibility: String = row
238                .get(2)
239                .map_err(|e| StorageError::Internal(e.to_string()))?;
240            let status: String = row
241                .get(3)
242                .map_err(|e| StorageError::Internal(e.to_string()))?;
243            let cwd: Option<String> = row.get(4).ok();
244            let created_at: String = row
245                .get(5)
246                .map_err(|e| StorageError::Internal(e.to_string()))?;
247            let updated_at: String = row
248                .get(6)
249                .map_err(|e| StorageError::Internal(e.to_string()))?;
250            let checkpoint_count: i64 = row.get(7).unwrap_or(0);
251            let active_checkpoint_id: Option<String> = row.get(8).ok();
252
253            sessions.push(SessionSummary {
254                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
255                title,
256                visibility: parse_visibility(&visibility),
257                status: parse_status(&status),
258                cwd,
259                created_at: parse_datetime(&created_at)?,
260                updated_at: parse_datetime(&updated_at)?,
261                message_count: checkpoint_count as u32,
262                active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
263                last_message_at: None,
264            });
265        }
266
267        Ok(ListSessionsResult {
268            sessions,
269            total: None,
270        })
271    }
272
273    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
274        let conn = self.connection().await?;
275        let mut rows = conn
276            .query(
277                "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
278                [session_id.to_string()],
279            )
280            .await
281            .map_err(|e| StorageError::Internal(e.to_string()))?;
282
283        if let Ok(Some(row)) = rows.next().await {
284            let id: String = row
285                .get(0)
286                .map_err(|e| StorageError::Internal(e.to_string()))?;
287            let title: String = row
288                .get(1)
289                .map_err(|e| StorageError::Internal(e.to_string()))?;
290            let visibility: String = row
291                .get(2)
292                .map_err(|e| StorageError::Internal(e.to_string()))?;
293            let status: String = row
294                .get(3)
295                .map_err(|e| StorageError::Internal(e.to_string()))?;
296            let cwd: Option<String> = row.get(4).ok();
297            let created_at: String = row
298                .get(5)
299                .map_err(|e| StorageError::Internal(e.to_string()))?;
300            let updated_at: String = row
301                .get(6)
302                .map_err(|e| StorageError::Internal(e.to_string()))?;
303
304            // Get the latest checkpoint (reuse the same lock)
305            let active_checkpoint =
306                Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
307                    .await
308                    .ok();
309
310            Ok(Session {
311                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
312                title,
313                visibility: parse_visibility(&visibility),
314                status: parse_status(&status),
315                cwd,
316                created_at: parse_datetime(&created_at)?,
317                updated_at: parse_datetime(&updated_at)?,
318                active_checkpoint,
319            })
320        } else {
321            Err(StorageError::NotFound(format!(
322                "Session {} not found",
323                session_id
324            )))
325        }
326    }
327
328    async fn create_session(
329        &self,
330        request: &CreateSessionRequest,
331    ) -> Result<CreateSessionResult, StorageError> {
332        let now = Utc::now();
333        let session_id = Uuid::new_v4();
334        let checkpoint_id = Uuid::new_v4();
335
336        let conn = self.connection().await?;
337
338        // Create session
339        conn.execute(
340            "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
341            (
342                session_id.to_string(),
343                request.title.as_str(),
344                request.visibility.to_string(),
345                "ACTIVE",
346                request.cwd.as_deref(),
347                now.to_rfc3339(),
348                now.to_rfc3339(),
349            ),
350        )
351        .await
352        .map_err(|e| StorageError::Internal(e.to_string()))?;
353
354        // Create initial checkpoint
355        let state_json = serde_json::to_string(&request.initial_state)
356            .map_err(|e| StorageError::Internal(e.to_string()))?;
357
358        conn.execute(
359            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
360            (
361                checkpoint_id.to_string(),
362                session_id.to_string(),
363                None::<String>,
364                state_json,
365                now.to_rfc3339(),
366                now.to_rfc3339(),
367            ),
368        )
369        .await
370        .map_err(|e| StorageError::Internal(e.to_string()))?;
371
372        Ok(CreateSessionResult {
373            session_id,
374            checkpoint: Checkpoint {
375                id: checkpoint_id,
376                session_id,
377                parent_id: None,
378                state: request.initial_state.clone(),
379                created_at: now,
380                updated_at: now,
381            },
382        })
383    }
384
385    async fn update_session(
386        &self,
387        session_id: Uuid,
388        request: &UpdateSessionRequest,
389    ) -> Result<Session, StorageError> {
390        let now = Utc::now();
391
392        {
393            let conn = self.connection().await?;
394
395            // Update fields individually since libsql doesn't support dynamic params easily
396            if let Some(title) = &request.title {
397                conn.execute(
398                    "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
399                    (title.as_str(), now.to_rfc3339(), session_id.to_string()),
400                )
401                .await
402                .map_err(|e| StorageError::Internal(e.to_string()))?;
403            }
404            if let Some(visibility) = &request.visibility {
405                conn.execute(
406                    "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
407                    (
408                        visibility.to_string(),
409                        now.to_rfc3339(),
410                        session_id.to_string(),
411                    ),
412                )
413                .await
414                .map_err(|e| StorageError::Internal(e.to_string()))?;
415            }
416        }
417
418        self.get_session(session_id).await
419    }
420
421    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
422        // Mark as deleted instead of actually deleting
423        let now = Utc::now();
424        let conn = self.connection().await?;
425        conn.execute(
426            "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
427            (now.to_rfc3339(), session_id.to_string()),
428        )
429        .await
430        .map_err(|e| StorageError::Internal(e.to_string()))?;
431        Ok(())
432    }
433
434    async fn list_checkpoints(
435        &self,
436        session_id: Uuid,
437        query: &ListCheckpointsQuery,
438    ) -> Result<ListCheckpointsResult, StorageError> {
439        let limit = query.limit.unwrap_or(100);
440        let offset = query.offset.unwrap_or(0);
441
442        let sql = format!(
443            "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
444             WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
445            limit, offset
446        );
447
448        let conn = self.connection().await?;
449        let mut rows = conn
450            .query(&sql, [session_id.to_string()])
451            .await
452            .map_err(|e| StorageError::Internal(e.to_string()))?;
453
454        let mut checkpoints = Vec::new();
455        while let Ok(Some(row)) = rows.next().await {
456            let id: String = row
457                .get(0)
458                .map_err(|e| StorageError::Internal(e.to_string()))?;
459            let session_id: String = row
460                .get(1)
461                .map_err(|e| StorageError::Internal(e.to_string()))?;
462            let parent_id: Option<String> = row.get(2).ok();
463            let state: Option<String> = row.get(3).ok();
464            let created_at: String = row
465                .get(4)
466                .map_err(|e| StorageError::Internal(e.to_string()))?;
467            let updated_at: String = row
468                .get(5)
469                .map_err(|e| StorageError::Internal(e.to_string()))?;
470
471            let state: CheckpointState = if let Some(state_str) = state {
472                serde_json::from_str(&state_str).unwrap_or_default()
473            } else {
474                CheckpointState::default()
475            };
476
477            checkpoints.push(CheckpointSummary {
478                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
479                session_id: Uuid::from_str(&session_id)
480                    .map_err(|e| StorageError::Internal(e.to_string()))?,
481                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
482                message_count: state.messages.len() as u32,
483                created_at: parse_datetime(&created_at)?,
484                updated_at: parse_datetime(&updated_at)?,
485            });
486        }
487
488        Ok(ListCheckpointsResult {
489            checkpoints,
490            total: None,
491        })
492    }
493
494    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
495        let conn = self.connection().await?;
496        let mut rows = conn
497            .query(
498                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
499                [checkpoint_id.to_string()],
500            )
501            .await
502            .map_err(|e| StorageError::Internal(e.to_string()))?;
503
504        if let Ok(Some(row)) = rows.next().await {
505            let id: String = row
506                .get(0)
507                .map_err(|e| StorageError::Internal(e.to_string()))?;
508            let session_id: String = row
509                .get(1)
510                .map_err(|e| StorageError::Internal(e.to_string()))?;
511            let parent_id: Option<String> = row.get(2).ok();
512            let state: Option<String> = row.get(3).ok();
513            let created_at: String = row
514                .get(4)
515                .map_err(|e| StorageError::Internal(e.to_string()))?;
516            let updated_at: String = row
517                .get(5)
518                .map_err(|e| StorageError::Internal(e.to_string()))?;
519
520            let state: CheckpointState = if let Some(state_str) = state {
521                serde_json::from_str(&state_str).unwrap_or_default()
522            } else {
523                CheckpointState::default()
524            };
525
526            Ok(Checkpoint {
527                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
528                session_id: Uuid::from_str(&session_id)
529                    .map_err(|e| StorageError::Internal(e.to_string()))?,
530                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
531                state,
532                created_at: parse_datetime(&created_at)?,
533                updated_at: parse_datetime(&updated_at)?,
534            })
535        } else {
536            Err(StorageError::NotFound(format!(
537                "Checkpoint {} not found",
538                checkpoint_id
539            )))
540        }
541    }
542
543    async fn create_checkpoint(
544        &self,
545        session_id: Uuid,
546        request: &CreateCheckpointRequest,
547    ) -> Result<Checkpoint, StorageError> {
548        let now = Utc::now();
549        let checkpoint_id = Uuid::new_v4();
550
551        let state_json = serde_json::to_string(&request.state)
552            .map_err(|e| StorageError::Internal(e.to_string()))?;
553
554        let conn = self.connection().await?;
555
556        conn.execute(
557            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
558            (
559                checkpoint_id.to_string(),
560                session_id.to_string(),
561                request.parent_id.map(|id| id.to_string()),
562                state_json,
563                now.to_rfc3339(),
564                now.to_rfc3339(),
565            ),
566        )
567        .await
568        .map_err(|e| StorageError::Internal(e.to_string()))?;
569
570        // Update session's updated_at
571        conn.execute(
572            "UPDATE sessions SET updated_at = ? WHERE id = ?",
573            (now.to_rfc3339(), session_id.to_string()),
574        )
575        .await
576        .map_err(|e| StorageError::Internal(e.to_string()))?;
577
578        Ok(Checkpoint {
579            id: checkpoint_id,
580            session_id,
581            parent_id: request.parent_id,
582            state: request.state.clone(),
583            created_at: now,
584            updated_at: now,
585        })
586    }
587}
588
589// Helper functions
590fn parse_visibility(s: &str) -> SessionVisibility {
591    match s.to_uppercase().as_str() {
592        "PUBLIC" => SessionVisibility::Public,
593        _ => SessionVisibility::Private,
594    }
595}
596
597fn parse_status(s: &str) -> SessionStatus {
598    match s.to_uppercase().as_str() {
599        "DELETED" => SessionStatus::Deleted,
600        _ => SessionStatus::Active,
601    }
602}
603
604fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
605    DateTime::parse_from_rfc3339(s)
606        .map(|dt| dt.with_timezone(&Utc))
607        .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
608}