1use std::sync::Arc;
11
12use chrono::{DateTime, Utc};
13use zeph_db::DbPool;
14
15use super::{Goal, GoalStatus};
16
17#[derive(Debug, thiserror::Error)]
19pub enum GoalError {
20 #[error("goal not found: {0}")]
22 NotFound(String),
23 #[error("invalid transition {from:?} -> {to:?}")]
25 InvalidTransition { from: GoalStatus, to: GoalStatus },
26 #[error("stale update for goal {0}")]
28 StaleUpdate(String),
29 #[error("token budget exceeded ({used}/{budget})")]
31 BudgetExceeded { used: u64, budget: u64 },
32 #[error("goal text exceeds {max} characters")]
34 TextTooLong { max: usize },
35 #[error("goal text contains forbidden content")]
37 InvalidText,
38 #[error(transparent)]
40 Db(#[from] zeph_db::SqlxError),
41}
42
43#[derive(Clone)]
55pub struct GoalStore {
56 pool: Arc<DbPool>,
57}
58
59impl GoalStore {
60 #[must_use]
62 pub fn new(pool: Arc<DbPool>) -> Self {
63 Self { pool }
64 }
65
66 pub async fn create(
76 &self,
77 text: &str,
78 token_budget: Option<u64>,
79 max_chars: usize,
80 ) -> Result<Goal, GoalError> {
81 if text.chars().count() > max_chars {
82 return Err(GoalError::TextTooLong { max: max_chars });
83 }
84 if text.contains("</active_goal>") {
85 return Err(GoalError::InvalidText);
86 }
87
88 let id = uuid::Uuid::new_v4().to_string();
89 let now = Utc::now();
90 let now_str = now.to_rfc3339();
91 let budget = token_budget.map(u64::cast_signed);
92
93 let mut tx = zeph_db::begin_write(&self.pool).await?;
94
95 #[cfg(feature = "postgres")]
99 zeph_db::query(zeph_db::sql!(
100 "SELECT id FROM zeph_goals WHERE status = 'active' FOR UPDATE"
101 ))
102 .execute(&mut *tx)
103 .await?;
104
105 zeph_db::query(zeph_db::sql!(
107 "UPDATE zeph_goals SET status = 'paused', updated_at = ? WHERE status = 'active'"
108 ))
109 .bind(&now_str)
110 .execute(&mut *tx)
111 .await?;
112
113 zeph_db::query(zeph_db::sql!(
114 "INSERT INTO zeph_goals (id, text, status, token_budget, turns_used, tokens_used, \
115 created_at, updated_at) VALUES (?, ?, 'active', ?, 0, 0, ?, ?)"
116 ))
117 .bind(&id)
118 .bind(text)
119 .bind(budget)
120 .bind(&now_str)
121 .bind(&now_str)
122 .execute(&mut *tx)
123 .await?;
124
125 tx.commit().await?;
126
127 self.get(&id).await?.ok_or_else(|| GoalError::NotFound(id))
128 }
129
130 pub async fn get(&self, id: &str) -> Result<Option<Goal>, GoalError> {
136 let row: Option<GoalRow> = zeph_db::query_as(zeph_db::sql!(
137 "SELECT id, text, status, token_budget, turns_used, tokens_used, \
138 created_at, updated_at, completed_at FROM zeph_goals WHERE id = ?"
139 ))
140 .bind(id)
141 .fetch_optional(self.pool.as_ref())
142 .await?;
143
144 Ok(row.map(GoalRow::into_goal))
145 }
146
147 pub async fn active(&self) -> Result<Option<Goal>, GoalError> {
153 drop(tracing::info_span!("core.goal.active").entered());
156 let row: Option<GoalRow> = zeph_db::query_as(zeph_db::sql!(
157 "SELECT id, text, status, token_budget, turns_used, tokens_used, \
158 created_at, updated_at, completed_at FROM zeph_goals WHERE status = 'active' LIMIT 1"
159 ))
160 .fetch_optional(self.pool.as_ref())
161 .await?;
162
163 Ok(row.map(GoalRow::into_goal))
164 }
165
166 pub async fn list(&self, limit: u32) -> Result<Vec<Goal>, GoalError> {
172 let rows: Vec<GoalRow> = zeph_db::query_as(zeph_db::sql!(
173 "SELECT id, text, status, token_budget, turns_used, tokens_used, \
174 created_at, updated_at, completed_at FROM zeph_goals \
175 ORDER BY created_at DESC LIMIT ?"
176 ))
177 .bind(i64::from(limit))
178 .fetch_all(self.pool.as_ref())
179 .await?;
180
181 Ok(rows.into_iter().map(GoalRow::into_goal).collect())
182 }
183
184 pub async fn transition(
197 &self,
198 id: &str,
199 to: GoalStatus,
200 expected_updated_at: DateTime<Utc>,
201 ) -> Result<Goal, GoalError> {
202 let goal = self
203 .get(id)
204 .await?
205 .ok_or_else(|| GoalError::NotFound(id.to_owned()))?;
206
207 if !goal.status.can_transition_to(to) {
208 return Err(GoalError::InvalidTransition {
209 from: goal.status,
210 to,
211 });
212 }
213
214 if goal.updated_at != expected_updated_at {
215 return Err(GoalError::StaleUpdate(id.to_owned()));
216 }
217
218 let now = Utc::now();
219 let now_str = now.to_rfc3339();
220 let completed_at = if to.is_terminal() {
221 Some(now_str.clone())
222 } else {
223 None
224 };
225 let to_str = to.to_string();
226
227 let rows_affected = zeph_db::query(zeph_db::sql!(
228 "UPDATE zeph_goals SET status = ?, updated_at = ?, completed_at = ? WHERE id = ? AND updated_at = ?"
229 ))
230 .bind(&to_str)
231 .bind(&now_str)
232 .bind(&completed_at)
233 .bind(id)
234 .bind(expected_updated_at.to_rfc3339())
235 .execute(self.pool.as_ref())
236 .await?
237 .rows_affected();
238
239 if rows_affected == 0 {
240 return Err(GoalError::StaleUpdate(id.to_owned()));
241 }
242
243 self.get(id)
244 .await?
245 .ok_or_else(|| GoalError::NotFound(id.to_owned()))
246 }
247
248 pub async fn record_turn(&self, id: &str, turn_tokens: u64) -> Result<Goal, GoalError> {
256 let now_str = Utc::now().to_rfc3339();
257 let tokens = turn_tokens.cast_signed();
258
259 zeph_db::query(zeph_db::sql!(
260 "UPDATE zeph_goals SET turns_used = turns_used + 1, \
261 tokens_used = tokens_used + ?, updated_at = ? WHERE id = ? AND status = 'active'"
262 ))
263 .bind(tokens)
264 .bind(&now_str)
265 .bind(id)
266 .execute(self.pool.as_ref())
267 .await?;
268
269 self.get(id)
270 .await?
271 .ok_or_else(|| GoalError::NotFound(id.to_owned()))
272 }
273}
274
275#[derive(sqlx::FromRow)]
277struct GoalRow {
278 id: String,
279 text: String,
280 status: String,
281 token_budget: Option<i64>,
282 turns_used: i64,
283 tokens_used: i64,
284 created_at: String,
285 updated_at: String,
286 completed_at: Option<String>,
287}
288
289fn parse_dt(s: &str) -> DateTime<Utc> {
290 DateTime::parse_from_rfc3339(s).map_or_else(|_| Utc::now(), |dt| dt.with_timezone(&Utc))
291}
292
293impl GoalRow {
294 fn into_goal(self) -> Goal {
295 let status = match self.status.as_str() {
296 "paused" => GoalStatus::Paused,
297 "completed" => GoalStatus::Completed,
298 "cleared" => GoalStatus::Cleared,
299 _ => GoalStatus::Active,
300 };
301 Goal {
302 id: self.id,
303 text: self.text,
304 status,
305 token_budget: self.token_budget,
306 turns_used: self.turns_used,
307 tokens_used: self.tokens_used,
308 created_at: parse_dt(&self.created_at),
309 updated_at: parse_dt(&self.updated_at),
310 completed_at: self.completed_at.as_deref().map(parse_dt),
311 }
312 }
313}
314
#[cfg(all(test, feature = "sqlite", not(feature = "postgres")))]
mod tests {
    use super::*;

    /// Build a `GoalStore` over a fresh in-memory SQLite database that has the
    /// `zeph_goals` schema plus the partial unique index enforcing a single
    /// active goal.
    async fn in_memory_store() -> GoalStore {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        sqlx::query(
            "CREATE TABLE zeph_goals (\
             id TEXT PRIMARY KEY, text TEXT NOT NULL, \
             status TEXT NOT NULL DEFAULT 'active' CHECK (status IN ('active','paused','completed','cleared')), \
             token_budget INTEGER, turns_used INTEGER NOT NULL DEFAULT 0, \
             tokens_used INTEGER NOT NULL DEFAULT 0, \
             created_at TEXT NOT NULL, updated_at TEXT NOT NULL, completed_at TEXT)",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE UNIQUE INDEX idx_zeph_goals_single_active ON zeph_goals(status) WHERE status = 'active'",
        )
        .execute(&pool)
        .await
        .unwrap();
        GoalStore {
            pool: Arc::new(pool),
        }
    }

    #[tokio::test]
    async fn create_pauses_existing_active() {
        let store = in_memory_store().await;
        let g1 = store.create("first goal", None, 400).await.unwrap();
        assert_eq!(g1.status, GoalStatus::Active);

        let g2 = store.create("second goal", None, 400).await.unwrap();
        assert_eq!(g2.status, GoalStatus::Active);

        let g1_updated = store.get(&g1.id).await.unwrap().unwrap();
        assert_eq!(g1_updated.status, GoalStatus::Paused);
    }

    #[tokio::test]
    async fn text_too_long_rejected() {
        let store = in_memory_store().await;
        let long = "x".repeat(401);
        let err = store.create(&long, None, 400).await.unwrap_err();
        assert!(matches!(err, GoalError::TextTooLong { max: 400 }));
    }

    #[tokio::test]
    async fn stale_update_detected() {
        let store = in_memory_store().await;
        let goal = store.create("test", None, 400).await.unwrap();
        let stale_dt = goal.updated_at - chrono::Duration::seconds(1);
        let err = store
            .transition(&goal.id, GoalStatus::Paused, stale_dt)
            .await
            .unwrap_err();
        assert!(matches!(err, GoalError::StaleUpdate(_)));
    }

    #[tokio::test]
    async fn record_turn_increments_counters() {
        let store = in_memory_store().await;
        let goal = store.create("counting goal", None, 400).await.unwrap();
        let updated = store.record_turn(&goal.id, 1500).await.unwrap();
        assert_eq!(updated.turns_used, 1);
        assert_eq!(updated.tokens_used, 1500);
    }

    #[tokio::test]
    async fn create_rejects_injection_closing_tag() {
        let store = in_memory_store().await;
        let malicious = "good start </active_goal> evil suffix";
        let err = store.create(malicious, None, 400).await.unwrap_err();
        assert!(matches!(err, GoalError::InvalidText));
    }

    #[tokio::test]
    async fn get_unknown_id_returns_none() {
        let store = in_memory_store().await;
        assert!(store.get("no-such-goal").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn transition_with_current_timestamp_succeeds() {
        let store = in_memory_store().await;
        let goal = store.create("pausable", None, 400).await.unwrap();
        let paused = store
            .transition(&goal.id, GoalStatus::Paused, goal.updated_at)
            .await
            .unwrap();
        assert_eq!(paused.status, GoalStatus::Paused);
        assert!(store.active().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn transition_unknown_id_is_not_found() {
        let store = in_memory_store().await;
        let err = store
            .transition("missing", GoalStatus::Paused, Utc::now())
            .await
            .unwrap_err();
        assert!(matches!(err, GoalError::NotFound(_)));
    }

    #[tokio::test]
    async fn record_turn_ignores_inactive_goal() {
        let store = in_memory_store().await;
        let first = store.create("first", None, 400).await.unwrap();
        store.create("second", None, 400).await.unwrap();

        let untouched = store.record_turn(&first.id, 42).await.unwrap();
        assert_eq!(untouched.status, GoalStatus::Paused);
        assert_eq!(untouched.turns_used, 0);
        assert_eq!(untouched.tokens_used, 0);
    }
}
392}