1use zeph_db::DbPool;
11#[allow(unused_imports)]
12use zeph_db::sql;
13
14use crate::error::MemoryError;
15
/// Lightweight listing entry for a persisted task graph.
///
/// Produced by [`RawGraphStore::list_graphs`]. It carries only the row metadata, not the
/// serialized graph itself; use [`RawGraphStore::load_graph`] to fetch the JSON.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GraphSummary {
    /// Graph identifier; the upsert in `save_graph` conflicts on this column.
    pub id: String,
    /// Goal description supplied to `save_graph`.
    pub goal: String,
    /// Status string exactly as supplied by the caller (e.g. `"created"`, `"running"`).
    pub status: String,
    /// Creation timestamp exactly as stored; listings are ordered by this value.
    pub created_at: String,
    /// Completion timestamp, or `None` when none was recorded.
    pub finished_at: Option<String>,
}
25
26pub trait RawGraphStore: Send + Sync {
31 #[allow(async_fn_in_trait)]
37 async fn save_graph(
38 &self,
39 id: &str,
40 goal: &str,
41 status: &str,
42 graph_json: &str,
43 created_at: &str,
44 finished_at: Option<&str>,
45 ) -> Result<(), MemoryError>;
46
47 #[allow(async_fn_in_trait)]
55 async fn load_graph(&self, id: &str) -> Result<Option<String>, MemoryError>;
56
57 #[allow(async_fn_in_trait)]
63 async fn list_graphs(&self, limit: u32) -> Result<Vec<GraphSummary>, MemoryError>;
64
65 #[allow(async_fn_in_trait)]
73 async fn delete_graph(&self, id: &str) -> Result<bool, MemoryError>;
74}
75
76#[derive(Debug, Clone)]
78pub struct TaskGraphStore {
79 pool: DbPool,
80}
81
82impl TaskGraphStore {
83 #[must_use]
85 pub fn new(pool: DbPool) -> Self {
86 Self { pool }
87 }
88}
89
90impl RawGraphStore for TaskGraphStore {
91 #[tracing::instrument(skip_all, name = "memory.graph.save_graph")]
92 async fn save_graph(
93 &self,
94 id: &str,
95 goal: &str,
96 status: &str,
97 graph_json: &str,
98 created_at: &str,
99 finished_at: Option<&str>,
100 ) -> Result<(), MemoryError> {
101 zeph_db::query(sql!(
102 "INSERT INTO task_graphs (id, goal, status, graph_json, created_at, finished_at) \
103 VALUES (?, ?, ?, ?, ?, ?) \
104 ON CONFLICT(id) DO UPDATE SET \
105 goal = excluded.goal, \
106 status = excluded.status, \
107 graph_json = excluded.graph_json, \
108 created_at = excluded.created_at, \
109 finished_at = excluded.finished_at"
110 ))
111 .bind(id)
112 .bind(goal)
113 .bind(status)
114 .bind(graph_json)
115 .bind(created_at)
116 .bind(finished_at)
117 .execute(&self.pool)
118 .await
119 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
120 Ok(())
121 }
122
123 #[cfg(not(feature = "postgres"))]
124 #[tracing::instrument(skip_all, name = "memory.graph.load_graph")]
125 async fn load_graph(&self, id: &str) -> Result<Option<String>, MemoryError> {
126 let row: Option<(String,)> =
127 zeph_db::query_as(sql!("SELECT graph_json FROM task_graphs WHERE id = ?"))
128 .bind(id)
129 .fetch_optional(&self.pool)
130 .await
131 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
132 Ok(row.map(|(json,)| json))
133 }
134
135 #[cfg(not(feature = "postgres"))]
136 async fn list_graphs(&self, limit: u32) -> Result<Vec<GraphSummary>, MemoryError> {
137 let rows: Vec<(String, String, String, String, Option<String>)> = zeph_db::query_as(sql!(
138 "SELECT id, goal, status, created_at, finished_at \
139 FROM task_graphs \
140 ORDER BY created_at DESC \
141 LIMIT ?"
142 ))
143 .bind(limit)
144 .fetch_all(&self.pool)
145 .await
146 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
147
148 Ok(rows
149 .into_iter()
150 .map(|(id, goal, status, created_at, finished_at)| GraphSummary {
151 id,
152 goal,
153 status,
154 created_at,
155 finished_at,
156 })
157 .collect())
158 }
159
160 #[cfg(feature = "postgres")]
161 #[tracing::instrument(skip_all, name = "memory.graph.load_graph")]
162 async fn load_graph(&self, id: &str) -> Result<Option<String>, MemoryError> {
163 let row: Option<String> = sqlx::query_scalar::<sqlx::Postgres, String>(sql!(
164 "SELECT graph_json FROM task_graphs WHERE id = ?"
165 ))
166 .bind(id)
167 .fetch_optional(&self.pool)
168 .await
169 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
170 Ok(row)
171 }
172
173 #[cfg(feature = "postgres")]
174 async fn list_graphs(&self, limit: u32) -> Result<Vec<GraphSummary>, MemoryError> {
175 use sqlx::Row as _;
176
177 let rows = sqlx::query::<sqlx::Postgres>(sql!(
178 "SELECT id, goal, status, created_at, finished_at \
179 FROM task_graphs \
180 ORDER BY created_at DESC \
181 LIMIT ?"
182 ))
183 .bind(i64::from(limit))
184 .fetch_all(&self.pool)
185 .await
186 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
187
188 Ok(rows
189 .into_iter()
190 .map(|row| GraphSummary {
191 id: row.get("id"),
192 goal: row.get("goal"),
193 status: row.get("status"),
194 created_at: row.get("created_at"),
195 finished_at: row.get("finished_at"),
196 })
197 .collect())
198 }
199
200 async fn delete_graph(&self, id: &str) -> Result<bool, MemoryError> {
201 let result = zeph_db::query(sql!("DELETE FROM task_graphs WHERE id = ?"))
202 .bind(id)
203 .execute(&self.pool)
204 .await
205 .map_err(|e| MemoryError::GraphStore(e.to_string()))?;
206 Ok(result.rows_affected() > 0)
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::DbStore;

    /// Builds a store over a fresh in-memory database.
    async fn make_store() -> TaskGraphStore {
        let db = DbStore::new(":memory:").await.expect("DbStore");
        TaskGraphStore::new(db.pool().clone())
    }

    #[tokio::test]
    async fn test_save_and_load_roundtrip() {
        let store = make_store().await;
        store
            .save_graph("id-1", "goal", "created", r#"{"key":"val"}"#, "100", None)
            .await
            .expect("save");
        let loaded = store
            .load_graph("id-1")
            .await
            .expect("load")
            .expect("should exist");
        assert_eq!(loaded, r#"{"key":"val"}"#);
    }

    #[tokio::test]
    async fn test_load_nonexistent() {
        let store = make_store().await;
        let result = store.load_graph("missing-id").await.expect("load");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_list_graphs_ordering() {
        let store = make_store().await;
        store
            .save_graph("id-1", "first", "created", "{}", "100", None)
            .await
            .expect("save 1");
        store
            .save_graph("id-2", "second", "created", "{}", "200", None)
            .await
            .expect("save 2");
        let list = store.list_graphs(10).await.expect("list");
        assert_eq!(list.len(), 2);
        // Newest `created_at` first.
        assert_eq!(list[0].id, "id-2");
        assert_eq!(list[1].id, "id-1");
    }

    #[tokio::test]
    async fn test_list_graphs_respects_limit() {
        let store = make_store().await;
        store
            .save_graph("id-1", "goal", "created", "{}", "100", None)
            .await
            .expect("save 1");
        store
            .save_graph("id-2", "goal", "created", "{}", "200", None)
            .await
            .expect("save 2");
        store
            .save_graph("id-3", "goal", "created", "{}", "300", None)
            .await
            .expect("save 3");
        let list = store.list_graphs(2).await.expect("list");
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].id, "id-3");
        assert_eq!(list[1].id, "id-2");
    }

    #[tokio::test]
    async fn test_list_graphs_returns_metadata() {
        let store = make_store().await;
        store
            .save_graph("id-1", "goal", "completed", "{}", "100", Some("150"))
            .await
            .expect("save");
        let list = store.list_graphs(10).await.expect("list");
        assert_eq!(list.len(), 1);
        let summary = &list[0];
        assert_eq!(summary.id, "id-1");
        assert_eq!(summary.goal, "goal");
        assert_eq!(summary.status, "completed");
        assert_eq!(summary.created_at, "100");
        assert_eq!(summary.finished_at.as_deref(), Some("150"));
    }

    #[tokio::test]
    async fn test_delete_graph() {
        let store = make_store().await;
        store
            .save_graph("id-del", "goal", "created", "{}", "1", None)
            .await
            .expect("save");
        let deleted = store.delete_graph("id-del").await.expect("delete");
        assert!(deleted);
        let loaded = store.load_graph("id-del").await.expect("load");
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn test_delete_nonexistent_returns_false() {
        let store = make_store().await;
        let deleted = store.delete_graph("missing-id").await.expect("delete");
        assert!(!deleted);
    }

    #[tokio::test]
    async fn test_save_overwrites_existing() {
        let store = make_store().await;
        store
            .save_graph("id-1", "old", "created", r#"{"v":1}"#, "1", None)
            .await
            .expect("save 1");
        store
            .save_graph("id-1", "new", "running", r#"{"v":2}"#, "1", None)
            .await
            .expect("save 2 (upsert)");
        let loaded = store
            .load_graph("id-1")
            .await
            .expect("load")
            .expect("exists");
        assert_eq!(loaded, r#"{"v":2}"#);
        // The upsert must also replace the metadata, not just the JSON.
        let list = store.list_graphs(10).await.expect("list");
        assert_eq!(list.len(), 1);
        assert_eq!(list[0].goal, "new");
        assert_eq!(list[0].status, "running");
    }
}