1use std::sync::Arc;
11
12use chrono::{DateTime, Utc};
13use zeph_db::DbPool;
14
15use super::{Goal, GoalStatus};
16
17#[non_exhaustive]
18#[derive(Debug, thiserror::Error)]
20pub enum GoalError {
21 #[error("goal not found: {0}")]
23 NotFound(String),
24 #[error("invalid transition {from:?} -> {to:?}")]
26 InvalidTransition { from: GoalStatus, to: GoalStatus },
27 #[error("stale update for goal {0}")]
29 StaleUpdate(String),
30 #[error("token budget exceeded ({used}/{budget})")]
32 BudgetExceeded { used: u64, budget: u64 },
33 #[error("goal text exceeds {max} characters")]
35 TextTooLong { max: usize },
36 #[error("goal text contains forbidden content")]
38 InvalidText,
39 #[error(transparent)]
41 Db(#[from] zeph_db::SqlxError),
42}
43
44#[derive(Clone)]
56pub struct GoalStore {
57 pool: Arc<DbPool>,
58}
59
60impl GoalStore {
61 #[must_use]
63 pub fn new(pool: Arc<DbPool>) -> Self {
64 Self { pool }
65 }
66
67 pub async fn create(
77 &self,
78 text: &str,
79 token_budget: Option<u64>,
80 max_chars: usize,
81 ) -> Result<Goal, GoalError> {
82 if text.chars().count() > max_chars {
83 return Err(GoalError::TextTooLong { max: max_chars });
84 }
85 if text.contains("</active_goal>") {
86 return Err(GoalError::InvalidText);
87 }
88
89 let id = uuid::Uuid::new_v4().to_string();
90 let now = Utc::now();
91 let now_str = now.to_rfc3339();
92 let budget = token_budget.map(u64::cast_signed);
93
94 let mut tx = zeph_db::begin_write(&self.pool).await?;
95
96 #[cfg(feature = "postgres")]
100 zeph_db::query(zeph_db::sql!(
101 "SELECT id FROM zeph_goals WHERE status = 'active' FOR UPDATE"
102 ))
103 .execute(&mut *tx)
104 .await?;
105
106 zeph_db::query(zeph_db::sql!(
108 "UPDATE zeph_goals SET status = 'paused', updated_at = ? WHERE status = 'active'"
109 ))
110 .bind(&now_str)
111 .execute(&mut *tx)
112 .await?;
113
114 zeph_db::query(zeph_db::sql!(
115 "INSERT INTO zeph_goals (id, text, status, token_budget, turns_used, tokens_used, \
116 created_at, updated_at) VALUES (?, ?, 'active', ?, 0, 0, ?, ?)"
117 ))
118 .bind(&id)
119 .bind(text)
120 .bind(budget)
121 .bind(&now_str)
122 .bind(&now_str)
123 .execute(&mut *tx)
124 .await?;
125
126 tx.commit().await?;
127
128 self.get(&id).await?.ok_or_else(|| GoalError::NotFound(id))
129 }
130
131 pub async fn get(&self, id: &str) -> Result<Option<Goal>, GoalError> {
137 let row: Option<GoalRow> = zeph_db::query_as(zeph_db::sql!(
138 "SELECT id, text, status, token_budget, turns_used, tokens_used, \
139 created_at, updated_at, completed_at FROM zeph_goals WHERE id = ?"
140 ))
141 .bind(id)
142 .fetch_optional(self.pool.as_ref())
143 .await?;
144
145 Ok(row.map(GoalRow::into_goal))
146 }
147
148 pub async fn active(&self) -> Result<Option<Goal>, GoalError> {
154 drop(tracing::info_span!("core.goal.active").entered());
157 let row: Option<GoalRow> = zeph_db::query_as(zeph_db::sql!(
158 "SELECT id, text, status, token_budget, turns_used, tokens_used, \
159 created_at, updated_at, completed_at FROM zeph_goals WHERE status = 'active' LIMIT 1"
160 ))
161 .fetch_optional(self.pool.as_ref())
162 .await?;
163
164 Ok(row.map(GoalRow::into_goal))
165 }
166
167 pub async fn list(&self, limit: u32) -> Result<Vec<Goal>, GoalError> {
173 let rows: Vec<GoalRow> = zeph_db::query_as(zeph_db::sql!(
174 "SELECT id, text, status, token_budget, turns_used, tokens_used, \
175 created_at, updated_at, completed_at FROM zeph_goals \
176 ORDER BY created_at DESC LIMIT ?"
177 ))
178 .bind(i64::from(limit))
179 .fetch_all(self.pool.as_ref())
180 .await?;
181
182 Ok(rows.into_iter().map(GoalRow::into_goal).collect())
183 }
184
185 pub async fn transition(
198 &self,
199 id: &str,
200 to: GoalStatus,
201 expected_updated_at: DateTime<Utc>,
202 ) -> Result<Goal, GoalError> {
203 let goal = self
204 .get(id)
205 .await?
206 .ok_or_else(|| GoalError::NotFound(id.to_owned()))?;
207
208 if !goal.status.can_transition_to(to) {
209 return Err(GoalError::InvalidTransition {
210 from: goal.status,
211 to,
212 });
213 }
214
215 if goal.updated_at != expected_updated_at {
216 return Err(GoalError::StaleUpdate(id.to_owned()));
217 }
218
219 let now = Utc::now();
220 let now_str = now.to_rfc3339();
221 let completed_at = if to.is_terminal() {
222 Some(now_str.clone())
223 } else {
224 None
225 };
226 let to_str = to.to_string();
227
228 let rows_affected = zeph_db::query(zeph_db::sql!(
229 "UPDATE zeph_goals SET status = ?, updated_at = ?, completed_at = ? WHERE id = ? AND updated_at = ?"
230 ))
231 .bind(&to_str)
232 .bind(&now_str)
233 .bind(&completed_at)
234 .bind(id)
235 .bind(expected_updated_at.to_rfc3339())
236 .execute(self.pool.as_ref())
237 .await?
238 .rows_affected();
239
240 if rows_affected == 0 {
241 return Err(GoalError::StaleUpdate(id.to_owned()));
242 }
243
244 self.get(id)
245 .await?
246 .ok_or_else(|| GoalError::NotFound(id.to_owned()))
247 }
248
249 pub async fn record_turn(&self, id: &str, turn_tokens: u64) -> Result<Goal, GoalError> {
257 let now_str = Utc::now().to_rfc3339();
258 let tokens = turn_tokens.cast_signed();
259
260 zeph_db::query(zeph_db::sql!(
261 "UPDATE zeph_goals SET turns_used = turns_used + 1, \
262 tokens_used = tokens_used + ?, updated_at = ? WHERE id = ? AND status = 'active'"
263 ))
264 .bind(tokens)
265 .bind(&now_str)
266 .bind(id)
267 .execute(self.pool.as_ref())
268 .await?;
269
270 self.get(id)
271 .await?
272 .ok_or_else(|| GoalError::NotFound(id.to_owned()))
273 }
274}
275
276#[derive(sqlx::FromRow)]
278struct GoalRow {
279 id: String,
280 text: String,
281 status: String,
282 token_budget: Option<i64>,
283 turns_used: i64,
284 tokens_used: i64,
285 created_at: String,
286 updated_at: String,
287 completed_at: Option<String>,
288}
289
290fn parse_dt(s: &str) -> DateTime<Utc> {
291 DateTime::parse_from_rfc3339(s).map_or_else(|_| Utc::now(), |dt| dt.with_timezone(&Utc))
292}
293
294impl GoalRow {
295 fn into_goal(self) -> Goal {
296 let status = match self.status.as_str() {
297 "paused" => GoalStatus::Paused,
298 "completed" => GoalStatus::Completed,
299 "cleared" => GoalStatus::Cleared,
300 _ => GoalStatus::Active,
301 };
302 Goal {
303 id: self.id,
304 text: self.text,
305 status,
306 token_budget: self.token_budget,
307 turns_used: self.turns_used,
308 tokens_used: self.tokens_used,
309 created_at: parse_dt(&self.created_at),
310 updated_at: parse_dt(&self.updated_at),
311 completed_at: self.completed_at.as_deref().map(parse_dt),
312 }
313 }
314}
315
316#[cfg(all(test, feature = "sqlite", not(feature = "postgres")))]
317mod tests {
318 use super::*;
319
320 async fn in_memory_store() -> GoalStore {
321 let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
322 sqlx::query(
323 "CREATE TABLE zeph_goals (\
324 id TEXT PRIMARY KEY, text TEXT NOT NULL, \
325 status TEXT NOT NULL DEFAULT 'active' CHECK (status IN ('active','paused','completed','cleared')), \
326 token_budget INTEGER, turns_used INTEGER NOT NULL DEFAULT 0, \
327 tokens_used INTEGER NOT NULL DEFAULT 0, \
328 created_at TEXT NOT NULL, updated_at TEXT NOT NULL, completed_at TEXT)",
329 )
330 .execute(&pool)
331 .await
332 .unwrap();
333 sqlx::query(
334 "CREATE UNIQUE INDEX idx_zeph_goals_single_active ON zeph_goals(status) WHERE status = 'active'",
335 )
336 .execute(&pool)
337 .await
338 .unwrap();
339 GoalStore {
340 pool: Arc::new(pool),
341 }
342 }
343
344 #[tokio::test]
345 async fn create_pauses_existing_active() {
346 let store = in_memory_store().await;
347 let g1 = store.create("first goal", None, 400).await.unwrap();
348 assert_eq!(g1.status, GoalStatus::Active);
349
350 let g2 = store.create("second goal", None, 400).await.unwrap();
351 assert_eq!(g2.status, GoalStatus::Active);
352
353 let g1_updated = store.get(&g1.id).await.unwrap().unwrap();
354 assert_eq!(g1_updated.status, GoalStatus::Paused);
355 }
356
357 #[tokio::test]
358 async fn text_too_long_rejected() {
359 let store = in_memory_store().await;
360 let long = "x".repeat(401);
361 let err = store.create(&long, None, 400).await.unwrap_err();
362 assert!(matches!(err, GoalError::TextTooLong { max: 400 }));
363 }
364
365 #[tokio::test]
366 async fn stale_update_detected() {
367 let store = in_memory_store().await;
368 let goal = store.create("test", None, 400).await.unwrap();
369 let stale_dt = goal.updated_at - chrono::Duration::seconds(1);
370 let err = store
371 .transition(&goal.id, GoalStatus::Paused, stale_dt)
372 .await
373 .unwrap_err();
374 assert!(matches!(err, GoalError::StaleUpdate(_)));
375 }
376
377 #[tokio::test]
378 async fn record_turn_increments_counters() {
379 let store = in_memory_store().await;
380 let goal = store.create("counting goal", None, 400).await.unwrap();
381 let updated = store.record_turn(&goal.id, 1500).await.unwrap();
382 assert_eq!(updated.turns_used, 1);
383 assert_eq!(updated.tokens_used, 1500);
384 }
385
386 #[tokio::test]
387 async fn create_rejects_injection_closing_tag() {
388 let store = in_memory_store().await;
389 let malicious = "good start </active_goal> evil suffix";
390 let err = store.create(malicious, None, 400).await.unwrap_err();
391 assert!(matches!(err, GoalError::InvalidText));
392 }
393}