// stakpak_api/local/storage.rs

//! Local SQLite storage implementation.
//!
//! Implements `SessionStorage` using a local SQLite database.

5use crate::storage::{
6    BackendInfo, Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest,
7    CreateSessionRequest, CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult,
8    ListSessionsQuery, ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary,
9    SessionVisibility, StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::{Connection, Database};
14use std::path::Path;
15use std::str::FromStr;
16use tempfile::TempDir;
17use uuid::Uuid;
18
19/// Local SQLite storage implementation
20///
21/// Uses a shared `Database` handle and opens a fresh `Connection` per
22/// operation to avoid libsql connection-sharing hazards under concurrency.
23pub struct LocalStorage {
24    db: Database,
25    backend_info: BackendInfo,
26    /// Owns temporary backing storage for in-memory mode and cleans it on drop.
27    _temp_dir: Option<TempDir>,
28}
29
30impl LocalStorage {
31    /// Create a new local storage instance
32    pub async fn new(db_path: &str) -> Result<Self, StorageError> {
33        let (resolved_path, temp_dir) = if db_path == ":memory:" {
34            // libsql in-memory databases are connection-scoped; use a temporary
35            // directory and clean it automatically when storage is dropped.
36            let temp_dir = tempfile::tempdir().map_err(|e| {
37                StorageError::Connection(format!("Failed to create temp directory: {}", e))
38            })?;
39            (temp_dir.path().join("local-storage.db"), Some(temp_dir))
40        } else {
41            (Path::new(db_path).to_path_buf(), None)
42        };
43
44        // Ensure parent directory exists.
45        if let Some(parent) = resolved_path.parent() {
46            std::fs::create_dir_all(parent).map_err(|e| {
47                StorageError::Connection(format!("Failed to create database directory: {}", e))
48            })?;
49        }
50
51        let db = libsql::Builder::new_local(&resolved_path)
52            .build()
53            .await
54            .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
55
56        let storage = Self {
57            db,
58            backend_info: BackendInfo::local(db_path.to_string()),
59            _temp_dir: temp_dir,
60        };
61        storage.configure_database_pragmas().await?;
62        storage.init_schema().await?;
63
64        Ok(storage)
65    }
66
67    /// Create from an existing database + connection pair.
68    ///
69    /// The provided connection is intentionally ignored; a fresh connection is
70    /// opened per operation.
71    pub async fn from_db_and_connection(
72        db: Database,
73        _conn: Connection,
74    ) -> Result<Self, StorageError> {
75        let storage = Self {
76            db,
77            backend_info: BackendInfo::local(":memory:"),
78            _temp_dir: None,
79        };
80        storage.configure_database_pragmas().await?;
81        storage.init_schema().await?;
82        Ok(storage)
83    }
84
85    /// Set database-level PRAGMAs that persist across connections.
86    ///
87    /// Per-connection PRAGMAs (busy_timeout, synchronous) are applied in
88    /// `connection()` on every fresh connection instead.
89    async fn configure_database_pragmas(&self) -> Result<(), StorageError> {
90        let conn = self.connect_raw()?;
91        stakpak_shared::sqlite::apply_database_pragmas(&conn)
92            .await
93            .map_err(|e| StorageError::Connection(e.to_string()))?;
94        Ok(())
95    }
96
97    /// Create from an existing connection.
98    ///
99    /// Deprecated because a bare `Connection` does not guarantee that its owning
100    /// `libsql::Database` outlives the connection, which can cause crashes.
101    #[deprecated(note = "use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)")]
102    pub async fn from_connection(_conn: Connection) -> Result<Self, StorageError> {
103        Err(StorageError::Connection(
104            "LocalStorage::from_connection is unsupported without the owning libsql::Database; use LocalStorage::new(...) or LocalStorage::from_db_and_connection(...)".to_string(),
105        ))
106    }
107
108    /// Initialize database schema by running migrations
109    async fn init_schema(&self) -> Result<(), StorageError> {
110        let conn = self.connection().await?;
111        super::migrations::run_migrations(&conn)
112            .await
113            .map_err(StorageError::Internal)
114    }
115
116    /// Open a raw connection without per-connection PRAGMAs.
117    pub(crate) fn connect_raw(&self) -> Result<Connection, StorageError> {
118        self.db
119            .connect()
120            .map_err(|e| StorageError::Connection(format!("Failed to connect to database: {}", e)))
121    }
122
123    /// Open a fresh connection with per-connection PRAGMAs applied.
124    ///
125    /// See [`stakpak_shared::sqlite::apply_connection_pragmas`] for details.
126    pub(crate) async fn connection(&self) -> Result<Connection, StorageError> {
127        let conn = self.connect_raw()?;
128        stakpak_shared::sqlite::apply_connection_pragmas(&conn)
129            .await
130            .map_err(|e| StorageError::Connection(e.to_string()))?;
131        Ok(conn)
132    }
133
134    /// Get the latest checkpoint for a session using a caller-provided connection.
135    async fn get_latest_checkpoint_for_session_inner(
136        conn: &Connection,
137        session_id: Uuid,
138    ) -> Result<Checkpoint, StorageError> {
139        let mut rows = conn
140            .query(
141                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
142                 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
143                [session_id.to_string()],
144            )
145            .await
146            .map_err(|e| StorageError::Internal(e.to_string()))?;
147
148        if let Ok(Some(row)) = rows.next().await {
149            let id: String = row
150                .get(0)
151                .map_err(|e| StorageError::Internal(e.to_string()))?;
152            let session_id: String = row
153                .get(1)
154                .map_err(|e| StorageError::Internal(e.to_string()))?;
155            let parent_id: Option<String> = row.get(2).ok();
156            let state: Option<String> = row.get(3).ok();
157            let created_at: String = row
158                .get(4)
159                .map_err(|e| StorageError::Internal(e.to_string()))?;
160            let updated_at: String = row
161                .get(5)
162                .map_err(|e| StorageError::Internal(e.to_string()))?;
163
164            let state: CheckpointState = if let Some(state_str) = state {
165                serde_json::from_str(&state_str).unwrap_or_default()
166            } else {
167                CheckpointState::default()
168            };
169
170            Ok(Checkpoint {
171                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
172                session_id: Uuid::from_str(&session_id)
173                    .map_err(|e| StorageError::Internal(e.to_string()))?,
174                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
175                state,
176                created_at: parse_datetime(&created_at)?,
177                updated_at: parse_datetime(&updated_at)?,
178            })
179        } else {
180            Err(StorageError::NotFound(format!(
181                "No checkpoints found for session {}",
182                session_id
183            )))
184        }
185    }
186}
187
188#[async_trait]
189impl SessionStorage for LocalStorage {
190    fn backend_info(&self) -> BackendInfo {
191        self.backend_info.clone()
192    }
193
194    async fn list_sessions(
195        &self,
196        query: &ListSessionsQuery,
197    ) -> Result<ListSessionsResult, StorageError> {
198        let limit = query.limit.unwrap_or(100);
199        let offset = query.offset.unwrap_or(0);
200
201        // `message_count` is derived from the latest checkpoint's state JSON using
202        // `json_array_length($.messages)`, so it reflects actual messages rather than
203        // checkpoint revisions (which are not user-facing).
204        let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at,
205            COALESCE((
206                SELECT json_array_length(c.state, '$.messages')
207                FROM checkpoints c
208                WHERE c.session_id = s.id
209                ORDER BY c.created_at DESC
210                LIMIT 1
211            ), 0) as message_count,
212            (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY c.created_at DESC LIMIT 1) as active_checkpoint_id
213            FROM sessions s WHERE 1=1".to_string();
214
215        // Use parameterized values for enum filters (safe because they come from
216        // our own Display impls, but we keep this consistent with the rest of
217        // the codebase).  The search term is the only free-form user input and
218        // is handled with a parameter below.
219        if let Some(status) = &query.status {
220            sql.push_str(&format!(" AND s.status = '{}'", status));
221        }
222        if let Some(visibility) = &query.visibility {
223            sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
224        }
225        if query.search.is_some() {
226            sql.push_str(" AND s.title LIKE '%' || ? || '%'");
227        }
228
229        sql.push_str(&format!(
230            " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
231            limit, offset
232        ));
233
234        let conn = self.connection().await?;
235        let mut rows = if let Some(search) = &query.search {
236            conn.query(&sql, [search.as_str()])
237                .await
238                .map_err(|e| StorageError::Internal(e.to_string()))?
239        } else {
240            conn.query(&sql, ())
241                .await
242                .map_err(|e| StorageError::Internal(e.to_string()))?
243        };
244
245        let mut sessions = Vec::new();
246        while let Ok(Some(row)) = rows.next().await {
247            let id: String = row
248                .get(0)
249                .map_err(|e| StorageError::Internal(e.to_string()))?;
250            let title: String = row
251                .get(1)
252                .map_err(|e| StorageError::Internal(e.to_string()))?;
253            let visibility: String = row
254                .get(2)
255                .map_err(|e| StorageError::Internal(e.to_string()))?;
256            let status: String = row
257                .get(3)
258                .map_err(|e| StorageError::Internal(e.to_string()))?;
259            let cwd: Option<String> = row.get(4).ok();
260            let created_at: String = row
261                .get(5)
262                .map_err(|e| StorageError::Internal(e.to_string()))?;
263            let updated_at: String = row
264                .get(6)
265                .map_err(|e| StorageError::Internal(e.to_string()))?;
266            let message_count: i64 = row.get(7).unwrap_or(0);
267            let active_checkpoint_id: Option<String> = row.get(8).ok();
268
269            sessions.push(SessionSummary {
270                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
271                title,
272                visibility: parse_visibility(&visibility),
273                status: parse_status(&status),
274                cwd,
275                created_at: parse_datetime(&created_at)?,
276                updated_at: parse_datetime(&updated_at)?,
277                message_count: message_count.max(0) as u32,
278                active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
279                last_message_at: None,
280            });
281        }
282
283        Ok(ListSessionsResult {
284            sessions,
285            total: None,
286        })
287    }
288
289    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
290        let conn = self.connection().await?;
291        let mut rows = conn
292            .query(
293                "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
294                [session_id.to_string()],
295            )
296            .await
297            .map_err(|e| StorageError::Internal(e.to_string()))?;
298
299        if let Ok(Some(row)) = rows.next().await {
300            let id: String = row
301                .get(0)
302                .map_err(|e| StorageError::Internal(e.to_string()))?;
303            let title: String = row
304                .get(1)
305                .map_err(|e| StorageError::Internal(e.to_string()))?;
306            let visibility: String = row
307                .get(2)
308                .map_err(|e| StorageError::Internal(e.to_string()))?;
309            let status: String = row
310                .get(3)
311                .map_err(|e| StorageError::Internal(e.to_string()))?;
312            let cwd: Option<String> = row.get(4).ok();
313            let created_at: String = row
314                .get(5)
315                .map_err(|e| StorageError::Internal(e.to_string()))?;
316            let updated_at: String = row
317                .get(6)
318                .map_err(|e| StorageError::Internal(e.to_string()))?;
319
320            // Get the latest checkpoint (reuse the same lock)
321            let active_checkpoint =
322                Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
323                    .await
324                    .ok();
325
326            Ok(Session {
327                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
328                title,
329                visibility: parse_visibility(&visibility),
330                status: parse_status(&status),
331                cwd,
332                created_at: parse_datetime(&created_at)?,
333                updated_at: parse_datetime(&updated_at)?,
334                active_checkpoint,
335            })
336        } else {
337            Err(StorageError::NotFound(format!(
338                "Session {} not found",
339                session_id
340            )))
341        }
342    }
343
344    async fn create_session(
345        &self,
346        request: &CreateSessionRequest,
347    ) -> Result<CreateSessionResult, StorageError> {
348        let now = Utc::now();
349        let session_id = Uuid::new_v4();
350        let checkpoint_id = Uuid::new_v4();
351
352        let conn = self.connection().await?;
353
354        // Create session
355        conn.execute(
356            "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
357            (
358                session_id.to_string(),
359                request.title.as_str(),
360                request.visibility.to_string(),
361                "ACTIVE",
362                request.cwd.as_deref(),
363                now.to_rfc3339(),
364                now.to_rfc3339(),
365            ),
366        )
367        .await
368        .map_err(|e| StorageError::Internal(e.to_string()))?;
369
370        // Create initial checkpoint
371        let state_json = serde_json::to_string(&request.initial_state)
372            .map_err(|e| StorageError::Internal(e.to_string()))?;
373
374        conn.execute(
375            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
376            (
377                checkpoint_id.to_string(),
378                session_id.to_string(),
379                None::<String>,
380                state_json,
381                now.to_rfc3339(),
382                now.to_rfc3339(),
383            ),
384        )
385        .await
386        .map_err(|e| StorageError::Internal(e.to_string()))?;
387
388        Ok(CreateSessionResult {
389            session_id,
390            checkpoint: Checkpoint {
391                id: checkpoint_id,
392                session_id,
393                parent_id: None,
394                state: request.initial_state.clone(),
395                created_at: now,
396                updated_at: now,
397            },
398        })
399    }
400
401    async fn update_session(
402        &self,
403        session_id: Uuid,
404        request: &UpdateSessionRequest,
405    ) -> Result<Session, StorageError> {
406        let now = Utc::now();
407
408        {
409            let conn = self.connection().await?;
410
411            // Update fields individually since libsql doesn't support dynamic params easily
412            if let Some(title) = &request.title {
413                conn.execute(
414                    "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
415                    (title.as_str(), now.to_rfc3339(), session_id.to_string()),
416                )
417                .await
418                .map_err(|e| StorageError::Internal(e.to_string()))?;
419            }
420            if let Some(visibility) = &request.visibility {
421                conn.execute(
422                    "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
423                    (
424                        visibility.to_string(),
425                        now.to_rfc3339(),
426                        session_id.to_string(),
427                    ),
428                )
429                .await
430                .map_err(|e| StorageError::Internal(e.to_string()))?;
431            }
432        }
433
434        self.get_session(session_id).await
435    }
436
437    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
438        // Mark as deleted instead of actually deleting
439        let now = Utc::now();
440        let conn = self.connection().await?;
441        conn.execute(
442            "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
443            (now.to_rfc3339(), session_id.to_string()),
444        )
445        .await
446        .map_err(|e| StorageError::Internal(e.to_string()))?;
447        Ok(())
448    }
449
450    async fn list_checkpoints(
451        &self,
452        session_id: Uuid,
453        query: &ListCheckpointsQuery,
454    ) -> Result<ListCheckpointsResult, StorageError> {
455        let limit = query.limit.unwrap_or(100);
456        let offset = query.offset.unwrap_or(0);
457
458        let sql = format!(
459            "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
460             WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
461            limit, offset
462        );
463
464        let conn = self.connection().await?;
465        let mut rows = conn
466            .query(&sql, [session_id.to_string()])
467            .await
468            .map_err(|e| StorageError::Internal(e.to_string()))?;
469
470        let mut checkpoints = Vec::new();
471        while let Ok(Some(row)) = rows.next().await {
472            let id: String = row
473                .get(0)
474                .map_err(|e| StorageError::Internal(e.to_string()))?;
475            let session_id: String = row
476                .get(1)
477                .map_err(|e| StorageError::Internal(e.to_string()))?;
478            let parent_id: Option<String> = row.get(2).ok();
479            let state: Option<String> = row.get(3).ok();
480            let created_at: String = row
481                .get(4)
482                .map_err(|e| StorageError::Internal(e.to_string()))?;
483            let updated_at: String = row
484                .get(5)
485                .map_err(|e| StorageError::Internal(e.to_string()))?;
486
487            let state: CheckpointState = if let Some(state_str) = state {
488                serde_json::from_str(&state_str).unwrap_or_default()
489            } else {
490                CheckpointState::default()
491            };
492
493            checkpoints.push(CheckpointSummary {
494                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
495                session_id: Uuid::from_str(&session_id)
496                    .map_err(|e| StorageError::Internal(e.to_string()))?,
497                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
498                message_count: state.messages.len() as u32,
499                created_at: parse_datetime(&created_at)?,
500                updated_at: parse_datetime(&updated_at)?,
501            });
502        }
503
504        Ok(ListCheckpointsResult {
505            checkpoints,
506            total: None,
507        })
508    }
509
510    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
511        let conn = self.connection().await?;
512        let mut rows = conn
513            .query(
514                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
515                [checkpoint_id.to_string()],
516            )
517            .await
518            .map_err(|e| StorageError::Internal(e.to_string()))?;
519
520        if let Ok(Some(row)) = rows.next().await {
521            let id: String = row
522                .get(0)
523                .map_err(|e| StorageError::Internal(e.to_string()))?;
524            let session_id: String = row
525                .get(1)
526                .map_err(|e| StorageError::Internal(e.to_string()))?;
527            let parent_id: Option<String> = row.get(2).ok();
528            let state: Option<String> = row.get(3).ok();
529            let created_at: String = row
530                .get(4)
531                .map_err(|e| StorageError::Internal(e.to_string()))?;
532            let updated_at: String = row
533                .get(5)
534                .map_err(|e| StorageError::Internal(e.to_string()))?;
535
536            let state: CheckpointState = if let Some(state_str) = state {
537                serde_json::from_str(&state_str).unwrap_or_default()
538            } else {
539                CheckpointState::default()
540            };
541
542            Ok(Checkpoint {
543                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
544                session_id: Uuid::from_str(&session_id)
545                    .map_err(|e| StorageError::Internal(e.to_string()))?,
546                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
547                state,
548                created_at: parse_datetime(&created_at)?,
549                updated_at: parse_datetime(&updated_at)?,
550            })
551        } else {
552            Err(StorageError::NotFound(format!(
553                "Checkpoint {} not found",
554                checkpoint_id
555            )))
556        }
557    }
558
559    async fn create_checkpoint(
560        &self,
561        session_id: Uuid,
562        request: &CreateCheckpointRequest,
563    ) -> Result<Checkpoint, StorageError> {
564        let now = Utc::now();
565        let checkpoint_id = Uuid::new_v4();
566
567        let state_json = serde_json::to_string(&request.state)
568            .map_err(|e| StorageError::Internal(e.to_string()))?;
569
570        let conn = self.connection().await?;
571
572        conn.execute(
573            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
574            (
575                checkpoint_id.to_string(),
576                session_id.to_string(),
577                request.parent_id.map(|id| id.to_string()),
578                state_json,
579                now.to_rfc3339(),
580                now.to_rfc3339(),
581            ),
582        )
583        .await
584        .map_err(|e| StorageError::Internal(e.to_string()))?;
585
586        // Update session's updated_at
587        conn.execute(
588            "UPDATE sessions SET updated_at = ? WHERE id = ?",
589            (now.to_rfc3339(), session_id.to_string()),
590        )
591        .await
592        .map_err(|e| StorageError::Internal(e.to_string()))?;
593
594        Ok(Checkpoint {
595            id: checkpoint_id,
596            session_id,
597            parent_id: request.parent_id,
598            state: request.state.clone(),
599            created_at: now,
600            updated_at: now,
601        })
602    }
603}
604
605// Helper functions
606fn parse_visibility(s: &str) -> SessionVisibility {
607    match s.to_uppercase().as_str() {
608        "PUBLIC" => SessionVisibility::Public,
609        _ => SessionVisibility::Private,
610    }
611}
612
613fn parse_status(s: &str) -> SessionStatus {
614    match s.to_uppercase().as_str() {
615        "DELETED" => SessionStatus::Deleted,
616        _ => SessionStatus::Active,
617    }
618}
619
620fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
621    DateTime::parse_from_rfc3339(s)
622        .map(|dt| dt.with_timezone(&Utc))
623        .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
624}