stakpak_api/local/
db.rs

1use crate::models::*;
2use libsql::{Connection, Row};
3use std::str::FromStr;
4use uuid::Uuid;
5
6#[derive(serde::Deserialize)]
7struct SessionRow {
8    id: String,
9    title: String,
10    agent_id: String,
11    visibility: String,
12    created_at: String,
13    updated_at: String,
14}
15
16#[derive(serde::Deserialize)]
17struct CheckpointRow {
18    id: String,
19    session_id: String,
20    status: String,
21    execution_depth: i32,
22    parent_id: Option<String>,
23    created_at: String,
24    updated_at: String,
25    state: Option<String>,
26}
27
28pub async fn init_schema(conn: &Connection) -> Result<(), String> {
29    conn.execute(
30        "CREATE TABLE IF NOT EXISTS sessions (
31            id TEXT PRIMARY KEY,
32            title TEXT NOT NULL,
33            agent_id TEXT NOT NULL,
34            visibility TEXT NOT NULL,
35            created_at TEXT NOT NULL,
36            updated_at TEXT NOT NULL
37        )",
38        (),
39    )
40    .await
41    .map_err(|e| e.to_string())?;
42
43    conn.execute(
44        "CREATE TABLE IF NOT EXISTS checkpoints (
45            id TEXT PRIMARY KEY,
46            session_id TEXT NOT NULL,
47            status TEXT NOT NULL,
48            execution_depth INTEGER NOT NULL,
49            parent_id TEXT,
50            created_at TEXT NOT NULL,
51            updated_at TEXT NOT NULL,
52            state TEXT,
53            FOREIGN KEY(session_id) REFERENCES sessions(id),
54            FOREIGN KEY(parent_id) REFERENCES checkpoints(id)
55        )",
56        (),
57    )
58    .await
59    .map_err(|e| e.to_string())?;
60
61    Ok(())
62}
63
64pub async fn list_sessions(conn: &Connection) -> Result<Vec<AgentSession>, String> {
65    let mut rows = conn
66        .query("SELECT * FROM sessions ORDER BY updated_at DESC", ())
67        .await
68        .map_err(|e| e.to_string())?;
69
70    let mut sessions = Vec::new();
71
72    while let Ok(Some(row)) = rows.next().await {
73        let session_row: SessionRow = libsql::de::from_row(&row).map_err(|e| e.to_string())?;
74        let session_id = Uuid::from_str(&session_row.id).map_err(|e| e.to_string())?;
75
76        let checkpoints = get_session_checkpoints(conn, session_id).await?;
77
78        sessions.push((session_row, checkpoints).try_into()?);
79    }
80
81    Ok(sessions)
82}
83
84pub async fn get_session(conn: &Connection, session_id: Uuid) -> Result<AgentSession, String> {
85    let mut rows = conn
86        .query(
87            "SELECT * FROM sessions WHERE id = ?",
88            [session_id.to_string()],
89        )
90        .await
91        .map_err(|e| e.to_string())?;
92
93    if let Ok(Some(row)) = rows.next().await {
94        let session_row: SessionRow = libsql::de::from_row(&row).map_err(|e| e.to_string())?;
95        let checkpoints = get_session_checkpoints(conn, session_id).await?;
96
97        Ok((session_row, checkpoints).try_into()?)
98    } else {
99        Err("Session not found".to_string())
100    }
101}
102
103async fn get_session_checkpoints(
104    conn: &Connection,
105    session_id: Uuid,
106) -> Result<Vec<AgentCheckpointListItem>, String> {
107    let mut rows = conn
108        .query(
109            "SELECT * FROM checkpoints WHERE session_id = ? ORDER BY created_at ASC",
110            [session_id.to_string()],
111        )
112        .await
113        .map_err(|e| e.to_string())?;
114
115    let mut checkpoints = Vec::new();
116
117    while let Ok(Some(row)) = rows.next().await {
118        checkpoints.push((&row).try_into()?);
119    }
120
121    Ok(checkpoints)
122}
123
124impl TryFrom<CheckpointRow> for AgentCheckpointListItem {
125    type Error = String;
126
127    fn try_from(row: CheckpointRow) -> Result<Self, Self::Error> {
128        Ok(AgentCheckpointListItem {
129            id: Uuid::from_str(&row.id).map_err(|e| e.to_string())?,
130            status: match row.status.as_str() {
131                "RUNNING" => AgentStatus::Running,
132                "COMPLETE" => AgentStatus::Complete,
133                "BLOCKED" => AgentStatus::Blocked,
134                "FAILED" => AgentStatus::Failed,
135                _ => AgentStatus::Failed,
136            },
137            execution_depth: row.execution_depth as usize,
138            parent: row.parent_id.map(|pid| AgentParentCheckpoint {
139                id: Uuid::from_str(&pid).unwrap_or_default(),
140            }),
141            created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
142                .map_err(|e| e.to_string())?
143                .with_timezone(&chrono::Utc),
144            updated_at: chrono::DateTime::parse_from_rfc3339(&row.updated_at)
145                .map_err(|e| e.to_string())?
146                .with_timezone(&chrono::Utc),
147        })
148    }
149}
150
151impl TryFrom<&Row> for AgentCheckpointListItem {
152    type Error = String;
153
154    fn try_from(row: &Row) -> Result<Self, Self::Error> {
155        let checkpoint_row: CheckpointRow = libsql::de::from_row(row).map_err(|e| e.to_string())?;
156        checkpoint_row.try_into()
157    }
158}
159
160impl TryFrom<(SessionRow, Vec<AgentCheckpointListItem>)> for AgentSession {
161    type Error = String;
162
163    fn try_from(
164        (row, checkpoints): (SessionRow, Vec<AgentCheckpointListItem>),
165    ) -> Result<Self, Self::Error> {
166        Ok(AgentSession {
167            id: Uuid::from_str(&row.id).map_err(|e| e.to_string())?,
168            title: row.title,
169            agent_id: AgentID::from_str(&row.agent_id).map_err(|e| e.to_string())?,
170            visibility: match row.visibility.as_str() {
171                "PUBLIC" => AgentSessionVisibility::Public,
172                _ => AgentSessionVisibility::Private,
173            },
174            created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
175                .map_err(|e| e.to_string())?
176                .with_timezone(&chrono::Utc),
177            updated_at: chrono::DateTime::parse_from_rfc3339(&row.updated_at)
178                .map_err(|e| e.to_string())?
179                .with_timezone(&chrono::Utc),
180            checkpoints,
181        })
182    }
183}
184
185pub async fn get_checkpoint(
186    conn: &Connection,
187    checkpoint_id: Uuid,
188) -> Result<RunAgentOutput, String> {
189    let mut rows = conn
190        .query(
191            "SELECT * FROM checkpoints WHERE id = ?",
192            [checkpoint_id.to_string()],
193        )
194        .await
195        .map_err(|e| e.to_string())?;
196
197    if let Ok(Some(row)) = rows.next().await {
198        let checkpoint_row: CheckpointRow =
199            libsql::de::from_row(&row).map_err(|e| e.to_string())?;
200        let checkpoint = (&row).try_into()?;
201
202        let session_id = Uuid::from_str(&checkpoint_row.session_id).map_err(|e| e.to_string())?;
203        let session = get_session(conn, session_id).await?;
204
205        let state = if let Some(s) = checkpoint_row.state {
206            serde_json::from_str(&s).map_err(|e| e.to_string())?
207        } else {
208            return Err("Checkpoint state not found".to_string());
209        };
210
211        Ok(RunAgentOutput {
212            checkpoint,
213            session: session.into(),
214            output: state,
215        })
216    } else {
217        Err("Checkpoint not found".to_string())
218    }
219}
220
221pub async fn get_latest_checkpoint(
222    conn: &Connection,
223    session_id: Uuid,
224) -> Result<RunAgentOutput, String> {
225    let mut rows = conn
226        .query(
227            "SELECT * FROM checkpoints WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
228            [session_id.to_string()],
229        )
230        .await
231        .map_err(|e| e.to_string())?;
232
233    if let Ok(Some(row)) = rows.next().await {
234        let checkpoint_row: CheckpointRow =
235            libsql::de::from_row(&row).map_err(|e| e.to_string())?;
236        let checkpoint_id = Uuid::from_str(&checkpoint_row.id).map_err(|e| e.to_string())?;
237        get_checkpoint(conn, checkpoint_id).await
238    } else {
239        Err("No checkpoints found for session".to_string())
240    }
241}
242
243pub async fn create_session(conn: &Connection, session: &AgentSession) -> Result<(), String> {
244    conn.execute(
245        "INSERT INTO sessions (id, title, agent_id, visibility, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
246        (
247            session.id.to_string(),
248            session.title.as_str(),
249            match session.agent_id {
250                AgentID::PabloV1 => "pablo:v1",
251            },
252            session.visibility.to_string(),
253            session.created_at.to_rfc3339(),
254            session.updated_at.to_rfc3339(),
255        ),
256    )
257    .await
258    .map_err(|e| e.to_string())?;
259    Ok(())
260}
261
262pub async fn create_checkpoint(
263    conn: &Connection,
264    session_id: Uuid,
265    checkpoint: &AgentCheckpointListItem,
266    state: &AgentOutput,
267) -> Result<(), String> {
268    let state_json = serde_json::to_string(state).map_err(|e| e.to_string())?;
269
270    conn.execute(
271        "INSERT INTO checkpoints (id, session_id, status, execution_depth, parent_id, created_at, updated_at, state) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
272        (
273            checkpoint.id.to_string(),
274            session_id.to_string(),
275            checkpoint.status.to_string(),
276            checkpoint.execution_depth as i32,
277            checkpoint.parent.as_ref().map(|p| p.id.to_string()),
278            checkpoint.created_at.to_rfc3339(),
279            checkpoint.updated_at.to_rfc3339(),
280            state_json,
281        ),
282    )
283    .await
284    .map_err(|e| e.to_string())?;
285    Ok(())
286}