Skip to main content

stakpak_api/local/
storage.rs

1//! Local SQLite storage implementation
2//!
3//! Implements SessionStorage using local SQLite database.
4
5use crate::storage::{
6    Checkpoint, CheckpointState, CheckpointSummary, CreateCheckpointRequest, CreateSessionRequest,
7    CreateSessionResult, ListCheckpointsQuery, ListCheckpointsResult, ListSessionsQuery,
8    ListSessionsResult, Session, SessionStatus, SessionStorage, SessionSummary, SessionVisibility,
9    StorageError, UpdateSessionRequest,
10};
11use async_trait::async_trait;
12use chrono::{DateTime, Utc};
13use libsql::Connection;
14use std::path::Path;
15use std::str::FromStr;
16use tokio::sync::Mutex;
17use uuid::Uuid;
18
19/// Local SQLite storage implementation
20///
21/// Uses a `Mutex<Connection>` to ensure thread-safe access to the
22/// underlying libsql connection, which is not `Send + Sync` by default.
23pub struct LocalStorage {
24    conn: Mutex<Connection>,
25}
26
27impl LocalStorage {
28    /// Create a new local storage instance
29    pub async fn new(db_path: &str) -> Result<Self, StorageError> {
30        // Ensure parent directory exists
31        if let Some(parent) = Path::new(db_path).parent() {
32            std::fs::create_dir_all(parent).map_err(|e| {
33                StorageError::Connection(format!("Failed to create database directory: {}", e))
34            })?;
35        }
36
37        let db = libsql::Builder::new_local(db_path)
38            .build()
39            .await
40            .map_err(|e| StorageError::Connection(format!("Failed to open database: {}", e)))?;
41
42        let conn = db.connect().map_err(|e| {
43            StorageError::Connection(format!("Failed to connect to database: {}", e))
44        })?;
45
46        let storage = Self {
47            conn: Mutex::new(conn),
48        };
49        storage.init_schema().await?;
50
51        Ok(storage)
52    }
53
54    /// Create from existing connection (for use with AgentClient)
55    ///
56    /// This will initialize the database schema if it hasn't been set up yet.
57    pub async fn from_connection(conn: Connection) -> Result<Self, StorageError> {
58        let storage = Self {
59            conn: Mutex::new(conn),
60        };
61        storage.init_schema().await?;
62        Ok(storage)
63    }
64
65    /// Initialize database schema
66    ///
67    /// This creates tables if they don't exist and migrates old schema to new schema.
68    /// Old schema had: sessions(agent_id), checkpoints(status, execution_depth)
69    /// New schema has: sessions(status, cwd), checkpoints(state required)
70    async fn init_schema(&self) -> Result<(), StorageError> {
71        let conn = self.conn.lock().await;
72
73        // Create sessions table (compatible with both old and new schema)
74        conn.execute(
75            "CREATE TABLE IF NOT EXISTS sessions (
76                    id TEXT PRIMARY KEY,
77                    title TEXT NOT NULL,
78                    agent_id TEXT,
79                    visibility TEXT NOT NULL DEFAULT 'PRIVATE',
80                    status TEXT DEFAULT 'ACTIVE',
81                    cwd TEXT,
82                    created_at TEXT NOT NULL,
83                    updated_at TEXT NOT NULL
84                )",
85            (),
86        )
87        .await
88        .map_err(|e| StorageError::Internal(e.to_string()))?;
89
90        // Create checkpoints table (compatible with both old and new schema)
91        conn.execute(
92            "CREATE TABLE IF NOT EXISTS checkpoints (
93                    id TEXT PRIMARY KEY,
94                    session_id TEXT NOT NULL,
95                    status TEXT,
96                    execution_depth INTEGER,
97                    parent_id TEXT,
98                    state TEXT,
99                    created_at TEXT NOT NULL,
100                    updated_at TEXT NOT NULL,
101                    FOREIGN KEY(session_id) REFERENCES sessions(id),
102                    FOREIGN KEY(parent_id) REFERENCES checkpoints(id)
103                )",
104            (),
105        )
106        .await
107        .map_err(|e| StorageError::Internal(e.to_string()))?;
108
109        // Migrate old schema: add missing columns if they don't exist
110        // These will silently fail if columns already exist, which is fine
111        let _ = conn
112            .execute(
113                "ALTER TABLE sessions ADD COLUMN status TEXT DEFAULT 'ACTIVE'",
114                (),
115            )
116            .await;
117        let _ = conn
118            .execute("ALTER TABLE sessions ADD COLUMN cwd TEXT", ())
119            .await;
120
121        // Create index for faster lookups
122        conn.execute(
123            "CREATE INDEX IF NOT EXISTS idx_checkpoints_session_id ON checkpoints(session_id)",
124            (),
125        )
126        .await
127        .map_err(|e| StorageError::Internal(e.to_string()))?;
128
129        Ok(())
130    }
131
132    /// Get connection reference (locked)
133    ///
134    /// Useful for direct SQL access in tests or migrations.
135    pub fn connection(&self) -> &Mutex<Connection> {
136        &self.conn
137    }
138
139    /// Get the latest checkpoint for a session
140    ///
141    /// Expects the caller to already hold the connection lock.
142    async fn get_latest_checkpoint_for_session_inner(
143        conn: &Connection,
144        session_id: Uuid,
145    ) -> Result<Checkpoint, StorageError> {
146        let mut rows = conn
147            .query(
148                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
149                 WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
150                [session_id.to_string()],
151            )
152            .await
153            .map_err(|e| StorageError::Internal(e.to_string()))?;
154
155        if let Ok(Some(row)) = rows.next().await {
156            let id: String = row
157                .get(0)
158                .map_err(|e| StorageError::Internal(e.to_string()))?;
159            let session_id: String = row
160                .get(1)
161                .map_err(|e| StorageError::Internal(e.to_string()))?;
162            let parent_id: Option<String> = row.get(2).ok();
163            let state: Option<String> = row.get(3).ok();
164            let created_at: String = row
165                .get(4)
166                .map_err(|e| StorageError::Internal(e.to_string()))?;
167            let updated_at: String = row
168                .get(5)
169                .map_err(|e| StorageError::Internal(e.to_string()))?;
170
171            let state: CheckpointState = if let Some(state_str) = state {
172                serde_json::from_str(&state_str).unwrap_or_default()
173            } else {
174                CheckpointState::default()
175            };
176
177            Ok(Checkpoint {
178                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
179                session_id: Uuid::from_str(&session_id)
180                    .map_err(|e| StorageError::Internal(e.to_string()))?,
181                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
182                state,
183                created_at: parse_datetime(&created_at)?,
184                updated_at: parse_datetime(&updated_at)?,
185            })
186        } else {
187            Err(StorageError::NotFound(format!(
188                "No checkpoints found for session {}",
189                session_id
190            )))
191        }
192    }
193}
194
195#[async_trait]
196impl SessionStorage for LocalStorage {
197    async fn list_sessions(
198        &self,
199        query: &ListSessionsQuery,
200    ) -> Result<ListSessionsResult, StorageError> {
201        let limit = query.limit.unwrap_or(100);
202        let offset = query.offset.unwrap_or(0);
203
204        let mut sql = "SELECT s.id, s.title, s.visibility, COALESCE(s.status, 'ACTIVE') as status, s.cwd, s.created_at, s.updated_at, 
205            (SELECT COUNT(*) FROM checkpoints c WHERE c.session_id = s.id) as checkpoint_count,
206            (SELECT id FROM checkpoints c WHERE c.session_id = s.id ORDER BY created_at DESC LIMIT 1) as active_checkpoint_id
207            FROM sessions s WHERE 1=1".to_string();
208
209        // Use parameterized values for enum filters (safe because they come from
210        // our own Display impls, but we keep this consistent with the rest of
211        // the codebase).  The search term is the only free-form user input and
212        // is handled with a parameter below.
213        if let Some(status) = &query.status {
214            sql.push_str(&format!(" AND s.status = '{}'", status));
215        }
216        if let Some(visibility) = &query.visibility {
217            sql.push_str(&format!(" AND s.visibility = '{}'", visibility));
218        }
219        if query.search.is_some() {
220            sql.push_str(" AND s.title LIKE '%' || ? || '%'");
221        }
222
223        sql.push_str(&format!(
224            " ORDER BY s.updated_at DESC LIMIT {} OFFSET {}",
225            limit, offset
226        ));
227
228        let conn = self.conn.lock().await;
229        let mut rows = if let Some(search) = &query.search {
230            conn.query(&sql, [search.as_str()])
231                .await
232                .map_err(|e| StorageError::Internal(e.to_string()))?
233        } else {
234            conn.query(&sql, ())
235                .await
236                .map_err(|e| StorageError::Internal(e.to_string()))?
237        };
238
239        let mut sessions = Vec::new();
240        while let Ok(Some(row)) = rows.next().await {
241            let id: String = row
242                .get(0)
243                .map_err(|e| StorageError::Internal(e.to_string()))?;
244            let title: String = row
245                .get(1)
246                .map_err(|e| StorageError::Internal(e.to_string()))?;
247            let visibility: String = row
248                .get(2)
249                .map_err(|e| StorageError::Internal(e.to_string()))?;
250            let status: String = row
251                .get(3)
252                .map_err(|e| StorageError::Internal(e.to_string()))?;
253            let cwd: Option<String> = row.get(4).ok();
254            let created_at: String = row
255                .get(5)
256                .map_err(|e| StorageError::Internal(e.to_string()))?;
257            let updated_at: String = row
258                .get(6)
259                .map_err(|e| StorageError::Internal(e.to_string()))?;
260            let checkpoint_count: i64 = row.get(7).unwrap_or(0);
261            let active_checkpoint_id: Option<String> = row.get(8).ok();
262
263            sessions.push(SessionSummary {
264                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
265                title,
266                visibility: parse_visibility(&visibility),
267                status: parse_status(&status),
268                cwd,
269                created_at: parse_datetime(&created_at)?,
270                updated_at: parse_datetime(&updated_at)?,
271                message_count: checkpoint_count as u32,
272                active_checkpoint_id: active_checkpoint_id.and_then(|id| Uuid::from_str(&id).ok()),
273                last_message_at: None,
274            });
275        }
276
277        Ok(ListSessionsResult {
278            sessions,
279            total: None,
280        })
281    }
282
283    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError> {
284        let conn = self.conn.lock().await;
285        let mut rows = conn
286            .query(
287                "SELECT id, title, visibility, COALESCE(status, 'ACTIVE') as status, cwd, created_at, updated_at FROM sessions WHERE id = ?",
288                [session_id.to_string()],
289            )
290            .await
291            .map_err(|e| StorageError::Internal(e.to_string()))?;
292
293        if let Ok(Some(row)) = rows.next().await {
294            let id: String = row
295                .get(0)
296                .map_err(|e| StorageError::Internal(e.to_string()))?;
297            let title: String = row
298                .get(1)
299                .map_err(|e| StorageError::Internal(e.to_string()))?;
300            let visibility: String = row
301                .get(2)
302                .map_err(|e| StorageError::Internal(e.to_string()))?;
303            let status: String = row
304                .get(3)
305                .map_err(|e| StorageError::Internal(e.to_string()))?;
306            let cwd: Option<String> = row.get(4).ok();
307            let created_at: String = row
308                .get(5)
309                .map_err(|e| StorageError::Internal(e.to_string()))?;
310            let updated_at: String = row
311                .get(6)
312                .map_err(|e| StorageError::Internal(e.to_string()))?;
313
314            // Get the latest checkpoint (reuse the same lock)
315            let active_checkpoint =
316                Self::get_latest_checkpoint_for_session_inner(&conn, session_id)
317                    .await
318                    .ok();
319
320            Ok(Session {
321                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
322                title,
323                visibility: parse_visibility(&visibility),
324                status: parse_status(&status),
325                cwd,
326                created_at: parse_datetime(&created_at)?,
327                updated_at: parse_datetime(&updated_at)?,
328                active_checkpoint,
329            })
330        } else {
331            Err(StorageError::NotFound(format!(
332                "Session {} not found",
333                session_id
334            )))
335        }
336    }
337
338    async fn create_session(
339        &self,
340        request: &CreateSessionRequest,
341    ) -> Result<CreateSessionResult, StorageError> {
342        let now = Utc::now();
343        let session_id = Uuid::new_v4();
344        let checkpoint_id = Uuid::new_v4();
345
346        let conn = self.conn.lock().await;
347
348        // Create session
349        conn.execute(
350            "INSERT INTO sessions (id, title, visibility, status, cwd, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
351            (
352                session_id.to_string(),
353                request.title.as_str(),
354                request.visibility.to_string(),
355                "ACTIVE",
356                request.cwd.as_deref(),
357                now.to_rfc3339(),
358                now.to_rfc3339(),
359            ),
360        )
361        .await
362        .map_err(|e| StorageError::Internal(e.to_string()))?;
363
364        // Create initial checkpoint
365        let state_json = serde_json::to_string(&request.initial_state)
366            .map_err(|e| StorageError::Internal(e.to_string()))?;
367
368        conn.execute(
369            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
370            (
371                checkpoint_id.to_string(),
372                session_id.to_string(),
373                None::<String>,
374                state_json,
375                now.to_rfc3339(),
376                now.to_rfc3339(),
377            ),
378        )
379        .await
380        .map_err(|e| StorageError::Internal(e.to_string()))?;
381
382        Ok(CreateSessionResult {
383            session_id,
384            checkpoint: Checkpoint {
385                id: checkpoint_id,
386                session_id,
387                parent_id: None,
388                state: request.initial_state.clone(),
389                created_at: now,
390                updated_at: now,
391            },
392        })
393    }
394
395    async fn update_session(
396        &self,
397        session_id: Uuid,
398        request: &UpdateSessionRequest,
399    ) -> Result<Session, StorageError> {
400        let now = Utc::now();
401
402        {
403            let conn = self.conn.lock().await;
404
405            // Update fields individually since libsql doesn't support dynamic params easily
406            if let Some(title) = &request.title {
407                conn.execute(
408                    "UPDATE sessions SET title = ?, updated_at = ? WHERE id = ?",
409                    (title.as_str(), now.to_rfc3339(), session_id.to_string()),
410                )
411                .await
412                .map_err(|e| StorageError::Internal(e.to_string()))?;
413            }
414            if let Some(visibility) = &request.visibility {
415                conn.execute(
416                    "UPDATE sessions SET visibility = ?, updated_at = ? WHERE id = ?",
417                    (
418                        visibility.to_string(),
419                        now.to_rfc3339(),
420                        session_id.to_string(),
421                    ),
422                )
423                .await
424                .map_err(|e| StorageError::Internal(e.to_string()))?;
425            }
426        }
427
428        self.get_session(session_id).await
429    }
430
431    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError> {
432        // Mark as deleted instead of actually deleting
433        let now = Utc::now();
434        let conn = self.conn.lock().await;
435        conn.execute(
436            "UPDATE sessions SET status = 'DELETED', updated_at = ? WHERE id = ?",
437            (now.to_rfc3339(), session_id.to_string()),
438        )
439        .await
440        .map_err(|e| StorageError::Internal(e.to_string()))?;
441        Ok(())
442    }
443
444    async fn list_checkpoints(
445        &self,
446        session_id: Uuid,
447        query: &ListCheckpointsQuery,
448    ) -> Result<ListCheckpointsResult, StorageError> {
449        let limit = query.limit.unwrap_or(100);
450        let offset = query.offset.unwrap_or(0);
451
452        let sql = format!(
453            "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints 
454             WHERE session_id = ? ORDER BY created_at ASC LIMIT {} OFFSET {}",
455            limit, offset
456        );
457
458        let conn = self.conn.lock().await;
459        let mut rows = conn
460            .query(&sql, [session_id.to_string()])
461            .await
462            .map_err(|e| StorageError::Internal(e.to_string()))?;
463
464        let mut checkpoints = Vec::new();
465        while let Ok(Some(row)) = rows.next().await {
466            let id: String = row
467                .get(0)
468                .map_err(|e| StorageError::Internal(e.to_string()))?;
469            let session_id: String = row
470                .get(1)
471                .map_err(|e| StorageError::Internal(e.to_string()))?;
472            let parent_id: Option<String> = row.get(2).ok();
473            let state: Option<String> = row.get(3).ok();
474            let created_at: String = row
475                .get(4)
476                .map_err(|e| StorageError::Internal(e.to_string()))?;
477            let updated_at: String = row
478                .get(5)
479                .map_err(|e| StorageError::Internal(e.to_string()))?;
480
481            let state: CheckpointState = if let Some(state_str) = state {
482                serde_json::from_str(&state_str).unwrap_or_default()
483            } else {
484                CheckpointState::default()
485            };
486
487            checkpoints.push(CheckpointSummary {
488                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
489                session_id: Uuid::from_str(&session_id)
490                    .map_err(|e| StorageError::Internal(e.to_string()))?,
491                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
492                message_count: state.messages.len() as u32,
493                created_at: parse_datetime(&created_at)?,
494                updated_at: parse_datetime(&updated_at)?,
495            });
496        }
497
498        Ok(ListCheckpointsResult {
499            checkpoints,
500            total: None,
501        })
502    }
503
504    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError> {
505        let conn = self.conn.lock().await;
506        let mut rows = conn
507            .query(
508                "SELECT id, session_id, parent_id, state, created_at, updated_at FROM checkpoints WHERE id = ?",
509                [checkpoint_id.to_string()],
510            )
511            .await
512            .map_err(|e| StorageError::Internal(e.to_string()))?;
513
514        if let Ok(Some(row)) = rows.next().await {
515            let id: String = row
516                .get(0)
517                .map_err(|e| StorageError::Internal(e.to_string()))?;
518            let session_id: String = row
519                .get(1)
520                .map_err(|e| StorageError::Internal(e.to_string()))?;
521            let parent_id: Option<String> = row.get(2).ok();
522            let state: Option<String> = row.get(3).ok();
523            let created_at: String = row
524                .get(4)
525                .map_err(|e| StorageError::Internal(e.to_string()))?;
526            let updated_at: String = row
527                .get(5)
528                .map_err(|e| StorageError::Internal(e.to_string()))?;
529
530            let state: CheckpointState = if let Some(state_str) = state {
531                serde_json::from_str(&state_str).unwrap_or_default()
532            } else {
533                CheckpointState::default()
534            };
535
536            Ok(Checkpoint {
537                id: Uuid::from_str(&id).map_err(|e| StorageError::Internal(e.to_string()))?,
538                session_id: Uuid::from_str(&session_id)
539                    .map_err(|e| StorageError::Internal(e.to_string()))?,
540                parent_id: parent_id.and_then(|id| Uuid::from_str(&id).ok()),
541                state,
542                created_at: parse_datetime(&created_at)?,
543                updated_at: parse_datetime(&updated_at)?,
544            })
545        } else {
546            Err(StorageError::NotFound(format!(
547                "Checkpoint {} not found",
548                checkpoint_id
549            )))
550        }
551    }
552
553    async fn create_checkpoint(
554        &self,
555        session_id: Uuid,
556        request: &CreateCheckpointRequest,
557    ) -> Result<Checkpoint, StorageError> {
558        let now = Utc::now();
559        let checkpoint_id = Uuid::new_v4();
560
561        let state_json = serde_json::to_string(&request.state)
562            .map_err(|e| StorageError::Internal(e.to_string()))?;
563
564        let conn = self.conn.lock().await;
565
566        conn.execute(
567            "INSERT INTO checkpoints (id, session_id, parent_id, state, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
568            (
569                checkpoint_id.to_string(),
570                session_id.to_string(),
571                request.parent_id.map(|id| id.to_string()),
572                state_json,
573                now.to_rfc3339(),
574                now.to_rfc3339(),
575            ),
576        )
577        .await
578        .map_err(|e| StorageError::Internal(e.to_string()))?;
579
580        // Update session's updated_at
581        conn.execute(
582            "UPDATE sessions SET updated_at = ? WHERE id = ?",
583            (now.to_rfc3339(), session_id.to_string()),
584        )
585        .await
586        .map_err(|e| StorageError::Internal(e.to_string()))?;
587
588        Ok(Checkpoint {
589            id: checkpoint_id,
590            session_id,
591            parent_id: request.parent_id,
592            state: request.state.clone(),
593            created_at: now,
594            updated_at: now,
595        })
596    }
597}
598
599// Helper functions
600fn parse_visibility(s: &str) -> SessionVisibility {
601    match s.to_uppercase().as_str() {
602        "PUBLIC" => SessionVisibility::Public,
603        _ => SessionVisibility::Private,
604    }
605}
606
607fn parse_status(s: &str) -> SessionStatus {
608    match s.to_uppercase().as_str() {
609        "DELETED" => SessionStatus::Deleted,
610        _ => SessionStatus::Active,
611    }
612}
613
614fn parse_datetime(s: &str) -> Result<DateTime<Utc>, StorageError> {
615    DateTime::parse_from_rfc3339(s)
616        .map(|dt| dt.with_timezone(&Utc))
617        .map_err(|e| StorageError::Internal(format!("Failed to parse datetime: {}", e)))
618}