1use sqlx::SqlitePool;
11
12use crate::error::MemoryError;
13
/// Metadata describing a stored task graph, without the serialized graph JSON itself.
///
/// Returned by [`RawGraphStore::list_graphs`] so callers can browse stored graphs
/// without loading each graph's JSON payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphSummary {
    /// Identifier the graph was saved under.
    pub id: String,
    /// Goal text saved with the graph.
    pub goal: String,
    /// Status string as last saved; the store does not interpret it.
    pub status: String,
    /// Creation timestamp, kept as the text the caller supplied.
    pub created_at: String,
    /// Finish timestamp, or `None` when no finish time was saved.
    pub finished_at: Option<String>,
}
23
24pub trait RawGraphStore: Send + Sync {
29 #[allow(async_fn_in_trait)]
35 async fn save_graph(
36 &self,
37 id: &str,
38 goal: &str,
39 status: &str,
40 graph_json: &str,
41 created_at: &str,
42 finished_at: Option<&str>,
43 ) -> Result<(), MemoryError>;
44
45 #[allow(async_fn_in_trait)]
53 async fn load_graph(&self, id: &str) -> Result<Option<String>, MemoryError>;
54
55 #[allow(async_fn_in_trait)]
61 async fn list_graphs(&self, limit: u32) -> Result<Vec<GraphSummary>, MemoryError>;
62
63 #[allow(async_fn_in_trait)]
71 async fn delete_graph(&self, id: &str) -> Result<bool, MemoryError>;
72}
73
74#[derive(Debug, Clone)]
76pub struct SqliteGraphStore {
77 pool: SqlitePool,
78}
79
80impl SqliteGraphStore {
81 #[must_use]
83 pub fn new(pool: SqlitePool) -> Self {
84 Self { pool }
85 }
86}
87
88impl RawGraphStore for SqliteGraphStore {
89 async fn save_graph(
90 &self,
91 id: &str,
92 goal: &str,
93 status: &str,
94 graph_json: &str,
95 created_at: &str,
96 finished_at: Option<&str>,
97 ) -> Result<(), MemoryError> {
98 sqlx::query(
99 "INSERT INTO task_graphs (id, goal, status, graph_json, created_at, finished_at) \
100 VALUES (?, ?, ?, ?, ?, ?) \
101 ON CONFLICT(id) DO UPDATE SET \
102 goal = excluded.goal, \
103 status = excluded.status, \
104 graph_json = excluded.graph_json, \
105 created_at = excluded.created_at, \
106 finished_at = excluded.finished_at",
107 )
108 .bind(id)
109 .bind(goal)
110 .bind(status)
111 .bind(graph_json)
112 .bind(created_at)
113 .bind(finished_at)
114 .execute(&self.pool)
115 .await
116 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
117 Ok(())
118 }
119
120 async fn load_graph(&self, id: &str) -> Result<Option<String>, MemoryError> {
121 let row: Option<(String,)> =
122 sqlx::query_as("SELECT graph_json FROM task_graphs WHERE id = ?")
123 .bind(id)
124 .fetch_optional(&self.pool)
125 .await
126 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
127 Ok(row.map(|(json,)| json))
128 }
129
130 async fn list_graphs(&self, limit: u32) -> Result<Vec<GraphSummary>, MemoryError> {
131 let rows: Vec<(String, String, String, String, Option<String>)> = sqlx::query_as(
132 "SELECT id, goal, status, created_at, finished_at \
133 FROM task_graphs \
134 ORDER BY created_at DESC \
135 LIMIT ?",
136 )
137 .bind(limit)
138 .fetch_all(&self.pool)
139 .await
140 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
141
142 Ok(rows
143 .into_iter()
144 .map(|(id, goal, status, created_at, finished_at)| GraphSummary {
145 id,
146 goal,
147 status,
148 created_at,
149 finished_at,
150 })
151 .collect())
152 }
153
154 async fn delete_graph(&self, id: &str) -> Result<bool, MemoryError> {
155 let result = sqlx::query("DELETE FROM task_graphs WHERE id = ?")
156 .bind(id)
157 .execute(&self.pool)
158 .await
159 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
160 Ok(result.rows_affected() > 0)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;

    /// Builds a graph store on a fresh in-memory SQLite database.
    async fn make_store() -> SqliteGraphStore {
        let sqlite = SqliteStore::new(":memory:").await.expect("SqliteStore");
        SqliteGraphStore::new(sqlite.pool().clone())
    }

    #[tokio::test]
    async fn test_save_and_load_roundtrip() {
        let store = make_store().await;
        store
            .save_graph("id-1", "goal", "created", r#"{"key":"val"}"#, "100", None)
            .await
            .expect("save");
        let loaded = store
            .load_graph("id-1")
            .await
            .expect("load")
            .expect("should exist");
        assert_eq!(loaded, r#"{"key":"val"}"#);
    }

    #[tokio::test]
    async fn test_load_nonexistent() {
        let store = make_store().await;
        let result = store.load_graph("missing-id").await.expect("load");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_list_graphs_ordering() {
        let store = make_store().await;
        store
            .save_graph("id-1", "first", "created", "{}", "100", None)
            .await
            .expect("save 1");
        store
            .save_graph("id-2", "second", "created", "{}", "200", None)
            .await
            .expect("save 2");
        let list = store.list_graphs(10).await.expect("list");
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].id, "id-2");
        assert_eq!(list[1].id, "id-1");
    }

    #[tokio::test]
    async fn test_list_graphs_respects_limit() {
        let store = make_store().await;
        // Single-digit timestamps so text ordering matches numeric ordering.
        for (id, created_at) in [("id-1", "1"), ("id-2", "2"), ("id-3", "3")] {
            store
                .save_graph(id, "goal", "created", "{}", created_at, None)
                .await
                .expect("save");
        }
        let list = store.list_graphs(2).await.expect("list");
        let ids: Vec<&str> = list.iter().map(|summary| summary.id.as_str()).collect();
        assert_eq!(ids, ["id-3", "id-2"]);
    }

    #[tokio::test]
    async fn test_list_graphs_returns_summary_fields() {
        let store = make_store().await;
        store
            .save_graph("id-1", "goal", "done", "{}", "100", Some("150"))
            .await
            .expect("save");
        let list = store.list_graphs(10).await.expect("list");
        assert_eq!(list.len(), 1);
        let summary = &list[0];
        assert_eq!(summary.id, "id-1");
        assert_eq!(summary.goal, "goal");
        assert_eq!(summary.status, "done");
        assert_eq!(summary.created_at, "100");
        assert_eq!(summary.finished_at.as_deref(), Some("150"));
    }

    #[tokio::test]
    async fn test_delete_graph() {
        let store = make_store().await;
        store
            .save_graph("id-del", "goal", "created", "{}", "1", None)
            .await
            .expect("save");
        let deleted = store.delete_graph("id-del").await.expect("delete");
        assert!(deleted);
        let loaded = store.load_graph("id-del").await.expect("load");
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_delete_nonexistent_returns_false() {
        let store = make_store().await;
        let deleted = store.delete_graph("missing-id").await.expect("delete");
        assert!(!deleted);
    }

    #[tokio::test]
    async fn test_save_overwrites_existing() {
        let store = make_store().await;
        store
            .save_graph("id-1", "old", "created", r#"{"v":1}"#, "1", None)
            .await
            .expect("save 1");
        store
            .save_graph("id-1", "new", "running", r#"{"v":2}"#, "1", None)
            .await
            .expect("save 2 (upsert)");
        let loaded = store
            .load_graph("id-1")
            .await
            .expect("load")
            .expect("exists");
        assert_eq!(loaded, r#"{"v":2}"#);
    }

    #[tokio::test]
    async fn test_save_overwrite_updates_summary() {
        let store = make_store().await;
        store
            .save_graph("id-1", "old", "created", "{}", "1", None)
            .await
            .expect("save 1");
        store
            .save_graph("id-1", "new", "running", "{}", "1", Some("2"))
            .await
            .expect("save 2 (upsert)");
        let list = store.list_graphs(10).await.expect("list");
        // The upsert must not create a second row.
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].goal, "new");
        assert_eq!(list[0].status, "running");
        assert_eq!(list[0].finished_at.as_deref(), Some("2"));
    }
}