1use crate::models::*;
2use libsql::{Connection, Row};
3use std::str::FromStr;
4use uuid::Uuid;
5
/// Raw row of the `sessions` table, as created by `init_schema`.
///
/// Populated via `libsql::de::from_row` in `list_sessions` / `get_session`;
/// field names presumably have to match the column names — TODO confirm
/// against libsql's deserializer. Converted into an `AgentSession` by the
/// `TryFrom<(SessionRow, Vec<AgentCheckpointListItem>)>` impl.
#[derive(serde::Deserialize)]
struct SessionRow {
    // Hyphenated UUID string (written with `Uuid::to_string`).
    id: String,
    title: String,
    // Serialized agent identifier, e.g. "pablo:v1" (see `create_session`).
    agent_id: String,
    // Written with `AgentSessionVisibility::to_string`; only "PUBLIC" is
    // recognized when reading back, anything else is treated as private.
    visibility: String,
    // RFC 3339 timestamps (written with `to_rfc3339`, read with
    // `parse_from_rfc3339`).
    created_at: String,
    updated_at: String,
}
15
/// Raw row of the `checkpoints` table, as created by `init_schema`.
///
/// Populated via `libsql::de::from_row`; field names presumably have to match
/// the column names — TODO confirm against libsql's deserializer.
#[derive(serde::Deserialize)]
struct CheckpointRow {
    // Hyphenated UUID string.
    id: String,
    // UUID string of the owning row in `sessions`.
    session_id: String,
    // Status written with `AgentStatus::to_string`; read back expecting
    // "RUNNING", "COMPLETE", "BLOCKED" or "FAILED".
    status: String,
    // Stored as INTEGER; converted to/from the `usize` depth on the list item.
    execution_depth: i32,
    // UUID string of the parent checkpoint, NULL for a root checkpoint.
    parent_id: Option<String>,
    // RFC 3339 timestamps.
    created_at: String,
    updated_at: String,
    // JSON-serialized `AgentOutput` (see `create_checkpoint`); nullable in the
    // schema, but `get_checkpoint` treats a missing value as an error.
    state: Option<String>,
}
27
28pub async fn init_schema(conn: &Connection) -> Result<(), String> {
29 conn.execute(
30 "CREATE TABLE IF NOT EXISTS sessions (
31 id TEXT PRIMARY KEY,
32 title TEXT NOT NULL,
33 agent_id TEXT NOT NULL,
34 visibility TEXT NOT NULL,
35 created_at TEXT NOT NULL,
36 updated_at TEXT NOT NULL
37 )",
38 (),
39 )
40 .await
41 .map_err(|e| e.to_string())?;
42
43 conn.execute(
44 "CREATE TABLE IF NOT EXISTS checkpoints (
45 id TEXT PRIMARY KEY,
46 session_id TEXT NOT NULL,
47 status TEXT NOT NULL,
48 execution_depth INTEGER NOT NULL,
49 parent_id TEXT,
50 created_at TEXT NOT NULL,
51 updated_at TEXT NOT NULL,
52 state TEXT,
53 FOREIGN KEY(session_id) REFERENCES sessions(id),
54 FOREIGN KEY(parent_id) REFERENCES checkpoints(id)
55 )",
56 (),
57 )
58 .await
59 .map_err(|e| e.to_string())?;
60
61 Ok(())
62}
63
64pub async fn list_sessions(conn: &Connection) -> Result<Vec<AgentSession>, String> {
65 let mut rows = conn
66 .query("SELECT * FROM sessions ORDER BY updated_at DESC", ())
67 .await
68 .map_err(|e| e.to_string())?;
69
70 let mut sessions = Vec::new();
71
72 while let Ok(Some(row)) = rows.next().await {
73 let session_row: SessionRow = libsql::de::from_row(&row).map_err(|e| e.to_string())?;
74 let session_id = Uuid::from_str(&session_row.id).map_err(|e| e.to_string())?;
75
76 let checkpoints = get_session_checkpoints(conn, session_id).await?;
77
78 sessions.push((session_row, checkpoints).try_into()?);
79 }
80
81 Ok(sessions)
82}
83
84pub async fn get_session(conn: &Connection, session_id: Uuid) -> Result<AgentSession, String> {
85 let mut rows = conn
86 .query(
87 "SELECT * FROM sessions WHERE id = ?",
88 [session_id.to_string()],
89 )
90 .await
91 .map_err(|e| e.to_string())?;
92
93 if let Ok(Some(row)) = rows.next().await {
94 let session_row: SessionRow = libsql::de::from_row(&row).map_err(|e| e.to_string())?;
95 let checkpoints = get_session_checkpoints(conn, session_id).await?;
96
97 Ok((session_row, checkpoints).try_into()?)
98 } else {
99 Err("Session not found".to_string())
100 }
101}
102
103async fn get_session_checkpoints(
104 conn: &Connection,
105 session_id: Uuid,
106) -> Result<Vec<AgentCheckpointListItem>, String> {
107 let mut rows = conn
108 .query(
109 "SELECT * FROM checkpoints WHERE session_id = ? ORDER BY created_at ASC",
110 [session_id.to_string()],
111 )
112 .await
113 .map_err(|e| e.to_string())?;
114
115 let mut checkpoints = Vec::new();
116
117 while let Ok(Some(row)) = rows.next().await {
118 checkpoints.push((&row).try_into()?);
119 }
120
121 Ok(checkpoints)
122}
123
124impl TryFrom<CheckpointRow> for AgentCheckpointListItem {
125 type Error = String;
126
127 fn try_from(row: CheckpointRow) -> Result<Self, Self::Error> {
128 Ok(AgentCheckpointListItem {
129 id: Uuid::from_str(&row.id).map_err(|e| e.to_string())?,
130 status: match row.status.as_str() {
131 "RUNNING" => AgentStatus::Running,
132 "COMPLETE" => AgentStatus::Complete,
133 "BLOCKED" => AgentStatus::Blocked,
134 "FAILED" => AgentStatus::Failed,
135 _ => AgentStatus::Failed,
136 },
137 execution_depth: row.execution_depth as usize,
138 parent: row.parent_id.map(|pid| AgentParentCheckpoint {
139 id: Uuid::from_str(&pid).unwrap_or_default(),
140 }),
141 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
142 .map_err(|e| e.to_string())?
143 .with_timezone(&chrono::Utc),
144 updated_at: chrono::DateTime::parse_from_rfc3339(&row.updated_at)
145 .map_err(|e| e.to_string())?
146 .with_timezone(&chrono::Utc),
147 })
148 }
149}
150
151impl TryFrom<&Row> for AgentCheckpointListItem {
152 type Error = String;
153
154 fn try_from(row: &Row) -> Result<Self, Self::Error> {
155 let checkpoint_row: CheckpointRow = libsql::de::from_row(row).map_err(|e| e.to_string())?;
156 checkpoint_row.try_into()
157 }
158}
159
160impl TryFrom<(SessionRow, Vec<AgentCheckpointListItem>)> for AgentSession {
161 type Error = String;
162
163 fn try_from(
164 (row, checkpoints): (SessionRow, Vec<AgentCheckpointListItem>),
165 ) -> Result<Self, Self::Error> {
166 Ok(AgentSession {
167 id: Uuid::from_str(&row.id).map_err(|e| e.to_string())?,
168 title: row.title,
169 agent_id: AgentID::from_str(&row.agent_id).map_err(|e| e.to_string())?,
170 visibility: match row.visibility.as_str() {
171 "PUBLIC" => AgentSessionVisibility::Public,
172 _ => AgentSessionVisibility::Private,
173 },
174 created_at: chrono::DateTime::parse_from_rfc3339(&row.created_at)
175 .map_err(|e| e.to_string())?
176 .with_timezone(&chrono::Utc),
177 updated_at: chrono::DateTime::parse_from_rfc3339(&row.updated_at)
178 .map_err(|e| e.to_string())?
179 .with_timezone(&chrono::Utc),
180 checkpoints,
181 })
182 }
183}
184
185pub async fn get_checkpoint(
186 conn: &Connection,
187 checkpoint_id: Uuid,
188) -> Result<RunAgentOutput, String> {
189 let mut rows = conn
190 .query(
191 "SELECT * FROM checkpoints WHERE id = ?",
192 [checkpoint_id.to_string()],
193 )
194 .await
195 .map_err(|e| e.to_string())?;
196
197 if let Ok(Some(row)) = rows.next().await {
198 let checkpoint_row: CheckpointRow =
199 libsql::de::from_row(&row).map_err(|e| e.to_string())?;
200 let checkpoint = (&row).try_into()?;
201
202 let session_id = Uuid::from_str(&checkpoint_row.session_id).map_err(|e| e.to_string())?;
203 let session = get_session(conn, session_id).await?;
204
205 let state = if let Some(s) = checkpoint_row.state {
206 serde_json::from_str(&s).map_err(|e| e.to_string())?
207 } else {
208 return Err("Checkpoint state not found".to_string());
209 };
210
211 Ok(RunAgentOutput {
212 checkpoint,
213 session: session.into(),
214 output: state,
215 })
216 } else {
217 Err("Checkpoint not found".to_string())
218 }
219}
220
221pub async fn get_latest_checkpoint(
222 conn: &Connection,
223 session_id: Uuid,
224) -> Result<RunAgentOutput, String> {
225 let mut rows = conn
226 .query(
227 "SELECT * FROM checkpoints WHERE session_id = ? ORDER BY created_at DESC LIMIT 1",
228 [session_id.to_string()],
229 )
230 .await
231 .map_err(|e| e.to_string())?;
232
233 if let Ok(Some(row)) = rows.next().await {
234 let checkpoint_row: CheckpointRow =
235 libsql::de::from_row(&row).map_err(|e| e.to_string())?;
236 let checkpoint_id = Uuid::from_str(&checkpoint_row.id).map_err(|e| e.to_string())?;
237 get_checkpoint(conn, checkpoint_id).await
238 } else {
239 Err("No checkpoints found for session".to_string())
240 }
241}
242
243pub async fn create_session(conn: &Connection, session: &AgentSession) -> Result<(), String> {
244 conn.execute(
245 "INSERT INTO sessions (id, title, agent_id, visibility, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)",
246 (
247 session.id.to_string(),
248 session.title.as_str(),
249 match session.agent_id {
250 AgentID::PabloV1 => "pablo:v1",
251 },
252 session.visibility.to_string(),
253 session.created_at.to_rfc3339(),
254 session.updated_at.to_rfc3339(),
255 ),
256 )
257 .await
258 .map_err(|e| e.to_string())?;
259 Ok(())
260}
261
262pub async fn create_checkpoint(
263 conn: &Connection,
264 session_id: Uuid,
265 checkpoint: &AgentCheckpointListItem,
266 state: &AgentOutput,
267) -> Result<(), String> {
268 let state_json = serde_json::to_string(state).map_err(|e| e.to_string())?;
269
270 conn.execute(
271 "INSERT INTO checkpoints (id, session_id, status, execution_depth, parent_id, created_at, updated_at, state) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
272 (
273 checkpoint.id.to_string(),
274 session_id.to_string(),
275 checkpoint.status.to_string(),
276 checkpoint.execution_depth as i32,
277 checkpoint.parent.as_ref().map(|p| p.id.to_string()),
278 checkpoint.created_at.to_rfc3339(),
279 checkpoint.updated_at.to_rfc3339(),
280 state_json,
281 ),
282 )
283 .await
284 .map_err(|e| e.to_string())?;
285 Ok(())
286}