Skip to main content

zeph_core/orchestration/
graph.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::fmt;
5use std::path::PathBuf;
6use std::str::FromStr;
7
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10use zeph_memory::sqlite::graph_store::{GraphSummary, RawGraphStore};
11
12use super::error::OrchestrationError;
13
14/// Index of a task within a `TaskGraph.tasks` Vec.
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub struct TaskId(pub(crate) u32);
17
18impl TaskId {
19    /// Returns the index for Vec access.
20    #[must_use]
21    pub fn index(self) -> usize {
22        self.0 as usize
23    }
24}
25
26impl fmt::Display for TaskId {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        write!(f, "{}", self.0)
29    }
30}
31
32/// Unique identifier for a `TaskGraph`.
33#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
34pub struct GraphId(Uuid);
35
36impl GraphId {
37    /// Generate a new random v4 `GraphId`.
38    #[must_use]
39    pub fn new() -> Self {
40        Self(Uuid::new_v4())
41    }
42}
43
44impl Default for GraphId {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl fmt::Display for GraphId {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        write!(f, "{}", self.0)
53    }
54}
55
56impl FromStr for GraphId {
57    type Err = OrchestrationError;
58
59    fn from_str(s: &str) -> Result<Self, Self::Err> {
60        Uuid::parse_str(s)
61            .map(GraphId)
62            .map_err(|e| OrchestrationError::InvalidGraph(format!("invalid graph id '{s}': {e}")))
63    }
64}
65
66/// Lifecycle status of a single task node.
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
68#[serde(rename_all = "snake_case")]
69pub enum TaskStatus {
70    Pending,
71    Ready,
72    Running,
73    Completed,
74    Failed,
75    Skipped,
76    Canceled,
77}
78
79impl TaskStatus {
80    /// Returns `true` if the status is a terminal state.
81    #[must_use]
82    pub fn is_terminal(self) -> bool {
83        matches!(
84            self,
85            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Skipped | TaskStatus::Canceled
86        )
87    }
88}
89
90impl fmt::Display for TaskStatus {
91    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92        match self {
93            TaskStatus::Pending => write!(f, "pending"),
94            TaskStatus::Ready => write!(f, "ready"),
95            TaskStatus::Running => write!(f, "running"),
96            TaskStatus::Completed => write!(f, "completed"),
97            TaskStatus::Failed => write!(f, "failed"),
98            TaskStatus::Skipped => write!(f, "skipped"),
99            TaskStatus::Canceled => write!(f, "canceled"),
100        }
101    }
102}
103
104/// Lifecycle status of a `TaskGraph`.
105#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
106#[serde(rename_all = "snake_case")]
107pub enum GraphStatus {
108    Created,
109    Running,
110    Completed,
111    Failed,
112    Canceled,
113    Paused,
114}
115
116impl fmt::Display for GraphStatus {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        match self {
119            GraphStatus::Created => write!(f, "created"),
120            GraphStatus::Running => write!(f, "running"),
121            GraphStatus::Completed => write!(f, "completed"),
122            GraphStatus::Failed => write!(f, "failed"),
123            GraphStatus::Canceled => write!(f, "canceled"),
124            GraphStatus::Paused => write!(f, "paused"),
125        }
126    }
127}
128
129/// What to do when a task fails.
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
131#[serde(rename_all = "snake_case")]
132pub enum FailureStrategy {
133    /// Abort the entire graph.
134    #[default]
135    Abort,
136    /// Retry the task up to `max_retries` times.
137    Retry,
138    /// Skip the task and its dependents.
139    Skip,
140    /// Pause the graph and ask the user.
141    Ask,
142}
143
144impl fmt::Display for FailureStrategy {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        match self {
147            FailureStrategy::Abort => write!(f, "abort"),
148            FailureStrategy::Retry => write!(f, "retry"),
149            FailureStrategy::Skip => write!(f, "skip"),
150            FailureStrategy::Ask => write!(f, "ask"),
151        }
152    }
153}
154
155impl FromStr for FailureStrategy {
156    type Err = OrchestrationError;
157
158    fn from_str(s: &str) -> Result<Self, Self::Err> {
159        match s {
160            "abort" => Ok(FailureStrategy::Abort),
161            "retry" => Ok(FailureStrategy::Retry),
162            "skip" => Ok(FailureStrategy::Skip),
163            "ask" => Ok(FailureStrategy::Ask),
164            other => Err(OrchestrationError::InvalidGraph(format!(
165                "unknown failure strategy '{other}': expected one of abort, retry, skip, ask"
166            ))),
167        }
168    }
169}
170
171/// Output produced by a completed task.
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct TaskResult {
174    pub output: String,
175    pub artifacts: Vec<PathBuf>,
176    pub duration_ms: u64,
177    pub agent_id: Option<String>,
178    pub agent_def: Option<String>,
179}
180
181/// A single node in the task DAG.
182#[derive(Debug, Clone, Serialize, Deserialize)]
183pub struct TaskNode {
184    pub id: TaskId,
185    pub title: String,
186    pub description: String,
187    pub agent_hint: Option<String>,
188    pub status: TaskStatus,
189    /// Indices of tasks this node depends on.
190    pub depends_on: Vec<TaskId>,
191    pub result: Option<TaskResult>,
192    pub assigned_agent: Option<String>,
193    pub retry_count: u32,
194    /// Per-task override; `None` means use graph default.
195    pub failure_strategy: Option<FailureStrategy>,
196    pub max_retries: Option<u32>,
197}
198
199impl TaskNode {
200    /// Create a new pending task with the given index.
201    #[must_use]
202    pub fn new(id: u32, title: impl Into<String>, description: impl Into<String>) -> Self {
203        Self {
204            id: TaskId(id),
205            title: title.into(),
206            description: description.into(),
207            agent_hint: None,
208            status: TaskStatus::Pending,
209            depends_on: Vec::new(),
210            result: None,
211            assigned_agent: None,
212            retry_count: 0,
213            failure_strategy: None,
214            max_retries: None,
215        }
216    }
217}
218
219/// A directed acyclic graph of tasks to be executed by the orchestrator.
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct TaskGraph {
222    pub id: GraphId,
223    pub goal: String,
224    pub tasks: Vec<TaskNode>,
225    pub status: GraphStatus,
226    pub default_failure_strategy: FailureStrategy,
227    pub default_max_retries: u32,
228    pub created_at: String,
229    pub finished_at: Option<String>,
230}
231
232impl TaskGraph {
233    /// Create a new graph with `Created` status.
234    #[must_use]
235    pub fn new(goal: impl Into<String>) -> Self {
236        Self {
237            id: GraphId::new(),
238            goal: goal.into(),
239            tasks: Vec::new(),
240            status: GraphStatus::Created,
241            default_failure_strategy: FailureStrategy::default(),
242            default_max_retries: 3,
243            created_at: chrono_now(),
244            finished_at: None,
245        }
246    }
247}
248
249pub(crate) fn chrono_now() -> String {
250    // ISO-8601 UTC timestamp, consistent with the rest of the codebase.
251    // Format: "2026-03-05T22:04:41Z"
252    let secs = std::time::SystemTime::now()
253        .duration_since(std::time::UNIX_EPOCH)
254        .map_or(0, |d| d.as_secs());
255    // Manual formatting: seconds since epoch → ISO-8601 UTC
256    // Days since epoch, then decompose into year/month/day
257    let (y, mo, d, h, mi, s) = epoch_secs_to_datetime(secs);
258    format!("{y:04}-{mo:02}-{d:02}T{h:02}:{mi:02}:{s:02}Z")
259}
260
/// Convert Unix epoch seconds to `(year, month, day, hour, minute, second)` in UTC.
///
/// Uses the proleptic Gregorian calendar; `month` is in `1..=12` and `day` in
/// `1..=31`. The date part follows Howard Hinnant's `civil_from_days`
/// algorithm: days are counted from 0000-03-01 so that every (shifted) year
/// starts on March 1st and the leap day is always the *last* day of a year.
/// That makes the 400/100/4-year leap rules exact integer arithmetic, with no
/// special-casing of where the leap year falls inside a cycle.
fn epoch_secs_to_datetime(secs: u64) -> (u64, u8, u8, u8, u8, u8) {
    const SECS_PER_DAY: u64 = 86_400;
    // Days in one 400-year Gregorian cycle ("era").
    const DAYS_PER_ERA: u64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01 in the proleptic Gregorian calendar.
    const DAYS_FROM_0000_03_01_TO_EPOCH: u64 = 719_468;

    // Time of day. Each value is strictly below 24 or 60, so the casts are lossless.
    let secs_of_day = secs % SECS_PER_DAY;
    let h = (secs_of_day / 3_600) as u8;
    let mi = (secs_of_day / 60 % 60) as u8;
    let s = (secs_of_day % 60) as u8;

    // Calendar date, counted from 0000-03-01.
    let z = secs / SECS_PER_DAY + DAYS_FROM_0000_03_01_TO_EPOCH;
    let era = z / DAYS_PER_ERA;
    let doe = z % DAYS_PER_ERA; // day of era, 0..=146_096
    // Year of era, 0..=399: corrects the naive `doe / 365` for 4-, 100- and 400-year leap rules.
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365;
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, 0..=365
    let mp = (5 * doy + 2) / 153; // March-based month: 0 = March .. 11 = February
    let day = doy - (153 * mp + 2) / 5 + 1; // 1..=31
    let month = if mp < 10 { mp + 3 } else { mp - 9 }; // civil month, 1..=12
    // January and February belong to the following civil year.
    let year = era * 400 + yoe + u64::from(month <= 2);

    // month <= 12 and day <= 31, so these conversions cannot fail.
    let month = u8::try_from(month).unwrap_or(1);
    let day = u8::try_from(day).unwrap_or(1);

    (year, month, day, h, mi, s)
}
311
312/// Maximum allowed length for a `TaskGraph` goal string.
313const MAX_GOAL_LEN: usize = 1024;
314
315/// Type-safe wrapper around `RawGraphStore` that handles `TaskGraph` serialization.
316///
317/// Consumers in `zeph-core` use this instead of `RawGraphStore` directly, so they
318/// never need to deal with JSON strings.
319///
320/// # Storage layout
321///
322/// The `task_graphs` table stores both metadata columns (`goal`, `status`,
323/// `created_at`, `finished_at`) and the full `graph_json` blob. The metadata
324/// columns are summary/index data used for listing and filtering; `graph_json`
325/// is the authoritative source for full graph reconstruction. On `load`, only
326/// `graph_json` is deserialized — the columns are not consulted.
327pub struct GraphPersistence<S: RawGraphStore> {
328    store: S,
329}
330
331impl<S: RawGraphStore> GraphPersistence<S> {
332    /// Create a new `GraphPersistence` wrapping the given store.
333    pub fn new(store: S) -> Self {
334        Self { store }
335    }
336
337    /// Persist a `TaskGraph` (upsert).
338    ///
339    /// Returns `OrchestrationError::InvalidGraph` if `graph.goal` exceeds
340    /// `MAX_GOAL_LEN` (1024) characters.
341    ///
342    /// # Errors
343    ///
344    /// Returns `OrchestrationError::Persistence` on serialization or database failure.
345    pub async fn save(&self, graph: &TaskGraph) -> Result<(), OrchestrationError> {
346        if graph.goal.len() > MAX_GOAL_LEN {
347            return Err(OrchestrationError::InvalidGraph(format!(
348                "goal exceeds {MAX_GOAL_LEN} character limit ({} chars)",
349                graph.goal.len()
350            )));
351        }
352        let json = serde_json::to_string(graph)
353            .map_err(|e| OrchestrationError::Persistence(e.to_string()))?;
354        self.store
355            .save_graph(
356                &graph.id.to_string(),
357                &graph.goal,
358                &graph.status.to_string(),
359                &json,
360                &graph.created_at,
361                graph.finished_at.as_deref(),
362            )
363            .await
364            .map_err(|e| OrchestrationError::Persistence(e.to_string()))
365    }
366
367    /// Load a `TaskGraph` by its `GraphId`.
368    ///
369    /// Returns `None` if not found.
370    ///
371    /// # Errors
372    ///
373    /// Returns `OrchestrationError::Persistence` on database or deserialization failure.
374    pub async fn load(&self, id: &GraphId) -> Result<Option<TaskGraph>, OrchestrationError> {
375        match self
376            .store
377            .load_graph(&id.to_string())
378            .await
379            .map_err(|e| OrchestrationError::Persistence(e.to_string()))?
380        {
381            Some(json) => {
382                let graph = serde_json::from_str(&json)
383                    .map_err(|e| OrchestrationError::Persistence(e.to_string()))?;
384                Ok(Some(graph))
385            }
386            None => Ok(None),
387        }
388    }
389
390    /// List stored graphs (newest first).
391    ///
392    /// # Errors
393    ///
394    /// Returns `OrchestrationError::Persistence` on database failure.
395    pub async fn list(&self, limit: u32) -> Result<Vec<GraphSummary>, OrchestrationError> {
396        self.store
397            .list_graphs(limit)
398            .await
399            .map_err(|e| OrchestrationError::Persistence(e.to_string()))
400    }
401
402    /// Delete a graph by its `GraphId`.
403    ///
404    /// Returns `true` if a row was deleted.
405    ///
406    /// # Errors
407    ///
408    /// Returns `OrchestrationError::Persistence` on database failure.
409    pub async fn delete(&self, id: &GraphId) -> Result<bool, OrchestrationError> {
410        self.store
411            .delete_graph(&id.to_string())
412            .await
413            .map_err(|e| OrchestrationError::Persistence(e.to_string()))
414    }
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_taskid_display() {
        assert_eq!(TaskId(3).to_string(), "3");
    }

    #[test]
    fn test_graphid_display_and_new() {
        let id = GraphId::new();
        let s = id.to_string();
        assert_eq!(s.len(), 36, "UUID string should be 36 chars");
        let parsed: GraphId = s.parse().expect("should parse back");
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_graphid_from_str_invalid() {
        let err = "not-a-uuid".parse::<GraphId>();
        assert!(err.is_err());
    }

    #[test]
    fn test_task_status_is_terminal() {
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Skipped.is_terminal());
        assert!(TaskStatus::Canceled.is_terminal());

        assert!(!TaskStatus::Pending.is_terminal());
        assert!(!TaskStatus::Ready.is_terminal());
        assert!(!TaskStatus::Running.is_terminal());
    }

    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Pending.to_string(), "pending");
        assert_eq!(TaskStatus::Ready.to_string(), "ready");
        assert_eq!(TaskStatus::Running.to_string(), "running");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Skipped.to_string(), "skipped");
        assert_eq!(TaskStatus::Canceled.to_string(), "canceled");
    }

    #[test]
    fn test_task_status_serde_matches_display() {
        for status in [
            TaskStatus::Pending,
            TaskStatus::Ready,
            TaskStatus::Running,
            TaskStatus::Completed,
            TaskStatus::Failed,
            TaskStatus::Skipped,
            TaskStatus::Canceled,
        ] {
            let json = serde_json::to_string(&status).expect("serialize");
            assert_eq!(json, format!("\"{status}\""));
        }
    }

    #[test]
    fn test_failure_strategy_default() {
        assert_eq!(FailureStrategy::default(), FailureStrategy::Abort);
    }

    #[test]
    fn test_failure_strategy_display() {
        assert_eq!(FailureStrategy::Abort.to_string(), "abort");
        assert_eq!(FailureStrategy::Retry.to_string(), "retry");
        assert_eq!(FailureStrategy::Skip.to_string(), "skip");
        assert_eq!(FailureStrategy::Ask.to_string(), "ask");
    }

    #[test]
    fn test_failure_strategy_display_from_str_roundtrip() {
        for strategy in [
            FailureStrategy::Abort,
            FailureStrategy::Retry,
            FailureStrategy::Skip,
            FailureStrategy::Ask,
        ] {
            let parsed: FailureStrategy = strategy.to_string().parse().expect("roundtrip");
            assert_eq!(parsed, strategy);
        }
    }

    #[test]
    fn test_graph_status_display() {
        assert_eq!(GraphStatus::Created.to_string(), "created");
        assert_eq!(GraphStatus::Running.to_string(), "running");
        assert_eq!(GraphStatus::Completed.to_string(), "completed");
        assert_eq!(GraphStatus::Failed.to_string(), "failed");
        assert_eq!(GraphStatus::Canceled.to_string(), "canceled");
        assert_eq!(GraphStatus::Paused.to_string(), "paused");
    }

    #[test]
    fn test_graph_status_serde_matches_display() {
        for status in [
            GraphStatus::Created,
            GraphStatus::Running,
            GraphStatus::Completed,
            GraphStatus::Failed,
            GraphStatus::Canceled,
            GraphStatus::Paused,
        ] {
            let json = serde_json::to_string(&status).expect("serialize");
            assert_eq!(json, format!("\"{status}\""));
        }
    }

    #[test]
    fn test_task_graph_serde_roundtrip() {
        let mut graph = TaskGraph::new("test goal");
        graph.tasks.push(TaskNode::new(0, "task 0", "do something"));
        let json = serde_json::to_string(&graph).expect("serialize");
        let restored: TaskGraph = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(graph.id, restored.id);
        assert_eq!(graph.goal, restored.goal);
        assert_eq!(graph.tasks.len(), restored.tasks.len());
    }

    #[test]
    fn test_task_node_serde_roundtrip() {
        let mut node = TaskNode::new(1, "compile", "run cargo build");
        node.agent_hint = Some("rust-dev".to_string());
        node.depends_on = vec![TaskId(0)];
        let json = serde_json::to_string(&node).expect("serialize");
        let restored: TaskNode = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(node.id, restored.id);
        assert_eq!(node.title, restored.title);
        assert_eq!(node.depends_on, restored.depends_on);
    }

    #[test]
    fn test_task_result_serde_roundtrip() {
        let result = TaskResult {
            output: "ok".to_string(),
            artifacts: vec![PathBuf::from("/tmp/out.bin")],
            duration_ms: 500,
            agent_id: Some("agent-1".to_string()),
            agent_def: None,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: TaskResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(result.output, restored.output);
        assert_eq!(result.duration_ms, restored.duration_ms);
        assert_eq!(result.artifacts, restored.artifacts);
    }

    #[test]
    fn test_failure_strategy_from_str() {
        assert_eq!(
            "abort".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Abort
        );
        assert_eq!(
            "retry".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Retry
        );
        assert_eq!(
            "skip".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Skip
        );
        assert_eq!(
            "ask".parse::<FailureStrategy>().unwrap(),
            FailureStrategy::Ask
        );
        assert!("abort_all".parse::<FailureStrategy>().is_err());
        assert!("".parse::<FailureStrategy>().is_err());
    }

    #[test]
    fn test_chrono_now_iso8601_format() {
        let ts = chrono_now();
        // Format: "YYYY-MM-DDTHH:MM:SSZ" — 20 chars
        assert_eq!(ts.len(), 20, "timestamp should be 20 chars: {ts}");
        assert!(ts.ends_with('Z'), "should end with Z: {ts}");
        assert!(ts.contains('T'), "should contain T: {ts}");
        // Year should be >= 2024
        let year: u32 = ts[..4].parse().expect("year should be numeric");
        assert!(year >= 2024, "year should be >= 2024: {year}");
    }

    #[test]
    fn test_epoch_secs_to_datetime_known_dates() {
        assert_eq!(epoch_secs_to_datetime(0), (1970, 1, 1, 0, 0, 0));
        // 1970-01-01T01:02:03Z
        assert_eq!(epoch_secs_to_datetime(3_723), (1970, 1, 1, 1, 2, 3));
        // 2024-02-29T12:00:00Z (leap day)
        assert_eq!(
            epoch_secs_to_datetime(1_709_208_000),
            (2024, 2, 29, 12, 0, 0)
        );
        // 2026-01-01T00:00:00Z
        assert_eq!(
            epoch_secs_to_datetime(1_767_225_600),
            (2026, 1, 1, 0, 0, 0)
        );
    }

    #[test]
    fn test_failure_strategy_serde_snake_case() {
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Abort).unwrap(),
            "\"abort\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Retry).unwrap(),
            "\"retry\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Skip).unwrap(),
            "\"skip\""
        );
        assert_eq!(
            serde_json::to_string(&FailureStrategy::Ask).unwrap(),
            "\"ask\""
        );
    }

    #[test]
    fn test_graph_persistence_save_rejects_long_goal() {
        // GraphPersistence::save() is async and requires a real store; the guard
        // itself is exercised by the async persistence tests in zeph-memory.
        // Here we only pin the limit and check that a goal one character over
        // it is indeed over the limit.
        let graph = TaskGraph::new("x".repeat(MAX_GOAL_LEN + 1));
        assert!(
            graph.goal.chars().count() > MAX_GOAL_LEN,
            "test setup: goal must exceed limit"
        );
        assert_eq!(MAX_GOAL_LEN, 1024);
    }
}