Skip to main content

stakpak_api/local/
storage.rs

1//! Local SQLite storage implementation
2//!
//! Implements [`SessionStorage`] on top of a local SQLite database (opened through libsql).
4
5use crate::storage::{
6    Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
7    CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
8    ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
9    StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::Connection;
14use std::path::Path;
15use std::str::FromStr;
16use tokio::sync::Mutex;
17use uuid::Uuid;
18
19/// Local SQLite storage implementation
20///
21/// Uses a `Mutex<Connection>` to ensure thread-safe access to the
22/// underlying libsql connection, which is not `Send + Sync` by default.
23pub struct LocalStorage {
24    conn: Mutex<Connection>,
25}
26
27impl LocalStorage {
28    /// Create a new local storage instance
29    pub async fn new(db_path: &str) -> Result<Self, StorageError> {
30        // Ensure parent directory exists
31        if let Some(parent) = Path::new(db_path).parent() {
32            std::fs::create_dir_all(parent).map_err(|e| {
33                StorageError::Connection(format!("Failed to create database directory: {}", e))
34            })?;
35        }
36
37        let db = libsql::Builder::new_local(db_path)
38            .build()
39            .await
40            .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
41
42        let conn = db.connect().map_err(|e| {
43            StorageError::Connection(format!("Failed to connect to database: {}", e))
44        })?;
45
46        let storage = Self {
47            conn: Mutex::new(conn),
48        };
49        storage.init_schema().await?;
50
51        Ok(storage)
52    }
53
54    /// Create from existing connection (for use with AgentClient)
55    ///
56    /// This will initialize the database schema if it hasn't been set up yet.
57    pub async fn from_connection(conn: Connection) -> Result<Self, StorageError> {
58        let storage = Self {
59            conn: Mutex::new(conn),
60        };
61        storage.init_schema().await?;
62        Ok(storage)
63    }
64
65    /// Initialize database schema by running migrations
66    async fn init_schema(&self) -> Result<(), StorageError> {
67        let conn = self.conn.lock().await;
68        super::migrations::run_migrations(&conn)
69            .await
70            .map_err(StorageError::Internal)
71    }
72
73    /// Get connection reference (locked)
74    ///
75    /// Useful for direct SQL access in tests or migrations.
76    pub fn connection(&self) -> &Mutex<Connection> {
77        &self.conn
78    }
79
80    /// Get the latest checkpoint for a session
81    ///
82    /// Expects the caller to already hold the connection lock.
83    async fn get_latest_checkpoint_for_session_inner(
84        conn: &Connection,
85        session_id: Uuid,
86    ) -> Result<Checkpoint, StorageError> {
87        let mut rows = conn
88            .query(
89                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
90                 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
91                [session_id.to_string()],
92            )
93            .await
94            .map_err(|e| StorageError::Internal(e.to_string()))?;
95
96        if let Ok(Some(row)) = rows.next().await {
97            let id: String = row
98                .get(0)
99                .map_err(|e| StorageError::Internal(e.to_string()))?;
100            let session_id: String = row
101                .get(1)
102                .map_err(|e| StorageError::Internal(e.to_string()))?;
103            let parent_id: Option<String> = row.get(2).ok();
104            let state: Option<String> = row.get(3).ok();
105            let created_at: String = row
106                .get(4)
107                .map_err(|e| StorageError::Internal(e.to_string()))?;
108            let updated_at: String = row
109                .get(5)
110                .map_err(|e| StorageError::Internal(e.to_string()))?;
111
112            let state: CheckpointState = if let Some(state_str) = state {
113                serde_json::from_str(&state_str).unwrap_or_default()
114            } else {
115                CheckpointState::default()
116            };
117
118            Ok(Checkpoint {
119                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
120                session_id: Uuid::from_str(&session_id)
121                    .map_err(|e| StorageError::Internal(e.to_string()))?,
122                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
123                state,
124                created_at: parse_datetime(&created_at)?,
125                updated_at: parse_datetime(&updated_at)?,
126            })
127        } else {
128            Err(StorageError::NotFound(format!(
129                "No checkpoints found for session {}",
130                session_id
131            )))
132        }
133    }
134}
135
136#[async_trait]
137impl SessionStorage for LocalStorage {
138    async fn list_sessions(
139        &self,
140        query: &ListSessionsQuery,
141    ) -> Result<ListSessionsResult, StorageError> {
142        let limit = query.limit.unwrap_or(100);
143        let offset = query.offset.unwrap_or(0);
144
145        let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at, 
146            (SELECT COUNT(*) FROM checkpoints c WHERE c.session_id = s.id) as checkpoint_count,
147            (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY created_at DESC LIMIT 1) as active_checkpoint_id
148            FROM sessions s WHERE 1=1".to_string();
149
150        // Use parameterized values for enum filters (safe because they come from
151        // our own Display impls, but we keep this consistent with the rest of
152        // the codebase).  The search term is the only free-form user input and
153        // is handled with a parameter below.
154        if let Some(status) = &query.status {
155            sql.push_str(&format!(" AND s.status = '{}'", status));
156        }
157        if let Some(visibility) = &query.visibility {
158            sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
159        }
160        if query.search.is_some() {
161            sql.push_str(" AND s.title LIKE '%' || ? || '%'");
162        }
163
164        sql.push_str(&format!(
165            " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
166            limit, offset
167        ));
168
169        let conn = self.conn.lock().await;
170        let mut rows = if let Some(search) = &query.search {
171            conn.query(&sql, [search.as_str()])
172                .await
173                .map_err(|e| StorageError::Internal(e.to_string()))?
174        } else {
175            conn.query(&sql, ())
176                .await
177                .map_err(|e| StorageError::Internal(e.to_string()))?
178        };
179
180        let mut sessions = Vec::new();
181        while let Ok(Some(row)) = rows.next().await {
182            let id: String = row
183                .get(0)
184                .map_err(|e| StorageError::Internal(e.to_string()))?;
185            let title: String = row
186                .get(1)
187                .map_err(|e| StorageError::Internal(e.to_string()))?;
188            let visibility: String = row
189                .get(2)
190                .map_err(|e| StorageError::Internal(e.to_string()))?;
191            let status: String = row
192                .get(3)
193                .map_err(|e| StorageError::Internal(e.to_string()))?;
194            let cwd: Option<String> = row.get(4).ok();
195            let created_at: String = row
196                .get(5)
197                .map_err(|e| StorageError::Internal(e.to_string()))?;
198            let updated_at: String = row
199                .get(6)
200                .map_err(|e| StorageError::Internal(e.to_string()))?;
201            let checkpoint_count: i64 = row.get(7).unwrap_or(0);
202            let active_checkpoint_id: Option<String> = row.get(8).ok();
203
204            sessions.push(SessionSummary {
205                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
206                title,
207                visibility: parse_visibility(&visibility),
208                status: parse_status(&status),
209                cwd,
210                created_at: parse_datetime(&created_at)?,
211                updated_at: parse_datetime(&updated_at)?,
212                message_count: checkpoint_count as u32,
213                active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
214                last_message_at: None,
215            });
216        }
217
218        Ok(ListSessionsResult {
219            sessions,
220            total: None,
221        })
222    }
223
224    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
225        let conn = self.conn.lock().await;
226        let mut rows = conn
227            .query(
228                "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
229                [session_id.to_string()],
230            )
231            .await
232            .map_err(|e| StorageError::Internal(e.to_string()))?;
233
234        if let Ok(Some(row)) = rows.next().await {
235            let id: String = row
236                .get(0)
237                .map_err(|e| StorageError::Internal(e.to_string()))?;
238            let title: String = row
239                .get(1)
240                .map_err(|e| StorageError::Internal(e.to_string()))?;
241            let visibility: String = row
242                .get(2)
243                .map_err(|e| StorageError::Internal(e.to_string()))?;
244            let status: String = row
245                .get(3)
246                .map_err(|e| StorageError::Internal(e.to_string()))?;
247            let cwd: Option<String> = row.get(4).ok();
248            let created_at: String = row
249                .get(5)
250                .map_err(|e| StorageError::Internal(e.to_string()))?;
251            let updated_at: String = row
252                .get(6)
253                .map_err(|e| StorageError::Internal(e.to_string()))?;
254
255            // Get the latest checkpoint (reuse the same lock)
256            let active_checkpoint =
257                Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
258                    .await
259                    .ok();
260
261            Ok(Session {
262                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
263                title,
264                visibility: parse_visibility(&visibility),
265                status: parse_status(&status),
266                cwd,
267                created_at: parse_datetime(&created_at)?,
268                updated_at: parse_datetime(&updated_at)?,
269                active_checkpoint,
270            })
271        } else {
272            Err(StorageError::NotFound(format!(
273                "Session {} not found",
274                session_id
275            )))
276        }
277    }
278
279    async fn create_session(
280        &self,
281        request: &CreateSessionRequest,
282    ) -> Result<CreateSessionResult, StorageError> {
283        let now = Utc::now();
284        let session_id = Uuid::new_v4();
285        let checkpoint_id = Uuid::new_v4();
286
287        let conn = self.conn.lock().await;
288
289        // Create session
290        conn.execute(
291            "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
292            (
293                session_id.to_string(),
294                request.title.as_str(),
295                request.visibility.to_string(),
296                "ACTIVE",
297                request.cwd.as_deref(),
298                now.to_rfc3339(),
299                now.to_rfc3339(),
300            ),
301        )
302        .await
303        .map_err(|e| StorageError::Internal(e.to_string()))?;
304
305        // Create initial checkpoint
306        let state_json = serde_json::to_string(&request.initial_state)
307            .map_err(|e| StorageError::Internal(e.to_string()))?;
308
309        conn.execute(
310            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
311            (
312                checkpoint_id.to_string(),
313                session_id.to_string(),
314                None::<String>,
315                state_json,
316                now.to_rfc3339(),
317                now.to_rfc3339(),
318            ),
319        )
320        .await
321        .map_err(|e| StorageError::Internal(e.to_string()))?;
322
323        Ok(CreateSessionResult {
324            session_id,
325            checkpoint: Checkpoint {
326                id: checkpoint_id,
327                session_id,
328                parent_id: None,
329                state: request.initial_state.clone(),
330                created_at: now,
331                updated_at: now,
332            },
333        })
334    }
335
336    async fn update_session(
337        &self,
338        session_id: Uuid,
339        request: &UpdateSessionRequest,
340    ) -> Result<Session, StorageError> {
341        let now = Utc::now();
342
343        {
344            let conn = self.conn.lock().await;
345
346            // Update fields individually since libsql doesn't support dynamic params easily
347            if let Some(title) = &request.title {
348                conn.execute(
349                    "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
350                    (title.as_str(), now.to_rfc3339(), session_id.to_string()),
351                )
352                .await
353                .map_err(|e| StorageError::Internal(e.to_string()))?;
354            }
355            if let Some(visibility) = &request.visibility {
356                conn.execute(
357                    "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
358                    (
359                        visibility.to_string(),
360                        now.to_rfc3339(),
361                        session_id.to_string(),
362                    ),
363                )
364                .await
365                .map_err(|e| StorageError::Internal(e.to_string()))?;
366            }
367        }
368
369        self.get_session(session_id).await
370    }
371
372    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
373        // Mark as deleted instead of actually deleting
374        let now = Utc::now();
375        let conn = self.conn.lock().await;
376        conn.execute(
377            "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
378            (now.to_rfc3339(), session_id.to_string()),
379        )
380        .await
381        .map_err(|e| StorageError::Internal(e.to_string()))?;
382        Ok(())
383    }
384
385    async fn list_checkpoints(
386        &self,
387        session_id: Uuid,
388        query: &ListCheckpointsQuery,
389    ) -> Result<ListCheckpointsResult, StorageError> {
390        let limit = query.limit.unwrap_or(100);
391        let offset = query.offset.unwrap_or(0);
392
393        let sql = format!(
394            "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
395             WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
396            limit, offset
397        );
398
399        let conn = self.conn.lock().await;
400        let mut rows = conn
401            .query(&sql, [session_id.to_string()])
402            .await
403            .map_err(|e| StorageError::Internal(e.to_string()))?;
404
405        let mut checkpoints = Vec::new();
406        while let Ok(Some(row)) = rows.next().await {
407            let id: String = row
408                .get(0)
409                .map_err(|e| StorageError::Internal(e.to_string()))?;
410            let session_id: String = row
411                .get(1)
412                .map_err(|e| StorageError::Internal(e.to_string()))?;
413            let parent_id: Option<String> = row.get(2).ok();
414            let state: Option<String> = row.get(3).ok();
415            let created_at: String = row
416                .get(4)
417                .map_err(|e| StorageError::Internal(e.to_string()))?;
418            let updated_at: String = row
419                .get(5)
420                .map_err(|e| StorageError::Internal(e.to_string()))?;
421
422            let state: CheckpointState = if let Some(state_str) = state {
423                serde_json::from_str(&state_str).unwrap_or_default()
424            } else {
425                CheckpointState::default()
426            };
427
428            checkpoints.push(CheckpointSummary {
429                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
430                session_id: Uuid::from_str(&session_id)
431                    .map_err(|e| StorageError::Internal(e.to_string()))?,
432                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
433                message_count: state.messages.len() as u32,
434                created_at: parse_datetime(&created_at)?,
435                updated_at: parse_datetime(&updated_at)?,
436            });
437        }
438
439        Ok(ListCheckpointsResult {
440            checkpoints,
441            total: None,
442        })
443    }
444
445    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
446        let conn = self.conn.lock().await;
447        let mut rows = conn
448            .query(
449                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
450                [checkpoint_id.to_string()],
451            )
452            .await
453            .map_err(|e| StorageError::Internal(e.to_string()))?;
454
455        if let Ok(Some(row)) = rows.next().await {
456            let id: String = row
457                .get(0)
458                .map_err(|e| StorageError::Internal(e.to_string()))?;
459            let session_id: String = row
460                .get(1)
461                .map_err(|e| StorageError::Internal(e.to_string()))?;
462            let parent_id: Option<String> = row.get(2).ok();
463            let state: Option<String> = row.get(3).ok();
464            let created_at: String = row
465                .get(4)
466                .map_err(|e| StorageError::Internal(e.to_string()))?;
467            let updated_at: String = row
468                .get(5)
469                .map_err(|e| StorageError::Internal(e.to_string()))?;
470
471            let state: CheckpointState = if let Some(state_str) = state {
472                serde_json::from_str(&state_str).unwrap_or_default()
473            } else {
474                CheckpointState::default()
475            };
476
477            Ok(Checkpoint {
478                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
479                session_id: Uuid::from_str(&session_id)
480                    .map_err(|e| StorageError::Internal(e.to_string()))?,
481                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
482                state,
483                created_at: parse_datetime(&created_at)?,
484                updated_at: parse_datetime(&updated_at)?,
485            })
486        } else {
487            Err(StorageError::NotFound(format!(
488                "Checkpoint {} not found",
489                checkpoint_id
490            )))
491        }
492    }
493
494    async fn create_checkpoint(
495        &self,
496        session_id: Uuid,
497        request: &CreateCheckpointRequest,
498    ) -> Result<Checkpoint, StorageError> {
499        let now = Utc::now();
500        let checkpoint_id = Uuid::new_v4();
501
502        let state_json = serde_json::to_string(&request.state)
503            .map_err(|e| StorageError::Internal(e.to_string()))?;
504
505        let conn = self.conn.lock().await;
506
507        conn.execute(
508            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
509            (
510                checkpoint_id.to_string(),
511                session_id.to_string(),
512                request.parent_id.map(|id| id.to_string()),
513                state_json,
514                now.to_rfc3339(),
515                now.to_rfc3339(),
516            ),
517        )
518        .await
519        .map_err(|e| StorageError::Internal(e.to_string()))?;
520
521        // Update session's updated_at
522        conn.execute(
523            "UPDATE sessions SET updated_at = ? WHERE id = ?",
524            (now.to_rfc3339(), session_id.to_string()),
525        )
526        .await
527        .map_err(|e| StorageError::Internal(e.to_string()))?;
528
529        Ok(Checkpoint {
530            id: checkpoint_id,
531            session_id,
532            parent_id: request.parent_id,
533            state: request.state.clone(),
534            created_at: now,
535            updated_at: now,
536        })
537    }
538}
539
540// Helper functions
541fn parse_visibility(s: &str) -> SessionVisibility {
542    match s.to_uppercase().as_str() {
543        "PUBLIC" => SessionVisibility::Public,
544        _ => SessionVisibility::Private,
545    }
546}
547
548fn parse_status(s: &str) -> SessionStatus {
549    match s.to_uppercase().as_str() {
550        "DELETED" => SessionStatus::Deleted,
551        _ => SessionStatus::Active,
552    }
553}
554
555fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
556    DateTime::parse_from_rfc3339(s)
557        .map(|dt| dt.with_timezone(&Utc))
558        .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
559}